旷视科技ShuffleNetV1代码分析[pytorch版]

一、前述 

旷视科技针对于ShuffleNet系列网络在GitHub网站上已开源,其链接:https://github.com/megvii-model/ShuffleNet-Series

在这个系列中,包括了ShuffleNetV1/V2网络,如下图所示。 

我们点开ShuffleNetV1文件夹,如下图所示。 

  • ShuffleNetV1文件夹中有五个文件,分别为:README.md、blocks.py、network.py、train.py、utils.py文件。
  • 其中,blocks.py中的代码是ShuffleNetV1的基本模块;
  • network.py 中的代码是 blocks.py 中基本模块堆叠出来的 ShuffleNetV1 网络;
  • train.py 中是训练 ImageNet 数据集图像分类的训练代码;
  • utils.py 是一些常用的函数。

旷视科技GitHub网站给出的ShufflNetV1网络的结果,如下表所示: 

二、代码分析

2.1 blocks.py(ShuffleNetV1 Unit) 

我们先来回顾以下ShuffleNetV1 Unit,如下图(b)、图(c)所示。 
图(b)表示的是stride=1的ShuffleNetV1 Unit,在该基本单元中,右侧被称为主分支,在该主分支中:
①先1×1GConv(group pointwise convolution)降维,第一个红色模块;
②然后channel shuffle,蓝色模块;
③再然后3×3DWConv(depthwise convolution),绿色模块;
④然后再1×1GConv升维,第二个红色模块。

图(c)表示的是stride=2 的ShuffleNetV1 Unit,在该基本单元中,右侧被称为主分支,在该主分支中:
①先1×1GConv(group pointwise convolution)降维,第一个红色模块;
②然后channel shuffle,蓝色模块;
③再然后3×3DWConv(depthwise convolution),绿色模块;
④然后再1×1GConv升维,第二个红色模块。

ShuffleNetV1 Unit
(b)stride = 1; (c)stride = 2

ShuffleNetV1网络基本模块的总体代码如下所示,该代码包括了:stride=1的基本单元构建、stride=2的基本单元构建、channel shuffle(通道重排)操作。 

# blocks.py
import torch
import torch.nn as nn
import torch.nn.functional as Fclass ShuffleV1Block(nn.Module):def __init__(self, inp, oup, *, group, first_group, mid_channels, ksize, stride):super(ShuffleV1Block, self).__init__()self.stride = strideassert stride in [1, 2]self.mid_channels = mid_channelsself.ksize = ksizepad = ksize // 2self.pad = padself.inp = inpself.group = groupif stride == 2:outputs = oup - inpelse:outputs = oupbranch_main_1 = [# pwnn.Conv2d(inp, mid_channels, 1, 1, 0, groups=1 if first_group else group, bias=False),nn.BatchNorm2d(mid_channels),nn.ReLU(inplace=True),# dwnn.Conv2d(mid_channels, mid_channels, ksize, stride, pad, groups=mid_channels, bias=False),nn.BatchNorm2d(mid_channels),]branch_main_2 = [# pw-linearnn.Conv2d(mid_channels, outputs, 1, 1, 0, groups=group, bias=False),nn.BatchNorm2d(outputs),]self.branch_main_1 = nn.Sequential(*branch_main_1)self.branch_main_2 = nn.Sequential(*branch_main_2)if stride == 2:self.branch_proj = nn.AvgPool2d(kernel_size=3, stride=2, padding=1)def forward(self, old_x):x = old_xx_proj = old_xx = self.branch_main_1(x)if self.group > 1:x = self.channel_shuffle(x)x = self.branch_main_2(x)if self.stride == 1:return F.relu(x + x_proj)elif self.stride == 2:return torch.cat((self.branch_proj(x_proj), F.relu(x)), 1)def channel_shuffle(self, x):batchsize, num_channels, height, width = x.data.size()assert num_channels % self.group == 0group_channels = num_channels // self.groupx = x.reshape(batchsize, group_channels, self.group, height, width)x = x.permute(0, 2, 1, 3, 4)x = x.reshape(batchsize, num_channels, height, width)return x

