Pytorch深度学习实战3-7:详解数据加载DataLoader与模型处理
创始人
2025-06-01 06:04:03

目录

  • 1 数据集Dataset
  • 2 数据加载DataLoader
  • 3 常用预处理方法
  • 4 模型处理
  • 5 实例:MNIST数据集处理

1 数据集Dataset

Dataset类是Pytorch中图像数据集操作的核心类,Pytorch中所有数据集加载类都继承自Dataset父类。当我们自定义数据集处理时,必须实现Dataset类中的三个接口:

  • 初始化
    def __init__(self)
    
    构造函数,定义一些数据集的公有属性,如数据集下载地址、名称等
  • 数据集大小
    def __len__(self)
    
    返回数据集大小,不同的数据集有不同的衡量数据量的方式
  • 数据集索引
    def __getitem__(self, index):
    
    支持数据集索引功能,以实现形如dataset[i]得到数据集中的第i + 1个数据的功能。__getitem__是后期迭代数据时执行的具体函数,其返回值决定了循环变量,例如
    class data(Dataset)...def __getitem__(self, idx: int):if self.transforms:img = self.transforms(img)return img, label			# 返回的值即为后续迭代的循环变量for images, labels in dataLoader:...
    

2 数据加载DataLoader

为什么有了数据集Dataset还需要数据加载器DataLoader呢?原因在于神经网络需要进一步借助DataLoader对数据进行划分,也就是我们常说的batch,此外DataLoader还实现了打乱数据集、多线程等操作。

DataLoader本质是一个可迭代对象,可以使用形如

for inputs, labels in dataloaders

进行可迭代对象的访问。

我们一般不需要去实现DataLoader的接口,只需要在构造函数中指定相应的参数即可,比如常见的batch_sizeshuffle等参数。

下面这张图非常好地说明了DatasetDataLoader的关系

在这里插入图片描述

接下来总结数据构造的三步法

  1. 继承Dataset对象,并实现__len__()__getitem__()魔法方法,该步骤的主要目的在于将文件形式的数据集处理为模型可用的标准数据格式,并加载到内存中;
  2. DataLoader对象封装Dataset,使其成为可迭代对象;
  3. 遍历DataLoader对象以将数据加载到模型中进行训练。

3 常用预处理方法

在数据集Dataset__getitem__()中利用torchvision.transforms进行数据预处理与变换

常见的数据预处理变换方法总结如下表

序号变换含义
1RandomCrop(size, ...)对输入图像依据给定size随机裁剪
2CenterCrop(size, ...)对输入图像依据给定size从中心裁剪
3RandomResizedCrop(size, ...)对输入图像随机长宽比裁剪,再放缩到给定size
4FiveCrop(size, ...)对输入图像进行上下左右及中心裁剪,返回五张图像(size)组成的四维张量
5TenCrop(size, vertical_flip=False)对输入图像进行上下左右及中心裁剪,再全部翻转(水平或垂直),返回十张图像(size)组成的四维张量
6RandomHorizontalFlip(p=0.5)对输入图像按概率p随机进行水平翻转
7RandomVerticalFlip(p=0.5)对输入图像按概率p随机进行垂直翻转
8RandomRotation(degree, ...)对输入图像在degree内随机旋转某角度
9Resize(size, ...)对输入图像重置分辨率
10Normalize(mean, std)对输入图像各通道进行标准化
11ToTensor()将输入图像或ndarray 转换为tensor并归一化
12Pad(padding, fill=0, padding_mode=‘constant’)对输入图像进行填充
13ColorJitter(brightness=0, contrast=0, saturation=0, hue=0)对输入图像修改亮度、对比度、饱和度、色度等
14Grayscale(num_output_channels=1)对输入图像转灰度
15LinearTransformation(matrix)对输入图像进行线性变换
16RandomAffine(...)对输入图像进行仿射变换
17RandomGrayscale(p=0.1)对输入图像按概率p随机转灰度
18ToPILImage(mode=None)对输入图像转PIL格式图像
19RandomOrder()随机打乱transforms操作顺序

