PARP: Prune, Adjust and Re-Prune
for Self-Supervised Speech Recognition

Cheng-I Jeff Lai1,2     Yang Zhang2*     Alexander H. Liu1*     Shiyu Chang2,4*
Yi-Lun Liao1     Yung-Sung Chuang1,3     Kaizhi Qian2     Sameer Khurana1
David Cox2     James Glass1
1MIT CSAIL  2MIT-IBM Watson AI Lab  3National Taiwan University  4UC Santa Barbara

NeurIPS 2021 Spotlight

[Paper]      [Code]      [Colab]      [Bibtex]

Skip to:    [Abstract]      [Approach Summary]      [Result Summary]      [A Hypothesis]      [Video & Poster]     

TLDR: A simple and efficient pruning method that discovers sparse subnetworks in self-supervised pre-trained initializations (wav2vec 2.0/XLSR-53), which can be finetuned to the same downstream low-resource ASR results as the full model. See illustration below.

Full Abstract:
Self-supervised speech representation learning (speech SSL) has demonstrated the benefit of scale in learning rich representations for Automatic Speech Recognition (ASR) with limited paired data, such as wav2vec 2.0. We investigate the existence of sparse subnetworks in pre-trained speech SSL models that achieve even better low-resource ASR results. However, directly applying widely adopted pruning methods such as the Lottery Ticket Hypothesis (LTH) is suboptimal in the computational cost needed. Moreover, we show that the discovered subnetworks yield minimal performance gain compared to the original dense network.
        We present Prune-Adjust-Re-Prune (PARP), which discovers and finetunes subnetworks for much better performance, while only requiring a single downstream ASR finetuning run. PARP is inspired by our surprising observation that subnetworks pruned for pre-training tasks need merely a slight adjustment to achieve a sizeable performance boost in downstream ASR tasks. Extensive experiments on low-resource ASR verify (1) sparse subnetworks exist in mono-lingual/multi-lingual pre-trained speech SSL, and (2) the computational advantage and performance gain of PARP over baseline pruning methods.
        In particular, on the 10min Librispeech split without LM decoding, PARP discovers subnetworks from wav2vec 2.0 with an absolute 10.9%/12.6% WER decrease compared to the full model. We further demonstrate the effectiveness of PARP via: cross-lingual pruning without any phone recognition degradation, the discovery of a multi-lingual subnetwork for 10 spoken languages in 1 finetuning run, and its applicability to pre-trained BERT/XLNet for natural language tasks.


Approach Summary


A lesson from the past:
Recent progress in self-supervised speech representations (e.g. the wav2vec series, the UniSpeech series, BigSSL, etc.) has proven the importance of scaling up the representation modules to attain SOTA ASR performance. Additionally, the SUPERB benchmark shows that regardless of the SSL objective, model scaling is critical across 10 downstream speech tasks. Basically, for a given SSL objective, the bigger the model, the better the downstream results!

This leads us to a series of research questions:
Given a self-supervised speech representation framework on the right, such as wav2vec 2.0/XLSR-53, do there exist sparse subnetworks within the self-supervised pre-trained initializations with downstream low-resource ASR performance similar to the full model? What are the properties of these sparse subnetworks, and what can we learn from them? Beyond applying the pruning methods from the vision community, is there a more efficient approach to discovering these sparse subnetworks?

Our Goal:
Discover sparse subnetworks within pre-trained speech SSL models for low-resource ASR such that (1) they are found and finetuned efficiently within 1 ASR finetuning pass, and (2) they attain the same or even lower WER than the original dense speech SSL model.


A Surprising Observation:
For any downstream spoken language, the non-zero ASR pruning masks obtained from task-agnostic subnetwork discovery have high overlap with those obtained from task-aware subnetwork discovery.
For instance, on the right, applying unstructured magnitude pruning (UMP) to a wav2vec2 finetuned for Spanish and applying UMP again to a wav2vec2 finetuned for French results in 97% overlap in their pruning patterns. This holds for any amount of downstream supervision, pre-training model scale, and sparsity.

The flip side of this observation is that a task-agnostic pruning mask/subnetwork provides a good basis for a task-dependent one.
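To make the overlap measurement concrete, here is a minimal PyTorch sketch (not the released code) that extracts UMP masks from two checkpoints and computes the fraction of positions on which they agree; the helper names and the random stand-in tensors are ours for illustration only.

```python
import torch

def ump_mask(weights, sparsity):
    """Unstructured magnitude pruning: globally keep the largest-magnitude weights."""
    flat = torch.cat([w.detach().abs().flatten() for w in weights])
    k = int(sparsity * flat.numel())                        # number of weights to prune
    threshold = flat.kthvalue(k).values                     # k-th smallest magnitude
    return [(w.detach().abs() > threshold) for w in weights]  # True = kept (non-zero)

def mask_overlap(masks_a, masks_b):
    """Fraction of positions on which two pruning masks agree."""
    agree = sum((a == b).sum().item() for a, b in zip(masks_a, masks_b))
    total = sum(a.numel() for a in masks_a)
    return agree / total

# Random tensors stand in for the weight matrices of two finetuned wav2vec 2.0
# checkpoints (e.g., Spanish- vs. French-finetuned), purely for illustration.
weights_es = [torch.randn(768, 768) for _ in range(12)]
weights_fr = [torch.randn(768, 768) for _ in range(12)]
print(mask_overlap(ump_mask(weights_es, 0.5), ump_mask(weights_fr, 0.5)))
```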


The Algorithm:
Step 1: Language-Agnostic Initial Subnetwork. Directly prune the pre-trained SSL model, such as wav2vec2/XLSR, at the target sparsity to obtain an initial subnetwork and an initial pruning mask. Alternatively, prune a wav2vec2/XLSR finetuned on a non-target language.
Step 2: Language-Aware Subnetwork Adjustment. Finetune the initial subnetwork on the target downstream task/language. During finetuning, zero out the pruned weights specified by the pruning mask, but still allow them to be updated by gradient descent during backpropagation. After a number of model updates, re-prune the updated subnetwork at the target sparsity.

In practice, we repeat Step 2 multiple times within 1 full downstream ASR finetuning run. The figure below illustrates the 2 Steps; a minimal code sketch follows.
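Below is a minimal PyTorch-style sketch of the two steps, assuming `model` is a pre-trained wav2vec2/XLSR encoder, `dataloader` yields target-language batches, `asr_loss` computes the downstream CTC loss, and `ump_mask` is the magnitude-pruning helper sketched earlier. Hyperparameters and helper names are illustrative, not the official release code.

```python
import torch

def apply_mask(model, masks):
    """Zero out the pruned weights in place (masks: param name -> bool tensor, True = kept)."""
    with torch.no_grad():
        for name, param in model.named_parameters():
            if name in masks:
                param.mul_(masks[name].to(param.dtype))

def parp_finetune(model, dataloader, asr_loss, sparsity,
                  init_masks=None, adjust_interval=50, total_steps=10000):
    prunable = {n: p for n, p in model.named_parameters() if "weight" in n}

    # Step 1: initial subnetwork -- either direct magnitude pruning of the pre-trained
    # model, or a mask taken from a non-target-language finetuned model (init_masks).
    masks = init_masks or dict(zip(prunable, ump_mask(list(prunable.values()), sparsity)))
    apply_mask(model, masks)

    optimizer = torch.optim.Adam(model.parameters(), lr=5e-5)
    step = 0
    while step < total_steps:
        for batch in dataloader:
            # Step 2a (adjust): all weights, including the zeroed-out ones, receive
            # gradient updates, so accidentally pruned important weights can be revived.
            optimizer.zero_grad()
            asr_loss(model, batch).backward()
            optimizer.step()
            step += 1
            # Step 2b (re-prune): every `adjust_interval` updates, re-prune at the
            # target sparsity and zero out the (possibly different) pruned weights.
            if step % adjust_interval == 0:
                masks = dict(zip(prunable, ump_mask(list(prunable.values()), sparsity)))
                apply_mask(model, masks)
            if step >= total_steps:
                break
    return model, masks
```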



Results Summary




Pruning for Low-Resource English ASR:
PARP (black line) outperforms or matches baseline pruning methods in all settings. We also found that on the 10min split w/o LM decoding, sparse subnetworks found with PARP from wav2vec2 attain an absolute 10% WER reduction over the full wav2vec2.




Pruning for Low-Resource Multi-Lingual ASR:
We pruned wav2vec2/XLSR for 10 spoken languages in low-resource conditions: Spanish (es), French (fr), Italian (it), Kyrgyz (ky), Dutch (nl), Russian (ru), Swedish (sv-SE), Turkish (tr), Tatar (tt), and Mandarin (zh-TW). Observe the similar pruning curves across languages; PARP (black and pink lines) again outperforms or matches baseline pruning methods in all settings.



Cross-Lingual Subnetwork Transfer for wav2vec 2.0:
We investigate the transferability of a sparse subnetwork discovered for a source language by finetuning it on another target language.
Upper Left: transfer with regular finetuning leads to an increase in PER over same-pair (no language mismatch) transfer.
Upper Right: transfer with PARP leads to no increase (sometimes even a decrease!) in PER over same-pair transfer. For example, a Spanish sparse subnetwork from wav2vec2/XLSR can be efficiently adapted to French w/o any PER increase.
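As a sketch of how this transfer could be run with the illustrative `parp_finetune` helper outlined above (not the official release), the source-language mask simply replaces the task-agnostic mask in Step 1; `spanish_masks`, `french_loader`, and `ctc_loss_fr` are placeholders for your own artifacts.

```python
# Cross-lingual transfer sketch: start the adjust/re-prune loop from the mask of a
# source-language (e.g., Spanish) subnetwork rather than the task-agnostic one.
model, french_masks = parp_finetune(
    model, french_loader, ctc_loss_fr, sparsity=0.5, init_masks=spanish_masks
)
```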




Cross-Task Subnetwork Transfer for BERT/XLNet:
We investigate the transferability of a sparse subnetwork discovered for a source natural language task from GLUE by finetuning it on another natural language task.
This experiment shows the applicability of PARP to natural language domains.




Cross-Task Subnetwork Transfer for wav2vec 2.0:
We investigate the transferability of a sparse subnetwork discovered for a source speech task from SUPERB by finetuning it on another speech task.
This experiment shows the (potential) applicability of PARP to other speech tasks such as spoken language understanding or speaker recognition.



Pruning as an Alternative for Representation Probing:
We empirically found that pruning/sparsity has a nice correspondence to the quality of the representation for downstream tasks.
Top Table: pruned-weight localization across layers. We consistently found that middle layers are pruned less, while initial and final layers are pruned much more.
Bottom Plot: borrowed from wav2vec-U, showing that middle-layer representations are more valuable for downstream ASR.
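One way to produce such per-layer numbers, assuming `masks` is the final PARP pruning mask (a dict of parameter name -> boolean tensor with True = kept) and that transformer parameters follow the common `encoder.layers.<i>.` naming, is sketched below; adapt the name parsing to your checkpoint.

```python
from collections import defaultdict

def per_layer_pruned_ratio(masks):
    """Fraction of weights pruned in each transformer layer (True in a mask = kept)."""
    kept, total = defaultdict(int), defaultdict(int)
    for name, mask in masks.items():
        if "encoder.layers." in name:
            layer = int(name.split("encoder.layers.")[1].split(".")[0])
            kept[layer] += mask.sum().item()
            total[layer] += mask.numel()
    # Per the observation above, middle layers should show lower pruned ratios.
    return {layer: 1.0 - kept[layer] / total[layer] for layer in sorted(total)}
```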

A Hypothesis


A new insight into self-supervised representation learning:
We hypothesize the existence of downstream task/language-specific weights in self-supervised pre-trained initializations. These important weights account for only a very small fraction of the overall model weights, whether in pre-trained wav2vec2/XLSR or BERT/XLNet. Therefore, for a given task/language, most of the pruning mask can be obtained "freely" with task-agnostic pruning. The adjustment step then "recovers" the accidentally pruned-out important weights by reviving them with gradient updates, and since there are so few of them, the adjustment can be done efficiently.

Presentations

Reference

CIJ Lai, Y Zhang, AH Liu, S Chang, YL Liao, YS Chuang, K Qian, S Khurana, D Cox, J Glass. PARP: Prune, Adjust and Re-Prune for Self-Supervised Speech Recognition.
NeurIPS, 2021.

@inproceedings{lai2021parp,
  title={PARP: Prune, Adjust and Re-Prune for Self-Supervised Speech Recognition},
  author={Lai, Cheng-I Jeff and Zhang, Yang and Liu, Alexander H and Chang, Shiyu and Liao, Yi-Lun and Chuang, Yung-Sung and Qian, Kaizhi and Khurana, Sameer and Cox, David and Glass, James},
  booktitle={NeurIPS},
  year={2021}
}

Broader Impact: The broader impact of this research work is making speech technologies more accessible in two orthogonal dimensions: (i) extending modern-day speech technology to many under-explored low-resource spoken languages, and (ii) introducing a new and flexible pruning technique to current and future speech SSL frameworks that reduces the computational costs required for adapting (finetuning) them to custom settings. We do not foresee potential societal harm from this work.

Acknowledgements: We thank IBM for the donation to MIT of the Satori GPU cluster, and John Cohn for maintaining the cluster. We also thank Lucy Chai, Wei-Ning Hsu, Desh Raj, Shu-wen Leo Yang, Abdelrahman Mohamed, Erica Cooper, and the anonymous reviewers for helpful suggestions and paper editing. This work is part of the low-resource language learning project funded by the MIT-IBM Watson AI Lab. The website template is borrowed from here.