我们一步一步做好乐高积木然后将这些乐高积木拼装起来,如下: 

图(b)主分支代码如下: 

图(c)主分支代码如下:
图(c)侧分支代码如下:

channel shuffle代码: 

做好乐高积木之后,我们在forward函数中开始搭建这些乐高积木,如下所示: 

2.2 networks.py (ShuffleNetV1网络架构)

ShuffleNetV1网络架构: 

 ShuffleNetV1 网络架构代码:

import torch
import torch.nn as nn
from blocks import ShuffleV1Blockclass ShuffleNetV1(nn.Module):def __init__(self, input_size=224, n_class=1000, model_size='2.0x', group=None):super(ShuffleNetV1, self).__init__()print('model size is ', model_size)assert group is not Noneself.stage_repeats = [4, 8, 4]self.model_size = model_sizeif group == 3:if model_size == '0.5x':self.stage_out_channels = [-1, 12, 120, 240, 480]elif model_size == '1.0x':self.stage_out_channels = [-1, 24, 240, 480, 960]elif model_size == '1.5x':self.stage_out_channels = [-1, 24, 360, 720, 1440]elif model_size == '2.0x':self.stage_out_channels = [-1, 48, 480, 960, 1920]else:raise NotImplementedErrorelif group == 8:if model_size == '0.5x':self.stage_out_channels = [-1, 16, 192, 384, 768]elif model_size == '1.0x':self.stage_out_channels = [-1, 24, 384, 768, 1536]elif model_size == '1.5x':self.stage_out_channels = [-1, 24, 576, 1152, 2304]elif model_size == '2.0x':self.stage_out_channels = [-1, 48, 768, 1536, 3072]else:raise NotImplementedError# building first layerinput_channel = self.stage_out_channels[1]self.first_conv = nn.Sequential(nn.Conv2d(3, input_channel, 3, 2, 1, bias=False),nn.BatchNorm2d(input_channel),nn.ReLU(inplace=True),)self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)self.features = []for idxstage in range(len(self.stage_repeats)):numrepeat = self.stage_repeats[idxstage]output_channel = self.stage_out_channels[idxstage+2]for i in range(numrepeat):stride = 2 if i == 0 else 1first_group = idxstage == 0 and i == 0self.features.append(ShuffleV1Block(input_channel, output_channel,group=group, first_group=first_group,mid_channels=output_channel // 4, ksize=3, stride=stride))input_channel = output_channelself.features = nn.Sequential(*self.features)self.globalpool = nn.AvgPool2d(7)self.classifier = nn.Sequential(nn.Linear(self.stage_out_channels[-1], n_class, bias=False))self._initialize_weights()def forward(self, x):x = self.first_conv(x)x = self.maxpool(x)x = self.features(x)x = self.globalpool(x)x = x.contiguous().view(-1, self.stage_out_channels[-1])x = self.classifier(x)return xdef _initialize_weights(self):for name, m in self.named_modules():if isinstance(m, nn.Conv2d):if 'first' in name:nn.init.normal_(m.weight, 0, 0.01)else:nn.init.normal_(m.weight, 0, 1.0 / m.weight.shape[1])if m.bias is not None:nn.init.constant_(m.bias, 0)elif isinstance(m, nn.BatchNorm2d):nn.init.constant_(m.weight, 1)if m.bias is not None:nn.init.constant_(m.bias, 0.0001)nn.init.constant_(m.running_mean, 0)elif isinstance(m, nn.BatchNorm1d):nn.init.constant_(m.weight, 1)if m.bias is not None:nn.init.constant_(m.bias, 0.0001)nn.init.constant_(m.running_mean, 0)elif isinstance(m, nn.Linear):nn.init.normal_(m.weight, 0, 0.01)if m.bias is not None:nn.init.constant_(m.bias, 0)if __name__ == "__main__":model = ShuffleNetV1(group=3)# print(model)test_data = torch.rand(5, 3, 224, 224)test_outputs = model(test_data)print(test_outputs.size())

