From 16cb20c151d2b0c51cd69b8c090ffa0544913588 Mon Sep 17 00:00:00 2001 From: pinto0309 Date: Mon, 11 Apr 2022 21:47:45 +0900 Subject: [PATCH] Support for onnx.ModelProto input, optional bug fixes --- README.md | 26 ++++++++++++++------------ sne4onnx/__init__.py | 2 +- sne4onnx/onnx_network_extraction.py | 20 +++++++++++--------- 3 files changed, 26 insertions(+), 22 deletions(-) diff --git a/README.md b/README.md index ab6b2be..475b6a0 100644 --- a/README.md +++ b/README.md @@ -72,18 +72,15 @@ $ python Help on function extraction in module sne4onnx.onnx_network_extraction: extraction( - input_onnx_file_path: str, input_op_names: List[str], output_op_names: List[str], - output_onnx_file_path: Union[str, NoneType] = '', - onnx_graph: Union[onnx.onnx_ml_pb2.ModelProto, NoneType] = None + input_onnx_file_path: Union[str, NoneType] = '', + onnx_graph: Union[onnx.onnx_ml_pb2.ModelProto, NoneType] = None, + output_onnx_file_path: Union[str, NoneType] = '' ) -> onnx.onnx_ml_pb2.ModelProto Parameters ---------- - input_onnx_file_path: str - Input onnx file path. - input_op_names: List[str] List of OP names to specify for the input layer of the model. Specify the name of the OP, separated by commas. @@ -94,16 +91,21 @@ extraction( Specify the name of the OP, separated by commas. e.g. ['ddd','eee','fff'] - output_onnx_file_path: Optional[str] - Output onnx file path. - If not specified, .onnx is not output. - Default: '' + input_onnx_file_path: Optional[str] + Input onnx file path. + Either input_onnx_file_path or onnx_graph must be specified. + onnx_graph If specified, ignore input_onnx_file_path and process onnx_graph. onnx_graph: Optional[onnx.ModelProto] onnx.ModelProto. Either input_onnx_file_path or onnx_graph must be specified. onnx_graph If specified, ignore input_onnx_file_path and process onnx_graph. + output_onnx_file_path: Optional[str] + Output onnx file path. + If not specified, .onnx is not output. + Default: '' + Returns ------- extracted_graph: onnx.ModelProto @@ -125,9 +127,9 @@ $ sne4onnx \ from sne4onnx import extraction extracted_graph = extraction( - input_onnx_file_path='input.onnx', input_op_names=['aaa', 'bbb', 'ccc'], output_op_names=['ddd', 'eee', 'fff'], + input_onnx_file_path='input.onnx', output_onnx_file_path='output.onnx', ) ``` @@ -138,8 +140,8 @@ from sne4onnx import extraction extracted_graph = extraction( input_op_names=['aaa', 'bbb', 'ccc'], output_op_names=['ddd', 'eee', 'fff'], - output_onnx_file_path='output.onnx', onnx_graph=graph, + output_onnx_file_path='output.onnx', ) ``` diff --git a/sne4onnx/__init__.py b/sne4onnx/__init__.py index 930cd4a..7283581 100644 --- a/sne4onnx/__init__.py +++ b/sne4onnx/__init__.py @@ -1,3 +1,3 @@ from sne4onnx.onnx_network_extraction import extraction, main -__version__ = '1.0.4' +__version__ = '1.0.5' diff --git a/sne4onnx/onnx_network_extraction.py b/sne4onnx/onnx_network_extraction.py index 7ce5e92..546a4d9 100644 --- a/sne4onnx/onnx_network_extraction.py +++ b/sne4onnx/onnx_network_extraction.py @@ -32,19 +32,16 @@ class Color: def extraction( - input_onnx_file_path: str, input_op_names: List[str], output_op_names: List[str], - output_onnx_file_path: Optional[str] = '', + input_onnx_file_path: Optional[str] = '', onnx_graph: Optional[onnx.ModelProto] = None, + output_onnx_file_path: Optional[str] = '', ) -> onnx.ModelProto: """ Parameters ---------- - input_onnx_file_path: str - Input onnx file path. - input_op_names: List[str] List of OP names to specify for the input layer of the model.\n\ Specify the name of the OP, separated by commas.\n\ @@ -55,16 +52,21 @@ def extraction( Specify the name of the OP, separated by commas.\n\ e.g. ['ddd','eee','fff'] - output_onnx_file_path: Optional[str] - Output onnx file path.\n\ - If not specified, .onnx is not output.\n\ - Default: '' + input_onnx_file_path: Optional[str] + Input onnx file path.\n\ + Either input_onnx_file_path or onnx_graph must be specified.\n\ + onnx_graph If specified, ignore input_onnx_file_path and process onnx_graph. onnx_graph: Optional[onnx.ModelProto] onnx.ModelProto.\n\ Either input_onnx_file_path or onnx_graph must be specified.\n\ onnx_graph If specified, ignore input_onnx_file_path and process onnx_graph. + output_onnx_file_path: Optional[str] + Output onnx file path.\n\ + If not specified, .onnx is not output.\n\ + Default: '' + Returns ------- extracted_graph: onnx.ModelProto