Resource-Efficient Transformer Pruning for Finetuning of Large Models
2024-07-22
1. Introduction
Intro
- LLMs are pre-trained on very large datasets, and these foundation models (PLMs, pre-trained language models) are then finetuned on a downstream task/dataset to maximize task-specific performance at deployment.
Problem Statement
- Over-parameterization and large training datasets are known to yield better generalization when training PLMs. → Very large PLMs that require hundreds of GBs of GPU memory for finetuning and inference.
- Effectively finetuning a PLM with limited resources while maintaining performance remains an open challenge.
- Pruning
- Based on the argument that not all weights of a PLM are necessary for a given downstream task.
- Representative pruning techniques: pre-finetune pruning and post-finetune pruning. The latter performs better, because pre-finetune pruning may yield suboptimal results by removing parameters that are important for the finetuning task.
- Many existing approaches still need some form of finetuning.
Contribution and Scope: RECAP
1) Explores different chunks of the model through an iterative CPU-GPU collaboration that cycles between pruning, finetuning, and updating stages (sketched after this list).
- Prune the model with Taylor-approximation-based importance estimation.
- Upload the pruned model to the GPU and finetune.
- Transfer the updated weights to the CPU.
2) Generates a finetuning mask controlled by Fisher information, which determines the subset of the pruned model that should be updated during finetuning.
- Prevents early saturation in the exploration process and further reduces the GPU memory footprint of gradients and optimizer states.
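A minimal sketch of one prune-finetune-update cycle, assuming pruning is realized as weight masking so parameter shapes stay fixed; all names, the top-k rule, and the loop structure are illustrative, not the paper's implementation:

```python
import copy
import torch

def taylor_importance(model, batch, loss_fn):
    """Per-weight first-order Taylor importance |gradient * weight|,
    a common approximation of the loss change when a weight is removed."""
    model.zero_grad()
    x, y = batch
    loss_fn(model(x), y).backward()
    return {n: (p.grad * p).abs() for n, p in model.named_parameters()
            if p.grad is not None}

def recap_cycle(full_model, batches, loss_fn, optimizer_cls,
                sparsity=0.5, iters=4):
    """Hypothetical sketch of the iterative CPU-GPU collaboration:
    the full model stays on the CPU; only the pruned (masked) model
    is finetuned on the GPU, bounding GPU memory usage."""
    for _ in range(iters):
        # (CPU) Score weights on a tiny data sample and build pruning masks.
        scores = taylor_importance(full_model, batches[0], loss_fn)
        masks = {}
        for n, s in scores.items():
            k = max(1, int(sparsity * s.numel()))
            thresh = s.flatten().kthvalue(k).values
            masks[n] = (s > thresh).float()  # keep high-importance weights

        # (GPU) Finetune the pruned model on the downstream task.
        pruned = copy.deepcopy(full_model).to("cuda")
        opt = optimizer_cls(pruned.parameters())
        for batch in batches:
            x, y = (t.to("cuda") for t in batch)
            opt.zero_grad()
            loss_fn(pruned(x), y).backward()
            opt.step()
            with torch.no_grad():  # re-apply the pruning mask after each step
                for n, p in pruned.named_parameters():
                    if n in masks:
                        p.mul_(masks[n].to("cuda"))

        # (CPU) Fold the surviving weights' updates back into the full model,
        # so the next iteration can explore a different chunk.
        pruned.to("cpu")
        with torch.no_grad():
            for n, p in pruned.named_parameters():
                if n not in masks:
                    continue
                full_p = full_model.get_parameter(n)
                full_p.copy_(torch.where(masks[n].bool(), p, full_p))
    return full_model
```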
3. Methodology
- Limitations of pre-finetune pruning
- Prunes the full model, and then finetunes the pruned model on GPU.
- Suboptimal pruning: The inability to incorporate critical information related to the downstream task during pruning.
- Limitations of post-finetune pruning
- Conducts finetuning both before (to find optimal pruning params.) and after pruning.
- The full model is first finetuned on the GPU and then pruned → the pruned model is finetuned again on the GPU.
- RECAP
- Iteratively finetunes various chunks of the full model to maintain high quality in the finetuned pruned model (= to find optimal pruning params.)
- Leverages the CPU to determine which part of the model to operate on and which weights to update at each iteration.
- RECAP Steps
- Inject task-specific heads into the pre-trained base model.
- Compute importance metrics for all model weights on the CPU to decide which groups to prune.
- Sample a tiny subset of the dataset and approximate how much the objective loss changes when removing weights, using a first-order Taylor expansion.
- Determine which weights within the pruned model should be updated during the 2nd finetuning.
- Use empirical Fisher information so that weights with low gradient magnitude are not updated (sketched after these steps).
- Load the pruned model and finetuning mask onto the GPU. Perform the 2nd finetuning on the downstream task/dataset.
- Only update weights selected by the pre-computed finetuning mask. Merge the updated weights back into the full model kept on the CPU.
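A sketch of how such a Fisher-controlled finetuning mask might be built and applied (the `keep_ratio` top-k rule and helper names are assumptions): the empirical Fisher information of each weight is approximated by its squared gradient averaged over a few sampled batches, and only high-Fisher weights remain trainable.

```python
import torch

def fisher_finetuning_mask(model, data_loader, loss_fn,
                           keep_ratio=0.1, num_batches=8):
    """Binary finetuning mask from the empirical Fisher information,
    approximated per weight as the mean squared gradient."""
    fisher = {n: torch.zeros_like(p) for n, p in model.named_parameters()}
    for i, (x, y) in enumerate(data_loader):
        if i >= num_batches:
            break
        model.zero_grad()
        loss_fn(model(x), y).backward()
        for n, p in model.named_parameters():
            if p.grad is not None:
                fisher[n] += p.grad.detach() ** 2 / num_batches
    masks = {}
    for n, f in fisher.items():
        k = max(1, int(keep_ratio * f.numel()))
        # Threshold at the k-th largest Fisher value per tensor.
        thresh = f.flatten().kthvalue(f.numel() - k + 1).values
        masks[n] = (f >= thresh).float()  # 1 = trainable, 0 = frozen
    return masks

def mask_gradients(model, masks):
    """Zero the gradients of frozen weights before optimizer.step()."""
    for n, p in model.named_parameters():
        if p.grad is not None:
            p.grad.mul_(masks[n])
```

Zeroing gradients is only the simplest way to express the mask; the memory savings on gradients and optimizer states described above come from not materializing state for frozen weights at all.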
3.1. Pruning Stage
3.1.1. Grouping Model Weights
- Unstructured pruning: Pruning decision made for each weight. → Unstructured sparsity in weight matrices (masked version of the original) after pruning.
- Structured pruning: Model weights are grouped, and the same pruning decision is made for weights within each group. → Hardware-friendly
- Transformer weights (grouping sketched after this list)
- Attention modules: Each group contains the Q, K, V, and output weight matrices corresponding to one head.
- Feedforward layers: Each group contains the model weights corresponding to one hidden dimension.
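A sketch of this grouping for one Transformer layer, assuming the standard PyTorch `nn.Linear` weight layout of `(out_features, in_features)`; the tensor names are illustrative:

```python
import torch

def group_attention_heads(w_q, w_k, w_v, w_out, num_heads):
    """One group per attention head: the rows of the Q/K/V projections
    that produce the head, plus the columns of the output projection
    that consume it. Pruning a group removes the whole head."""
    head_dim = w_q.shape[0] // num_heads
    groups = []
    for h in range(num_heads):
        rows = slice(h * head_dim, (h + 1) * head_dim)
        groups.append({
            "q": w_q[rows, :], "k": w_k[rows, :], "v": w_v[rows, :],
            "out": w_out[:, rows],
        })
    return groups

def group_ffn_hidden_dims(w_up, w_down):
    """One group per feedforward hidden dimension: row i of the
    up-projection together with column i of the down-projection."""
    return [{"up": w_up[i, :], "down": w_down[:, i]}
            for i in range(w_up.shape[0])]
```

Because a group is pruned or kept as a whole, the resulting matrices simply shrink along one dimension, which keeps dense, hardware-friendly shapes.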
3.1.2. Importance Estimation for Pruning
- Compute the importance of each weight group within the model, using a Taylor-expansion-based weight importance estimation approach (a common formulation is given below).
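A common first-order formulation of this score (the notation here is an assumption, not necessarily the paper's): the importance of a weight group G is the estimated change in the loss L on the sampled data D when the group is zeroed out,

```latex
I_G \;=\; \bigl|\, \mathcal{L}(\mathcal{D};\, \mathbf{w})
        - \mathcal{L}(\mathcal{D};\, \mathbf{w} \text{ with } w_i = 0,\ i \in G) \,\bigr|
   \;\approx\; \Bigl|\, \sum_{i \in G} \frac{\partial \mathcal{L}}{\partial w_i}\, w_i \Bigr|
```

Groups with the lowest scores are pruned first; the gradients can be accumulated from the same tiny sampled subset mentioned in the steps above.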