利用Pytorch进行CNN分析

这篇文章主要介绍“利用Pytorch进行CNN分析”,在日常操作中,相信很多人在利用Pytorch进行CNN分析问题上存在疑惑,小编查阅了各式资料,整理出简单好用的操作方法,希望对大家解答”利用Pytorch进行CNN分析”的疑惑有所帮助!接下来,请跟着小编一起来学习吧!

工具

开源深度学习库: PyTorch

数据集: MNIST

实现

初始要求

利用Pytorch进行CNN分析

首先建立基本的BASE网络,在Pytorch中有如下code:

class Net(nn.Module):     def __init__(self):         super(Net, self).__init__()         self.conv1 = nn.Conv2d(1, 20, kernel_size=(5, 5), stride=(1, 1), padding=0)         self.conv2 = nn.Conv2d(20, 50, kernel_size=(5, 5), stride=(1, 1), padding=0)         self.fc1 = nn.Linear(4*4*50, 500)         self.fc2 = nn.Linear(500, 10)      def forward(self, x):         x = F.max_pool2d(self.conv1(x), 2)         x = F.max_pool2d(self.conv2(x), 2)         x = x.view(-1, 4*4*50)         x = F.relu(self.fc1(x))         x = self.fc2(x)         return F.log_softmax(x)

这部分代码见 base.py .

<>强问题答:预处理

利用Pytorch进行CNN分析

即要求将MNIST数据集按照规则读取并且转变到适合处理的格式。这里读取的代码沿用了BigDL Python,支持的读取方式,无需细说,根据MNIST主页上的数据格式可以很快读出,关键块有读取32位比特的函数:

 def  _read32 (bytestream):,,,,, dt =, numpy.dtype (numpy.uint32) .newbyteorder(& # 39;祝辞& # 39;),,,,#,大端模式读取,* * *字节在前(MSB 第一),,,,,return  numpy.frombuffer (bytestream.read (4), dtype=dt) [0] 

读出后是(N, 1, 28日,28)的张量,每个像素是0 - 255的值,首先做一下归一化,将所有值除以255年,得到一个0 - 1的值,然后再正常化,训练集和测试集的均值方差都已知,直接做即可。由于训练集和测试集的均值方差都是针对归一化后的数据来说的,所以刚开始没做归一化,所以向前输出和研究生很离谱,后来才发现是这里出了问题。

这部分代码见预处理。py .

<>强问题B:基模型

利用Pytorch进行CNN分析

将随机种子设置为0,在前10000个训练样本上学习参数,* * *看20个时代之后的测试集错误率。* * *结果为:

 Test 集:,Average 失:,0.0014,准确性:,9732/10000  (97.3%) 

可以看的到,基模型准确率并不是那么的高。

<强>问题C:批处理规范化v。null

利用Pytorch进行CNN分析