
Resource-Efficient Transformer Pruning for Finetuning of Large Models

2024-07-22



1. Introduction

Intro

  • LLMs are pre-trained on very large datasets, and these foundation models (pre-trained models, PLMs) are then finetuned on a downstream task/dataset to maximize task-specific performance at deployment.

Problem Statement

  • Over-parameterization and large training datasets are known to improve generalization when training PLMs. → The result is very large PLMs that require hundreds of GBs of GPU memory for finetuning and inference.
  • Effectively finetuning a PLM with limited resources while maintaining performance remains an open challenge.

  • Pruning
    • Based on the argument that not all weights of a PLM are necessary for a given downstream task.
    • Representative pruning techniques are pre-finetune and post-finetune pruning. Post-finetune pruning performs better, because pre-finetune pruning may remove parameters that are important for the finetuning task and thus yield suboptimal results.
    • Many existing approaches still need some form of finetuning.
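The "hundreds of GBs" figure follows from simple arithmetic. The sketch below assumes fp32 storage (4 bytes per value) and the Adam optimizer, so full finetuning keeps a weight, a gradient, and two moment estimates per parameter (16 bytes); activations, buffers, and framework overhead are ignored. The function `finetune_memory_gb` and its knobs `keep_ratio` (fraction of weights left after pruning) and `update_ratio` (fraction of remaining weights that receive gradients and optimizer state) are hypothetical, and the numbers are illustrative, not results from the paper.

```python
def finetune_memory_gb(n_params, keep_ratio=1.0, update_ratio=1.0, bytes_per_value=4):
    """Rough GPU memory in GB (1 GB = 1e9 bytes) for weights, gradients and Adam state.

    Assumption: activations, buffers, and framework overhead are ignored.
    """
    resident = n_params * keep_ratio        # weights of the (pruned) model on the GPU
    trainable = resident * update_ratio     # weights that receive gradients
    weight_bytes = resident * bytes_per_value
    # one gradient plus two Adam moment estimates per trainable weight
    train_bytes = trainable * 3 * bytes_per_value
    return (weight_bytes + train_bytes) / 1e9

full = finetune_memory_gb(7e9)
pruned = finetune_memory_gb(7e9, keep_ratio=0.5)
pruned_masked = finetune_memory_gb(7e9, keep_ratio=0.5, update_ratio=0.25)
print(f"{full:.1f} {pruned:.1f} {pruned_masked:.1f}")  # 112.0 56.0 24.5
```

Under these assumptions a 7B-parameter model already needs about 112 GB for weights, gradients, and optimizer state alone, and a 70B model about 1.1 TB.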

Contribution and Scope: RECAP

1) Explores different chunks of the model through an iterative CPU-GPU collaboration cycling between pruning, finetuning, and updating stages.

  • Prune the model with Taylor-approximation-based importance estimation.
  • Upload the pruned model to the GPU and finetune.
  • Transfer the updated weights to the CPU.

2) Generates a finetuning mask, controlled by Fisher information, that determines which subset of the pruned model is updated during finetuning.

  • Prevents early saturation in the exploration process and further reduces the GPU memory footprint of gradients and optimizer states.
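The pruning, finetuning, and updating cycle of contribution 1 can be sketched on a toy problem. This is a minimal illustration under heavy assumptions, not RECAP's implementation: the "model" is a linear regressor whose weights live in a Python list standing in for CPU memory, the "GPU" copy is a smaller list holding only the kept weights, importance is the per-weight first-order Taylor score |g_i · w_i| computed on a tiny random sample, and finetuning is plain gradient descent. The function `recap_round` and all of its parameters are hypothetical.

```python
import random

def predict(w, x):
    return sum(wi * xi for wi, xi in zip(w, x))

def mse_and_grad(w, data):
    """Mean squared error of a linear model y = w.x and its gradient w.r.t. w."""
    loss, grad = 0.0, [0.0] * len(w)
    for x, y in data:
        err = predict(w, x) - y
        loss += err * err
        for i, xi in enumerate(x):
            grad[i] += 2.0 * err * xi
    n = len(data)
    return loss / n, [g / n for g in grad]

def recap_round(full_w, data, k, rng, lr=0.01, steps=20, sample_size=4):
    """One hypothetical prune -> finetune -> update round on a toy linear model."""
    # "CPU": score every weight with |g_i * w_i| on a tiny random sample, keep the top k.
    _, g = mse_and_grad(full_w, rng.sample(data, sample_size))
    scores = [abs(gi * wi) for gi, wi in zip(g, full_w)]
    keep = sorted(sorted(range(len(full_w)), key=lambda i: scores[i], reverse=True)[:k])
    # "Upload": the pruned model holds only the kept weights and their input features
    # (equivalent to zeroing the pruned weights).
    sub_w = [full_w[i] for i in keep]
    sub_data = [([x[i] for i in keep], y) for x, y in data]
    loss_before, _ = mse_and_grad(sub_w, sub_data)
    # "GPU": finetune the pruned model with full-batch gradient descent.
    for _ in range(steps):
        _, sub_g = mse_and_grad(sub_w, sub_data)
        sub_w = [w - lr * gw for w, gw in zip(sub_w, sub_g)]
    loss_after, _ = mse_and_grad(sub_w, sub_data)
    # Transfer back: write the finetuned weights into the full model on the "CPU".
    new_w = list(full_w)
    for j, i in enumerate(keep):
        new_w[i] = sub_w[j]
    return new_w, keep, loss_before, loss_after

rng = random.Random(42)
true_w = [3.0, -2.0, 0.5, 0.0, 0.0, 1.0, 0.0, -1.0]
data = []
for _ in range(32):
    x = [rng.uniform(-1, 1) for _ in range(8)]
    data.append((x, predict(true_w, x)))

w = [0.5] * 8  # stand-in for pre-trained weights
history = []
for _ in range(3):
    prev = w
    w, keep, before, after = recap_round(w, data, k=4, rng=rng)
    history.append((prev, w, keep, before, after))
    print(keep, round(before, 3), round(after, 3))
```

Because only the kept weights leave the "CPU" copy and come back, pruned-away weights remain intact in the full model and may be selected in a later round, which is the sense in which repeated rounds explore different chunks.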

3. Methodology

  • Limitations of pre-finetune pruning
    • Prunes the full model, and then finetunes the pruned model on GPU.
    • Suboptimal pruning: Information critical to the downstream task cannot be incorporated during pruning.
  • Limitations of post-finetune pruning
    • Conducts finetuning both before pruning (to find the optimal pruning parameters) and after pruning.
    • The full model is first finetuned on the GPU and then pruned → the pruned model is finetuned again on the GPU.
  • RECAP
    • Iteratively finetunes different chunks of the full model so that the final finetuned pruned model stays high quality; the iterations serve to find good pruning parameters.
    • Leverages the CPU to determine which part of the model to operate on and which weights to update at each iteration.
  • RECAP Steps
    1. Inject task-specific heads into the pre-trained base model.
    2. Compute pruning importance metrics for all model weights on the CPU.
      • Sample a tiny subset of the dataset and approximate how much the objective loss changes when removing weights using a first-order Taylor expansion.
    3. Determine which weights within the pruned model should be updated during the subsequent finetuning.
      • Use the empirical Fisher information, so that weights with low Fisher scores (i.e., small squared gradients) are not updated.
    4. Load the pruned model and the finetuning mask to the GPU, then finetune on the downstream task/dataset.
    5. Update only the weights selected by the precomputed finetuning mask, then write the updated weights back to the full model on the CPU.
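Steps 3 and 5 can be illustrated with a small sketch. It assumes the diagonal empirical Fisher, i.e. the mean of squared per-sample gradients, and a simple top-k rule controlled by a hypothetical `update_ratio`; the notes do not specify RECAP's exact selection rule, and the helper names are illustrative only.

```python
def empirical_fisher_diag(per_sample_grads):
    """Diagonal empirical Fisher: mean of squared per-sample gradients."""
    n = len(per_sample_grads)
    return [sum(g[i] ** 2 for g in per_sample_grads) / n
            for i in range(len(per_sample_grads[0]))]

def fisher_finetune_mask(per_sample_grads, update_ratio):
    """Mark the top `update_ratio` fraction of weights (by Fisher score) as trainable."""
    fisher = empirical_fisher_diag(per_sample_grads)
    k = max(1, round(update_ratio * len(fisher)))
    top = set(sorted(range(len(fisher)), key=lambda i: fisher[i], reverse=True)[:k])
    return fisher, [i in top for i in range(len(fisher))]

def masked_sgd_step(w, grad, mask, lr):
    """Apply an SGD update only where the mask is True; frozen weights are copied."""
    return [wi - lr * gi if m else wi for wi, gi, m in zip(w, grad, mask)]

# Three per-sample gradients for a 4-weight (pruned) model.
grads = [[0.1, -2.0, 0.0, 0.5],
         [0.3, 1.0, 0.0, -0.5],
         [-0.2, 1.5, 0.1, 0.4]]
fisher, mask = fisher_finetune_mask(grads, update_ratio=0.5)
print(mask)  # [False, True, False, True]
w = masked_sgd_step([1.0] * 4, [2.0] * 4, mask, lr=0.25)
print(w)  # [1.0, 0.5, 1.0, 0.5]
```

Weights outside the mask keep their values, so in this sketch only `sum(mask)` weights would need gradients and optimizer state on the GPU.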

3.1. Pruning Stage

3.1.1. Grouping Model Weights

  1. Unstructured pruning: A pruning decision is made for each individual weight. → After pruning, the weight matrices have unstructured sparsity (a masked version of the original).
  2. Structured pruning: Model weights are grouped, and the same pruning decision is made for all weights within a group. → Hardware-friendly, since whole groups are removed and the remaining matrices stay dense.
  • Transformer weights
    • Attention modules: Each group contains the Q, K, V, and output weight matrices corresponding to one attention head.
    • Feedforward layers: Each group contains the model weights corresponding to one hidden (intermediate) dimension.
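The two grouping schemes can be made concrete with a small sketch. It assumes the PyTorch-style `out × in` storage for linear weights, so a feedforward hidden neuron owns one row of `W1` (and one entry of `b1`) plus one column of `W2`, and an attention head owns `d_head` consecutive rows of `Wq`, `Wk`, `Wv` and the matching columns of `Wo`. Matrices are plain nested lists, and the helper names are hypothetical.

```python
def prune_ffn(W1, b1, W2, keep):
    """Keep only the FFN hidden neurons listed in `keep`.

    W1: d_ff x d_model, b1: d_ff, W2: d_model x d_ff (out x in convention).
    """
    W1_p = [list(W1[j]) for j in keep]              # rows of the up-projection
    b1_p = [b1[j] for j in keep]
    W2_p = [[row[j] for j in keep] for row in W2]   # columns of the down-projection
    return W1_p, b1_p, W2_p

def prune_heads(Wq, Wk, Wv, Wo, d_head, keep_heads):
    """Keep only the attention heads in `keep_heads`.

    Wq, Wk, Wv: (n_heads * d_head) x d_model; Wo: d_model x (n_heads * d_head).
    """
    rows = [h * d_head + r for h in keep_heads for r in range(d_head)]
    Wq_p, Wk_p, Wv_p = ([list(W[r]) for r in rows] for W in (Wq, Wk, Wv))
    Wo_p = [[row[r] for r in rows] for row in Wo]
    return Wq_p, Wk_p, Wv_p, Wo_p

def ffn_forward(W1, b1, W2, x, hidden_mask=None):
    """ReLU feedforward layer; `hidden_mask` optionally zeroes hidden activations."""
    h = [max(0.0, sum(w * xi for w, xi in zip(row, x)) + b) for row, b in zip(W1, b1)]
    if hidden_mask is not None:
        h = [hj if m else 0.0 for hj, m in zip(h, hidden_mask)]
    return [sum(w * hj for w, hj in zip(row, h)) for row in W2]

W1 = [[1.0, 2.0], [3.0, -4.0], [0.5, 0.5]]   # d_ff = 3, d_model = 2
b1 = [0.0, 2.0, -0.5]
W2 = [[1.0, 1.0, 1.0], [2.0, 0.0, -1.0]]
x = [1.0, 1.0]
pruned_out = ffn_forward(*prune_ffn(W1, b1, W2, keep=[0, 2]), x)
masked_out = ffn_forward(W1, b1, W2, x, hidden_mask=[True, False, True])
print(pruned_out, masked_out)  # [3.5, 5.5] [3.5, 5.5]

Wa = [[1.0, 2.0], [3.0, 4.0]]  # 2 heads with d_head = 1, d_model = 2
Wq_p, Wk_p, Wv_p, Wo_p = prune_heads(Wa, Wa, Wa, Wa, d_head=1, keep_heads=[1])
print(Wq_p, Wo_p)  # [[3.0, 4.0]] [[2.0], [4.0]]
```

The FFN check shows that dropping a neuron's group gives the same output as zeroing its activation in the full layer, while the pruned layer is a smaller dense layer.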

3.1.2. Importance Estimation for Pruning

  • Compute the importance of each weight group in the model using a Taylor-expansion-based importance estimate.
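Removing a group G sets its weights to zero, i.e. the weights change by δw_G = -w_G, so a first-order Taylor expansion gives ΔL ≈ -Σ_{i∈G} g_i w_i, where g_i = ∂L/∂w_i is estimated on a tiny data sample. One common form of group importance is the magnitude of this estimate, I(G) = |Σ_{i∈G} g_i w_i|. The sketch below assumes that form; the paper's exact variant (for example, taking the absolute value per weight instead of per group) may differ, and the group names and numbers are made up for illustration.

```python
def group_taylor_importance(weights, grads, groups):
    """First-order Taylor importance |sum_{i in G} g_i * w_i| for each named group."""
    return {name: abs(sum(grads[i] * weights[i] for i in idx))
            for name, idx in groups.items()}

def groups_to_prune(importance, n_prune):
    """Names of the `n_prune` least important groups."""
    return sorted(importance, key=importance.get)[:n_prune]

weights = [0.5, -1.0, 2.0, 0.1, 0.0, 3.0]
grads = [0.2, 0.4, -0.1, 1.0, 5.0, 0.01]   # assumed averaged over a tiny data sample
groups = {"head0": [0, 1], "head1": [2, 3], "neuron0": [4, 5]}
imp = group_taylor_importance(weights, grads, groups)
print(groups_to_prune(imp, 1))  # ['neuron0']
```

Note that `neuron0` holds the largest gradient yet is least important, because that gradient sits on a weight that is already zero: the score weighs gradients by the weights they would remove.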

3.2. Finetuning Stage