
Learned Step Size Quantization

2024-01-09


Keywords: #Quantization


0. Abstract

  • Proposal
    1. Builds upon existing methods for learning weights in quantized networks by improving how the quantizer itself is configured
    2. Introduces a novel means to estimate the task loss gradient at each weight and activation layer’s quantizer step size, such that it can be learned jointly with other network parameters.

1. Introduction

  • Learned Step Size Quantization (LSQ): A new way to learn the quantization mapping for each layer in a deep network

2. Methods

  • Given data to quantize $v$, a quantizer step size $s$, and the number of positive and negative quantization levels $Q_P$ and $Q_N$, the paper defines (both written out after this list):
    • A quantizer that computes $\bar{v}$, a quantized, integer-scaled representation of the data
    • A quantized representation $\hat{v}$ of the data at the same scale as $v$
  • Given an encoding with $b$ bits: for unsigned data, $Q_N = 0$ and $Q_P = 2^b - 1$; for signed data, $Q_N = 2^{b-1}$ and $Q_P = 2^{b-1} - 1$
  • For inference, $\bar{w}$ and $\bar{x}$ can be used as inputs to integer matrix multiplication in convolution or fully connected layers, and the output of such layers is then rescaled by the step size using scalar-tensor multiplication
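
Written out (reconstructed from the description above, with $\lfloor \cdot \rceil$ denoting rounding to the nearest integer), the quantizer and its rescaled output take the form:

$$\bar{v} = \left\lfloor \mathrm{clip}\!\left(\frac{v}{s},\, -Q_N,\, Q_P\right) \right\rceil, \qquad \hat{v} = \bar{v} \times s$$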

2.1 Step Size Gradient

  • LSQ learns $s$ from the training loss by introducing the following gradient through the quantizer to the step size parameter
  • The gradient is derived by applying the straight-through estimator (STE) to the round function
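
Reconstructed from the description above (under the standard STE assumption that the round function has gradient 1), the step size gradient takes the piecewise form:

$$\frac{\partial \hat{v}}{\partial s} =
\begin{cases}
-\dfrac{v}{s} + \left\lfloor \dfrac{v}{s} \right\rceil & \text{if } -Q_N < \dfrac{v}{s} < Q_P \\
-Q_N & \text{if } \dfrac{v}{s} \le -Q_N \\
Q_P & \text{if } \dfrac{v}{s} \ge Q_P
\end{cases}$$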

  • For QIL and PACT, the relative proximity of $v$ to the transition point between quantized states does not impact the gradient to the quantization parameters.
  • However, the closer a given $v$ is to a transition point, the more likely it is to change its quantization bin ($\bar{v}$) as a result of a learned update to $s$ (since a smaller change in $s$ is required), resulting in a large jump in $\hat{v}$
  • Thus, we would expect $\partial \hat{v}/\partial s$ to increase as the distance from $v$ to a transition point decreases
  • Each layer of weights and each layer of activations has a distinct step size

2.2 Step Size Gradient Scale

  • (You et al., 2017) Good convergence is achieved when the ratio of the average update magnitude to the average parameter magnitude is approximately the same for all weight layers.
  • Thus, for a network trained on some loss function $L$, the ratio below should on average be near 1.
  • $\nabla_s L$: gradient of the task loss $L$ with respect to the step size $s$
  • $\nabla_w L$, $w$: the corresponding loss gradient and values for a layer’s weights
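
The ratio in question, reconstructed for a layer with weights $w$, is:

$$R = \frac{\nabla_s L}{s} \bigg/ \frac{\lVert \nabla_w L \rVert}{\lVert w \rVert} \approx 1$$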

  • We multiply the step size loss gradient by a gradient scale $g = 1/\sqrt{N_W Q_P}$ for weight step sizes ($g = 1/\sqrt{N_F Q_P}$ for activation step sizes, with $N_W$ the number of weights and $N_F$ the number of features in the layer), mainly because (see the code sketch after this list):
    1. The step size parameter should be smaller as precision increases ← data is quantized more finely
    2. Step size updates should be larger as the number of quantized items increases ← more items are summed across when computing its gradient
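
To make the mechanics concrete, below is a minimal PyTorch sketch of an LSQ-style quantizer combining the STE round with the gradient scale $g$. This is an illustrative reconstruction, not the paper’s code: the names (`grad_scale`, `round_ste`, `LsqQuantizer`) and the step size initialization are assumptions.

```python
import math
import torch
import torch.nn as nn


def grad_scale(x: torch.Tensor, scale: float) -> torch.Tensor:
    # Forward: returns x unchanged. Backward: gradient is multiplied by `scale`.
    return (x - x * scale).detach() + x * scale


def round_ste(x: torch.Tensor) -> torch.Tensor:
    # Round to nearest integer; the backward pass treats round as the identity (STE).
    return (x.round() - x).detach() + x


class LsqQuantizer(nn.Module):
    """Quantizes a tensor with a single learned step size s (one per layer)."""

    def __init__(self, bits: int, signed: bool, num_elements: int):
        super().__init__()
        if signed:
            self.q_n = 2 ** (bits - 1)        # Q_N
            self.q_p = 2 ** (bits - 1) - 1    # Q_P
        else:
            self.q_n = 0
            self.q_p = 2 ** bits - 1
        # Gradient scale g = 1 / sqrt(N * Q_P), where N is the number of weights
        # (or features, for activations) feeding this quantizer.
        self.g = 1.0 / math.sqrt(num_elements * self.q_p)
        self.s = nn.Parameter(torch.tensor(1.0))  # learned step size (init is a placeholder)

    def forward(self, v: torch.Tensor) -> torch.Tensor:
        s = grad_scale(self.s, self.g)                              # scale only the gradient of s
        v_bar = round_ste(torch.clamp(v / s, -self.q_n, self.q_p))  # integer-scaled v̄
        v_hat = v_bar * s                                           # v̂, same scale as v
        return v_hat
```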

2.3 Training

  • Quantizers are trained with LSQ by making their step sizes learnable parameters with loss gradient computed using the quantizer gradient above.
  • (Courbariaux et al., 2015) Training quantized networks: full precision weights are stored and updated, quantized weights and activations are used for the forward and backward passes, and the gradient through the quantizer round function is computed using the STE (Bengio et al., 2013) such that the round function acts as the identity in the backward pass, i.e. $\partial \lfloor v \rceil / \partial v = 1$ (a layer built this way is sketched after this list)
  • Networks are optimized with stochastic gradient descent
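
A sketch of such a layer, reusing the `LsqQuantizer` from the earlier snippet (assumed PyTorch; `QuantConv2d` and the choice of `num_elements` for the activation quantizer are illustrative):

```python
import torch.nn as nn
import torch.nn.functional as F


class QuantConv2d(nn.Conv2d):
    def __init__(self, *args, bits: int = 2, **kwargs):
        super().__init__(*args, **kwargs)
        # Full precision weights live in self.weight; they are what SGD updates.
        self.w_quant = LsqQuantizer(bits, signed=True, num_elements=self.weight.numel())
        # N_F (number of features) is approximated here by the input channel count.
        self.a_quant = LsqQuantizer(bits, signed=False, num_elements=self.in_channels)

    def forward(self, x):
        # Quantized weights and activations are used in the forward pass;
        # gradients reach self.weight through the STE inside the quantizer.
        w_hat = self.w_quant(self.weight)
        x_hat = self.a_quant(x)
        return F.conv2d(x_hat, w_hat, self.bias, self.stride,
                        self.padding, self.dilation, self.groups)
```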

  • Input activations and weights are quantized to $\hat{x}$ and $\hat{w}$ in all layers except the first and last, which are always 8-bit → making the first and last layers high precision has become standard practice for quantized networks
  • All other parameters are fp32
  • All quantized networks are initialized using weights from a trained full precision model with an equivalent architecture, then fine-tuned in the quantized space (a minimal training setup is sketched after this list)
  • Cosine learning rate decay without restarts (Loshchilov & Hutter, 2016)
  • Under the assumption that the optimal solution for 8-bit networks is close to the full precision solution (McKinstry et al., 2018), 8-bit networks were trained for 1 epoch while all other networks were trained for 90 epochs.
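
A minimal sketch of this fine-tuning recipe, assuming PyTorch; the tiny model and the hyperparameter values are placeholders rather than the paper’s settings:

```python
import torch
import torch.nn as nn

# Placeholder model; in practice this would be the quantized network,
# initialized from a trained full precision model of the same architecture:
#   model.load_state_dict(torch.load("fp32_baseline.pt"), strict=False)
# (strict=False because the learned step sizes are new parameters.)
model = nn.Sequential(nn.Linear(784, 256), nn.ReLU(), nn.Linear(256, 10))

optimizer = torch.optim.SGD(model.parameters(), lr=0.01,
                            momentum=0.9, weight_decay=1e-4)
epochs = 90  # 8-bit networks: 1 epoch, per the assumption above
# Cosine learning rate decay without restarts (Loshchilov & Hutter, 2016)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=epochs)

for epoch in range(epochs):
    # ... standard forward / backward / optimizer.step() over the training data ...
    scheduler.step()
```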

3. Results

3.1 Weight Decay

  • Reducing model precision reduces a model’s tendency to overfit, and thus reduces the need for regularization in the form of weight decay to achieve good performance
  • Lower precision networks reach higher accuracy with less weight decay

3.2 Comparison with Other Approaches

  • In some cases, we report slightly higher accuracy on full precision networks than in their original publications, which we attribute to our use of cosine learning rate decay (Loshchilov & Hutter, 2016).