diff --git a/quantization/object_detection/trt/yolov3/data_reader.py b/quantization/object_detection/trt/yolov3/data_reader.py index 3d0a67511..2f8b8d034 100644 --- a/quantization/object_detection/trt/yolov3/data_reader.py +++ b/quantization/object_detection/trt/yolov3/data_reader.py @@ -97,7 +97,7 @@ def get_next(self): def load_serial(self): width = self.width height = self.height - nchw_data_list, filename_list, image_size_list = preprocess_func(self.image_folder, height, width, + nchw_data_list, filename_list, image_size_list = self.preprocess_func(self.image_folder, height, width, self.start_index, self.stride) input_name = self.input_name @@ -131,7 +131,7 @@ def load_batches(self): for index in range(0, stride, batch_size): start_index = self.start_index + index print("Load batch from index %s ..." % (str(start_index))) - nchw_data_list, filename_list, image_size_list = preprocess_func(self.image_folder, height, width, + nchw_data_list, filename_list, image_size_list = self.preprocess_func(self.image_folder, height, width, start_index, batch_size) if nchw_data_list.size == 0: @@ -225,7 +225,7 @@ def load_batches(self): for index in range(0, stride, batch_size): start_index = self.start_index + index print("Load batch from index %s ..." % (str(start_index))) - nchw_data_list, filename_list, image_size_list = preprocess_func( + nchw_data_list, filename_list, image_size_list = self.preprocess_func( self.image_folder, height, width, start_index, batch_size) if nchw_data_list.size == 0: diff --git a/quantization/object_detection/trt/yolov3/e2e_user_yolov3_example.py b/quantization/object_detection/trt/yolov3/e2e_user_yolov3_example.py index 0dc046b67..a8f5db19f 100644 --- a/quantization/object_detection/trt/yolov3/e2e_user_yolov3_example.py +++ b/quantization/object_detection/trt/yolov3/e2e_user_yolov3_example.py @@ -43,7 +43,7 @@ def get_calibration_table(model_path, augmented_model_path, calibration_dataset) # data_reader = YoloV3DataReader(calibration_dataset, stride=1000, batch_size=20, model_path=augmented_model_path) # calibrator.collect_data(data_reader) - write_calibration_table(calibrator.compute_range()) + write_calibration_table(calibrator.compute_data()) print('calibration table generated and saved.')