[Python] Three-Layer BP Neural Network: Derivation, MNIST, Optimizer Comparison, and Loss-Function Comparison

期权论坛 (Options Forum) · Anonymous user · 2021-5-26 12:45 · 473 views
<h1>I. Overview</h1>
<p>The derivation in this article follows pp. 102&ndash;103 of the "watermelon book" (Zhou Zhihua's <em>Machine Learning</em>); the code is based on <a href="https://blog.csdn.net/ebzxw/article/details/81591437">this post</a>. It implements handwritten-digit recognition with a three-layer neural network.</p>
<h1>II. Theoretical Derivation</h1>
<h2>1. Parameter Definitions</h2>
<p><img alt="" class="blockcode" height="308" src="https://beijingoptbbs.oss-cn-beijing.aliyuncs.com/cs/5606289-05eeae19d5a83ec3de17a3accb98e8c2.jpg" width="567"></p>
<p>A three-layer neural network has a single hidden layer. The parameters are defined as follows:</p>
<table align="center" border="1" cellpadding="1" cellspacing="1" style="width:500px;"><tbody><tr><td style="text-align:center;">x</td><td style="text-align:center;width:300px;">input-layer input</td></tr><tr><td style="text-align:center;">v</td><td style="text-align:center;width:300px;">weights between the input and hidden layers</td></tr><tr><td style="text-align:center;">α</td><td style="text-align:center;width:300px;">hidden-layer input</td></tr><tr><td style="text-align:center;">b</td><td style="text-align:center;width:300px;">hidden-layer output</td></tr><tr><td style="text-align:center;">w</td><td style="text-align:center;width:300px;">weights between the hidden and output layers</td></tr><tr><td style="text-align:center;"><img alt="\beta" class="mathcode" src="https://beijingoptbbs.oss-cn-beijing.aliyuncs.com/cs/5606289-6dbd45f6c0b67511494a8782ec5fc360.latex"></td><td style="text-align:center;width:300px;">output-layer input</td></tr><tr><td style="text-align:center;"><img alt="\hat y" class="mathcode" src="https://beijingoptbbs.oss-cn-beijing.aliyuncs.com/cs/5606289-62c84e9c298386b857c8c9e613ef3987.latex"></td><td style="text-align:center;width:300px;">output-layer output</td></tr></tbody></table>
<p>The parameters are related as follows:</p>
<p><img alt="\alpha_h&#61; \sum_{i&#61;1}^{d}{v_{ih}x_i}" class="mathcode" src="https://private.codecogs.com/gif.latex?%5Calpha_h%3D%20%5Csum_%7Bi%3D1%7D%5E%7Bd%7D%7Bv_%7Bih%7Dx_i%7D"></p>
<p><img alt="b_h&#61;f(\alpha_h)" class="mathcode" src="https://beijingoptbbs.oss-cn-beijing.aliyuncs.com/cs/5606289-d988b26d7d65b793d09687c120a48550.latex"></p>
<p><img alt="\beta_j&#61;\sum^{q}_{h&#61;1}{w_{hj}b_h}" class="mathcode" src="https://private.codecogs.com/gif.latex?%5Cbeta_j%3D%5Csum%5E%7Bq%7D_%7Bh%3D1%7D%7Bw_%7Bhj%7Db_h%7D"></p>
<p><img alt="\hat y&#61;f(\beta_j)" class="mathcode" src="https://beijingoptbbs.oss-cn-beijing.aliyuncs.com/cs/5606289-e39a7d0056a6da123f0e6533caab4e91.latex"></p>
<p>In the equations above, f(·) is the activation function. The watermelon book uses the sigmoid function as the default activation and the squared error as the loss function; the derivation below follows those choices.</p>
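<p>The four relations above can be sketched as a NumPy forward pass. This is a minimal illustration, not the article's final code; the dimensions d, q, l and the random initialization are illustrative assumptions:</p>

```python
import numpy as np

def sigmoid(x):
    """The activation function f used throughout the derivation."""
    return 1.0 / (1.0 + np.exp(-x))

# Illustrative sizes: d inputs, q hidden units, l outputs
d, q, l = 784, 30, 10
rng = np.random.default_rng(0)
v = rng.standard_normal((d, q)) * 0.01   # input->hidden weights v_ih
w = rng.standard_normal((q, l)) * 0.01   # hidden->output weights w_hj

x = rng.random(d)          # one input sample
alpha = x @ v              # hidden-layer input:  alpha_h = sum_i v_ih * x_i
b = sigmoid(alpha)         # hidden-layer output: b_h = f(alpha_h)
beta = b @ w               # output-layer input:  beta_j = sum_h w_hj * b_h
y_hat = sigmoid(beta)      # output-layer output: y_hat_j = f(beta_j)
```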
<h2>2. Derivation</h2>
<p>Let the loss function be <img alt="E" class="mathcode" src="https://beijingoptbbs.oss-cn-beijing.aliyuncs.com/cs/5606289-cb9cfb02d007b4a2d0e16d0f47bf4713.latex">, given by:</p>
<p><img alt="E&#61;\frac{1}{2}\sum^{l}_{j&#61;1}({y_j-\hat y_j})^2" class="mathcode" src="https://private.codecogs.com/gif.latex?E%3D%5Cfrac%7B1%7D%7B2%7D%5Csum%5E%7Bl%7D_%7Bj%3D1%7D%28%7By_j-%5Chat%20y_j%7D%29%5E2"></p>
<p>Taking the partial derivative with respect to <img alt="w_{hj}" class="mathcode" src="https://beijingoptbbs.oss-cn-beijing.aliyuncs.com/cs/5606289-7fad34506e22255718bf344a125174c9.latex"> (CSDN unfortunately does not support multi-line formula editing, so the derivation is handwritten):</p>
<p><img alt="" class="blockcode" src="https://beijingoptbbs.oss-cn-beijing.aliyuncs.com/cs/5606289-1e03e57c343c9b3051e9cf55e884cb1a.png"></p>
<p>This gives the update rule for <img alt="w_{hj}" class="mathcode" src="https://beijingoptbbs.oss-cn-beijing.aliyuncs.com/cs/5606289-7fad34506e22255718bf344a125174c9.latex">:</p>
<p><img alt="\Delta w_{hj}&#61;\eta(y_j-\hat y_j)\hat y_j(1-\hat y_j)b_h" class="mathcode" src="https://beijingoptbbs.oss-cn-beijing.aliyuncs.com/cs/5606289-efd6b00f441de26a7e7a1297611bf644.latex"></p>
<p>Similarly, taking the partial derivative with respect to <img alt="v_{ih}" class="mathcode" src="https://beijingoptbbs.oss-cn-beijing.aliyuncs.com/cs/5606289-738292f752e55ccd6839926be318587c.latex">:</p>
<p><img alt="" class="blockcode" src="https://beijingoptbbs.oss-cn-beijing.aliyuncs.com/cs/5606289-0875d0338e1467668ec88165f3332afd.png"></p>
<p>This gives the update rule for <img alt="v_{ih}" class="mathcode" src="https://beijingoptbbs.oss-cn-beijing.aliyuncs.com/cs/5606289-738292f752e55ccd6839926be318587c.latex">:</p>
<p><img alt="\Delta v_{ih}&#61;\eta x_i b_h(1-b_h)\sum^l_{j&#61;1}((y_j-\hat y_j)\hat y_j(1-\hat y_j)w_{hj})" class="mathcode" src="https://private.codecogs.com/gif.latex?%5CDelta%20v_%7Bih%7D%3D%5Ceta%20x_i%20b_h%281-b_h%29%5Csum%5El_%7Bj%3D1%7D%28%28y_j-%5Chat%20y_j%29%5Chat%20y_j%281-%5Chat%20y_j%29w_%7Bhj%7D%29"></p>
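<p>The two update rules can be sketched as a single SGD step in NumPy. This is a hedged illustration of the formulas, not the article's implementation; the sizes, learning rate η, and random initialization are made-up choices:</p>

```python
import numpy as np

def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))

# Illustrative sizes and learning rate
d, q, l, eta = 4, 3, 2, 0.1
rng = np.random.default_rng(1)
v = rng.standard_normal((d, q))   # v_ih
w = rng.standard_normal((q, l))   # w_hj

x = rng.random(d)                 # one training sample
y = np.array([0.0, 1.0])          # its target

# Forward pass
b = sigmoid(x @ v)                # hidden-layer output b_h
y_hat = sigmoid(b @ w)            # network output y_hat_j

# Output-layer term: g_j = (y_j - y_hat_j) * y_hat_j * (1 - y_hat_j)
g = (y - y_hat) * y_hat * (1 - y_hat)
# Hidden-layer term: e_h = b_h * (1 - b_h) * sum_j g_j * w_hj
e = b * (1 - b) * (w @ g)

w += eta * np.outer(b, g)         # Delta w_hj = eta * g_j * b_h
v += eta * np.outer(x, e)         # Delta v_ih = eta * e_h * x_i
```

<p>For a sufficiently small η, one such step reduces the squared error on this sample.</p>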
<h1>III. Code Implementation</h1>
<p>SGD is used as the optimization method.</p>
<h2>1. Initializing the Dataset</h2>
<p>The MNIST dataset can be downloaded through TensorFlow:</p>
<pre class="blockcode"><code class="language-python"># Note: tensorflow.examples.tutorials was removed in TensorFlow 2.x,
# so this code requires TensorFlow 1.x.
import tensorflow.examples.tutorials.mnist.input_data as input_data
mnist = input_data.read_data_sets("./MNIST_data/", one_hot=True)</code></pre>
<p>The mnist object holds all the data: mnist.train.images is a 55000&times;784 2-D array storing the training inputs, one row of 784 pixel values per image, and mnist.train.labels is a 55000&times;10 2-D array storing the training labels, where each row is a one-hot vector with a 1 at the index of the digit and 0 elsewhere.</p>
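<p>The one-hot label layout described above can be reproduced with plain NumPy (the raw digit labels here are made-up examples, so no download is needed):</p>

```python
import numpy as np

# One-hot encoding: each row has a single 1 at the digit's index,
# matching the layout of mnist.train.labels with one_hot=True.
labels = np.array([3, 0, 9])        # hypothetical raw digit labels
one_hot = np.eye(10)[labels]        # shape (3, 10)
```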