UNet - 训练数据train
创始人
2024-04-09 12:03:29

目录

1. train 训练数据

2. Loss 值

3. 完整代码


1. train 训练数据

训练的代码只是在之前图像分类的基础上做了一些更改,具体的可以看下面的文章

pytorch 搭建 LeNet 网络对 CIFAR-10 图片分类https://blog.csdn.net/qq_44886601/article/details/127498256

首先,导入之前定义的UNet 网络

然后,加载训练集和测试集

因为加载数据集被重写过,所以这里传入的是训练的图像,然后根据里面的replace就能找到对应的标签

这里训练的时候可以将数据打乱,测试的时候没有必要,batch_size 因为电脑硬件的问题设置成2,再大的话这里内存就会不够了

 

然后定义优化器和损失函数,这里用的是BCE加上sigmoid的损失函数

训练的时候,要将模式改为train模式,然后训练的步骤很常规

梯度清零->前向传播->计算损失函数->反向传播->更新参数

 

这里测试的时候有些区别

因为这里UNet 网络的输出是一幅图像,而之前将label改为了二值图像(归一化后是0 1)。所以这里计算准确率的时候,将预测的图像也变为二值图像,计算准确率用的是对应图像像素点的灰度值是否相等的方法

 

最后保留最好准确率的那个参数就行了

 

2. Loss 值

这是跑了20 个epoch的输出

 

3. 完整代码

from model import UNet                  # 导入Unet 网络
from dataset import Data_Loader         # 数据处理
from torch import optim
import torch.nn as nn
import torch# 网络训练模块
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')   # GPU or CPU
print(device)
net = UNet(in_channels=1, num_classes=1)                                # 加载网络
net.to(device)                                                          # 将网络加载到device上# 加载训练集
train_path = "./data/train/image"
trainset = Data_Loader(train_path)
train_loader = torch.utils.data.DataLoader(dataset=trainset,batch_size=2,shuffle=True)# len(trainset)  样本总数:21# 加载测试集
test_path = "./data/test/image"
testset = Data_Loader(test_path)
test_loader = torch.utils.data.DataLoader(dataset=testset,batch_size=2)optimizer = optim.RMSprop(net.parameters(),lr = 0.000001,weight_decay=1e-8,momentum=0.9)     # 定义优化器
criterion = nn.BCEWithLogitsLoss()                                                           # 定义损失函数save_path = './UNet.pth'        # 网络参数的保存路径
best_acc = 0.0                  # 保存最好的准确率for epoch in range(20):net.train()     # 训练模式running_loss = 0.0for image,label in train_loader:                   # 读取数据和labeloptimizer.zero_grad()                          # 梯度清零pred = net(image.to(device))                   # 前向传播loss = criterion(pred, label.to(device))       # 计算损失loss.backward()                                # 反向传播optimizer.step()                               # 梯度下降running_loss += loss.item()                    # 计算损失和net.eval()  # 测试模式acc = 0.0   # 正确率total = 0with torch.no_grad():for test_image, test_label in test_loader:outputs = net(test_image.to(device))     # 前向传播outputs[outputs >= 0] = 1  # 将预测图片转为二值图片outputs[outputs < 0] = 0acc += (outputs == test_label.to(device)).sum().item() / (480*480)     # 计算预测图片与真实图片像素点一致的精度:acc = 相同的 / 总个数total += test_label.size(0)accurate = acc / total  # 计算整个test上面的正确率print('[epoch %d] train_loss: %.3f  test_accuracy: %.3f %%' %(epoch + 1, running_loss, accurate*100))if accurate > best_acc:     # 保留最好的精度best_acc = accuratetorch.save(net.state_dict(), save_path)     # 保存网络参数

相关内容

热门资讯

埃菲尔铁塔在哪 中国仿建埃菲尔... 2019年4月26日,广西南宁市,街头惊现一座巨型山寨版埃菲尔铁塔,高约20米,白色塔身,造型逼真,...
苗族的传统节日 贵州苗族节日有... 【岜沙苗族芦笙节】岜沙,苗语叫“分送”,距从江县城7.5公里,是世界上最崇拜树木并以树为神的枪手部落...
北京的名胜古迹 北京最著名的景... 北京从元代开始,逐渐走上帝国首都的道路,先是成为大辽朝五大首都之一的南京城,随着金灭辽,金代从海陵王...
长白山自助游攻略 吉林长白山游... 昨天介绍了西坡的景点详细请看链接:一个人的旅行,据说能看到长白山天池全凭运气,您的运气如何?今日介绍...
应用未安装解决办法 平板应用未... ---IT小技术,每天Get一个小技能!一、前言描述苹果IPad2居然不能安装怎么办?与此IPad不...
脚上的穴位图 脚面经络图对应的... 人体穴位作用图解大全更清晰直观的标注了各个人体穴位的作用,包括头部穴位图、胸部穴位图、背部穴位图、胳...
猫咪吃了塑料袋怎么办 猫咪误食... 你知道吗?塑料袋放久了会长猫哦!要说猫咪对塑料袋的喜爱程度完完全全可以媲美纸箱家里只要一有塑料袋的响...
demo什么意思 demo版本... 618快到了,各位的小金库大概也在准备开闸放水了吧。没有小金库的,也该向老婆撒娇卖萌服个软了,一切只...
世界上最漂亮的人 世界上最漂亮... 此前在某网上,选出了全球265万颜值姣好的女性。从这些数量庞大的女性群体中,人们投票选出了心目中最美...
埃菲尔铁塔在哪 中国仿建埃菲尔... 2019年4月26日,广西南宁市,街头惊现一座巨型山寨版埃菲尔铁塔,高约20米,白色塔身,造型逼真,...
苗族的传统节日 贵州苗族节日有... 【岜沙苗族芦笙节】岜沙,苗语叫“分送”,距从江县城7.5公里,是世界上最崇拜树木并以树为神的枪手部落...
北京的名胜古迹 北京最著名的景... 北京从元代开始,逐渐走上帝国首都的道路,先是成为大辽朝五大首都之一的南京城,随着金灭辽,金代从海陵王...
长白山自助游攻略 吉林长白山游... 昨天介绍了西坡的景点详细请看链接:一个人的旅行,据说能看到长白山天池全凭运气,您的运气如何?今日介绍...
世界上最漂亮的人 世界上最漂亮... 此前在某网上,选出了全球265万颜值姣好的女性。从这些数量庞大的女性群体中,人们投票选出了心目中最美...
应用未安装解决办法 平板应用未... ---IT小技术,每天Get一个小技能!一、前言描述苹果IPad2居然不能安装怎么办?与此IPad不...
脚上的穴位图 脚面经络图对应的... 人体穴位作用图解大全更清晰直观的标注了各个人体穴位的作用,包括头部穴位图、胸部穴位图、背部穴位图、胳...
demo什么意思 demo版本... 618快到了,各位的小金库大概也在准备开闸放水了吧。没有小金库的,也该向老婆撒娇卖萌服个软了,一切只...
猫咪吃了塑料袋怎么办 猫咪误食... 你知道吗?塑料袋放久了会长猫哦!要说猫咪对塑料袋的喜爱程度完完全全可以媲美纸箱家里只要一有塑料袋的响...