IA-RED2: Interpretability-Aware Redundancy Reduction for Vision Transformer

Bowen Pan1, Rameswar Panda2, Yifan Jiang3, Zhangyang Wang3,
Rogerio Feris2, Aude Oliva1,2

1 MIT CSAIL,   2 MIT-IBM Watson AI Lab,   3 UT Austin

[Paper]       [Code (coming soon)]       [Interpretation Tool]


We propose a novel Interpretability-Aware REDundancy REDuction (IA-RED2) framework for reducing the redundancy of vision transformers. The key mechanism IA-RED2 uses to increase efficiency is to dynamically drop less informative patches from the original input sequence, thereby shortening the sequence the transformer has to process. While the original vision transformer tokenizes all of the input patches, it neglects the fact that some of them are redundant, and that this redundancy is input-dependent (see the figures below). We leverage the idea of dynamic inference and adopt a policy network (referred to as the multi-head interpreter) to decide which patches are uninformative and then discard them. Our method is inherently interpretability-aware, as the policy network learns to discriminate which regions are crucial for the final prediction.
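
To make this concrete, here is a minimal sketch (not the authors' released implementation) of input-dependent patch dropping: a small scoring module rates every patch token, and only the highest-scoring tokens are kept for the rest of the network. The names PatchScorer and drop_uninformative, as well as the keep ratio, are illustrative assumptions.

```python
# Minimal sketch of input-dependent patch dropping; not the released IA-RED2 code.
import torch
import torch.nn as nn

class PatchScorer(nn.Module):
    """Scores each patch token with a keep-probability in [0, 1]."""
    def __init__(self, dim):
        super().__init__()
        self.score = nn.Sequential(
            nn.Linear(dim, dim // 4), nn.GELU(),
            nn.Linear(dim // 4, 1), nn.Sigmoid())

    def forward(self, tokens):                 # tokens: (B, N, dim)
        return self.score(tokens).squeeze(-1)  # keep-probabilities: (B, N)

def drop_uninformative(tokens, scores, keep_ratio=0.7):
    """Keep only the top `keep_ratio` fraction of patch tokens per image."""
    num_keep = max(1, int(tokens.shape[1] * keep_ratio))
    idx = scores.topk(num_keep, dim=1).indices                # (B, num_keep)
    idx = idx.unsqueeze(-1).expand(-1, -1, tokens.shape[-1])  # (B, num_keep, dim)
    return tokens.gather(1, idx)                              # pruned token sequence
```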


Method

[Figure]

Illustration of our IA-RED2 framework. We divide the transformer into D groups, each containing a multi-head interpreter and L combinations of multi-head self-attention (MSA) and feed-forward network (FFN) blocks. Before being fed into the MSA and FFN, the patch tokens are evaluated by the multi-head interpreter, which drops the uninformative ones. The multi-head interpreters are optimized with a reward that accounts for both efficiency and accuracy.
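
The sketch below illustrates the kind of reward such interpreters could be optimized with: it favors dropping many patches (efficiency) only when the pruned model still predicts the correct label (accuracy). The exact functional form and the weighting factor omega are assumptions for illustration, not the formula from the paper.

```python
# Illustrative efficiency/accuracy reward for the multi-head interpreters;
# the functional form and `omega` are assumptions, not the paper's exact reward.
import torch

def interpreter_reward(keep_mask, logits, labels, omega=0.5):
    """keep_mask: (B, N) binary keep decisions; logits: (B, C); labels: (B,)."""
    efficiency = 1.0 - keep_mask.float().mean(dim=1)    # fraction of dropped patches
    correct = (logits.argmax(dim=1) == labels).float()  # 1 if prediction survives pruning
    # Reward efficiency only when the prediction stays correct; penalize
    # dropping patches that break the prediction.
    return correct * (1.0 + omega * efficiency) - (1.0 - correct)
```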


Qualitative Results of Interpretability-Aware Heatmaps

[Figure]

Qualitative comparison of heatmaps highlighting the informative regions of the input images, produced by MemNet, by the raw attention at the second block, and by our method with the DeiT-S model. Our method interprets the part-level regions of the objects of interest noticeably better.


Qualitative Results of Redundancy Reduction

[Figure]

We visualize the hierarchical redundancy reduction process of our method with the DeiT-S model. The number in the upper-left corner of each image indicates the ratio of remaining patches. From left to right, the network progressively drops redundant patches and focuses more on the high-level features of the objects.


Interpretation Demo

We further provide an interpretation tool for readers who want to explore the interpretability of our model. The usage is simple: python interpreter.py -p {image_path} -o {output_dir}. Running the tool requires an environment with Python 3.6 (or above), torch 1.7 (or above), and timm 0.3.2 (or above) installed. [Download]
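
If you prefer to call the tool from Python, the hypothetical wrapper below invokes the same command via subprocess; only the documented -p and -o flags are used, and the paths are placeholders.

```python
# Hypothetical wrapper around the provided CLI; the paths below are placeholders.
import subprocess

subprocess.run(
    ["python", "interpreter.py",
     "-p", "examples/cat.jpg",   # placeholder: path to an input image
     "-o", "outputs/"],          # placeholder: directory for the generated heatmaps
    check=True,
)
```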


Reference

B. Pan, R. Panda, Y. Jiang, Z. Wang, R. Feris, and A. Oliva. IA-RED2: Interpretability-Aware Redundancy Reduction for Vision Transformer. arXiv, 2021. [Bibtex]


Acknowledgements

We thank IBM for the donation to MIT of the Satori GPU cluster. This work is supported by the MIT-IBM Watson AI Lab and its member companies, Nexplore and Woodside.