昇思MindSpore学习笔记4-02生成式--DCGAN生成漫画头像

摘要:

        记录了昇思MindSpore AI框架使用70171张动漫头像图片训练一个DCGAN神经网络生成式对抗网络,并用来生成漫画头像的过程、步骤。包括环境准备、下载数据集、加载数据和预处理、构造网络、模型训练等。

一、概念

深度卷积对抗生成网络DCGAN

Deep Convolutional Generative Adversarial Networks

        扩展GAN

        判别器

                组成

                        卷积层

                        BatchNorm层

                        LeakyReLU激活层

                功能

                        输入是3*64*64图像

                        输出是真图像概率

        生成器

                组成

                        转置卷积层

                        BatchNorm层

                        ReLU激活层

                功能

                        输入是标准正态分布中提取出的隐向量z

                        输出是3*64*64 RGB图像。

  • 环境准备
%%capture captured_output
# 实验环境已经预装了mindspore==2.2.14,如需更换mindspore版本,可更改下面mindspore的版本号
!pip uninstall mindspore -y
!pip install -i https://pypi.mirrors.ustc.edu.cn/simple mindspore==2.2.14

三、数据准备与处理

1.下载数据集

下载到指定目录下并解压代码如下

from download import download
url = "https://download.mindspore.cn/dataset/Faces/faces.zip"
path = download(url, "./faces", kind="zip", replace=True)

输出:

Downloading data from https://download-mindspore.osinfra.cn/dataset/Faces/faces.zip (274.6 MB)

file_sizes: 100%|████████████████████████████| 288M/288M [00:52<00:00, 5.49MB/s]
Extracting zip file...
Successfully downloaded / unzipped to ./faces

2.数据集介绍

使用的动漫头像数据集共有70,171张动漫头像图片,图片大小均为96*96。

数据集目录结构如下:

./faces/faces
├── 0.jpg
├── 1.jpg
├── 2.jpg
├── 3.jpg
├── 4.jpg
    ...
├── 70169.jpg
└── 70170.jpg

3.数据处理

(1) 执行过程参数定义:

batch_size = 128          # 批量大小
image_size = 64           # 训练图像空间大小
nc = 3                    # 图像彩色通道数
nz = 100                  # 隐向量的长度
ngf = 64                  # 特征图在生成器中的大小
ndf = 64                  # 特征图在判别器中的大小
num_epochs = 3            # 训练周期数
lr = 0.0002               # 学习率
beta1 = 0.5               # Adam优化器的beta1超参数

(2) 数据处理和增强

create_dataset_imagenet函数

import numpy as np
import mindspore.dataset as ds
import mindspore.dataset.vision as vision
​
def create_dataset_imagenet(dataset_path):
    """数据加载"""
    dataset = ds.ImageFolderDataset(dataset_path,
                                    num_parallel_workers=4,
                                    shuffle=True,
                                    decode=True)
​
    # 数据增强操作
    transforms = [
        vision.Resize(image_size),
        vision.CenterCrop(image_size),
        vision.HWC2CHW(),
        lambda x: ((x / 255).astype("float32"))
    ]
​
    # 数据映射操作
    dataset = dataset.project('image')
    dataset = dataset.map(transforms, 'image')
​
    # 批量操作
    dataset = dataset.batch(batch_size)
    return dataset
​
dataset = create_dataset_imagenet('./faces')

(3) 查看训练数据

matplotlib模块

数据转换成字典迭代器

        create_dict_iterator函数

import matplotlib.pyplot as plt
​
def plot_data(data):
    # 可视化部分训练数据
    plt.figure(figsize=(10, 3), dpi=140)
    for i, image in enumerate(data[0][:30], 1):
        plt.subplot(3, 10, i)
        plt.axis("off")
        plt.imshow(image.transpose(1, 2, 0))
    plt.show()
​
sample_data = next(dataset.create_tuple_iterator(output_numpy=True))
plot_data(sample_data)

四、构造网络

