-
Notifications
You must be signed in to change notification settings - Fork 836
/
onnx_export.py
167 lines (135 loc) · 4.6 KB
/
onnx_export.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
# encoding: utf-8
"""
@author: xingyu liao
@contact: [email protected]
"""
import logging
import os
import argparse
import io
import sys
import onnx
import onnxoptimizer
import torch
from onnxsim import simplify
from torch.onnx import OperatorExportTypes
sys.path.append('.')
from fastreid.config import get_cfg
from fastreid.modeling.meta_arch import build_model
from fastreid.utils.file_io import PathManager
from fastreid.utils.checkpoint import Checkpointer
from fastreid.utils.logger import setup_logger
# import some modules added in project like this below
# sys.path.append("projects/FastDistill")
# from fastdistill import *
# Initialise the package-level "fastreid" logger through fastreid's own
# setup_logger helper, then fetch this script's logger, which sits under
# "fastreid" in logging's dotted-name hierarchy.
setup_logger(name="fastreid")
logger = logging.getLogger(name="fastreid.onnx_export")
def setup_cfg(args):
    """Create the frozen fastreid config for this export run.

    The YAML file named by ``args.config_file`` is merged into the default
    config first; the ``KEY VALUE`` overrides collected in ``args.opts`` are
    applied on top of it. The resulting node is frozen before it is returned.
    """
    config_path = args.config_file
    overrides = args.opts

    node = get_cfg()
    node.merge_from_file(config_path)
    node.merge_from_list(overrides)
    node.freeze()
    return node
def get_parser():
    """Build the command-line parser for the PyTorch -> ONNX export script.

    Returns:
        argparse.ArgumentParser: parser exposing ``--config-file`` (required),
        ``--name``, ``--output``, ``--batch-size`` and ``--opts``.
    """
    parser = argparse.ArgumentParser(description="Convert Pytorch to ONNX model")
    parser.add_argument(
        "--config-file",
        metavar="FILE",
        # setup_cfg() merges this file unconditionally, so a missing value would
        # only fail later with an opaque error; let argparse reject it up front.
        required=True,
        help="path to config file",
    )
    parser.add_argument(
        "--name",
        default="baseline",
        help="name for converted model"
    )
    parser.add_argument(
        "--output",
        default='onnx_model',
        help='path to save converted onnx model'
    )
    parser.add_argument(
        '--batch-size',
        default=1,
        type=int,
        # Used as the batch dimension of the dummy tracing input. No dynamic axes
        # are declared at export, so the graph is presumably fixed to this size.
        help="the maximum batch size of onnx runtime"
    )
    parser.add_argument(
        "--opts",
        help="Modify config options using the command-line 'KEY VALUE' pairs",
        default=[],
        # Everything after --opts is collected verbatim as KEY VALUE pairs.
        nargs=argparse.REMAINDER,
    )
    return parser
def remove_initializer_from_input(model):
    """Drop graph inputs that duplicate an initializer of the same name.

    Args:
        model: an ONNX model object, such as the one returned by
            ``onnxsim.simplify``. Its ``graph.input`` container is edited in
            place.

    Returns:
        The same ``model`` object, always. For ``model.ir_version < 4``,
        initializers must also be listed as graph inputs. In that case nothing
        is removed, a warning is logged and the model is returned unchanged,
        so callers that reassign the result never receive ``None``.
    """
    if model.ir_version < 4:
        # Same logger as the module-level ``logger`` ("fastreid.onnx_export").
        logging.getLogger("fastreid.onnx_export").warning(
            'Model with ir_version below 4 requires to include initilizer in graph input'
        )
        return model

    initializer_names = {initializer.name for initializer in model.graph.initializer}
    graph_inputs = model.graph.input
    # Iterate over a snapshot because entries are removed from the live container.
    for graph_input in list(graph_inputs):
        if graph_input.name in initializer_names:
            graph_inputs.remove(graph_input)
    return model
