Skip to content

Commit

Permalink
Replace the remaining target onnx version checking by the opset number. (#187)
Browse files Browse the repository at this point in the history

* replace the target onnx version by the opset number.

* fix for the tests.

* Remove the targeted_onnx from the parse function.

* one more fix.

* unify the opset comparison.
  • Loading branch information
wenbingl authored Nov 28, 2018
1 parent cb365db commit 93b1789
Show file tree
Hide file tree
Showing 35 changed files with 98 additions and 147 deletions.
4 changes: 1 addition & 3 deletions onnxmltools/convert/common/_container.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,7 +120,7 @@ class ModelComponentContainer(ModelContainer):
encapsulated in a ONNX ModelProto.
'''

def __init__(self, target_opset, targeted_onnx):
def __init__(self, target_opset):
'''
:param target_opset: number, for example, 7 for ONNX 1.2, and 8 for ONNX 1.3.
:param targeted_onnx: A string, for example, '1.1.2' and '1.2'.
Expand All @@ -139,8 +139,6 @@ def __init__(self, target_opset, targeted_onnx):
self.node_domain_version_pair_sets = set()
# The targeted ONNX operator set (referred to as opset) that matches the ONNX version.
self.target_opset = target_opset
# The targeted ONNX version. All produced operators should be supported by the targeted ONNX version.
self.targeted_onnx_version = targeted_onnx

def _make_value_info(self, variable):
value_info = helper.ValueInfoProto()
Expand Down
33 changes: 16 additions & 17 deletions onnxmltools/convert/common/_topology.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
import re
import warnings
from logging import getLogger
from distutils.version import StrictVersion
from ...proto import onnx
from ...proto import helper
from ...proto import get_opset_number_from_onnx
Expand All @@ -15,7 +16,6 @@
from . import utils
from .data_types import *
from ._container import ModelComponentContainer
from .utils import compare_strict_version
from .optimizer import optimize_onnx
from .interface import OperatorBase

Expand Down Expand Up @@ -50,15 +50,15 @@ def full_name(self):

class Operator(OperatorBase):

def __init__(self, onnx_name, scope, type, raw_operator, targeted_onnx_version):
def __init__(self, onnx_name, scope, type, raw_operator, target_opset):
'''
:param onnx_name: A unique ID, which is a string
:param scope: The name of the scope where this operator is declared. It's a string.
:param type: A object which uniquely characterizes the type of this operator. For example, it can be a string,
pooling, if this operator is associated with a CoreML pooling layer.
:param raw_operator: The original operator which defines this operator; for example, a scikit-learn Imputer and
a CoreML Normalizer.
:param targeted_onnx_version: A StrictVersion object indicating the ONNX version used
:param target_opset: The target opset number for the converted model.
'''
self.onnx_name = onnx_name # operator name in the converted model
self.scope = scope
Expand All @@ -68,7 +68,7 @@ def __init__(self, onnx_name, scope, type, raw_operator, targeted_onnx_version):
self.outputs = []
self.is_evaluated = None
self.is_abandoned = False
self.targeted_onnx_version = targeted_onnx_version
self.target_opset = target_opset

@property
def full_name(self):
Expand Down Expand Up @@ -105,21 +105,20 @@ def infer_types(self):

class Scope:

def __init__(self, name, parent_scopes=None, variable_name_set=None, operator_name_set=None,
targeted_onnx_version=None):
def __init__(self, name, parent_scopes=None, variable_name_set=None, operator_name_set=None, target_opset=None):
'''
:param name: A string, the unique ID of this scope in a Topology object
:param parent_scopes: A list of Scope objects. The last element should be the direct parent scope (i.e., where
this scope is declared).
:param variable_name_set: A set of strings serving as the name pool of variables
:param operator_name_set: A set of strings serving as the name pool of operators
:param targeted_onnx_version: A StrictVersion object indicating the ONNX version used
:param target_opset: The target opset number for the converted model.
'''
self.name = name
self.parent_scopes = parent_scopes if parent_scopes else list()
self.onnx_variable_names = variable_name_set if variable_name_set is not None else set()
self.onnx_operator_names = operator_name_set if operator_name_set is not None else set()
self.targeted_onnx_version = targeted_onnx_version
self.target_opset = target_opset

# An one-to-many map from raw variable name to ONNX variable names. It looks like
# (key, value) = (raw_name, [onnx_name, onnx_name1, onnx_name2, ..., onnx_nameN])
Expand Down Expand Up @@ -210,7 +209,7 @@ def declare_local_operator(self, type, raw_model=None):
This function is used to declare new local operator.
'''
onnx_name = self.get_unique_operator_name(str(type))
operator = Operator(onnx_name, self.name, type, raw_model, self.targeted_onnx_version)
operator = Operator(onnx_name, self.name, type, raw_model, self.target_opset)
self.operators[onnx_name] = operator
return operator

Expand Down Expand Up @@ -238,7 +237,7 @@ def delete_local_variable(self, onnx_name):
class Topology:

def __init__(self, model, default_batch_size=1, initial_types=None,
reserved_variable_names=None, reserved_operator_names=None, target_opset=None, targeted_onnx=None,
reserved_variable_names=None, reserved_operator_names=None, target_opset=None,
custom_conversion_functions=None, custom_shape_calculators=None, metadata_props=None):
'''
Initialize a Topology object, which is an intermediate representation of a computational graph.
Expand All @@ -260,7 +259,7 @@ def __init__(self, model, default_batch_size=1, initial_types=None,
self.initial_types = initial_types if initial_types else list()
self.metadata_props = metadata_props if metadata_props else dict()
self.default_batch_size = default_batch_size
self.targeted_onnx_version = targeted_onnx
self.target_opset = target_opset
self.custom_conversion_functions = custom_conversion_functions if custom_conversion_functions else {}
self.custom_shape_calculators = custom_shape_calculators if custom_shape_calculators else {}

Expand Down Expand Up @@ -302,9 +301,9 @@ def _generate_unique_name(seed, existing_names):
def get_unique_scope_name(self, seed):
return Topology._generate_unique_name(seed, self.scope_names)

def declare_scope(self, seed, parent_scopes=list()):
scope = Scope(self.get_unique_scope_name(seed), parent_scopes, self.variable_name_set,
self.operator_name_set, self.targeted_onnx_version)
def declare_scope(self, seed, parent_scopes=None):
scope = Scope(self.get_unique_scope_name(seed), parent_scopes,
self.variable_name_set, self.operator_name_set, self.target_opset)
self.scopes.append(scope)
return scope

Expand Down Expand Up @@ -639,7 +638,7 @@ def convert_topology(topology, model_name, doc_string, target_opset, targeted_on
include '1.1.2', '1.2', and so on.
:return: a ONNX ModelProto
'''
if targeted_onnx is not None and compare_strict_version(targeted_onnx, onnx.__version__) != 0:
if targeted_onnx is not None and StrictVersion(targeted_onnx) != StrictVersion(onnx.__version__):
warnings.warn(
'targeted_onnx is deprecated, please specify target_opset for the target model.\n' +
'*** ONNX version conflict found. The installed version is %s while the targeted version is %s' % (
Expand All @@ -652,7 +651,7 @@ def convert_topology(topology, model_name, doc_string, target_opset, targeted_on

topology._initialize_graph_status_for_traversing()

container = ModelComponentContainer(target_opset, targeted_onnx)
container = ModelComponentContainer(target_opset)

# Put roots and leaves as ONNX's model into buffers. They will be added into ModelComponentContainer later.
tensor_inputs = {}
Expand Down Expand Up @@ -776,7 +775,7 @@ def convert_topology(topology, model_name, doc_string, target_opset, targeted_on
getLogger('onnxmltools').warning('The maximum opset needed by this model is only %d.' % op_version)

# Add extra information
add_metadata_props(onnx_model, topology.metadata_props)
add_metadata_props(onnx_model, topology.metadata_props, target_opset)
onnx_model.ir_version = onnx_proto.IR_VERSION
onnx_model.producer_name = utils.get_producer()
onnx_model.producer_version = utils.get_producer_version()
Expand Down
6 changes: 3 additions & 3 deletions onnxmltools/convert/common/shape_calculator.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
import six
from ._registration import register_shape_calculator
from .data_types import Int64TensorType, FloatTensorType, StringTensorType, DictionaryType, SequenceType
from .utils import check_input_and_output_numbers, check_input_and_output_types, compare_strict_version
from .utils import check_input_and_output_numbers, check_input_and_output_types


def calculate_linear_classifier_output_shapes(operator):
Expand Down Expand Up @@ -39,7 +39,7 @@ def calculate_linear_classifier_output_shapes(operator):
operator.outputs[0].type = StringTensorType(shape=[N])
if len(class_labels) > 2 or operator.type != 'SklearnLinearSVC':
# For multi-class classifier, we produce a map for encoding the probabilities of all classes
if compare_strict_version(operator.targeted_onnx_version, '1.2') < 0:
if operator.target_opset < 7:
operator.outputs[1].type = DictionaryType(StringTensorType([1]), FloatTensorType([1]))
else:
operator.outputs[1].type = SequenceType(DictionaryType(StringTensorType([]), FloatTensorType([])), N)
Expand All @@ -50,7 +50,7 @@ def calculate_linear_classifier_output_shapes(operator):
operator.outputs[0].type = Int64TensorType(shape=[N])
if len(class_labels) > 2 or operator.type != 'SklearnLinearSVC':
# For multi-class classifier, we produce a map for encoding the probabilities of all classes
if compare_strict_version(operator.targeted_onnx_version, '1.2') < 0:
if operator.target_opset < 7:
operator.outputs[1].type = DictionaryType(Int64TensorType([1]), FloatTensorType([1]))
else:
operator.outputs[1].type = SequenceType(DictionaryType(Int64TensorType([]), FloatTensorType([])), N)
Expand Down
28 changes: 0 additions & 28 deletions onnxmltools/convert/common/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -295,31 +295,3 @@ def check_input_and_output_types(operator, good_input_types=None, good_output_ty
raise RuntimeError('Operator %s (type: %s) got an output %s with a wrong type %s. Only %s are allowed' \
% (operator.full_name, operator.type, variable.full_name, type(variable.type),
good_output_types))


def compare_strict_version(v1, v2):
    """
    Compare two ONNX version identifiers.

    :param v1: the targeted version, usually the version of the installed
        ONNX module; may be a version string, a ``StrictVersion``, or None
    :param v2: a specific ONNX version to compare against (typically one
        that introduced an API change); must not be None
    :return: -1 when *v1* is older than *v2*, 0 when they are equal,
        1 when *v1* is more recent.

    A *v1* of None — or any object that does not carry a parsed
    ``version`` attribute — is treated as a development build of ONNX
    and therefore compares as the most recent (returns 1).
    """
    if v2 is None:
        raise ValueError("v2 must not be None.")
    if v1 is None:
        # Development/unreleased ONNX: always considered newest.
        return 1

    target = StrictVersion(v1) if isinstance(v1, str) else v1
    if not hasattr(target, 'version'):
        # Not a parsed version object: treat as a development build.
        return 1

    reference = StrictVersion(v2) if isinstance(v2, str) else v2
    if target < reference:
        return -1
    return 0 if target == reference else 1

Loading

0 comments on commit 93b1789

Please sign in to comment.