完整示例

一个深度学习项目包括了模型的设计、损失函数的设计、梯度更新的方法、模型的保存与加载、模型的训练过程、训练过程可视化、使用模型进行预测等几个主要模块

以 CIFAR10 数据集为例

自定义模型

import torch
import torch.nn as nn


class My_CIFAR10_Model(nn.Module):
    """CNN classifier for CIFAR-10.

    Three conv+max-pool stages followed by two linear layers.
    Input:  (N, 3, 32, 32) image batch.
    Output: (N, 10) class scores.
    """

    def __init__(self):
        super().__init__()
        # A very complex network could be split into several sub-modules and
        # composed in forward(); the idea is the same as one nn.Sequential.
        layers = [
            # conv: in_channels=3, out_channels=32, kernel=5, stride=1, padding=2
            nn.Conv2d(3, 32, 5, 1, 2),
            # max pooling with a 2x2 kernel
            nn.MaxPool2d(2),
            nn.Conv2d(32, 32, 5, 1, 2),
            nn.MaxPool2d(2),
            nn.Conv2d(32, 64, 5, 1, 2),
            nn.MaxPool2d(2),
            # flatten to (N, 64*4*4)
            nn.Flatten(),
            # linear: in_features = 64 channels * 4 h * 4 w (flattened), out_features = 64
            nn.Linear(64 * 4 * 4, 64),
            nn.Linear(64, 10),
        ]
        self.model = nn.Sequential(*layers)

    def forward(self, x):
        """Forward pass: map an image batch to per-class scores."""
        return self.model(x)


if __name__ == '__main__':
    # Smoke test: a batch of 64 all-ones images should yield a (64, 10) tensor.
    model = My_CIFAR10_Model()
    batch = torch.ones((64, 3, 32, 32))
    print(model(batch).shape)

模型的保存与加载


# Saving method 1: save the whole model object.
torch.save(model, "my_CIFAR10_model.pth")

# Saving method 2: save only the model parameters (officially recommended).
torch.save(model.state_dict(), "my_CIFAR10_model2.pth")

# Loading method 1 (matches saving method 1).
# NOTE: the model class must still be importable, e.g. from bi.CIFAModel import CIFAModel
model1 = torch.load("my_CIFAR10_model.pth")

# Loading method 2 (matches saving method 2): build the model, then load the state dict.
from bi.CIFAModel import CIFAModel
model2 = CIFAModel()
model2.load_state_dict(torch.load("my_CIFAR10_model2.pth"))

CPU和GPU的切换

需要切换的对象

  • 神经网络模型(neural network)
  • 数据(输入、标签)
  • 损失函数
# Option 1: check whether a GPU is available; if so, call .cuda() on each object.
# (Fragment: assumes model, loss_fn, imgs, labels are defined by surrounding code.)
if torch.cuda.is_available():
    model.cuda()

if torch.cuda.is_available():
    loss_fn.cuda()

if torch.cuda.is_available():
    imgs = imgs.cuda()

if torch.cuda.is_available():
    labels = labels.cuda()

# Option 2: define a device once, then call .to(device) everywhere — no repeated checks.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# With multiple GPUs you can pick a specific card; index 0 is the default.
# (This second assignment overrides the one above.)
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# For the model and loss function the return value may be ignored (in-place move).
model.to(device)
loss_fn.to(device)
# For Tensors the return value MUST be kept (e.g. CPU tensor -> GPU tensor is a new tensor).
imgs = imgs.to(device)
labels = labels.to(device)

模型的可视化

新版 PyTorch 已经内置了 TensorBoard 支持,可直接 `from torch.utils.tensorboard import SummaryWriter` 使用,无需单独安装 TensorBoardX

我们可以使用 TensorBoard 来进行模型的可视化

老版本的tensorboard可能需要依赖tensorflow运行

安装

pip install tensorboard

pip install tensorflow

写入标量示例

from torch.utils.tensorboard import SummaryWriter

