DeiT (Training Data-efficient Image Transformers & Distillation through Attention)
1. Overview
Facebook Research에서 2020년 12월 공개; https://github.com/facebookresearch/deit
Knowledge Distillation 기법을 적용하여 대용량 데이터셋으로 pre-training하는 과정 없이 높은 성능 달성
대부분의 구조는 ViT와 동일하며, data augmentation, regularization 등 다양한 기법들을 적용하고 기존의 class token에 distillation token을 추가하여 학습
2. Knowledge Distillation
Summary
Hinton et al., "Distilling the Knowledge in a Neural Network" (NIPS 2014)
한줄요약: 청출어람 (big 네트워크에서 축적된 정보를 small 네트워크로 전달하여 small 네트워크에서도 big 네트워크와 비슷한 성능을 내는 것이 목적)
big = teacher, small = student
Teacher 모델의 Inductive bias를 soft한 방법으로 전달
Supervised Learning으로 학습한 모델의 output은 hard label이 아니라 logit에 대한 출력으로 다른 클래스들의 가중치도 포함되어 있음.
Objective
Teacher의 모델의 softmax 분포와 student 모델의 softmax 분포의 KL divergence를 최소화.
Student loss + Distillation loss
Cross Entropy between ground truth and student's hard predictions(standard softmax) + Cross Entropy between the student's soft predictions and the teacher's soft targets
Soft Distillation:
Hard Distillation:
Temperature
더 많은 knowledge를 전달하기 위한 방법으로 Temperature(T)가 높을 수록 더 soft한 확률분포를 얻을 수 있음 (T = 1일 때, Softmax와 동일);
T
Logits
Softmax Probs
1
[30, 5, 2, 2]
[1.0000e+00, 1.3888e-11, 6.9144e-13, 6.9144e-13]
10
[3, 0.5, 0.2, 0.2]
[0.8308, 0.0682, 0.0505, 0.0505]
3. Model Architecture
기본적인 형태는 ViT와 동일하며, KD를 위한 distillation token을 추출
class token과 distillation token의 코사인 유사도는 0.06이나 embedding을 수행 후 class embedding과 distillation embedding의 코사인 유사도를 계산하면 0.93
class embedding과 distillation embedding을 concatenate하거나 합산하는 방법(late fusion)이 가능

4. Experiments
Training Details
Models
DeiT-B: baseline model; D = 768, h = 12, d = D/h = 64
BeiT-B 384: 384x384 해상도로 파인튜닝 수행
DeiT(증류수 모양): DeiT w/ distillation
DeiT-S(Small), DeiT-Ti(Tiny): DeiT의 경량화 버전

Settings
Transformer 기반 모델은 하이퍼파라메터 설정에 민감하므로 여러 가지 설정들을 조합하여 실험
Stochastic depth, Mixup, Cutmix, Repeated Augmentation, Random Augmentation 등 적용
Repeated Augmentation(RA)
이미지 배치 에서 의 이미지들을 샘플링 후, 데이터를 회 변환하여 늘림
동일한 성격의 이미지들이 섞이므로, i.i.d scheme에 벗어나기에 Small Batch에서는 퍼포먼스가 저하되지만, Large Batch에서는 성능이 향상됨

Performance
CNN 기반의 teacher 네트워크가 transformer보다 top-1 accuracy가 높음

Hard distillation 전반적으로 좋고, class token과 distillation token을 동시 사용하는 것이 성능이 좀 더 좋음

Distillation 모델이 CNN 모델보다 좀 더 연관성이 높음 -> Convnet의 inductive bias가 잘 전달됨을 알 수 있음.

Accuracy
ImageNet 기반으로 학습 시, ViT보다 우세하고(top-1 +6.3%) EfficientNet에 근접한 성능을 보임
5. Implementation
Source: https://github.com/facebookresearch/deit/blob/main/losses.py
References
Movie Clip
Implementation
Last updated
Was this helpful?