卷积神经网络的计算
创始人
2025-05-28 13:00:04

计算卷积神经网络中的参数

填充

我们在使用多重卷积的时候,常常会丢失边缘像素。由于我们使用小卷积核,因此对于单个卷积,我们只可能丢失几个像素。但随着我们应用连着许多卷积层,累积丢失的像素就多了起来。解决这个问题的方法即为填充(padding):在输入图像的边界填充元素(基本填充的都是0)
通常,我们添加php_hph​行填充,和pwp_wpw​列填充,则输出的形状为
(nh−kh+ph+1)∗(nw−kw+pw+1)(n_h - k_h + p_h + 1)*(n_w-k_w+p_w+1)(nh​−kh​+ph​+1)∗(nw​−kw​+pw​+1)
许多情况下,我们需要设置ph=kh−1p_h=k_h-1ph​=kh​−1,pw=kw−1p_w=k_w-1pw​=kw​−1使输出具有相同的高度和宽度。这样可以更好的预测输出的情况。
1、假设khk_hkh​为奇数,则高度的两侧分别填充ph/2p_h/2ph​/2
2、假设khk_hkh​为偶数,则宽度的上测分别填充pw/2p_w/2pw​/2上取整,下侧下取整

步幅

每次滑动元素的数量叫做步幅(stride),为了高效计算或是缩减采样次数,卷积窗口可以跳过中间位置,每次滑动多个元素。
当垂直步幅为shs_hsh​、水平步幅为sws_wsw​输出的形状为
⌊(nh−kh+ph+sh)/sh⌋∗⌊(nw−kw+pw+sw)/sw⌋\lfloor(n_h - k _h + p_h + s_h)/s_h\rfloor *\lfloor(n_w-k_w+p_w+s_w)/s_w\rfloor⌊(nh​−kh​+ph​+sh​)/sh​⌋∗⌊(nw​−kw​+pw​+sw​)/sw​⌋

1X1的卷积层

需要特别注意的是一层一的卷积层,即kh=kw=1。它唯一的计算发生在通道上,经常用1X1的卷积层来代替全连接层减少运算。

池化层(pooling)

最大池化和平均池化
nn.MaxPool2d(size, stride, padding)

AlexNet

AlexNet与LeNet进行对比
在这里插入图片描述