4 模型处理

考虑以下场景:

网络的部分层级结构已经收敛、无需调整;大型复杂网络需要微调(Fine-tune)某些结构或参数;希望基于已训练好的模型进行改善或其他研究工作。

这些场景下重新通过数据集训练整个神经网络并无必要,甚至会使模型不稳定,因此引入预训练(pretrained)。Pytorch允许用户保存已训练好的模型,或加载其他模型,避免往复的无谓重训练,其中模型参数文件以.pth为后缀

# 保存已训练模型
torch.save(model.state_dict(), path)
# 加载预训练模型
model.load_state_dict(torch.load(path), device)

通过设置模型某些层可学习参数的requires_grad属性为False即可固定这部分参数不被后续学习过程影响。深度学习框架应用优势之一在于预设了对GPU的支持,大大提高模型处理与训练的效率。Pytorch中通过mode.to(device)方法将模型部署到指定设备上(CPU/GPU),范式如下:

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
model.to(device)

工程上也常使用torch.nn.DataParallel(model, devices)来处理多GPU并行运算,其原理是:首先将模型加载到主GPU上,再将模型从主GPU产生若干副本到其余GPU,随后将一个batch中的数据按维度划分为不同的子任务给各GPU进行前向传播,得到的损失会被累积到主GPU上并由主GPU反向传播更新参数,最后将更新参数拷贝到其余GPU以开始下一轮训练。

5 实例:MNIST数据集处理

下面给出了处理MNIST手写数据集的完整代码,可以用于加深对数据处理流程的理解

from abc import abstractmethod
import numpy as np
from torchvision.datasets import mnist
from torch.utils.data import Dataset
from PIL import Imageclass mnistData(Dataset):'''* @breif: MNIST数据集抽象接口* @param[in]: dataPath -> 数据集存放路径* @param[in]: transforms -> 数据集变换'''    def __init__(self, dataPath: str, transforms=None) -> None:super().__init__()self.dataPath = dataPathself.transforms = transformsself.data, self.label = [], []def __len__(self) -> int:return len(self.label)def __getitem__(self, idx: int):img = self.data[idx]if self.transforms:img = self.transforms(img)return img, self.label[idx]@abstractmethoddef plot(self, index: int) -> None:pass@abstractmethoddef load(self) -> list:passdef plotData(self, index: int, info: str=None) -> None:'''* @breif: 可视化训练数据* @param[in]: index -> 数据集索引* @param[in]: info -> 备注信息* @retval: None'''print(info, " --index:", index, "--label:", self.label[index])  if info else \print(" --index:", index, "--label:", self.label[index])          img = Image.fromarray(np.uint8(self.data[index]))img.show()def loadData(self, train: bool) -> list:'''* @breif: 下载与加载数据集* @param[in]: train -> 是否为训练集* @retval: 数据与标签列表'''    # 如果指定目录下不存在数据集则下载dataSet   = mnist.MNIST(self.dataPath, train=train, download=True)# 初始化数据与标签data  = [ i[0] for i in dataSet ]label = [ i[1] for i in dataSet ]return data, labelclass mnistTrainData(mnistData):'''* @breif: MNIST训练集* @param[in]: dataPath -> 数据集存放路径* @param[in]: transforms -> 数据集变换'''    def __init__(self, dataPath: str, transforms=None) -> None:super().__init__(dataPath, transforms=transforms)self.data, self.label = self.load()def plot(self, index: int) -> None:self.plotData(index, "trainSet data")def load(self) -> list:return self.loadData(train=True)class mnistTestData(mnistData):'''* @breif: MNIST测试集* @param[in]: dataPath -> 数据集存放路径* @param[in]: transforms -> 数据集变换'''    def __init__(self, dataPath: str, transforms=None) -> None:super().__init__(dataPath, transforms=transforms)self.data, self.label = self.load()def plot(self, index: int) -> None:self.plotData(index, "testSet data")def load(self) -> list:return self.loadData(train=False)

在这里插入图片描述

相关内容

热门资讯