# Log the line y = 2x as 100 scalar points under the tag "y=2x".
writer = SummaryWriter('./tensorboard_log')
for step in range(100):
    writer.add_scalar("y=2x", step * 2, step)
writer.close()

写入图片示例

from torch.utils.tensorboard import SummaryWriter
from PIL import Image
import numpy as np

# Load ./1.png as an H x W x C array and log it under the tag "img_test".
image_array = np.array(Image.open('./1.png'))
log_writer = SummaryWriter("./tensorboard_log")
log_writer.add_image("img_test", image_array, dataformats="HWC")
log_writer.close()

写入matplotlib示例

# Fragment: assumes `plt` (matplotlib.pyplot), `writer` (SummaryWriter) and a
# step index `i` are defined by the surrounding code.
x = [1, 2, 3, 4, 5]
y = [2, 4, 6, 8, 10]
plt.plot(x, y)
# plt.gcf(): Get Current Figure; plt.gca(): Get Current Axes
writer.add_figure('matplotlib.img', plt.gcf(), i)

运行TensorBoard

tensorboard --logdir=./tensorboard_log

tensorboard --logdir=./tensorboard_log --port=6006

TensorBoard 的 SummaryWriter 还有很多其他的方法,如:add_histogram、add_graph、add_embedding、add_audio 等,可以去官方文档查看

模型训练

import torchvision
from torch.utils.data import DataLoader
from ai.My_CIFAR10_Model import My_CIFAR10_Model
import torch
import torch.nn as nn
from torch.utils.tensorboard import SummaryWriter
import time
import platform

# Pick GPU when available, otherwise fall back to CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Per-platform paths for the dataset, TensorBoard logs and saved models.
if platform.system() == 'Windows':
    test_data_dir = r'd:/wwwroot/python/ai/test_data'
    tensorboard_dir = r'd:/wwwroot/python/ai/tensorboard_logs'
    model_dir = r'd:/wwwroot/python/ai/model_data'
else:
    test_data_dir = './test_data'
    tensorboard_dir = './tensorboard_logs'
    model_dir = './model_data'

# Prepare the CIFAR-10 datasets (download=False: the data must already exist).
train_data = torchvision.datasets.CIFAR10(root=test_data_dir, train=True, transform=torchvision.transforms.ToTensor(),
                                          download=False)

test_data = torchvision.datasets.CIFAR10(root=test_data_dir, train=False, transform=torchvision.transforms.ToTensor(),
                                         download=False)

# 50000 training samples
train_data_len = len(train_data)
print(train_data_len)
# 10000 test samples
test_data_len = len(test_data)
print(test_data_len)

# Build the DataLoaders once, before the epoch loop, instead of re-creating
# them on every epoch / every evaluation pass (avoids needless per-epoch setup).
train_loader = DataLoader(train_data, batch_size=64)
test_loader = DataLoader(test_data, batch_size=64)

# Network model, moved to the selected device.
model = My_CIFAR10_Model()
model.to(device)

# Loss function: cross-entropy for 10-class classification, on the same device.
loss_fn = nn.CrossEntropyLoss()
loss_fn.to(device)

# Optimizer: plain stochastic gradient descent.
learning_rate = 0.01
optimizer = torch.optim.SGD(model.parameters(), lr=learning_rate)

# Number of training epochs.
epoch = 100
# Cumulative number of training steps (batches seen).
total_train_step = 0
# Cumulative number of test evaluations (one per epoch).
total_test_step = 0

# TensorBoard writer.
writer = SummaryWriter(tensorboard_dir)

