
Overcoming Oscillations in Quantization-Aware Training

2024-06-07


Keywords: #QAT #weightdistribution


0. Abstract

  • This paper delves into the phenomenon of weight oscillations and shows that they can lead to significant accuracy degradation due to wrongly estimated BN statistics.
  • We propose two novel QAT algorithms to overcome oscillations during training:
    • Oscillation dampening
    • Iterative weight freezing

1. Introduction

  • Main Problem: When using the STE for QAT, weights seemingly randomly oscillate between adjacent quantization levels, introducing detrimental noise into the optimization process.
  • Problem with weight oscillations: they corrupt the estimated inference statistics of the BN layers collected during training, leading to poor validation accuracy.
    • Pronounced in low-bit quantization of efficient networks with depth-wise separable layers.
    • Probable solution: BN statistics re-estimation
  • Our solution
    • Oscillation dampening
    • Iterative freezing of oscillating weights

2. Oscillations in QAT

2.1. Quantization-aware training

  • Original weights: commonly referred to as the latent weights or shadow weights
  • STE: We approximate the gradient of the rounding operator as 1 within the quantization limits.
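
A minimal sketch of what this looks like in PyTorch (the function name `fake_quantize` and the symmetric 4-bit grid are illustrative, not from the paper): the rounding happens in the forward pass, while the `.detach()` trick makes the backward pass behave as if the rounding operator had gradient 1.

```python
import torch

def fake_quantize(w, scale, num_bits=4):
    """Uniform symmetric fake quantization with a straight-through estimator (STE)."""
    qmin, qmax = -2 ** (num_bits - 1), 2 ** (num_bits - 1) - 1
    w_int = torch.clamp(torch.round(w / scale), qmin, qmax)
    w_q = w_int * scale
    # Forward value is w_q; in the backward pass the rounding is skipped,
    # so the gradient w.r.t. the latent weight w is simply 1 (within the limits).
    return w + (w_q - w).detach()
```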

2.2. Oscillations

  • Toy regression example: Looks complicated, but really just the MSE between label and output of quantized model.

  • As the latent (shadow) weight $w$ approaches the optimal value $w_\ast$, it starts oscillating around the decision threshold between the quantization levels above ($w_\uparrow$) and below ($w_\downarrow$) the optimal value, as opposed to converging to the level closest to the optimal value, $q(w_\ast)$.

  • Why does this happen? When the latent weight is above the threshold, it quantizes to $w_\uparrow > w_\ast$, so the STE pushes the latent weight down towards $w_\downarrow$. When it is below the threshold, it quantizes to $w_\downarrow < w_\ast$, so the STE pushes the latent weight up towards $w_\uparrow$.
    • With the input fixed, the sign of the gradient depends only on whether the current quantized value lies above or below the optimal weight, not on how far the latent weight is from the optimum. The latent weight therefore oscillates around the decision threshold rather than converging to the optimal value.
  • Frequency of oscillation: dictated by the distance of the optimal value from its closest quantization level.
    • Intuition: the closer $w_\ast$ is to one of the two levels, the smaller the gradient pulling the weight away from that level, so the weight spends more iterations on that side before crossing the threshold again (see the sketch below).
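
A minimal sketch of this toy setup (all constants are illustrative): a scalar weight is regressed against the output of a float-optimal weight through fake quantization. Once the latent weight reaches the decision threshold, the quantized value keeps flipping between the two adjacent levels instead of settling.

```python
import torch

scale = 0.1                                 # quantization step size
w_opt = 0.23                                # float-optimal weight, between levels 0.2 and 0.3
x = torch.ones(64)                          # fixed input
y = w_opt * x                               # regression target

w = torch.tensor(0.0, requires_grad=True)   # latent (shadow) weight
opt = torch.optim.SGD([w], lr=0.05)

for step in range(300):
    w_q = w + (torch.round(w / scale) * scale - w).detach()  # STE fake quantization
    loss = torch.mean((y - w_q * x) ** 2)
    opt.zero_grad()
    loss.backward()
    opt.step()
    # Once w climbs to ~0.25 (the decision threshold), w_q starts
    # oscillating between 0.2 and 0.3 while w hovers around 0.25.
```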

2.3. Oscillations in practice

  • We observe that many of the weights appear to randomly oscillate between two adjacent quantization levels.
  • After the supposed convergence of the network, a large fraction of the latent weights lie right at the decision boundary between grid points.
    • This reinforces the observation that a significant proportion of weights oscillate and cannot converge.

2.3.1. The effect on BN statistics

  • Many weights oscillate between quantized states, even at the supposed convergence of the network → Output statistics of each layer can vary significantly between gradient updates
  • BN layers track the running mean/variance so that they can be used during inference.
    • These running estimates can be corrupted by the distribution shift induced by oscillations.
  • The degradation in accuracy due to the shift induced by oscillations is influenced by two factors:
    1. The lower the bit-width $b$, the larger the distance between quantization levels, as it is roughly proportional to $1/2^b$.
    2. The number of weights per output channel: the smaller the number of weights, the larger the contribution of each individual weight to the final accumulation. When the number is large enough, the effects of oscillations average out due to the law of large numbers.
  • BN re-estimation: Is this why BN calibration shows much better accuracy than train-from-scratch? (See the sketch after this list.)
  • KL divergence to quantify the discrepancy between population and estimated statistics
    • The KL divergence is much larger for depth-wise separable layers than for point-wise convolutions.
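
A minimal sketch of the BN re-estimation idea (the helper name, batch count, and the use of a cumulative average are my assumptions): throw away the corrupted running estimates and recompute them with the final quantized weights before evaluating.

```python
import torch

@torch.no_grad()
def reestimate_bn_stats(model, data_loader, num_batches=50):
    """Recompute BN running statistics from scratch using the converged (quantized) model."""
    for m in model.modules():
        if isinstance(m, torch.nn.modules.batchnorm._BatchNorm):
            m.reset_running_stats()   # discard the oscillation-corrupted estimates
            m.momentum = None         # None => cumulative moving average over the batches seen
    model.train()                     # BN only updates running stats in train mode
    for i, (images, _) in enumerate(data_loader):
        model(images)
        if i + 1 >= num_batches:
            break
    model.eval()
```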

2.3.2. The effect on training

??

4. Overcoming oscillations in QAT

4.1. Quantifying Oscillations

  • Calculate frequency of oscillations over time using an exponential moving average (EMA)
  • For an oscillation to occur, two conditions need to be satisfied:
    1. The integer value of the weight needs to change.
    2. The direction of the change needs to be opposite to that of the previous change.
  • We then track the frequency of oscillations over time using an EMA of the oscillation indicator $o_t$: $f_t = m \cdot o_t + (1 - m) \cdot f_{t-1}$ (see the sketch below).
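
A minimal per-weight sketch of this bookkeeping (class and attribute names are illustrative): it records the last integer value and the sign of the last change, flags an oscillation when a new change reverses that sign, and updates the EMA $f_t$.

```python
import torch

class OscillationTracker:
    """EMA of how often each weight's integer value flips direction."""

    def __init__(self, w_int, momentum=0.01):
        self.momentum = momentum
        self.prev_int = w_int.clone()               # last integer weight values
        self.prev_delta = torch.zeros_like(w_int)   # last non-zero integer change
        self.freq = torch.zeros_like(w_int, dtype=torch.float)

    def update(self, w_int):
        delta = w_int - self.prev_int
        changed = delta != 0
        # Oscillation: the integer value changed AND the change has the
        # opposite sign of the previous change.
        oscillated = changed & (torch.sign(delta) == -torch.sign(self.prev_delta))
        # EMA of the oscillation indicator: f_t = m * o_t + (1 - m) * f_{t-1}
        self.freq = self.momentum * oscillated.float() + (1 - self.momentum) * self.freq
        # Remember the previous change only where an integer change actually happened.
        self.prev_delta = torch.where(changed, delta, self.prev_delta)
        self.prev_int = w_int.clone()
        return self.freq
```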

4.2. Oscillation dampening

  • When weights oscillate, they always move around the decision threshold between two quantization bins.
  • This means that oscillating weights are always close to the edge of the quantization bin.
  • In order to dampen the oscillatory behavior, we employ a regularization term that encourages latent weights to be close to the center of the bin rather than its edge.
  • Dampening loss (similar to weight decay): $\mathcal{L}_\text{dampen}=\lVert q(w)-w\rVert_F^2$, where the quantized weight $q(w)$ is treated as a constant so that the gradient only pulls the latent weight $w$ towards the bin center.
  • Final training objective: $\mathcal{L}=\mathcal{L}_\text{task}+\lambda\mathcal{L}_\text{dampen}$
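
A minimal sketch of this dampening term under the same symmetric fake-quantization assumptions as above (names are illustrative); the quantized value is detached so the gradient only pulls the latent weight towards the bin center.

```python
import torch

def dampening_loss(w, scale, num_bits=4):
    """Squared distance between the latent weight and its quantized value (the bin center)."""
    qmin, qmax = -2 ** (num_bits - 1), 2 ** (num_bits - 1) - 1
    w_q = torch.clamp(torch.round(w / scale), qmin, qmax) * scale
    # w_q is treated as a constant: only the latent weight w receives a gradient.
    return torch.sum((w_q.detach() - w) ** 2)

# Final objective, with lambda_dampen as the regularization strength:
# loss = task_loss + lambda_dampen * dampening_loss(w, scale)
```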