Quantization Speedup

class nni.compression.pytorch.quantization_speedup.ModelSpeedupTensorRT(model, input_shape, config=None, onnx_path='default_model.onnx', extra_layer_bits=32, strict_datatype=True, calibrate_type=tensorrt.CalibrationAlgoType.ENTROPY_CALIBRATION_2, calib_data_loader=None, calibration_cache='calibration.cache', batchsize=1, input_names=['actual_input_1'], output_names=['output1'])
Parameters
  • model (PyTorch model) -- The model to speed up by quantization.

  • input_shape (tuple) -- The input shape of the model, used when exporting the model with torch.onnx.export.

  • config (dict) -- Config recording the names of the layers to quantize and their number of bits.

  • onnx_path (str) -- The path where the ONNX model converted from the PyTorch model is stored.

  • extra_layer_bits (int) -- The number of bits used to quantize layers that are not listed in config.

  • strict_datatype (bool) -- Whether to constrain layer bits to the numbers given in config. If True, every layer is set strictly to its given number of bits. Otherwise, TensorRT chooses the precision of these layers automatically.

  • calibrate_type (tensorrt.tensorrt.CalibrationAlgoType) -- The calibration algorithm. Please refer to https://docs.nvidia.com/deeplearning/tensorrt/api/python_api/infer/Int8/Calibrator.html for details.

  • calib_data_loader -- The data used to calibrate the quantized model.

  • calibration_cache (str) -- The path where the calibration cache file is stored.

  • batchsize (int) -- The batch size used for calibration and inference.

  • input_names (list) -- Input names of the ONNX model, passed to torch.onnx.export when generating the ONNX model.

  • output_names (list) -- Output names of the ONNX model, passed to torch.onnx.export when generating the ONNX model.
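How config, extra_layer_bits and the bits-to-precision choice fit together can be sketched in plain Python. This is not NNI's implementation: the helper resolve_layer_bits, the PRECISION_BY_BITS table and the 'weight_bits'/'output_bits' keys in the example config are assumptions made for illustration only.

```python
from typing import Dict, Iterable

# Assumed mapping from bit width to the precision TensorRT would use.
PRECISION_BY_BITS = {32: "float32", 16: "float16", 8: "int8"}


def resolve_layer_bits(
    layer_names: Iterable[str],
    config: Dict[str, Dict[str, int]],
    extra_layer_bits: int = 32,
) -> Dict[str, int]:
    """Hypothetical helper: return the bit width chosen for every layer.

    Layers listed in ``config`` use their configured 'weight_bits'
    (assumed key name); every other layer falls back to ``extra_layer_bits``.
    """
    resolved = {}
    for name in layer_names:
        layer_cfg = config.get(name)
        bits = extra_layer_bits if layer_cfg is None else layer_cfg["weight_bits"]
        if bits not in PRECISION_BY_BITS:
            raise ValueError(f"unsupported bit width {bits} for layer {name!r}")
        resolved[name] = bits
    return resolved


# Example config in the assumed shape: layer name -> bit settings.
config = {
    "conv1": {"weight_bits": 8, "output_bits": 8},
    "fc1": {"weight_bits": 16, "output_bits": 16},
}
bits = resolve_layer_bits(["conv1", "conv2", "fc1"], config, extra_layer_bits=32)
precisions = {name: PRECISION_BY_BITS[b] for name, b in bits.items()}
print(precisions)
# {'conv1': 'int8', 'conv2': 'float32', 'fc1': 'float16'}
```

Here conv2 is absent from config, so it receives extra_layer_bits. With strict_datatype=True these per-layer choices would be enforced; otherwise TensorRT may pick a different precision for them.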

compress()

Get the ONNX config and build the TensorRT engine.

export_quantized_model(path)

Export the quantized TensorRT model engine, which can only be loaded through the TensorRT deserialization API.

Parameters

path (str) -- The path to export the model to.

inference(test_data)

Run inference with the built TensorRT engine.

Parameters

test_data (PyTorch tensor) -- Model input tensor.
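Because batchsize governs both calibration and inference, input data is consumed in chunks of that size. The sketch below assumes the engine was built with a static input shape whose first dimension is batchsize, so each call must receive exactly that many samples and an incomplete final chunk is dropped by default. iter_batches is a hypothetical helper, not part of NNI.

```python
from typing import Iterator, List, Sequence, TypeVar

T = TypeVar("T")


def iter_batches(
    samples: Sequence[T], batchsize: int, drop_last: bool = True
) -> Iterator[List[T]]:
    """Hypothetical helper: yield consecutive batches of ``batchsize`` samples.

    With ``drop_last=True`` (the default) a short trailing batch is skipped,
    matching the assumption of a fixed-batch engine.
    """
    if batchsize <= 0:
        raise ValueError("batchsize must be positive")
    for start in range(0, len(samples), batchsize):
        batch = list(samples[start:start + batchsize])
        if len(batch) < batchsize and drop_last:
            return
        yield batch


samples = list(range(10))
for batch in iter_batches(samples, batchsize=4):
    print(batch)
# [0, 1, 2, 3]
# [4, 5, 6, 7]
```

Passing drop_last=False keeps the trailing [8, 9] batch, which only makes sense if the engine accepts smaller batches.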

load_quantized_model(path)

Load a quantized TensorRT model engine from the specified path.

Parameters

path (str) -- The path of the exported model.
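A serialized TensorRT engine is an opaque byte buffer, so exporting and loading it amounts to a binary file round trip. The sketch below shows only that round trip with a placeholder payload. save_engine_bytes and load_engine_bytes are hypothetical helpers, and the sketch does not claim to reproduce what export_quantized_model or load_quantized_model do internally.

```python
import os
import tempfile


def save_engine_bytes(serialized: bytes, path: str) -> None:
    """Hypothetical helper: write a serialized engine to ``path`` in binary mode."""
    with open(path, "wb") as f:
        f.write(serialized)


def load_engine_bytes(path: str) -> bytes:
    """Hypothetical helper: read the raw serialized engine back from ``path``."""
    with open(path, "rb") as f:
        return f.read()


# Placeholder payload; a real one would come from bytes(engine.serialize()).
fake_engine = b"\x00TRT-ENGINE-PLACEHOLDER\xff"
with tempfile.TemporaryDirectory() as tmp:
    engine_path = os.path.join(tmp, "model.trt")
    save_engine_bytes(fake_engine, engine_path)
    restored = load_engine_bytes(engine_path)

print(restored == fake_engine)  # True
```

With TensorRT installed, the restored bytes would be handed to trt.Runtime(trt.Logger(trt.Logger.WARNING)).deserialize_cuda_engine(restored). Serialized engines are specific to the TensorRT version and GPU they were built with, so a file exported on one setup may not load on another.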