分析: 

 

asset函数:
张量的连续性:https://blog.csdn.net/m0_48241022/article/details/132804698 
如何理解张量、张量索引等:https://blog.csdn.net/m0_48241022/article/details/132729561
torch.nn.Conv2d函数:
torch.nn.BatchNorm2d函数:
torch.nn.ReLU函数:
torch.nn.AvgPool2d函数:
torch.nn.Linear函数:
torch.nn.Sequential函数:
torch.cat函数:
permute函数:

 

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.xdnf.cn/news/1544295.html

如若内容造成侵权/违法违规/事实不符,请联系一条长河网进行投诉反馈,一经查实,立即删除!

相关文章

为什么pmp证书只能对标cspm二级证书??

PMP跟CSPM一个是国际上的证书,一个是中国本土的证书。一个偏理论一个偏实践。对标来说,PMP的等级跟CSPM-2级一个等级,如果想对标更高级,那就要更高级的证书了。 现在有 PMP/PGMP 证书的可以不用参加考试就能换 CSPM -2 级跟 CSPM-…

2024年双十一买什么好?五款好物推荐!

​是不是很多朋友跟我一样,已经为双11做好了准备,打算开启买买买的节奏!作为一名家居兼数码博主,每年双11的时候都会疯狂囤很多物品,所以今天就跟大家来分享一下,我的双11购物清单,也给大家参考…

Linux-DHCP服务器搭建

环境 服务端:192.168.85.136 客户端:192.168.85.138 1. DHCP工作原理 DHCP动态分配IP地址。 2. DHCP服务器安装 2.1前提准备 # systemctl disable --now firewalld // 关闭firewalld自启动 # setenforce 0 # vim /etc/selinux/config SELINU…

解锁MySQL升级秘诀:提升性能、增强安全的必备指南

随着mysql不断演进,旧的版本不断地会发现新的漏洞,为修复漏洞体验新版本的功能,就需要对数据库进行升级操作。 升级注意点 备份!备份!备份! 1.从5.6升级到5.7需首先升级到5.6最新版;不支持跨…

Apache Iceberg 数据类型参考表

Apache Iceberg 概述-链接 Apache Iceberg 数据类型参考表 数据类型描述实例方法注意事项BOOLEAN布尔类型,表示真或假true, false用于条件判断,例如 WHERE is_active true。确保逻辑条件的正确性。INTEGER32位有符号整数42, -7可用于计算、聚合&#xf…

照片去水印怎么操作?3个高清壁纸无损去水印的教程分享

上网真好啊! 能够找到好多摄影大神分享的,超好看的自然景物照片,每一张都想拿来当电脑桌面壁纸、手机壁纸...... 但上网拿的照片有这点不好,就是大部分照片都带有防盗水印,虽说不影响照片的整体美观,但作为…

Android Studio报错 Cause connect timed out

Android Studio报错 Cause connect timed out 解决方法: 在gradle-wrapper.properties中更改distributionUrl为: distributionUrlhttps://mirrors.cloud.tencent.com/gradle/gradle-5.1.1-all.zip如果对你有帮助,就一键三连呗(关…

Win11+cuda11.7+spconv11.7搭建OpenPCdet

这里写自定义目录标题 前面詳細的教程參考:https://blog.csdn.net/xuegreat1/article/details/141892867 懶得寫了,先寫遇到的一些教程外的bug: 上文教程走完后運行demo.py,但是發現沒有裝mayavi庫,直接安裝報錯&#…

项目实战总结-Kafka实战应用核心要点

Kafka实战应用核心要点 一、前言二、Kafka避免重复消费2.1 消费者组机制2.2 幂等生产者2.3 事务性生产者/消费者2.4 手动提交偏移量2.5 外部存储管理偏移量2.6 去重逻辑2.7 幂等消息处理逻辑2.8 小结 三、Kafka持久化策略3.1 持久化文件3.2 segment 分段策略3.3 数据文件刷盘策…

