center loss 论文学习

center loss框架

从网络的的框架来看,center loss的主要工作是下图中的“Discriminative Features”。
fig1

普通的网络框架,在反向传播的过程中,根据类别标签,会将不同的类别划分开。如“Separable Features”所示,一开始两种颜色是混杂的,通过改变网络参数,让不同颜色能被分类器分开,就达到了目的。而这个过程中,只对不同类有要求,同一类没有进行约束。
center loss则是让类内的输出结果更加集中。

为了展示实际的效果,作者在mnist上进行了测试,下图是softmax分类器前面增加的一层的参数,其维度为2,这样就可以进行可视化的显示。

$$F=WX$$

$X$是上一层的输出,维度为800(根据论文计算得到),$F$为施加center loss的全连接层的输出,维度为2。那么权重参数$F$为{800,2}的矩阵。
fig2
在没有采用center loss时,不同类别的输出图像是一种花瓣,其特点是同一类的方差较大。可以找到分界线将不同类别区分开,虽然花瓣外尖端与其他类间距很大,花瓣中心的区分很小,很容易造成错误,如橘色区域,红线表示分类线。
这里写图片描述

如何让同一类颜色更集中呢?文中采用了center loss:
centerloss
很简单,每个将输出点与这类中心点的距离累加作为损失。
回想方差公式:
v
是不是很类似?降低center loss其实也可以看作是降低同类的方差。

实现

推荐EncodeTS/TensorFlow_Center_Loss的代码,使用TensorFlow实现,且有详细的中文注释。

center loss流程大致为:

  1. 初始化权重中心centers,形状为[num_classes, len_features],中心值为0
  2. 在一次iteration中,获取mini-batch中每一个样本对应的中心值,centers_batch,形状为[batch_size, feature_length](使用tf.gather技巧)
  3. 计算loss,特征与中心features - centers_batch的l2范数
  4. 根据论文公式(3)(4)更新权重中心:
    在一个mini-batch中,某一类$j$出现了$n$次,分解来看:
    1. 属于该类的第$i$个样本与中心距离$c_j-x_i$
    2. 同理算出这个类出现的$n$次样本的距离,并汇总求和
    3. 除以$n+1$

loss
center loss