Quantization Speedup

class nni.compression.pytorch.quantization_speedup.ModelSpeedupTensorRT(model, input_shape, config=None, onnx_path='default_model.onnx', extra_layer_bits=32, strict_datatype=True, calibrate_type=tensorrt.CalibrationAlgoType.ENTROPY_CALIBRATION_2, calib_data_loader=None, calibration_cache='calibration.cache', batchsize=1, input_names=['actual_input_1'], output_names=['output1'])
  • model (pytorch model) – The model to speed up by quantization.

  • input_shape (tuple) – The input shape of the model, which is passed to torch.onnx.export.

  • config (dict) – Configuration that records the layer names and their bit widths.

  • onnx_path (str) – The path where the ONNX model converted from the PyTorch model is stored.

  • extra_layer_bits (int) – The bit width used for layers that are not listed in config.

  • strict_datatype (bool) – Whether to constrain each layer to the bit width given in config. If true, every layer is set strictly to its given bit width. Otherwise, TensorRT chooses the layer precision automatically.

  • calibrate_type (tensorrt.tensorrt.CalibrationAlgoType) – The calibration algorithm. See https://docs.nvidia.com/deeplearning/tensorrt/api/python_api/infer/Int8/Calibrator.html for details.

  • calib_data_loader (numpy array) – The data used to calibrate the quantized model.

  • calibration_cache (str) – The path where the calibration cache file is stored.

  • batchsize (int) – The batch size used for calibration and inference.

  • input_names (list) – Input names of the ONNX model, passed to torch.onnx.export when generating the ONNX model.

  • output_names (list) – Output names of the ONNX model, passed to torch.onnx.export when generating the ONNX model.
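
The relationship between config and extra_layer_bits can be illustrated with a small sketch. The helper below is hypothetical and is not part of NNI. The layer names and the "weight_bits" key are assumptions made for the example, and the actual config layout may differ between NNI versions.

```python
# Hypothetical illustration, not NNI code. The layer names and the
# "weight_bits" key are assumed for this example.
def resolve_bits(layer_names, config, extra_layer_bits=32):
    """Return the bit width each layer would be assigned."""
    resolved = {}
    for name in layer_names:
        layer_cfg = config.get(name, {})
        # Layers missing from config fall back to extra_layer_bits.
        resolved[name] = layer_cfg.get("weight_bits", extra_layer_bits)
    return resolved


config = {"conv1": {"weight_bits": 8}, "fc1": {"weight_bits": 8}}
bits = resolve_bits(["conv1", "relu1", "fc1"], config, extra_layer_bits=32)
print(bits)
```

Here "relu1" is absent from config, so it receives the extra_layer_bits value while the two listed layers keep their configured 8 bits.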


Get the ONNX config and build the TensorRT engine.
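
A minimal sketch of constructing the class and building the engine might look like the following. The method name compress() is an assumption, since the method signature does not appear in this text. The build_engine helper and its speedup_cls parameter are hypothetical, and the latter exists only so the flow can be exercised without NNI and TensorRT installed.

```python
def build_engine(model, input_shape, config, batchsize=1, speedup_cls=None):
    """Construct the speedup object and build the TensorRT engine."""
    if speedup_cls is None:
        # Imported lazily: requires NNI, PyTorch, and TensorRT.
        from nni.compression.pytorch.quantization_speedup import (
            ModelSpeedupTensorRT,
        )
        speedup_cls = ModelSpeedupTensorRT
    engine = speedup_cls(model, input_shape, config=config, batchsize=batchsize)
    # Assumed method name for the "build TensorRT engine" step.
    engine.compress()
    return engine
```

With the real class, a call such as build_engine(model, (32, 1, 28, 28), config, batchsize=32) would export the model to ONNX and build the engine. The shape shown is only an example.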


Export the quantized TensorRT model engine, which can only be loaded through the TensorRT deserialization API.


path (str) – The path to export the model engine to.


Run inference with the built TensorRT engine.


test_data (pytorch tensor) – The model input tensor.
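
As a rough sketch, inference can be wrapped to report throughput. This assumes the method is named inference() and returns a pair of (outputs, elapsed time in seconds). Neither detail is stated in this text, so check both against the installed NNI version. The timed_inference helper is hypothetical.

```python
def timed_inference(engine, test_data, num_samples):
    """Run the engine and report throughput in samples per second."""
    # Assumed return value: (outputs, elapsed_seconds).
    outputs, elapsed = engine.inference(test_data)
    throughput = num_samples / elapsed if elapsed > 0 else float("inf")
    return outputs, throughput
```

Passing the number of samples explicitly keeps the helper independent of the tensor type used for test_data.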


Load a quantized TensorRT model engine from the specified path.


path (str) – The path of the exported model engine.
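
The export and load steps can be combined into a round trip, sketched below. The method names export_quantized_model() and load_quantized_model() are assumptions, since the signatures do not appear in this text. The export_then_reload helper is hypothetical, and fresh_engine stands for a second, already constructed speedup object.

```python
from pathlib import Path


def export_then_reload(engine, fresh_engine, path):
    """Save a built engine to disk, then load it into another object."""
    # Assumed method name for the export step.
    engine.export_quantized_model(path)
    if not Path(path).is_file():
        raise FileNotFoundError(f"no engine file written to {path}")
    # Assumed method name for the load step.
    fresh_engine.load_quantized_model(path)
    return fresh_engine
```

Checking that the file exists before loading surfaces a failed export early instead of as a deserialization error.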