深度学习技巧应用5-神经网络中的模型剪枝技巧
创始人
2025-06-01 14:27:14

大家好,我是微学AI,今天给大家带来深度学习技巧应用5-神经网络中的模型剪枝技巧,模型剪枝是深度学习中的一个重要的技巧应用,用好了可以简化模型已经提高模型推理深度。

一、模型剪枝技巧介绍

模型剪枝是一种常用的深度学习模型优化技巧,其目的是通过去除模型中一些不必要的参数或节点,从而提高模型的运行效率和准确性。在模型剪枝技巧中,最常见的方法是结构化剪枝和非结构化剪枝。其中,结构化剪枝对模型的结构进行优化,例如对整个卷积层或全连接层进行剪枝;非结构化剪枝则对模型的参数进行优化,通常是通过数值的大小来判断参数的重要性,然后将数值较小的参数删除掉。

模型剪枝就像是整理植物的枝干一样,将模型中不必要的枝干切除,让模型更加紧凑。就像一个苹果树上的枝干一样,如果有太多的枝干,会浪费苹果树能量的同时也会影响果实的质量和产量。同样的,模型中的参数如果过多,会降低模型的训练速度和推理速度,同时也可能会过拟合数据。因此我们需要通过剪枝来减少模型中的不必要的权重和神经元。

 在模型剪枝中,我们对模型中的一些参数做出修改,使得它们的影响减少,从而降低整个模型的冗余度。这些修改包括减少层数,减小卷积核的大小,以及删除某些层中不需要的节点等。通过这些修改,在不影响模型精度的前提下,我们可以减少模型的大小和训练时间。

二、模型剪枝代码案例

在PyTorch框架中,可以通过torch.nn.utils.prune库中的函数来实现模型剪枝。我将使用ResNet18模型,在训练MNIST数据集时进行结构化剪枝,从而削减卷积层的参数量。

import torch
import torch.nn as nn
import torch.nn.utils.prune as prune
import torch.optim as optim
import torchvision
import torchvision.transforms as transformstransform = transforms.Compose([transforms.Resize((32, 32)),transforms.ToTensor(),transforms.Normalize((0.5,), (0.5,))])trainset = torchvision.datasets.MNIST(root='./data', train=True,download=True, transform=transform)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=128,shuffle=True, num_workers=2)testset = torchvision.datasets.MNIST(root='./data', train=False,download=True, transform=transform)
testloader = torch.utils.data.DataLoader(testset, batch_size=128,shuffle=False, num_workers=2)class ResNet18(nn.Module):def __init__(self):super(ResNet18, self).__init__()self.conv1 = nn.Conv2d(1, 64, kernel_size=3, stride=1, padding=1, bias=False)self.bn1 = nn.BatchNorm2d(64)self.relu = nn.ReLU(inplace=True)self.layer1 = nn.Sequential(nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1, bias=False),nn.BatchNorm2d(64),nn.ReLU(inplace=True),nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1, bias=False),nn.BatchNorm2d(64))self.layer2 = nn.Sequential(nn.Conv2d(64, 128, kernel_size=3, stride=2, padding=1, bias=False),nn.BatchNorm2d(128),nn.ReLU(inplace=True),nn.Conv2d(128, 128, kernel_size=3, stride=1, padding=1, bias=False),nn.BatchNorm2d(128))self.layer3 = nn.Sequential(nn.Conv2d(128, 256, kernel_size=3, stride=2, padding=1, bias=False),nn.BatchNorm2d(256),nn.ReLU(inplace=True),nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1, bias=False),nn.BatchNorm2d(256))self.layer4 = nn.Sequential(nn.Conv2d(256, 512, kernel_size=3, stride=2, padding=1, bias=False),nn.BatchNorm2d(512),nn.ReLU(inplace=True),nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1, bias=False),nn.BatchNorm2d(512))self.avgpool = nn.AdaptiveAvgPool2d((1, 1))self.fc = nn.Linear(512, 10)def forward(self, x):x = self.conv1(x)x = self.bn1(x)x = self.relu(x)x = self.layer1(x)x = self.layer2(x)x = self.layer3(x)x = self.layer4(x)x = self.avgpool(x)x = x.view(x.size(0), -1)x = self.fc(x)return x
#模型剪枝
def model_pruning(model):layer1 = model.conv1prune.random_unstructured(layer1, name="weight", amount=0.3)prune.remove(layer1, 'weight')for name, module in model.named_modules():if isinstance(module, torch.nn.Conv2d):prune.l1_unstructured(module, name='weight', amount=0.5)prune.remove(module, 'weight')elif isinstance(module, torch.nn.Linear):prune.l1_unstructured(module, name='weight', amount=0.5)prune.remove(module, 'weight')return 0device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
net = ResNet18()
model_pruning(net)
net.to(device)criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(net.parameters(), lr=0.01, momentum=0.9, weight_decay=5e-4)

在模型中对ResNet18的所有Conv2d,Linear层进行剪枝,使用L1正则化的方法,将50%的权重参数裁剪掉。代码使用`self.named_modules()`遍历了模型中的所有层,当遇到一个`nn.Conv2d`层时,就使用`prune.l1_unstructured`函数对其进行剪枝。`prune.l1_unstructured`剪枝方法会移除模块中最小的abs(weight)* amount个参数,并将它们设置为0。其中,`amount`表示要裁剪掉的权重比例,这里设置为0.5即50%。