模型权重随机初始化

范围:mean为0,sigma为0.02的正态分布【数学不好】

1. 生成器

生成器G

        隐向量z映射数据空间

        数据源是图像

        生成与图像大小相同的 RGB 图像

        Conv2dTranspose转置卷积层

        每个层与BatchNorm2d层和ReLu激活层配对

        tanh函数

        输出[-1,1]范围内数据

DCGAN生成图像过程如下所示:

生成器结构参数:

        nz         隐向量z的长度

        ngf         有关生成器传播的特征图大小

        nc         输出图像通道数

生成器代码:

import mindspore as ms
from mindspore import nn, ops
from mindspore.common.initializer import Normal
​
weight_init = Normal(mean=0, sigma=0.02)
gamma_init = Normal(mean=1, sigma=0.02)
​
class Generator(nn.Cell):
    """DCGAN网络生成器"""
​
    def __init__(self):
        super(Generator, self).__init__()
        self.generator = nn.SequentialCell(
            nn.Conv2dTranspose(nz, ngf * 8, 4, 1, 'valid', weight_init=weight_init),
            nn.BatchNorm2d(ngf * 8, gamma_init=gamma_init),
            nn.ReLU(),
            nn.Conv2dTranspose(ngf * 8, ngf * 4, 4, 2, 'pad', 1, weight_init=weight_init),
            nn.BatchNorm2d(ngf * 4, gamma_init=gamma_init),
            nn.ReLU(),
            nn.Conv2dTranspose(ngf * 4, ngf * 2, 4, 2, 'pad', 1, weight_init=weight_init),
            nn.BatchNorm2d(ngf * 2, gamma_init=gamma_init),
            nn.ReLU(),
            nn.Conv2dTranspose(ngf * 2, ngf, 4, 2, 'pad', 1, weight_init=weight_init),
            nn.BatchNorm2d(ngf, gamma_init=gamma_init),
            nn.ReLU(),
            nn.Conv2dTranspose(ngf, nc, 4, 2, 'pad', 1, weight_init=weight_init),
            nn.Tanh()
            )
​
    def construct(self, x):
        return self.generator(x)
​
generator = Generator()

2. 判别器

判别器D

        二分类网络模型

                Conv2d

                BatchNorm2d

                LeakyReLU

                Sigmoid激活函数

                输出判定图像真实概率

判别器代码:

class Discriminator(nn.Cell):
    """DCGAN网络判别器"""
​
    def __init__(self):
        super(Discriminator, self).__init__()
        self.discriminator = nn.SequentialCell(
            nn.Conv2d(nc, ndf, 4, 2, 'pad', 1, weight_init=weight_init),
            nn.LeakyReLU(0.2),
            nn.Conv2d(ndf, ndf * 2, 4, 2, 'pad', 1, weight_init=weight_init),
            nn.BatchNorm2d(ngf * 2, gamma_init=gamma_init),
            nn.LeakyReLU(0.2),
            nn.Conv2d(ndf * 2, ndf * 4, 4, 2, 'pad', 1, weight_init=weight_init),
            nn.BatchNorm2d(ngf * 4, gamma_init=gamma_init),
            nn.LeakyReLU(0.2),
            nn.Conv2d(ndf * 4, ndf * 8, 4, 2, 'pad', 1, weight_init=weight_init),
            nn.BatchNorm2d(ngf * 8, gamma_init=gamma_init),
            nn.LeakyReLU(0.2),
            nn.Conv2d(ndf * 8, 1, 4, 1, 'valid', weight_init=weight_init),
            )
        self.adv_layer = nn.Sigmoid()
​
    def construct(self, x):
        out = self.discriminator(x)
        out = out.reshape(out.shape[0], -1)
        return self.adv_layer(out)
​
discriminator = Discriminator()

五、模型训练

1. 损失函数

二进制交叉熵损失函数MindSpore.nn.BCELoss

# 定义损失函数
adversarial_loss = nn.BCELoss(reduction='mean')

