最大似然法监督分类步骤_图解半监督学习FixMatch,只用10张标注图片训练CIFAR10...

论坛 期权论坛     
选择匿名的用户   2021-5-30 16:55   247   0
<div class="._5ce-wx-style" style="font-size:16px;">
<div class="rich_media_content" id="js_content">
  <strong>导读</strong>
  <p>仅使用10张带有标签的图像,它在CIFAR-10上的中位精度为78%,最大精度为84%,来看看是怎么做到的。</p>
  <p>深度学习在计算机视觉领域展示了非常有前途的结果。但是当将它应用于实际的医学成像等领域的时候,标签数据的缺乏是一个主要的挑战。</p>
  <p>在实际环境中,对数据做标注是一个耗时和昂贵的过程。你有很多的图片,由于资源约束,只有一小部分人可以进行标注。在这样的情况下,我们如何利用大量未标注的图像以及部分已标注的图像来提高我们的模型的性能?答案是semi-supervised学习。</p>
  <p>FixMatch是Google Brain的Sohn等人最近开发的一种半监督方法,它改善了半监督学习(SSL)的技术水平。它是对之前的方法(例如UDA和ReMixMatch)的简单组合。在本文中,我们将了解FixMatch的概念,并看到仅使用10张带有标签的图像,它在CIFAR-10上的中位精度为78%,最大精度为84%。</p>
  <h2><span style="font-weight:bold;">FixMatch背后的直觉</span></h2>
  <p>假设我们正在对猫与狗进行分类,但是我们的标签数据有限,并且有很多未标签的猫狗图像。</p>
  <figure style="text-align:center;">
   <img alt="4b795dbcf7d134d91b8768b648804eaf.png" src="https://beijingoptbbs.oss-cn-beijing.aliyuncs.com/cs/5606289-822e1c26099b21c1c425569e10017e76.png">
  </figure>
  <p>我们通常的“监督学习”方法将是仅在标注图像上训练分类器,而忽略未标注的图像。</p>
  <p><img alt="f28cb21d72b5d37a45f6299144377cfc.png" src="https://beijingoptbbs.oss-cn-beijing.aliyuncs.com/cs/5606289-342f7628193efae4a1f4c4bfae86103c.png"></p>
  <p>除了忽略未标注的图像,我们还可以应用以下方法。我们知道模型也应该能够处理图像的扰动,从而提高泛化能力。</p>
  <blockquote>
   <p>如果我们对未标注的图像进行图像增强,并让监督模型预测这些图像会怎么样?由于是同一张图片,因此两者的预测的标签应该相同。</p>
  </blockquote>
  <p><img alt="1a993ef94db021ca5c2c65e880dd8c4a.png" src="https://beijingoptbbs.oss-cn-beijing.aliyuncs.com/cs/5606289-148fee75c50df3dc6ae9a4c7b0a913ef.png"></p>
  <p>因此,即使不知道其正确的标签,我们也可以将未标注的图像用作训练流水线的一部分。这是FixMatch及其之前的许多论文背后的核心思想。</p>
  <h2><span style="font-weight:bold;">FixMatch的Pipeline</span></h2>
  <p>凭直觉,让我们看看如何在实践中实际应用FixMatch。下图总结了整个pipeline:</p>
  <p><img alt="2587d72bb466a4b985c1b0b19f5e2920.png" src="https://beijingoptbbs.oss-cn-beijing.aliyuncs.com/cs/5606289-a0f4fa8b0d86af273852c5f588eb766a.png"></p>
  <p>如图所示,我们使用交叉熵损失在标注的图像上训练了监督模型。对于每个未标注的图像,使用弱增强和强增强获得两个图像。弱增强图像被传递到我们的模型中,我们得到了关于类的预测。将最有信心的类别的概率与阈值进行比较。如果它高于阈值,那么我们将该类作为ground truth的标签,即伪标签。然后,将经过强增强的图像传递到我们的模型中,获取类别的预测。使用交叉熵损失将此概率分布与ground truth伪标签进行比较。两种损失组合起来进行模型的更新。</p>
  <h2><span style="font-weight:bold;">Pipeline的组件</span></h2>
  <h3><span style="font-weight:bold;">1. 训练数据和增强</span></h3>
  <p>FixMatch借鉴了UDA和ReMixMatch的这一思想,应用不同的增强方法,即在未标注的图像上进行弱增强以生成伪标签,同时在未标注图像上进行强增强以进行预测。</p>
  <p><strong>a. 弱增强</strong></p>
  <p>对于弱增强,本文使用标准的翻转和平移策略。它包括两个简单的增强:</p>
  <ul><li><p><strong>Random Horizontal Flip</strong></p></li></ul>
  <p><img alt="725fa7b69423c233de6283b5152f5e61.gif" src="https://beijingoptbbs.oss-cn-beijing.aliyuncs.com/cs/5606289-7d41bc7079093367f864e9913c0f5033.gif"></p>
  <figure style="text-align:center;"></figure>
  应用此增强的概率为50%。对于SVHN数据集,将跳过此步骤,因为那些图像包含与水平翻转无关的数字。在PyTorch中,可以使用transforms执行以下操作:
  <pre class="blockcode"><code>from PIL import Imageimport torchvision.transforms as transforms<br>im &#61; Image.open(&#39;dog.png&#39;)<br>weak_im &#61; transforms.RandomHorizontalFlip(p&#61;0.5)(im)</code></pre>
  <ul><li><p><strong>随机水平和垂直移动</strong></p><p><img alt="b86ca0c95fa0c7a60fee2238c0b74772.gif" src="https://beijingoptbbs.oss-cn-beijing.aliyuncs.com/cs/5606289-803820f67404a31edcda56f97c0aa21c.gif"></p><p>12.5%,在PyTorch中,可以
分享到 :
0 人收藏
您需要登录后才可以回帖 登录 | 立即注册

本版积分规则

积分:3875789
帖子:775174
精华:0
期权论坛 期权论坛
发布
内容

下载期权论坛手机APP