【实验报告】实验一 图像的... 实验目的熟悉Matlab图像运算的基础——矩阵运算;熟悉图像矩阵的显示方法࿰...
MATLAB | 全网最详细网... 一篇超超超长,超超超全面网络图绘制教程,本篇基本能讲清楚所有绘制要点&#...
大模型落地比趋势更重要,NLP... 全球很多人都开始相信,以ChatGPT为代表的大模型,将带来一场NLP领...
Linux学习之端口、网络协议... 端口:设备与外界通讯交流的出口 网络协议:   网络协议是指计算机通信网...
kuernetes 资源对象分... 文章目录1. pod 状态1.1 容器启动错误类型1.2 ImagePullBackOff 错误1....
STM32实战项目-数码管 程序实现功能: 1、上电后,数码管间隔50ms计数; 2、...
TM1638和TM1639差异... TM1638和TM1639差异说明 ✨本文不涉及具体的单片机代码驱动内容,值针对芯...
Qt+MySql开发笔记:Qt... 若该文为原创文章,转载请注明原文出处 本文章博客地址:https://h...
Java内存模型中的happe... 第29讲 | Java内存模型中的happen-before是什么? Java 语言...
《扬帆优配》算力概念股大爆发,... 3月22日,9股封单金额超亿元,工业富联、鸿博股份、鹏鼎控股分别为3.0...
CF1763D Valid B... CF1763D Valid Bitonic Permutations 题目大意 拱形排列࿰...
SQL语法 DDL、DML、D... 文章目录1 SQL通用语法2 SQL分类3 DDL 数据定义语言3.1 数据库操作3.2 表操作3....
文心一言 VS ChatGPT... 3月16号,百度正式发布了『文心一言』,这是国内公司第一次发布类Chat...
CentOS8提高篇5:磁盘分...        首先需要在虚拟机中模拟添加一块新的硬盘设备,然后进行分区、格式化、挂载等...
Linux防火墙——SNAT、... 目录 NAT 一、SNAT策略及作用 1、概述 SNAT应用环境 SNAT原理 SNAT转换前提条...
部署+使用集群的算力跑CPU密... 我先在开头做一个总结,表达我最终要做的事情和最终环境是如何的,然后我会一...
Uploadifive 批量文... Uploadifive 批量文件上传_uploadifive 多个上传按钮_asing1elife的...
C++入门语法基础 文章目录:1. 什么是C++2. 命名空间2.1 域的概念2.2 命名...
2023年全国DAMA-CDG... DAMA认证为数据管理专业人士提供职业目标晋升规划,彰显了职业发展里程碑及发展阶梯定义...
php实现助记词转TRX,ET... TRX助记词转地址网上都是Java,js或其他语言开发的示例,一个简单的...
【分割数据集操作集锦】毕设记录 1. 按要求将CSV文件转成json文件 有时候一些网络模型的源码会有data.json这样的文件里...
Postman接口测试之断言 如果你看文字部分还是不太理解的话,可以看看这个视频,详细介绍postma...
前端学习第三阶段-第4章 jQ... 4-1 jQuery介绍及常用API导读 01-jQuery入门导读 02-JavaScri...
4、linux初级——Linu... 目录 一、用CRT连接开发板 1、安装CRT调试工具 2、连接开发板 3、开机后ctrl+c...
Urban Radiance ... Urban Radiance Fields:城市辐射场 摘要:这项工作的目标是根据扫描...
天干地支(Java) 题目描述 古代中国使用天干地支来记录当前的年份。 天干一共有十个,分别为:...
SpringBoot雪花ID长... Long类型精度丢失 最近项目中使用雪花ID作为主键,雪花ID是19位Long类型数...
对JSP文件的理解 JSP是java程序。(JSP本质还是一个Servlet) JSP是&#...
【03173】2021年4月高... 一、单向填空题1、大量应用软件开发工具,开始于A、20世纪70年代B、20世纪 80年...
LeetCode5.最长回文子... 目录题目链接题目分析解题思路暴力中心向两边拓展搜索 题目链接 链接 题目分析 简单来说࿰...