2. 优化器

Adam优化器

        lr = 0.0002

        beta1 = 0.5

# 为生成器和判别器设置优化器
optimizer_D = nn.Adam(discriminator.trainable_params(), learning_rate=lr, beta1=beta1)
optimizer_G = nn.Adam(generator.trainable_params(), learning_rate=lr, beta1=beta1)
optimizer_G.update_parameters_name('optim_g.')
optimizer_D.update_parameters_name('optim_d.')

3. 训练模型

训练判别器

        提高判别图像真伪的概率

        Goodfellow方法:提高随机梯度更新判别器

        最大化logD(x)+log(1-D(G(z)))

训练生成器

        最小化log(1−D(G(z)))

        产生更好的虚拟图像

两个部分分别

        获取训练损失

        每个周期结束统计

        批量推送fixed_noise到生成器

        跟踪G的训练进度

模型训练正向逻辑:

def generator_forward(real_imgs, valid):
    # 将噪声采样为发生器的输入
    z = ops.standard_normal((real_imgs.shape[0], nz, 1, 1))
​
    # 生成一批图像
    gen_imgs = generator(z)
​
    # 损失衡量发生器绕过判别器的能力
    g_loss = adversarial_loss(discriminator(gen_imgs), valid)
​
    return g_loss, gen_imgs
​
def discriminator_forward(real_imgs, gen_imgs, valid, fake):
    # 衡量鉴别器从生成的样本中对真实样本进行分类的能力
    real_loss = adversarial_loss(discriminator(real_imgs), valid)
    fake_loss = adversarial_loss(discriminator(gen_imgs), fake)
    d_loss = (real_loss + fake_loss) / 2
    return d_loss
​
grad_generator_fn = ms.value_and_grad(generator_forward, None,
                                      optimizer_G.parameters,
                                      has_aux=True)
grad_discriminator_fn = ms.value_and_grad(discriminator_forward, None,
                                          optimizer_D.parameters)
​
@ms.jit
def train_step(imgs):
    valid = ops.ones((imgs.shape[0], 1), mindspore.float32)
    fake = ops.zeros((imgs.shape[0], 1), mindspore.float32)
​
    (g_loss, gen_imgs), g_grads = grad_generator_fn(imgs, valid)
    optimizer_G(g_grads)
    d_loss, d_grads = grad_discriminator_fn(imgs, gen_imgs, valid, fake)
    optimizer_D(d_grads)
​
    return g_loss, d_loss, gen_imgs

循环训练网络

        迭代50次收集生成器、判别器的损失一次

        绘制损失函数的图像

import mindspore
​
G_losses = []
D_losses = []
image_list = []
​
total = dataset.get_dataset_size()
for epoch in range(num_epochs):
    generator.set_train()
    discriminator.set_train()
    # 为每轮训练读入数据
    for i, (imgs, ) in enumerate(dataset.create_tuple_iterator()):
        g_loss, d_loss, gen_imgs = train_step(imgs)
        if i % 100 == 0 or i == total - 1:
            # 输出训练记录
            print('[%2d/%d][%3d/%d]   Loss_D:%7.4f  Loss_G:%7.4f' % (
                epoch + 1, num_epochs, i + 1, total, d_loss.asnumpy(), g_loss.asnumpy()))
        D_losses.append(d_loss.asnumpy())
        G_losses.append(g_loss.asnumpy())
​
    # 每个epoch结束后,使用生成器生成一组图片
    generator.set_train(False)
    fixed_noise = ops.standard_normal((batch_size, nz, 1, 1))
    img = generator(fixed_noise)
    image_list.append(img.transpose(0, 2, 3, 1).asnumpy())
​
    # 保存网络模型参数为ckpt文件
    mindspore.save_checkpoint(generator, "./generator.ckpt")
    mindspore.save_checkpoint(discriminator, "./discriminator.ckpt")

输出:

