-
Notifications
You must be signed in to change notification settings - Fork 836
/
trt_calibrator.py
102 lines (79 loc) · 3.43 KB
/
trt_calibrator.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
# encoding: utf-8
"""
@author: xingyu liao
@contact: [email protected]
Create a custom calibrator that is used to calibrate an int8 TensorRT model.
It needs to override some methods of trt.IInt8EntropyCalibrator2, such as get_batch_size, get_batch,
read_calibration_cache and write_calibration_cache.
"""
# based on:
# https://github.com/qq995431104/Pytorch2TensorRT/blob/master/myCalibrator.py
import os
import sys
import tensorrt as trt
import pycuda.driver as cuda
import pycuda.autoinit
import numpy as np
import torchvision.transforms as T
sys.path.append('../..')
from fastreid.data.build import _root
from fastreid.data.data_utils import read_image
from fastreid.data.datasets import DATASET_REGISTRY
import logging
from fastreid.data.transforms import ToTensor
logger = logging.getLogger('trt_export.calibrator')
class FeatEntropyCalibrator(trt.IInt8EntropyCalibrator2):
    """Entropy calibrator that feeds dataset images to TensorRT for INT8 calibration.

    The image list is built from the ``train``, ``query`` and ``gallery``
    splits of the dataset registered under ``args.calib_data``, shuffled once,
    and served in fixed-size batches through a single pre-allocated device
    buffer.
    """

    def __init__(self, args):
        """Build the image list and allocate the device-side input buffer.

        Args:
            args: object providing ``batch_size``, ``channel``, ``height``,
                ``width`` and ``calib_data`` (name looked up in
                ``DATASET_REGISTRY``).
        """
        trt.IInt8EntropyCalibrator2.__init__(self)

        self.cache_file = 'reid_feat.cache'

        self.batch_size = args.batch_size
        self.channel = args.channel
        self.height = args.height
        self.width = args.width
        self.transform = T.Compose([
            # NOTE(review): interpolation=3 is presumably PIL's BICUBIC code -- confirm.
            T.Resize((self.height, self.width), interpolation=3),  # [h,w]
            ToTensor(),
        ])

        dataset = DATASET_REGISTRY.get(args.calib_data)(root=_root)
        self._data_items = dataset.train + dataset.query + dataset.gallery
        np.random.shuffle(self._data_items)
        # The first field of every dataset item is the image path.
        self.imgs = [item[0] for item in self._data_items]

        self.batch_idx = 0
        # Only full batches are served; a trailing partial batch is dropped.
        self.max_batch_idx = len(self.imgs) // self.batch_size
        if self.max_batch_idx == 0:
            logger.warning(
                "calibration dataset has %d images, fewer than one batch of %d; "
                "no calibration batches will be produced",
                len(self.imgs), self.batch_size,
            )

        # Size in bytes of one float32 batch of shape (batch, channel, height, width).
        self.data_size = (self.batch_size * self.channel * self.height * self.width
                          * trt.float32.itemsize)
        self.device_input = cuda.mem_alloc(self.data_size)

    def next_batch(self):
        """Load and preprocess the next batch of calibration images.

        Returns:
            A contiguous float32 array of shape
            ``(batch_size, channel, height, width)``, or an empty array once
            all full batches have been consumed.

        Raises:
            ValueError: if a transformed image does not occupy exactly one
                batch slot worth of bytes.
        """
        if self.batch_idx >= self.max_batch_idx:
            return np.array([])

        start = self.batch_idx * self.batch_size
        batch_files = self.imgs[start:start + self.batch_size]
        batch_imgs = np.zeros((self.batch_size, self.channel, self.height, self.width),
                              dtype=np.float32)
        bytes_per_image = self.data_size // self.batch_size
        for i, f in enumerate(batch_files):
            img = read_image(f)
            img = self.transform(img).numpy()
            # Explicit check instead of ``assert`` so it still runs under ``python -O``;
            # otherwise a wrongly shaped image could be silently broadcast into the slot.
            if img.nbytes != bytes_per_image:
                raise ValueError('not valid img!' + f)
            batch_imgs[i] = img
        self.batch_idx += 1
        logger.info("batch:[%s/%s]", self.batch_idx, self.max_batch_idx)
        return np.ascontiguousarray(batch_imgs)

    def get_batch_size(self):
        """Return the calibration batch size."""
        return self.batch_size

    def get_batch(self, names, p_str=None):
        """Copy the next batch to the device and return its device pointer.

        Args:
            names: input tensor names passed in by the caller (unused here).
            p_str: unused; kept for signature compatibility.

        Returns:
            A one-element list holding the device buffer address as ``int``,
            or ``None`` when no complete batch is available or an error
            occurred (per the TensorRT calibrator API, ``None`` ends
            calibration).
        """
        try:
            batch_imgs = self.next_batch().ravel()
            expected_size = self.batch_size * self.channel * self.height * self.width
            if batch_imgs.size == 0 or batch_imgs.size != expected_size:
                return None
            # The batch is already float32; ascontiguousarray avoids an extra copy.
            cuda.memcpy_htod(self.device_input,
                             np.ascontiguousarray(batch_imgs, dtype=np.float32))
            return [int(self.device_input)]
        except Exception:
            # Keep the best-effort contract (None ends calibration) but leave a trace
            # instead of silently swallowing the error; KeyboardInterrupt/SystemExit
            # are no longer intercepted.
            logger.exception("failed to prepare calibration batch %d", self.batch_idx)
            return None

    def read_calibration_cache(self):
        """Return the bytes of the calibration cache file, or ``None`` if it does not exist."""
        # If there is a cache, use it instead of calibrating again.
        if os.path.exists(self.cache_file):
            with open(self.cache_file, "rb") as f:
                return f.read()
        return None

    def write_calibration_cache(self, cache):
        """Write the calibration cache bytes to ``self.cache_file``."""
        with open(self.cache_file, "wb") as f:
            f.write(cache)