迎国庆-为祖国庆生python、Java、C各显神通

" 金秋送爽,丹桂飘香“,我们即将即将迎来祖国母亲的华诞!! 七十余载风雨兼程,无数先辈以热血铸就辉煌,换来了今日的繁荣昌盛。从东方破晓的第一缕曙光,到星辰大海的无限探索,中…

git 删除 git push 失败的记录

文章目录 问题分析 问题 git push 失败后如何清理 commit 提交的内容 当我们 git push 失败后,如果下次有新的改动需要push时,会出现如下报错 分析 找到需要回退的那次commit的 哈希值 git log然后就回退到了指定版本,这个时候再把新修改…

解析rss链接数据,来长期把某博客数据订阅到自己的网站

目的 当我们打开这个订阅链接,会看到我们的文章信息以xml的形式呈现到浏览器页面中,怎么直接在我们自己的网站中,将这个链接的数据转为我们熟悉的json数据,然后渲染到自己的网站中呢 技术栈 react hookstypescriptwebpack 核心…

【C++掌中宝】深入理解函数重载:概念、规则与应用

文章目录 引言1. 什么是函数重载?2. 为什么需要函数重载?3. 编译器如何解决命名冲突?4. 为什么返回类型不参与重载?5. 重载函数的调用匹配规则6. 编译器如何解析重载函数的调用?7. 重载的限制与注意事项8. 总结结语 引…

柯桥小语种学习之语言交流 | 德语餐厅用语

01 一、入座与点餐 1. Guten Tag! Ein Tisch fr zwei Personen, bitte.(你好!请给我们一张两人桌。) 2. Knnen wir hier sitzen?(我们可以坐这里吗?) 3. Die Speisekarte, bitte.(请给我菜…

在Windows系统上安装的 zlib C++ 库

在Windows系统上安装的 zstd C 库 项目地址步骤步骤一步骤二步骤三如果生成过程中遇到如下错误: 效果 项目地址 https://github.com/madler/zlib 可以发现这个项目有CMakeLists.txt文件,那就比较好搞了 步骤 步骤一 git clone gitgithub.com:madler/zlib.git步骤二 cd zli…

丢失照片/消息/文件,当发现没有备份 Android 手机数据时急救方法

当人们发现他们没有备份 Android 手机数据时,通常为时已晚。但是,我们都不想永久丢失珍贵的照片, 消息和其他文件。这就是为什么您应该检查 遵循 5 大免费 Android 数据恢复工具和最佳替代品 他们。 排名前五的免费 Android 数据恢复软件 1.奇…

黑芝麻A1000-Ubuntu20.04(九)yolov5从训练到板端运行过程详解

宿主机:台式电脑 Ubuntu20.04 开发板:A1000(烧录版本SDK v2.3.1.2) 模型转换容器:bsnn-tools-container-stk-4.2.0 编译容器:a1000b-sdk-fad-2.3.1.2 yolov5使用工程:黑芝麻根据https://github.…

PHP探索校园新生态校园帮小程序系统小程序源码

探索校园新生态 —— 校园帮小程序系统,让生活更精彩! 🌱【开篇:走进未来校园,遇见新生态】🌱 你是否厌倦了传统校园的繁琐与单调?是否渴望在校园里也能享受到便捷、智能的生活体验&#xff1…

3d可视化图片:通过原图和深度图实现

1、depthy 在线体验demo: https://depthy.stamina.pl/#/ 也可以docker安装上面服务: docker run --rm -t -i -p 9000:9000 ndahlquist/depthy http://localhost:90001)首先传原图 2)再传对应深度图 3)效果 </ifra

网络事件管理

网络事件管理是运行组织 IT 网络不可或缺的一部分&#xff0c;网络事件管理的最终目标很简单&#xff1a;在发生中断时尽快恢复服务或功能。但是为了高效和一致地进行&#xff0c;IT 运营团队需要时刻保持警惕&#xff0c;不断了解网络事件&#xff0c;并且必须系统地遵循一套程…