[ 1/3][  1/549]   Loss_D: 0.2635  Loss_G: 4.8150
[ 1/3][101/549]   Loss_D: 0.4023  Loss_G: 4.9807
[ 1/3][201/549]   Loss_D: 0.2425  Loss_G: 1.6335
[ 1/3][301/549]   Loss_D: 0.5856  Loss_G: 0.6079
[ 1/3][401/549]   Loss_D: 0.1922  Loss_G: 4.3977
[ 1/3][501/549]   Loss_D: 0.1065  Loss_G: 2.3724
[ 1/3][549/549]   Loss_D: 0.1893  Loss_G: 1.6483
[ 2/3][  1/549]   Loss_D: 0.3370  Loss_G: 4.4347
[ 2/3][101/549]   Loss_D: 0.4681  Loss_G: 0.8623
[ 2/3][201/549]   Loss_D: 0.1856  Loss_G: 3.7501
[ 2/3][301/549]   Loss_D: 0.1932  Loss_G: 2.6333
[ 2/3][401/549]   Loss_D: 0.1310  Loss_G: 2.2524
[ 2/3][501/549]   Loss_D: 0.2531  Loss_G: 1.4690
[ 2/3][549/549]   Loss_D: 0.1192  Loss_G: 5.7166
[ 3/3][  1/549]   Loss_D: 0.0716  Loss_G: 2.9886
[ 3/3][101/549]   Loss_D: 0.1345  Loss_G: 2.6544
[ 3/3][201/549]   Loss_D: 0.1097  Loss_G: 2.8604
[ 3/3][301/549]   Loss_D: 0.2066  Loss_G: 6.1513
[ 3/3][401/549]   Loss_D: 0.0797  Loss_G: 3.2336
[ 3/3][501/549]   Loss_D: 0.2618  Loss_G: 4.0991
[ 3/3][549/549]   Loss_D: 0.5600  Loss_G:10.7509

4. 结果展示

描绘D和G损失与训练迭代的关系图:

plt.figure(figsize=(10, 5))
plt.title("Generator and Discriminator Loss During Training")
plt.plot(G_losses, label="G", color='blue')
plt.plot(D_losses, label="D", color='orange')
plt.xlabel("iterations")
plt.ylabel("Loss")
plt.legend()
plt.show()

输出:

显示隐向量fixed_noise训练生成的图像

import matplotlib.pyplot as plt
import matplotlib.animation as animation
​
def showGif(image_list):
    show_list = []
    fig = plt.figure(figsize=(8, 3), dpi=120)
    for epoch in range(len(image_list)):
        images = []
        for i in range(3):
            row = np.concatenate((image_list[epoch][i * 8:(i + 1) * 8]), axis=1)
            images.append(row)
        img = np.clip(np.concatenate((images[:]), axis=0), 0, 1)
        plt.axis("off")
        show_list.append([plt.imshow(img)])
​
    ani = animation.ArtistAnimation(fig, show_list, interval=1000, repeat_delay=1000, blit=True)
    ani.save('./dcgan.gif', writer='pillow', fps=1)
​
showGif(image_list)

输出:

训练次数增多,图像质量越好

num_epochs达到50以上,生成动漫头像图片与数据集较为相似

加载生成器网络模型参数文件来生成图像代码:

# 从文件中获取模型参数并加载到网络中
mindspore.load_checkpoint("./generator.ckpt", generator)
​
fixed_noise = ops.standard_normal((batch_size, nz, 1, 1))
img64 = generator(fixed_noise).transpose(0, 2, 3, 1).asnumpy()
​
fig = plt.figure(figsize=(8, 3), dpi=120)
images = []
for i in range(3):
    images.append(np.concatenate((img64[i * 8:(i + 1) * 8]), axis=1))
img = np.clip(np.concatenate((images[:]), axis=0), 0, 1)
plt.axis("off")
plt.imshow(img)
plt.show()

输出:

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mfbz.cn/a/774085.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

MMSC物料库位扩充