模型剪枝后主要的变化主要有:

1.部分权重参数被裁剪掉了,模型的稀疏性增加。

2.网络的计算复杂度会减小,从而在一定程度上提高了推理速度。

3.可能会对模型的精度产生一定的影响,因为裁剪掉的参数中可能包含了一些对模型来说很重要的信息。

模型训练:

if __name__=="__main__":for epoch in range(1):  # 训练5个轮次running_loss = 0.0for i, data in enumerate(trainloader, 0):inputs, labels = datainputs, labels = inputs.to(device), labels.to(device)optimizer.zero_grad()outputs = net(inputs)loss = criterion(outputs, labels)loss.backward()optimizer.step()running_loss += loss.item()if i % 2 == 0:  # 每100个批次输出一下当前的训练状态print('[%d, %5d] loss: %.3f' % (epoch + 1, i + 1, running_loss / 2))running_loss = 0.0correct = 0total = 0with torch.no_grad():for data in testloader:images, labels = dataimages, labels = images.to(device), labels.to(device)outputs = net(images)_, predicted = torch.max(outputs.data, 1)total += labels.size(0)correct += (predicted == labels).sum().item()print('Accuracy of the network on the 10000 test images: %d %%' % (100 * correct / total))

合作与问题都可以私聊,获取更多资料。

相关内容

热门资讯

【实验报告】实验一 图像的... 实验目的熟悉Matlab图像运算的基础——矩阵运算;熟悉图像矩阵的显示方法࿰...
MATLAB | 全网最详细网... 一篇超超超长,超超超全面网络图绘制教程,本篇基本能讲清楚所有绘制要点&#...
大模型落地比趋势更重要,NLP... 全球很多人都开始相信,以ChatGPT为代表的大模型,将带来一场NLP领...
Linux学习之端口、网络协议... 端口:设备与外界通讯交流的出口 网络协议:   网络协议是指计算机通信网...
kuernetes 资源对象分... 文章目录1. pod 状态1.1 容器启动错误类型1.2 ImagePullBackOff 错误1....
STM32实战项目-数码管 程序实现功能: 1、上电后,数码管间隔50ms计数; 2、...
TM1638和TM1639差异... TM1638和TM1639差异说明 ✨本文不涉及具体的单片机代码驱动内容,值针对芯...
Qt+MySql开发笔记:Qt... 若该文为原创文章,转载请注明原文出处 本文章博客地址:https://h...
Java内存模型中的happe... 第29讲 | Java内存模型中的happen-before是什么? Java 语言...
《扬帆优配》算力概念股大爆发,... 3月22日,9股封单金额超亿元,工业富联、鸿博股份、鹏鼎控股分别为3.0...
CF1763D Valid B... CF1763D Valid Bitonic Permutations 题目大意 拱形排列࿰...
SQL语法 DDL、DML、D... 文章目录1 SQL通用语法2 SQL分类3 DDL 数据定义语言3.1 数据库操作3.2 表操作3....
文心一言 VS ChatGPT... 3月16号,百度正式发布了『文心一言』,这是国内公司第一次发布类Chat...
CentOS8提高篇5:磁盘分...        首先需要在虚拟机中模拟添加一块新的硬盘设备,然后进行分区、格式化、挂载等...
Linux防火墙——SNAT、... 目录 NAT 一、SNAT策略及作用 1、概述 SNAT应用环境 SNAT原理 SNAT转换前提条...
部署+使用集群的算力跑CPU密... 我先在开头做一个总结,表达我最终要做的事情和最终环境是如何的,然后我会一...
Uploadifive 批量文... Uploadifive 批量文件上传_uploadifive 多个上传按钮_asing1elife的...
C++入门语法基础 文章目录:1. 什么是C++2. 命名空间2.1 域的概念2.2 命名...
2023年全国DAMA-CDG... DAMA认证为数据管理专业人士提供职业目标晋升规划,彰显了职业发展里程碑及发展阶梯定义...
php实现助记词转TRX,ET... TRX助记词转地址网上都是Java,js或其他语言开发的示例,一个简单的...
【分割数据集操作集锦】毕设记录 1. 按要求将CSV文件转成json文件 有时候一些网络模型的源码会有data.json这样的文件里...
Postman接口测试之断言 如果你看文字部分还是不太理解的话,可以看看这个视频,详细介绍postma...
前端学习第三阶段-第4章 jQ... 4-1 jQuery介绍及常用API导读 01-jQuery入门导读 02-JavaScri...
4、linux初级——Linu... 目录 一、用CRT连接开发板 1、安装CRT调试工具 2、连接开发板 3、开机后ctrl+c...
Urban Radiance ... Urban Radiance Fields:城市辐射场 摘要:这项工作的目标是根据扫描...
天干地支(Java) 题目描述 古代中国使用天干地支来记录当前的年份。 天干一共有十个,分别为:...
SpringBoot雪花ID长... Long类型精度丢失 最近项目中使用雪花ID作为主键,雪花ID是19位Long类型数...
对JSP文件的理解 JSP是java程序。(JSP本质还是一个Servlet) JSP是&#...
【03173】2021年4月高... 一、单向填空题1、大量应用软件开发工具,开始于A、20世纪70年代B、20世纪 80年...
LeetCode5.最长回文子... 目录题目链接题目分析解题思路暴力中心向两边拓展搜索 题目链接 链接 题目分析 简单来说࿰...