for i in range(epoch):
    start_time = time.time()
    print(f"-------------第 {i + 1} 轮训练开始-------------")

    # Switch to training mode; this matters for layers such as BatchNorm and
    # Dropout (see the nn.Module docs); calling it unconditionally is harmless.
    model.train()
    for data in train_loader:
        imgs, labels = data
        # Move the batch to the selected device.
        imgs = imgs.to(device)
        labels = labels.to(device)
        outputs = model(imgs)
        loss = loss_fn(outputs, labels)

        # Optimize: zero the gradients, backpropagate, step the optimizer.
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        # A rarely-used alternative to backward(), for reference only:
        # torch.autograd.grad(loss, xxx)

        total_train_step += 1
        # Printing every step is too noisy; log once every 100 steps.
        if total_train_step % 100 == 0:
            # .item() extracts the Python float instead of printing a tensor repr.
            print(f"训练次数:{total_train_step},loss:{loss.item()}")
            writer.add_scalar('train_loss', loss.item(), total_train_step)

    # Switch to evaluation mode; disables the training-time behaviour of
    # layers such as BatchNorm and Dropout (see the nn.Module docs).
    model.eval()
    total_test_loss = 0
    total_test_accuracy = 0
    # Evaluate on the test set after each epoch; no gradient tracking needed.
    with torch.no_grad():
        for data in test_loader:
            imgs, labels = data
            imgs = imgs.to(device)
            labels = labels.to(device)
            outputs = model(imgs)
            loss = loss_fn(outputs, labels)
            total_test_loss += loss.item()
            # argmax over dim 1 gives the predicted class; comparing with the
            # labels yields a boolean tensor whose sum is the hit count.
            # .item() keeps the accumulator a plain Python int (not a 0-d tensor).
            total_test_accuracy += (outputs.argmax(1) == labels).sum().item()
        total_test_step += 1
        print(f"测试次数:{total_test_step},loss:{total_test_loss}")
        print(f"测试次数:{total_test_step},正确率:{total_test_accuracy / test_data_len}")
        writer.add_scalar('test_loss', total_test_loss, total_test_step)
        writer.add_scalar('test_accuracy', total_test_accuracy / test_data_len, total_test_step)

    # Save the model parameters after every epoch.
    torch.save(model.state_dict(), f'{model_dir}/my_CIFAR10_{i}.pth')
    print(f'本轮训练耗时:{time.time() - start_time}')

writer.close()

模型预测

from PIL import Image
import torch
from ai.My_CIFAR10_Model import My_CIFAR10_Model
import torchvision
import platform

# Pick GPU when available, otherwise fall back to CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Per-platform paths for the test images and the saved model parameters.
if platform.system() == 'Windows':
    test_data_dir = r'd:/wwwroot/python/ai/test_data/dogs'
    model_dir = r'd:/wwwroot/python/ai/model_data'
else:
    test_data_dir = './test_data/dogs'
    model_dir = './model_data'

# CIFAR-10 class index -> human-readable label.
label_maps = {
    0: '飞机',
    1: '汽车',
    2: '鸟',
    3: '猫',
    4: '鹿',
    5: '狗',
    6: '青蛙',
    7: '马',
    8: '船',
    9: '卡车',
}

model = My_CIFAR10_Model()
# The model was trained on a GPU; map_location makes the checkpoint load
# correctly when predicting on a CPU-only machine.
model.load_state_dict(torch.load(f"{model_dir}/my_CIFAR10_99.pth", map_location=torch.device(device)))
# Switch to evaluation mode once, before the loop — no need to re-enable it
# for every image (the original re-called model.eval() per iteration).
model.eval()
# The network expects 32x32 inputs, so resize first, then convert to a tensor.
transform = torchvision.transforms.Compose([torchvision.transforms.Resize((32, 32)), torchvision.transforms.ToTensor()])

# All four test images are dogs.
dogs = [
    'tiantian.png',
    'xiaoli.png',
    'aqi.png',
    'maomao.png'
]

for dog in dogs:
    img = Image.open(f'{test_data_dir}/{dog}')
    # PNG files are RGBA; convert to RGB before transforming.
    img = transform(img.convert("RGB"))
    # No gradient tracking is needed for inference.
    with torch.no_grad():
        output = model(torch.reshape(img, (1, 3, 32, 32)))
        # This example mainly walks through the define/train/predict workflow;
        # accuracy is only ~60%, so it predicts 3 dogs and 1 cat.
        print(f'{dog}是:{label_maps[output.argmax(1).item()]}')

results matching ""

    No results matching ""