MMSC物料库位扩充 输入事务码MMSC&#xff1a; 回车后添加新的库位即可&#xff1a; 代码实现&#xff0c;使用BDC *&------------------------------------------------* *&BDC的定义 *&------------------------------------------------* DATA gt_bdcdata T…

【UE5.1】Chaos物理系统基础——03 炸开几何体集

目录 步骤 一、通过径向向量将几何体集炸开 二、优化炸开效果——让破裂的碎块自然下落 三、优化炸开效果——让碎块旋转起来 四、优化炸开效果——让碎块旋转的越来越慢 步骤 一、通过径向向量将几何体集炸开 1. 打开上一篇中&#xff08;【UE5.1】Chaos物理系统基础—…

百度出品_文心快码Comate提升程序员效率

1.文心快码 文心快码包含指令、插件 和 知识三种功能&#xff0c; 1&#xff09;指令包含Base64编码、Base64解码、JSON转TS类型、JSON转YAML、JWT解码喂JSON。 2&#xff09;插件包含 3&#xff09;指令包含如下功能&#xff1a; 官网链接

Jenkins 强制杀job

有时候有的jenkins job运行时间太长&#xff0c;在jenkins界面点击x按钮进行abort&#xff0c;会失败&#xff1a; 这时候点击&#xff1a; “Click here to forcibly terminate running steps” 会进一步kill 任务&#xff0c;但是也还是有杀不掉的可能性。 终极武器是jenkin…

Aigtek电压放大器参数有哪些

电压放大器是广泛应用于电子电路中的一种重要电路元件&#xff0c;它主要用于将输入信号的电压放大到所需的输出电压水平。在设计和使用电压放大器时&#xff0c;我们需要了解并考虑一系列的参数和特性。本文将详细介绍电压放大器的主要参数&#xff0c;包括放大倍数、带宽、输…

springboot校园购物网站APP-计算机毕业设计源码041037

摘 要 21世纪的今天&#xff0c;随着社会的不断发展与进步&#xff0c;人们对于信息科学化的认识&#xff0c;已由低层次向高层次发展&#xff0c;由原来的感性认识向理性认识提高&#xff0c;管理工作的重要性已逐渐被人们所认识&#xff0c;科学化的管理&#xff0c;使信息存…

【国产开源可视化引擎Meta2d.js】图层

独立图层 每个图元都有先后绘画顺序&#xff0c;即每个图元拥有一个独立图层&#xff0c;即meta2d.data().pens的数组索引。 可以通过meta2d.top/bottom/up/down等函数改变独立图层顺序。 分组图层 通过标签可以标识一个分组图层&#xff0c;通过meta2d.find(图层标签)获取…

己内酰胺纯化除杂的最佳工艺

己内酰胺纯化除杂的最佳工艺包括结晶法、离子交换树脂法、精馏法和萃取法等&#xff0c;每种方法都有其特定的应用场景和优缺点。以下是对这些方法的详细介绍&#xff1a; 最佳工艺介绍 ● 结晶法&#xff1a;通过调节pH值&#xff0c;使己内酰胺在特定条件下结晶&#xff0…

yolov8环境安装(可修改代码版本,源代码安装)

下载下来源文件以后&#xff0c;进去文件目录&#xff0c;然后输入pip指令&#xff0c;即可安装yolov8 cd ultralytics-main pip install -e . 直接使用pip安装的情况 当你使用pip install ultralytics这样的命令安装YOLOv8时&#xff0c;你实际上是在从Python包索引&#x…

C#实战|账号管理系统:通用登录窗体的实现。

哈喽,你好啊,我是雷工! 本节记录登录窗体的实现方法,比较有通用性,所有的项目登录窗体实现基本都是这个实现思路。 一通百通,以下为学习笔记。 01 登录窗体的逻辑 用户在登录窗输入账号和密码,如果输入账号和密码信息正确,点击【登录】按钮,则跳转显示主窗体,同时在固…

AI提示词:一个能让你的AI提升10倍逻辑能力的提示词,只有这几个字,Kimi和GPT都适用!