import torch
from torch import nn
from d2l import torch as d2lnet = nn.Sequential(# 这里使用一个11*11的更大窗口来捕捉对象。# 同时,步幅为4,以减少输出的高度和宽度。# 另外,输出通道的数目远大于LeNetnn.Conv2d(1, 96, kernel_size=11, stride=4, padding=1), nn.ReLU(),nn.MaxPool2d(kernel_size=3, stride=2),# 减小卷积窗口,使用填充为2来使得输入与输出的高和宽一致,且增大输出通道数nn.Conv2d(96, 256, kernel_size=5, padding=2), nn.ReLU(),nn.MaxPool2d(kernel_size=3, stride=2),# 使用三个连续的卷积层和较小的卷积窗口。# 除了最后的卷积层,输出通道的数量进一步增加。# 在前两个卷积层之后,汇聚层不用于减少输入的高度和宽度nn.Conv2d(256, 384, kernel_size=3, padding=1), nn.ReLU(),nn.Conv2d(384, 384, kernel_size=3, padding=1), nn.ReLU(),nn.Conv2d(384, 256, kernel_size=3, padding=1), nn.ReLU(),nn.MaxPool2d(kernel_size=3, stride=2),nn.Flatten(),# 这里,全连接层的输出数量是LeNet中的好几倍。使用dropout层来减轻过拟合nn.Linear(6400, 4096), nn.ReLU(),nn.Dropout(p=0.5),nn.Linear(4096, 4096), nn.ReLU(),nn.Dropout(p=0.5),# 最后是输出层。由于这里使用Fashion-MNIST,所以用类别数为10,而非论文中的1000nn.Linear(4096, 10))

224224 stride为4,padding为1,核大小为1111
第一层conv2d(224-11 + 2 + 1)/4 = 54
MaxPooling (54-3+1)/2=26
Conv2d (26 - 5 + 4 + 1) = 26
……
全连接层参数需要进行计算
256 * 5 * 5 = 6400//属于torch中不那么简洁的地方

X = torch.randn(1, 1, 224, 224)
for layer in net:X=layer(X)print(layer.__class__.__name__,'output shape:\t',X.shape)

输出为

ReLU output shape:   torch.Size([1, 96, 54, 54])
MaxPool2d output shape:      torch.Size([1, 96, 26, 26])
Conv2d output shape:         torch.Size([1, 256, 26, 26])
ReLU output shape:   torch.Size([1, 256, 26, 26])
MaxPool2d output shape:      torch.Size([1, 256, 12, 12])
Conv2d output shape:         torch.Size([1, 384, 12, 12])
ReLU output shape:   torch.Size([1, 384, 12, 12])
Conv2d output shape:         torch.Size([1, 384, 12, 12])
ReLU output shape:   torch.Size([1, 384, 12, 12])
Conv2d output shape:         torch.Size([1, 256, 12, 12])
ReLU output shape:   torch.Size([1, 256, 12, 12])
MaxPool2d output shape:      torch.Size([1, 256, 5, 5])
Flatten output shape:        torch.Size([1, 6400])
Linear output shape:         torch.Size([1, 4096])
ReLU output shape:   torch.Size([1, 4096])
Dropout output shape:        torch.Size([1, 4096])
Linear output shape:         torch.Size([1, 4096])
ReLU output shape:   torch.Size([1, 4096])
Dropout output shape:        torch.Size([1, 4096])
Linear output shape:         torch.Size([1, 10])

d2l学习笔记

相关内容

热门资讯

【C++初阶】10 .习题1 2022-09-16_命名空间 1. 命名空间的概念 下面关于C++命名空间描述错...
基于bearpi的智能小车--... 基于bearpi的智能小车--Qt上位机设计 前言一、界面原型1.主界面2.网络配置子窗口模块 二、...
三、Java核心技术(进阶)-... 一、概念 国际化编程:通过一套软件适配多个语言包。 二、相关函数 java.util....
水声功率放大器与宽带匹配技术研...   作为声呐设备重要的一份子,水声信号发射机承担着非常重要的作用。水声信号发射机其实是...
【C++】12.继承 1.引入继承 学生管理系统 学生 老师 社管阿姨 保安大叔 4个类 4个类有很多重复的东西...
LINUX中atd和crond... 一、atd与crond的区别1、运行方式不同,at只运行一次,而cron...
C++数据结构 —— 哈希表、... 目录 1.哈希概念 1.1哈希函数 1.2哈希冲突 2.闭散列实现 3.开散列实现 4.容器的封装 ...
Streamlit 学习笔记1 Streamlit 学习笔记1 文章目录Streamlit 学习笔记1首先 安利下streamlit...
基层区域应用平台为目标开发的基... 系统特点:  JAVA语言开发,MYSQL数据库,B/S架构 基于云计算...
数智链接,新一代校园招聘解决方... 疫情3年市场巨变,00后新生代初登上求职舞台,中和作用下,...
面试官:rem和vw有什么区别 "rem" 和 "vw"的区别 "rem" 和 "vw" 都是用于网页设计的CSS单位。 "rem"...
Pytest自动化测试框架完美... 简介 Allure Framework是一种灵活的、轻量级、多语言测试报告工具。 不仅可以以简洁的网...
华为nat配置实验:内网能够访... 一 需求分析1.1 需求 公司A在北京,公司B在上海,本次实验仅仅模拟局...
事务日志与 两阶段提交 文章目录 Redo Logredo的优点redo的组成redo的整体流程不同刷盘策略演示 Undo ...
【目标跟踪算法】Strong ... 1. Strong SORT算法的背景和概述 Strong SORT算法基于经典的Deep SORT...
Lock接口——JUC随记2 1、synchronized 1.1、synchronized的三种应用方式 一. 修饰实例方法&#...
IO流之字符流 文章目录1. 字符流概述2. 编码表3. 编码和解码4. 写数据的方式5. 读数据的方式6. 转换流...
C语言的灵魂---指针(基础) C语言灵魂指针1.什么是指针?2.指针的大小3.指针的分类3.1比较常规的指针类型3....
【华为OD机试真题JAVA】最... 标题:最优策略组合下的总的系统消耗资源数问题 | 时间限制:1秒 | 内存限制:262144K | ...
MATLAB | 给热图整点花... 前段时间写的特殊热图绘制函数迎来大更新,基础使用教程可以看看这一篇: h...
小知识·BitTorrent ... BitTorrent 简介从 P2P 说起经常在网上飙车的老司机应该都知道 BT 下载,...
Redis和Memcached...         对于大多数的系统服务来说,缓存是提高性能和可伸缩性的关键。一般情况下我...
[牛客算法总结]:重建二叉树    标签: 二叉树、DFS、先序遍历、中序遍历、递归   题目: 给定...
VS Code 将推出更多 A... 大家好,欢迎来到我们的二月更新!我们将为您带来与 JUnit 5 并行测...
为什么要推荐使用pnpm 在谈起pnpm时先来聊一聊之前的npm和yarn有什么存在的问题  npm2 在npm3之前我们安装...
多线程开发 文章目录多线程开发1. Thread创建多线程2. ThreadPoolExecutor创建进程池a...
闪存系统性能优化方向?NAND... Hello 大家好, 我是元存储~ 目录 前言 1. 提升效果 2. Cache Re...
关于复杂链表的复制问题(力扣) 上面我们已经说了两个关于链表的实现了,其中一个是单链表,另外一个是双向带...
STM32学习(二) 常用开发工具简介 安装仿真器驱动 DAP仿真器免驱ST LINK仿真器驱动安装方法:...
K8s配置jenkins Ma... 1、k8s安装jenkins 以阿里云的ACK为例 A、在有状态点击镜像创建,配置自己...