
Automated Mixed-Precision Quantization for Next-Generation Inference Hardware

MaryT_Intel
Employee

Key Takeaways

  • Learn about quantization to low bit-widths for model compression and performance speed-up on next-generation hardware

Published: March 31, 2021

 

Overview


Compute constraints at the edge, such as the resource footprint of endpoint/edge devices and the performance KPIs of Internet of Things applications, make optimization of Deep Neural Networks (DNNs) critical for deployment. By deployment we assume that the Deep Learning (DL) model is fixed in all aspects (e.g., task definition, dataset, training pipeline), so optimization speeds up model inference while the rest of the flow stays the same.
 
There are two main objectives for model optimization: reducing the model size and decreasing computational complexity, and the two are highly correlated. For example, reducing the model size in many cases leads to lower memory consumption and thus lower latency, which, in turn, improves performance. Many optimization methods have been proposed to tackle both objectives. Uniform quantization is perhaps one of the simplest: it leads to a small accuracy drop, a substantial reduction in model size, and a significant performance speedup, and it can be applied even without retraining or fine-tuning the model being optimized.
 
The essence of uniform quantization is the approximation of the original floating-point operations inside the network by their integer analogs. In the case of 8-bit uniform quantization, a convolution can be described as follows:

$$\mathrm{conv}(w_{fp32}, a_{fp32}) = \sum_j w_{fp32} \cdot a_{fp32} + b_{fp32} \approx \sum_j s_w (w_{i8} - zp_w) \cdot s_a (a_{i8} - zp_a) + s_w s_a b_{i32} = s_w s_a \Big( \sum_j (w_{i8} - zp_w)(a_{i8} - zp_a) + b_{i32} \Big)$$

where:

$s_w$ – floating-point weight scale, $zp_w$ – 8-bit weight zero-point

$s_a$ – floating-point activation scale, $zp_a$ – 8-bit activation zero-point

Thus, the most computationally expensive part, $(w_{i8} - zp_w) \cdot (a_{i8} - zp_a)$, can be executed in low precision. For more details about quantization-aware training for DNNs, we suggest referring to this paper.
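As a minimal NumPy illustration (not NNCF code) of this approximation: the integer accumulation is computed with the zero-points subtracted and then rescaled by the product of the scales, which closely reproduces the floating-point dot product.

```python
import numpy as np

def quantize(x, num_bits=8):
    """Asymmetric uniform quantization of a tensor to unsigned num_bits integers."""
    qmin, qmax = 0, 2 ** num_bits - 1
    scale = (x.max() - x.min()) / (qmax - qmin)
    zero_point = int(round(qmin - x.min() / scale))
    q = np.clip(np.round(x / scale) + zero_point, qmin, qmax).astype(np.int32)
    return q, scale, zero_point

w = np.random.randn(16).astype(np.float32)   # "weights"
a = np.random.randn(16).astype(np.float32)   # "activations"
w_q, s_w, zp_w = quantize(w)
a_q, s_a, zp_a = quantize(a)

# Integer part of the computation (what the hardware executes in low precision) ...
acc = np.sum((w_q - zp_w) * (a_q - zp_a))
# ... rescaled back to floating point to approximate the original dot product.
print(s_w * s_a * acc, np.dot(w, a))
```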

Recently, we wrote about the quantization capabilities of the Neural Network Compression Framework (NNCF). The framework is designed to perform various Deep Learning optimization methods with fine-tuning specifically for Intel hardware and is highly aligned with the Intel® Distribution of OpenVINO™ toolkit. In this article, we provide a more holistic view of mixed-precision quantization methods, i.e., those that allow quantizing different operations in the model to different precisions.
 

Mixed-precision quantization methods


Mixed-precision quantization is a promising method to get an additional speedup and model size reduction compared to uniform 8-bit quantization. It naturally exploits model redundancy by assigning lower bit-widths to the less useful or less sensitive layers in the model. The main problem is how to measure a sensitivity score and map it to a bit-width. Different methods propose different ways to tackle this problem. We selected two of the most well-known methods and implemented them in NNCF:
 

  • HAWQ-based mixed-precision method
     
  • AutoQ – an RL-based approach to precision selection

It is worth noting that the term “mixed precision” can refer to different levels of precision mixing in the model. For example, we can allow different layers to have different precisions so that each layer has one precision for all its inputs, or we can go further and allow different inputs of an operation to have different precisions. Fig. 1 shows examples of possible bit-width granularity.
 

figure 1

Figure 1. Examples of different bit assignment strategies: left – unified, right – non-unified


From the model inference standpoint, the first case of a uniform precision for all operation inputs is clear: the operation can be uniquely treated as a low-precision operation at the specific bit-width. In the second case of different numbers of bits for different inputs, there are two options for the runtime:
 

  • Select the highest bit-width at the model loading step and upconvert all inputs with lower bit-widths to the selected one. In this case, the only benefit we get from lower bit-widths is a reduction of the model size, for example, if the bit-width for weights is lower than for activations.
     
  • Let all inputs keep their defined bit-widths and be quantized to them at inference time, but upconvert them to the highest bit-width after memory prefetch, when executing the operation in low precision. This obviously requires hardware support, but it can bring an additional performance improvement due to latency reduction when transferring data to local memory.

The mixed-precision methods implemented in the NNCF also offer different levels of such bit-width assignment granularity. Let’s look at them in detail.

HAWQ


NNCF employs the HAWQ-v2 method to automatically find an “optimal” mixed-precision configuration by considering the sensitivity of each layer, i.e., how much lower-bit quantization of that layer affects the loss of the model. The most sensitive layers are kept at higher precision. The sensitivity of the i-th layer is calculated by multiplying the average Hessian trace of the loss w.r.t. the model parameters by the L2 norm of the quantization perturbation:
 

formula 1
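In standard notation, and following the HAWQ-v2 paper this method is based on, the per-layer sensitivity shown in the formula above can be written as (a reconstruction; HAWQ-v2 uses the squared L2 norm of the perturbation):

$$\Omega_i = \overline{\mathrm{Tr}}(H_i)\,\big\lVert Q(W_i) - W_i \big\rVert_2^2,$$

where $Q(\cdot)$ is the quantization function and $\overline{\mathrm{Tr}}(H_i)$ is the average Hessian trace of the loss with respect to the parameters $W_i$ of the $i$-th layer.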

The sum of the per-layer sensitivities forms a metric which serves as a proxy for the accuracy of the compressed model: the lower the metric, the more accurate the corresponding mixed-precision model should be on the validation dataset. To find the optimal trade-off between accuracy and performance of the mixed-precision model, we also compute a compression ratio: the ratio between the bit complexity of a fully INT8 model and that of the mixed-precision, lower bit-width one. The bit complexity of the model is the sum of the bit complexities of each quantized layer, which are defined as the product of the layer FLOPs and the quantization bit-width. The optimal configuration is found by calculating the sensitivity metric and the compression ratio for all possible bit-width settings and selecting the one with the minimal metric value among all configurations that reach the specified compression ratio. The compression ratio is defined by the user. We set it to 1.5 by default, which, in general, should be enough to compress the model while staying within a 1% accuracy drop.
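To make the selection rule concrete, here is a hypothetical Python sketch (not the NNCF implementation) with made-up per-layer FLOPs, Hessian traces, and quantization errors: among the candidate bit-width assignments that reach the requested compression ratio, the one with the smallest sensitivity metric is chosen.

```python
from itertools import product

flops = [100e6, 50e6, 25e6]                # per-layer FLOPs (illustrative numbers)
avg_hessian_trace = [0.8, 0.2, 0.05]       # per-layer average Hessian trace (illustrative)
perturbation = {2: 1.0, 4: 0.25, 8: 0.01}  # assumed quantization error per bit-width

int8_complexity = sum(f * 8 for f in flops)
target_ratio = 1.5

best = None
for bits in product([2, 4, 8], repeat=len(flops)):
    complexity = sum(f * b for f, b in zip(flops, bits))
    ratio = int8_complexity / complexity
    if ratio < target_ratio:
        continue  # not compressed enough, skip this configuration
    metric = sum(t * perturbation[b] for t, b in zip(avg_hessian_trace, bits))
    if best is None or metric < best[0]:
        best = (metric, bits)

print("selected bit-widths:", best[1])
```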

An exhaustive search over all possible bit-width arrangements is not feasible since it has an exponential time complexity: B^L, where B is the number of quantization precision options and L is the number of layers in the model. The complexity of the problem is vastly reduced by using the information about the sensitivity of each layer: layers with a smaller average Hessian trace are quantized to lower bit-widths, and vice versa. Thus, the search space for MobileNet-v2 is decreased from 3^53 ≈ 1.9 * 10^25 to 1.4 * 10^4, assuming there are three options for the bit-width: 2, 4, and 8. The Hessian trace is estimated with the randomized Hutchinson algorithm. Given a random vector v sampled from a standard normal distribution, the trace of a symmetric matrix H is equal to the expectation of a quadratic form:

formula 2
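In standard notation, Hutchinson's identity for a symmetric matrix $H$ reads (a reconstruction of the formula above):

$$\mathrm{Tr}(H) = \mathbb{E}_{v \sim \mathcal{N}(0, I)}\big[v^{\top} H v\big].$$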

The randomized algorithm estimates the expectation by Monte Carlo: sampling v from its distribution, evaluating the quadratic term, and averaging:

formula 3
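With $m$ sampled vectors $v_1, \dots, v_m$, the Monte Carlo estimate above can be written as (a standard-notation reconstruction):

$$\mathrm{Tr}(H) \approx \frac{1}{m}\sum_{k=1}^{m} v_k^{\top} H v_k.$$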

The quadratic term is evaluated by computing Hv, the product of the Hessian matrix with a given random vector v, without explicitly forming the Hessian. For the gradient of the loss with respect to the i-th block, G_i, and a random vector v that is independent of W_i, we have the equation:

formula 4
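In standard notation, since $v$ does not depend on $W_i$, differentiating $G_i^{\top} v$ once more gives (a reconstruction of the formula above):

$$\frac{\partial\,(G_i^{\top} v)}{\partial W_i} = H_i v,$$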

where H_i is the Hessian matrix of the loss with respect to W_i. Hence Hv can be computed by two backpropagation passes: the first with respect to the loss and the second with respect to the product of the gradients and the random vector.
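As a minimal PyTorch illustration (not the NNCF code) of this double-backpropagation trick, combined with a single-sample Hutchinson trace estimate:

```python
import torch

model = torch.nn.Linear(10, 2)
criterion = torch.nn.CrossEntropyLoss()
x, y = torch.randn(4, 10), torch.randint(0, 2, (4,))

loss = criterion(model(x), y)
params = [p for p in model.parameters() if p.requires_grad]
grads = torch.autograd.grad(loss, params, create_graph=True)  # first pass: G

v = [torch.randn_like(p) for p in params]                     # random vector v
grad_v = sum((g * vi).sum() for g, vi in zip(grads, v))       # scalar G^T v
hv = torch.autograd.grad(grad_v, params)                      # second pass: H v

# Hutchinson estimate of the Hessian trace from this single sample of v
trace_estimate = sum((h * vi).sum() for h, vi in zip(hv, v))
print(trace_estimate.item())
```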

AutoQ


Different model architectures possess varying redundancy and sensitivity with respect to accuracy, and different device architectures provide different performance capabilities. A mixed-precision model that performs efficiently on one device may behave differently on another. Finding the respective optimal precision configuration for a range of platforms (CPU, GPU, etc.) is a scaling and productivity challenge. To this end, NNCF provides an alternative mode for mixed-precision automation, namely AutoQ. Based on HAQ, AutoQ employs a deep reinforcement learning algorithm to automate the exploration of the device-specific precision space and eventually learn a quantization policy that achieves an optimal trade-off between accuracy, performance, and user-specified constraints (e.g., model size, latency, compute complexity, power).

figure 2

AutoQ utilizes an actor-critic algorithm, Deep Deterministic Policy Gradient (DDPG), for an efficient search over the bit-width space. DDPG is trained in an episodic fashion, converging to a deterministic mixed-precision policy after several episodes. An episode consists of a sequence of steps in which the DDPG agent transitions from quantizer to quantizer, predicting the precision of each layer. Each quantizer essentially denotes a state of the RL framework and is represented by attributes of the associated layers. For example, a quantizer for a 2D convolution is represented by its quantizer id (integer), input and output channel size, feature map dimension, kernel size, stride size, a boolean indicating a depthwise convolution, the number of parameters, and the previous quantizer action. We recommend checking out the implementation for the detailed featurization of different layer types.

When the agent enters a state (quantizer), it receives the state features and forward passes them through its network. The output of the forward pass is a scalar continuous action, which is subsequently mapped to the bit-width options of the current quantizer. The episode terminates after the prediction for the last quantizer, and a complete layer-wise mixed-precision policy is obtained. To ensure a policy fits within the user-specified compression ratio, the policy is post-processed by reducing the precision sequentially, starting from the last quantizer, until the compression ratio is met.
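As a toy illustration of this mapping step (the actual NNCF mapping may differ), a continuous action in [0, 1] can be discretized to the current quantizer's bit-width options like this:

```python
def action_to_bitwidth(action: float, bit_options=(2, 4, 8)) -> int:
    """Map a continuous DDPG action in [0, 1] to one of the discrete bit-width options."""
    action = min(max(action, 0.0), 1.0)
    index = min(int(action * len(bit_options)), len(bit_options) - 1)
    return bit_options[index]

print([action_to_bitwidth(a) for a in (0.1, 0.5, 0.95)])  # e.g. [2, 4, 8]
```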

To evaluate the quality of a mixed-precision policy, the NNCF backend quantizes the workload accordingly and performs evaluation with the user-registered function. The evaluated score, together with the state embedding and the predicted action, is appended to an experience vault that serves DDPG learning. The learning is carried out by sampling data points from the experience vault for supervised training of the DDPG network. This typically happens at a fixed interval; in the current implementation, it is performed after each episode evaluation. For bootstrapping, noise is added to the action output as a way of exploring and increasing the diversity of experience. As the episodic iterations progress, the noise magnitude is gradually reduced to zero, and the deterministic mixed-precision policy is considered converged at the end of the iterations. NNCF keeps track of the best policy and uses it for fine-tuning.


Use Mixed-precision Quantization from NNCF


From the usage standpoint, we suggest following the recommendations below:
 

  • Use the HAWQ method if you need fast and reproducible results for INT4-INT8 quantization.
     
  • Use AutoQ if you need to achieve a higher compression ratio. This method can potentially provide more accurate results due to its brute-force nature.

Both methods require some modifications of the training code to be enabled. Mixed-precision quantization in NNCF can be viewed as an extension of NNCF uniform quantization. For beginners, we recommend referring to the previous blog post and the documentation on how to adapt a PyTorch training script for uniform quantization and fine-tuning. The steps to integrate the NNCF mixed-precision flow are discussed below.

In NNCF, HAWQ and AutoQ are treated as precision initialization, since they determine the bit-width configuration for the downstream fine-tuning. Hence, to enable HAWQ or AutoQ mode, users specify it through the precision key in the “initializer” section of the NNCF config.

figure 3

Common parameters: target_device specifies the hardware awareness of NNCF; it determines the bit-width choices available for a particular layer with respect to the hardware. The bits field also defines the precision space of the quantizers, but it is only active in the absence of a target device.

HAWQ config: compression_ratio is the ratio between the bit complexity of a fully INT8 model and that of the mixed-precision, lower bit-width one. For example, a uniformly 8-bit quantized model has a compression ratio of 1, and uniform 4-bit quantization has a ratio of 2. By default, the compression ratio is 1.5, which should be enough to compress the model with no more than a 1% accuracy drop in the general case. If that does not happen, a lower ratio can be set.

AutoQ config: iter_number and compression_ratio are the two parameters that may require user experimentation. iter_number is the number of episodes of AutoML optimization. compression_ratio is the target model size after quantization, relative to the total parameter size in FP32. For example, a uniformly 8-bit quantized model corresponds to a compression ratio of 0.25, and uniform 4-bit quantization to 0.125. A good choice of iter_number and compression_ratio depends on the number of quantizers in a workload and the number of bit-width choices: more quantizable layers generally require more episodes to explore the larger solution space. Our intention is for users to tune only these two intuitive parameters; internally, adaptive mechanisms modify the rest of the AutoML hyperparameters. eval_subset_ratio is the ratio of the dataset to be used for evaluation in each iteration; it is used by the callback function (see below). Its primary purpose is to control the time taken per iteration, and hence the overall search time.
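As an illustration, a config enabling these modes could look roughly like the Python dictionaries below. The key names follow the parameters described above, but the exact schema (and the valid target_device values) should be checked against the NNCF documentation of your release.

```python
# Illustrative sketch of an NNCF config as a Python dict; not an authoritative schema.
hawq_config = {
    "input_info": {"sample_size": [1, 3, 224, 224]},
    "target_device": "VPU",  # hardware-aware bit-width choices; valid values depend on the release
    "compression": {
        "algorithm": "quantization",
        "initializer": {
            "precision": {
                "type": "hawq",
                "bits": [2, 4, 8],        # only used when no target device restricts the choices
                "compression_ratio": 1.5,
            }
        },
    },
}

# For AutoQ, the "precision" section would instead look roughly like this:
autoq_precision = {
    "type": "autoq",
    "iter_number": 300,        # number of AutoML episodes (illustrative value)
    "compression_ratio": 0.15, # target size relative to the FP32 parameter size
    "eval_subset_ratio": 0.3,  # share of the dataset used per evaluation
}
```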

figure 4

Figure 3: Training Script Adaptation to enable NNCF Mixed-Precision Quantization


In conjunction with the NNCF configuration, users are required to adapt the training script using the NNCF API. The snippet above covers the adaptation of both HAWQ and AutoQ for the ImageNet classification sample. Note that the HAWQ and AutoQ adaptations are independent of each other; users can adapt either one of them.

HAWQ requires users to register (L33-34) a loss module and a data loader for the NNCF backend; they are used for the Hessian trace calculation by backpropagation. The training data loader is preferable because it shuffles data samples, which in turn can improve the convergence of the algorithm. For some models, a loss callback (L17-21) should be specified (L34) in the case of complex output post-processing or multiple losses, as with Inception-v3 in the code snippet.

AutoQ requires users to define a callback function that instructs AutoQ how to evaluate the impact of a mixed-precision policy. This function is called internally prior to DDPG learning to set a baseline, and then in each episode to assess the agent's prediction of layer-wise bit-widths. Typically, users can wrap an existing evaluation function and register it for the NNCF backend. The callback (L24-26) wraps the validation function and uses Top-5 accuracy as the objective metric. During registration (L36-37), users provide the callback reference and a data loader. The callback, data loader, and registration mechanism are necessary because the dataset and objective can differ in meaning and implementation for a different workload. For the AutoQ data loader, users are free to reuse the original training, validation, or test data loader. If needed, users can create a custom data loader for a tailored dataset.
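The following is a rough, condensed sketch of these steps, assuming that model, criterion, validate, train_loader, and val_loader already exist in the training script; import paths and keyword names vary between NNCF releases, so follow the linked ImageNet sample rather than this snippet verbatim.

```python
from nncf import NNCFConfig
from nncf.torch import create_compressed_model, register_default_init_args  # path differs in older releases

nncf_config = NNCFConfig.from_json("nncf_mixed_precision.json")  # HAWQ or AutoQ config

def autoq_eval_fn(model, eval_loader):
    # Wrap the existing validation routine and return the objective metric
    # (Top-5 accuracy here) that AutoQ uses to score a mixed-precision policy.
    _, top5 = validate(eval_loader, model, criterion)
    return top5

nncf_config = register_default_init_args(
    nncf_config,
    train_loader,                 # used by HAWQ for the Hessian-trace estimation
    criterion=criterion,          # loss module registered for HAWQ
    autoq_eval_fn=autoq_eval_fn,  # policy-evaluation callback for AutoQ
    val_loader=val_loader,        # data loader passed to the AutoQ callback (keyword may differ)
)

compression_ctrl, model = create_compressed_model(model, nncf_config)
# ... then fine-tune `model` with the usual training loop.
```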
That is all! Users can launch the NNCF-adapted PyTorch script to quantize the model in mixed precision. For an out-of-the-box sample, here is the ImageNet classification script, which works with either the HAWQ or the AutoQ config.

Results


To demonstrate the optimization capabilities of the implemented methods, we conducted several training experiments using the NNCF training examples, which allow training Computer Vision models for three different tasks: Image Classification, Object Detection, and Semantic Segmentation. We experimented with and tuned the HAWQ and AutoQ knobs until the accuracy deviation was less than 1% (a common criterion for compression). We used the same fine-tuning schedule for both AutoQ and HAWQ for a fair comparison. All the experiments were run with the PyTorch framework.

Table 1 shows the results of applying the methods to popular DL models. Apart from post-compression accuracy, we tabulate the model size and theoretical complexity relative to uniform 8-bit quantization. Since PyTorch itself does not provide model compression capabilities, the Model Size Ratio is provided for the models converted to OpenVINO™ Intermediate Representation (IR), which saves the weights of quantized Convolutional and Fully-Connected layers in INT8 or INT4 precision. Bit complexity is defined as the number of Multiply-Accumulate (MAC) operations multiplied by the precision realized on the device. For instance, the precision of the convolution in Figure 1(b) is 4-bit, as the weights will be upcast to 4-bit.

In all cases, both methods achieve a significant compression ratio in both model size and bit complexity while satisfying the accuracy degradation criterion. It is worth noting that NNCF hides the complex implementation of the mixed-precision quantization algorithms from users; model deployment with HAWQ or AutoQ can be a matter of a few iterations of tweaking the configurable knobs and the fine-tuning schedule.
 

table 1

Table 1: Compression results for the HAWQ and AutoQ methods obtained with NNCF. The model size ratio is calculated for the mixed-precision model converted to OpenVINO™ IR, with respect to the fully INT8 model.


Conclusion


We introduced two mixed-precision quantization methods in the NNCF framework, namely HAWQ and AutoQ, aimed at improving compression capabilities for models represented as OpenVINO™ IR. The methods are complementary and handle different use-case scenarios: HAWQ can be used to quickly get a mixed-precision model, while AutoQ performs a comprehensive search over a variety of mixed-precision configurations, delivering potentially better results at some cost in time. Both methods are designed to be easily integrated into custom training code outside of NNCF. As for model inference acceleration, we leave this topic for further exploration once hardware that supports the corresponding mixed-precision execution is available.

If you have any ideas on how we can improve the product, we welcome contributions to the open-sourced NNCF framework. 


 

Notices & Disclaimers

Performance varies by use, configuration and other factors. Learn more at www.Intel.com/PerformanceIndex .  

Performance results are based on testing as of dates shown in configurations and may not reflect all publicly available updates. See backup for configuration details. No product or component can be absolutely secure. 

Your costs and results may vary. 

Intel technologies may require enabled hardware, software or service activation.

Intel disclaims all express and implied warranties, including without limitation, the implied warranties of merchantability, fitness for a particular purpose, and non-infringement, as well as any warranty arising from course of performance, course of dealing, or usage in trade.

© Intel Corporation.  Intel, the Intel logo, and other Intel marks are trademarks of Intel Corporation or its subsidiaries. Other names and brands may be claimed as the property of others.  

 

About the Author
Mary is the Community Manager for this site. She likes to bike, and do college and career coaching for high school students in her spare time.