昨天晚上和朋友聊天&#xff0c;聊到AI提示词在实际使用过程中的逻辑能力问题。 他也是一个AI提示词的重度使用者&#xff0c;但是会经常遇到一个问题&#xff1a;明明觉得自己的提示词描述的很清楚了&#xff0c;可是AI输出的内容还是达不到自己想要的效果。 今天给大家分享…

认识不-物联网“六域模型”有哪些有什么作用

如下参考源于苏州稳联授权可见认知域-感知域-网络域-应用域-管理域-安全域-物联网六域模型 苏州稳联 (iotrouter.cn) 认识物联网“六域模型”&#xff1a;构成与作用 “六域模型”是一个有效的框架。这个模型通过将物联网划分为六个相互关联的域&#xff0c;帮助我们更好地理…

【网络安全】漏洞挖掘之Spring Cloud注入漏洞

漏洞描述 Spring框架为现代基于java的企业应用程序(在任何类型的部署平台上)提供了一个全面的编程和配置模型。 Spring Cloud 中的 serveless框架 Spring Cloud Function 中的 RoutingFunction 类的 apply 方法将请求头中的“spring.cloud.function.routing-expression”参数…

免费的二级域名分发,您确定不要试试吗?

这是可爱的小盆友做的一个免费的二级域名&#xff0c;目前上线就送114514个积分&#xff0c;名额有限~ 首先进入>路明二级域名分发 - 免费稳定的二级域名分发服务 (kmyl.top)< 进来是这个样子&#xff08;如下图所示&#xff09; 话不多说&#xff0c;进入教程~ 第一章…

创建react的脚手架

Create React App 中文文档 (bootcss.com) 网址&#xff1a;creat-react-app.bootcss.com 主流的脚手架&#xff1a;creat-react-app 创建脚手架的方法&#xff1a; 方法一&#xff08;JS默认&#xff09;&#xff1a; 1. npx create-react-app my-app 2. cd my-app 3. …

Redis 7.x 系列【19】管道

有道无术&#xff0c;术尚可求&#xff0c;有术无道&#xff0c;止于术。 本系列Redis 版本 7.2.5 源码地址&#xff1a;https://gitee.com/pearl-organization/study-redis-demo 文章目录 1. 往返时间2. 管道技术3. 代码演示4. 其他批处理4.1 原生批处理命令4.2 事务4.3 脚本…

短链接学习day2

用户敏感信息脱敏展示&#xff1a; RequestParam 和 PathVariable的区别 注解是用于从request中接收请求的&#xff0c;两个都可以接收参数&#xff0c;关键点不同的是RequestParam 是从request里面拿取值&#xff0c;而 PathVariable 是从一个URI模板里面来填充。 PathVari…

CSDN导入本地md文件图片不能正常回显问题

标题 搭建图像仓库获取图片URL 路径替换 因为服务器读取不到本地图片&#xff0c;故不能正常回显&#xff0c;因此想要正常回显图片&#xff0c;我们首先要做的就是搭建一个可以存放图片的服务器&#xff0c;像你可以选择购买一个云服务器、FastDFS图片服务器、Minio多云对象存…

初阶数据结构二叉树练习系列(1)

这个系列的文章将带大家一起刷题&#xff0c;并且总结思路 温馨提示&#xff1a;本篇文章里的练习题仅适合刚学完二叉树的小白使用 相同的树 思路 情况分析&#xff1a;第一种情况&#xff1a;两棵树都为空 → 返回true 第二种情况&am…

【linux进程】进程地址空间(什么是进程地址空间?为什么要有进程地址空间?)

目录 一、前言 二、 程序的地址空间是真实的 --- 物理空间吗&#xff1f; 三、进程地址空间 &#x1f525; 操作系统是如何建立起进程与物理内存之间的联系的呢&#xff1f; &#x1f525;什么是进程地址空间&#xff1f; &#x1f525;为什么不能直接去访问物理内存&a…