Facebook DeiT(Data-efficient Image Transformers) 解析

Luna
Written by Luna on

    今天看Facebook AI的DeiT。比起Moco v3训练出来的模型,DeiT胜在模型小,训练和推理都更加迅速。且精准度还有了很大提高。

    迅速到什么程度呢?用一个8-GPU服务器,训练3天就可以(其中预训练占用了53小时,fine-tuning占20个小时)。

    看下图的效能比对,可以发现,DeiT确实是强悍的。

    接下来我们看DeiT的具体原理,首先,它的主干模型还是ViT,关于ViT可以看这篇:ViT (Vision Transformer)原理及代码解析,这里就不赘述了。

    它可以训练的这么快的原因是用了蒸馏技术(distillation)。 那什么是蒸馏技术呢?就是把一个模型的知识往另外一个模型迁移的过程。被迁移知识的那个模型,我们叫他teacher,是训练好的。从别的模型学到知识的模型,我们叫他student。通常是把大模型的知识往小模型上迁移。以期在不过分损失精准度的前提下,使推理速度大大提高。迁移的方式通常是让小模型和大模型有一样的输出。接下来看两种蒸馏方式:

软蒸馏(soft distillation)

    Zt是teacher模型的输出,Zs是student模型的输出。ψ表示的是softmax,KL是KL散度(Kullback-Leibler),又叫相对熵,Lce是交叉熵(cross-entropy)。y是真实标签(ground truth label)。

    这里参考Cross Entropy Loss简单说一下KL散度和交叉熵的区别,这里就不上公式了,用大白话说,如果觉得有公式更容易理解,可以看参考文章中的公式。

    首先说一下信息量的概念,信息量指的是一个事件发生含的信息量,发生的概率越小,含的信息量就越大,比如太阳有天从西边升起来了,那么这个事件含的信息量就超级大了,全世界都得炸。而计算机里的信息量的大小指的是,描述一个事件发生所需要的位元(bits)。这里的描述和我们通常的语言描述是不一样的。而信息熵是信息量的期望值,还是太阳升起这个问题,它有可能从东边升起,也有可能从西边,南边,北边升起,虽然概率无限趋近于0,但是墨菲定律告诉我们,只要有概率,就一定会发生。那么太阳升起的方向这个事件有一个信息量的期望值,我们怎么理解期望值呢?太阳升起是从东南西北哪个方向呢,每个方向都会由一个概率,假设这件事情遵循特定的概率分布,如果这件事情发生无数次,那么平均每次我们需要用多少个位元来描述这件事情呢。平均每次,事件携带的信息量的大小,就是信息熵

    接下来我们解释相对熵,也就是KL散度。比如太阳升起是从东南西北哪个方向,假设这件事情遵循特定概率分布p,假设我们不知道这个p,现在我们自己估一个概率分布q,是根据模型或者自己的认知设定的,那么用真实的概率分布会有一个信息熵,用我们估计的概率分布也会有一个信息熵,用估计的信息熵减去真实的信息熵,就是相对熵。相对熵计算的是,如果用我们估计的这个概率分布来替代真实的概率分布,如果这个事件会发生无数次,那么平均传输和描述这个事件需要多出多少位元。

    而交叉熵就是,如果用这个估计的分布替代真实的分布,如果这个事件发生无数次,平均需要多少位元来描述和传输这个事件。

    这两个都可以用来比较两个分布的差异,差异越大,相对熵或者交叉熵就越大,差异越小,相对熵或者交叉熵就越小。

    到这里我们来理解一下,为什么上面的公式,前面用的是交叉熵,后面用的是相对熵,因为前面y是固定的,不管我们怎么对样本进行变化,y都是不变的,因此由真实概率分布所得的信息熵是固定的,没有必要再去减一个固定值。但是后面的ψ(Zt/τ)则不一样了,如果我们用样本增强,或者我们不止有一个teacher,改变τ的值,那么ψ(Zt/τ)的值就会发生变化,信息熵也会发生变化,出于这种考虑,用相对熵保留差异部分可能可以更好的排除一些其他因素对loss造成的干扰。

    当然,我这种理解不一定对。

