学习⼊门必备:MAML(背景+论⽂解读+代码分析)
⽂章⽬录
前⾔
就在今年三四⽉份,炒出了⼀个“元宇宙”的新名词,相信⼤家并不陌⽣吧,百度百科的解释:“元宇宙(Metaverse)是利⽤科技⼿段进⾏链接与创造的,与现实世界映射与交互的虚拟世界,具备新型社会体系的数字⽣活空间。说起来⽐较遥远,跟我们⽬前现实并不是特别直观,但是元学习(Meta Learning)这个概念已经被提出了很多年了,让我们⼀探究竟吧。
今天给⼤家分享⼀篇⽐较经典的⽂章,也是⼊门元学习的必看论⽂:MAML
论⽂题⽬:Model-Agnostic Meta-Learning for Fast Adaptation of Deep Networks
模型不可知元学习在深度⽹络快速⾃适应中的应⽤
论⽂是2017年发表在ICML上的,⽬前被引⽤量也超过4600+,值得⼤家进⾏学习。
背景
元学习简介
在我前⼀篇⽂章已经介绍过元学习的⼀些基本概念和与机器学习的区别,⼤家感兴趣的话可以看⼀下
元学习问题定义
下⾯就是将元学习定义为双层优化问题,这是⼀个新思路,希望能够对元学习有更深刻的理解。
⾸先将元训练集分为⽀持集(Support)和查询集(Query);w可以看成算法;θ可以认为是模型参数
在内层优化阶段(Inner loop),在⽀持集中,采⽤w算法,根据task的loss值表现,来进⾏优化θ参数,最终根据Ltask最⼩值,内层优化得到最优的θ’值。
在外层优化阶段(Outer loop),在查询集中,根据内层优化的最优θ’值,计算当前Lmeta的值,根据多个任务后,计算出最⼩的所有任务的总loss值来优化w参数,不断调整w算法,最终在所有任务中表现最优。
快速学习根据双层优化的思想,我们可以将元学习问题也是可以定义为⼀个双层优化的问题。
⼩样本学习(Few shot learning)
问题定义
⼈类⾮常擅长通过极少量的样本识别⼀个新物体,⽐如⼩孩⼦只需要书中的⼀些图⽚就可以认识什么是“斑马”,什么是“犀⽜”。在⼈类的快速学习能⼒的启发下,研究⼈员希望机器学习模型在学习了⼀定类别的⼤量数据后,对于新的类别,只需要少量的样本就能快速学习,这就是 Few-shot Learning 要解决的问题。
通俗理解:在训练阶段模型学习⼤量数据,在测试阶段通过少量的样本学习后,可以快速的学习样本特征。
元学习/⼩样本学习基本特征
论⽂解读
Abstract
通过论⽂题⽬,我们会有⼀个⼤致的了解,Model-Agnostic(模型⽆关)、Fast Adaptation(快速适应)、Deep Networks(深度⽹络),可以看出这篇⽂章是适⽤于深度⽹络并且提出⼀种与模型⽆关的通⽤框架。
主要内容:提出了⼀种与模型⽆关的元学习算法,它与任何⽤梯度下降训练的模型都是兼容的,并且适⽤于各种不同的学习问题【分类、回归和强化学习】
元学习的⽬标:训练⼀个关于各种学习任务的模型,这样就可以只使⽤少量的训练样本来解决新的学习任务。
具体⽅法:模型的参数被显式地训练,使得少量的梯度步长和来⾃新任务的少量训练数据将在该任务上产⽣良好的泛化性能。
实验结果:证明了该⽅法在两个少镜头图像分类基准上的最优性能。
Introduction
关键思想是训练模型的初始参数,使模型在通过⼀个或多个⽤来⾃新任务的少量数据计算的⼀个或多个梯度步骤更新参数后,在新任务上具有最⼤的性能。
从特征学习的观点来看,训练模型的参数使得⼏个梯度步骤,甚⾄单个梯度步骤就可以在新任务上产⽣良好结果的过程可以被视为构建⼴泛适⽤于许多任务的内部表⽰,如果内部表⽰适⽤于许多任务,只需稍微微调参数(例如,主要通过修改前馈模型中的顶层权重)就可以产⽣良好的结果。
我们的程序针对易于微调和快速调整的模型进⾏了优化,允许在适合快速学习的空间进⾏调整。
从动⼒系统的观点来看,我们的学习过程可以被视为最⼤化新任务的损失函数对参数的敏感度:敏感度较⾼时,对参数的微⼩局部更改可导致深度⽹络快速适应的模型不可知元学习在任务损失⽅⾯的⼤幅改善。
这项⼯作的主要贡献是⼀种简单的与模型和任务⽆关的元学习算法,该算法训练模型的参数,以便少量的梯度更新将导致在新任务上的快速学习。
Motivation
传统的模型就是随机初始化,这样⼀开始的参数需要很多步更新后才能够达到⽐较好的结果。所以在MAML中想要获得⼀个⽐较好的初始值他和她,只经过⼀步更新后,就能够获得对于当前任务⽐较好
的参数。【我们可以看右边的图,θ根据三个loss值得到对应的更新⽅向,最终经过⼀步更新后,获得⼀个⽐较好的初始值,适⽤于其他任务。】
Model-Agnostic Meta-Learning
训练能够实现快速适应的模型,这是⼀种经常被形式化为极少机会学习的问题设置。
元学习问题设定
⼩样本学习中常见的⼀个概念:N-way N-shot
N-way 的意思是N分类
N-shot是在学习的样本中,每个类只提供N个样本
例⼦:常见⼩样本学习分类数据集MiniImagenet,5-way 5-shot
⼩样本元学习的⽬标:训练⼀个只使⽤⼏个数据点和训练迭代就能快速适应新任务的模型。
实现⽬标:在⼀组任务的元学习阶段对模型或学习者进⾏训练,使得训练的模型可以仅使⽤少量的⽰例或试验来快速适应新的任务。
为了实现这⼀点,相当于定义⼀个模型 f ,使得对于输⼊的X任务,会产⽣a. 我们训练这个⽹络使得它可以适应不同的⽆限的任务。
在图像分类中,其中L()是损失函数,q()是样本的分布,定义公式:
主要过程:从P(T)任务分布中选取新任务T,在k-shot的情境下,使⽤k个样本训练模型,从q()分布中选取k个样本,⽣成对应任务T的L().与模型⽆关的元学习算法细节
根据上述元学习的背景中,MAML算法的具体细节正如上图所⽰。⾸先进⼊内层优化,在⽀持集中,根据当前的θ值情况,进⾏⼀步更新得到θ’,经过内层优化后得到较好的θ’后,在外层优化查询集中,根据在查询集中θ’的loss值情况,在进⾏外层的θ更新,最终得到⼀个较好的θ初始值。
与元学习不同的点是:1、w值就是θ值本⾝ 2:同时Ltask和Lmeta的Loss设计是⼀致的。
伪代码
f (x )=a
T =L x ,a ,q x {(11)(1)}
上图是MAML的伪代码,下⾯就⼩曾哥就继续带⼤家⼀起来分析
1、⾸先随机初始化θ值
2、然后从P(T)中取出任务Ti
3、进⼊内层优化,根据当前初始化的θ值,评估梯度变化,然后进⾏⼀步更新,得到更新后的θi’值。
4、继续再从P(T)中取出其他任务Ti
5、然后根据更新后的θi’,在查询集中计算loss值,最后根据所有任务的loss值之和来进⾏更新θ值情况。上述是MAML算法的基本流程,有助于帮助⼤家理解。
算法实例讲解