def export_onnx_model(model, inputs):
    """
    Trace and export a model to onnx format, then apply onnxoptimizer passes.
    Args:
        model (nn.Module): the model to export. It and every submodule must
            already be in eval mode.
        inputs (torch.Tensor or tuple): example input used for tracing. A single
            tensor is passed as the model's only positional argument
            (``model(inputs)``); a tuple is unpacked as ``model(*inputs)``.
    Returns:
        an onnx model, after the "extract_constant_to_initializer",
        "eliminate_unused_initializer" and "fuse_bn_into_conv" passes.
    Raises:
        TypeError: if ``model`` is not a ``torch.nn.Module``.
        ValueError: if any module in ``model`` is still in training mode.
        RuntimeError: if the installed onnxoptimizer lacks a required pass.
    """
    if not isinstance(model, torch.nn.Module):
        raise TypeError(
            "export_onnx_model expects a torch.nn.Module, got {}".format(type(model).__name__)
        )
    # make sure all modules are in eval mode, onnx may change the training state
    # of the module if the states are not consistent. Explicit raises (rather
    # than asserts) keep the check active under `python -O`.
    still_training = [
        name or "<root>" for name, module in model.named_modules() if module.training
    ]
    if still_training:
        raise ValueError(
            "All modules must be in eval mode before ONNX export; {} module(s) still "
            "in training mode (e.g. {}). Call model.eval() first.".format(
                len(still_training), ", ".join(still_training[:5])
            )
        )

    logger.info("Beginning ONNX file converting")
    # Export the model to ONNX into an in-memory buffer and parse it back
    # before the buffer is closed.
    with torch.no_grad():
        with io.BytesIO() as f:
            torch.onnx.export(
                model,
                inputs,
                f,
                operator_export_type=OperatorExportTypes.ONNX_ATEN_FALLBACK,
                # verbose=True,  # NOTE: uncomment this for debugging
                # export_params=True,
            )
            onnx_model = onnx.load_from_string(f.getvalue())

    logger.info("Completed convert of ONNX model")

    # Apply ONNX's Optimization
    logger.info("Beginning ONNX model path optimization")
    passes = ["extract_constant_to_initializer", "eliminate_unused_initializer", "fuse_bn_into_conv"]
    available = set(onnxoptimizer.get_available_passes())
    missing = [p for p in passes if p not in available]
    if missing:
        raise RuntimeError(
            "onnxoptimizer is missing required pass(es): {}".format(", ".join(missing))
        )
    onnx_model = onnxoptimizer.optimize(onnx_model, passes)
    logger.info("Completed ONNX model path optimization")
    return onnx_model
def main():
    """Export the fastreid model described on the command line to ``<output>/<name>.onnx``.

    The export proceeds in these steps:

    1. Build the config, forcing ``MODEL.BACKBONE.PRETRAIN`` off and replacing a
       ``FastGlobalAvgPool`` head pool with ``GlobalAvgPool``.
    2. Load the checkpoint from ``cfg.MODEL.WEIGHTS``.
    3. Switch the backbone to deploy mode when it exposes ``deploy``.
    4. Trace with a random input of shape
       ``(batch_size, 3, SIZE_TEST[0], SIZE_TEST[1])``.
    5. Simplify with onnxsim, validating the result.
    6. Drop initializer-backed graph inputs and save.

    Raises:
        RuntimeError: if onnxsim cannot validate the simplified model.
    """
    args = get_parser().parse_args()
    cfg = setup_cfg(args)
    cfg.defrost()
    # The full checkpoint is loaded below, so presumably no separate pretrained
    # backbone weights are needed (TODO confirm PRETRAIN semantics).
    cfg.MODEL.BACKBONE.PRETRAIN = False
    # NOTE(review): presumably FastGlobalAvgPool does not export cleanly to ONNX,
    # hence the swap to the standard GlobalAvgPool - confirm.
    if cfg.MODEL.HEADS.POOL_LAYER == 'FastGlobalAvgPool':
        cfg.MODEL.HEADS.POOL_LAYER = 'GlobalAvgPool'
    model = build_model(cfg)
    Checkpointer(model).load(cfg.MODEL.WEIGHTS)
    if hasattr(model.backbone, 'deploy'):
        model.backbone.deploy(True)
    model.eval()
    logger.info(model)

    inputs = torch.randn(args.batch_size, 3, cfg.INPUT.SIZE_TEST[0], cfg.INPUT.SIZE_TEST[1]).to(model.device)
    onnx_model = export_onnx_model(model, inputs)

    model_simp, check = simplify(onnx_model)
    # Validate before any post-processing; an explicit raise (unlike assert)
    # is not stripped under `python -O`.
    if not check:
        raise RuntimeError("Simplified ONNX model could not be validated")
    cleaned = remove_initializer_from_input(model_simp)
    # Keep the simplified model if the helper hands back None instead of a model.
    if cleaned is not None:
        model_simp = cleaned

    PathManager.mkdirs(args.output)
    save_path = os.path.join(args.output, args.name + '.onnx')
    onnx.save_model(model_simp, save_path)
    logger.info("ONNX model file has already saved to {}!".format(save_path))


if __name__ == '__main__':
    main()