硬蒸馏 (hard-label distillation): 

    与软蒸馏不同的地方是,这里的yt是teacher的hard decision,yt = argmaxcZt(c),也就是说yt不再是一个概率分布,而是一个根据最高概率做出的一个决策结果(但其实也可以看成一个概率分布,这里只是相对而言)。

    与软蒸馏不同的,还有损失函数的选择。在硬蒸馏里,yt和y的关系是对等的,对结果起到的作用是对等的。大家会更偏好硬蒸馏,硬蒸馏参数更少,容易理解。这里的yt也会因为数据增强而产生不同结果,不是一定的。

    接下来看,DeiT是怎么实现软硬蒸馏的。

    在ViT的结构中,增加了一个和class token一样功能的distillation token。具体的代码实现可以看ViT (Vision Transformer)原理及代码解析,在一般的硬蒸馏里,用来和y计算差异的Zs和用来和yt计算差异的Zs是一致的,但是在DeiT里,分开了。

# https://github.com/facebookresearch/deit/blob/main/models.py        
        if self.training:
            # x是class token, x_dist是distillation token 
            return x, x_dist
        else:
            # during inference, return the average of both classifier predictions
            return (x + x_dist) / 2

    根据DeiT模型里的代码,训练的时候返回的是cls_token,和distillation token,而在推理的时候,返回的是cls_token和distillation_token的均值。也就是说DeiT中的class token和distillation_token被认定有等价的推理价值。取均值会让推理更准确。

# https://github.com/facebookresearch/deit/blob/main/losses.py    
    def forward(self, inputs, outputs, labels):
        """
        Args:
            inputs: The original inputs that are feed to the teacher model
            outputs: the outputs of the model to be trained. It is expected to be
                either a Tensor, or a Tuple[Tensor, Tensor], with the original output
                in the first position and the distillation predictions as the second output
            labels: the labels for the base criterion
        """
        outputs_kd = None
        if not isinstance(outputs, torch.Tensor):
            # assume that the model outputs a tuple of [outputs, outputs_kd]
            # class token,distillation token
            outputs, outputs_kd = outputs
        # class token和groud-truth labels产生base loss
        base_loss = self.base_criterion(outputs, labels)
        if self.distillation_type == 'none':
            return base_loss


        if outputs_kd is None:
            raise ValueError("When knowledge distillation is enabled, the model is "
                             "expected to return a Tuple[Tensor, Tensor] with the output of the "
                             "class_token and the dist_token")
        # don't backprop throught the teacher
        with torch.no_grad():
            teacher_outputs = self.teacher_model(inputs)


        if self.distillation_type == 'soft':
            T = self.tau
            # taken from https://github.com/peterliht/knowledge-distillation-pytorch/blob/master/model/net.py#L100
            # with slight modifications
            distillation_loss = F.kl_div(
                F.log_softmax(outputs_kd / T, dim=1),
                #We provide the teacher's targets in log probability because we use log_target=True 
                #(as recommended in pytorch https://github.com/pytorch/pytorch/blob/9324181d0ac7b4f7949a574dbc3e8be30abe7041/torch/nn/functional.py#L2719)
                #but it is possible to give just the probabilities and set log_target=False. In our experiments we tried both.
                F.log_softmax(teacher_outputs / T, dim=1),
                reduction='sum',
                log_target=True
            ) * (T * T) / outputs_kd.numel()
            #We divide by outputs_kd.numel() to have the legacy PyTorch behavior. 
            #But we also experiments output_kd.size(0) 
            #see issue 61(https://github.com/facebookresearch/deit/issues/61) for more details
        elif self.distillation_type == 'hard':
            distillation_loss = F.cross_entropy(outputs_kd, teacher_outputs.argmax(dim=1))


        loss = base_loss * (1 - self.alpha) + distillation_loss * self.alpha
        return loss

    看DeiT中的代码,软蒸馏和硬蒸馏都被定义了,inputs是样本,由teacher_model产生teacher_outputs,outputs中包含了class token和distillation token,class token和ground-truth label产生base_loss,而distillation token和teacher model产生的结果产生distillation_loss。

    DeiT的原理大概说清楚了,实验的话,实在是懒,作者做了大量实验,如果想用DeiT,还是建议仔细看实验的。之后可能会写一篇,怎么用DeiT做分类和Transfer的文章。期待的话就关注吧。

Comments

comments powered by Disqus