diff --git a/source/python.js b/source/python.js index 19e3c3c5a7..f44c6adb95 100644 --- a/source/python.js +++ b/source/python.js @@ -1625,6 +1625,19 @@ python.Execution = class { get(key, defaultValue) { return super.get(key) || defaultValue; } + setdefault(key, defaultValue) { + if (this.has(key)) { + return this.get(key); + } + const value = defaultValue || null; + this.set(key, value); + return value; + } + update(other) { + for (const [key, value] of other) { + this.set(key, value); + } + } }; this._modules = new dict(); this._registry = new Map(); @@ -1808,6 +1821,7 @@ python.Execution = class { this.registerFunction('operator.add'); this.registerFunction('operator.eq'); this.registerFunction('operator.ge'); + this.registerFunction('operator.getitem'); this.registerFunction('operator.gt'); this.registerFunction('operator.mul'); this.registerFunction('operator.mod'); @@ -4845,8 +4859,9 @@ python.Execution = class { this.name = name; this.op = op; this.target = target; - this._input_nodes = new Map(); - this.users = new Map(); + this._input_nodes = new builtins.dict(); + this.__update_args_kwargs(args, kwargs); + this.users = new builtins.dict(); this.type = return_type; this._prev = this; this._next = this; @@ -4854,6 +4869,9 @@ python.Execution = class { this._repr_fn = null; this.meta = new builtins.dict(); } + get args() { + return this._args; + } get next() { return this._next; } @@ -4867,12 +4885,52 @@ python.Execution = class { const [p, n] = [this._prev, this._next]; [p._next, n._prev] = [n, p]; } + __update_args_kwargs(new_args, new_kwargs) { + const update_users_and_input_nodes = (n) => { + if (n instanceof torch.fx.node.Node) { + this._input_nodes.setdefault(n); + n.users.setdefault(this); + } + return n; + }; + const map_aggregate = (a, fn) => { + if (a instanceof builtins.tuple) { + const t = new builtins.tuple(a.map((elem) => map_aggregate(elem, fn))); + if (!builtins.hasattr(a, '_fields')) { + return t; + } + throw new python.Error('Not 
implemented.'); + // return type(a)(*t); + } else if (Array.isArray(a)) { + return a.map((elem) => map_aggregate(elem, fn)); + } else if (a instanceof builtins.dict) { + const rv = new builtins.dict(); + for (const [k, v] of a) { + rv.__setitem__(k, map_aggregate(v, fn)); + } + return rv; + } else if (a instanceof builtins.slice) { + throw new python.Error('Not implemented.'); + // return slice(map_aggregate(a.start, fn), map_aggregate(a.stop, fn), map_aggregate(a.step, fn)) + } + return fn(a); + }; + for (const old_use of this._input_nodes.keys()) { + old_use.users.pop(this); + } + // object.__setattr__(self, "_input_nodes", {}) + this._input_nodes = new builtins.dict(); + // object.__setattr__(self, "_args", map_aggregate(new_args, update_users_and_input_nodes)) + this._args = map_aggregate(new_args, update_users_and_input_nodes); + // object.__setattr__(self, "_kwargs", map_aggregate(new_kwargs, update_users_and_input_nodes)) + this._kwargs = map_aggregate(new_kwargs, update_users_and_input_nodes); + } }); torch.fx.Node = torch.fx.node.Node; torch.fx.graph.Node = torch.fx.node.Node; this.registerType('torch.fx.graph.Graph', class { constructor() { - this._root = new torch.fx.node.Node(self, '', 'root', '', [], {}); + this._root = new torch.fx.node.Node(self, '', 'root', '', new builtins.list(), new builtins.dict()); this._used_names = new Map(); this._len = 0; this._graph_namespace = new torch.fx.graph._Namespace(); @@ -4896,8 +4954,8 @@ python.Execution = class { return this.create_node('placeholder', name, args, type_expr); } create_node(op, target, args, kwargs, name, type_expr) { - args = args || []; - kwargs = kwargs || {}; + args = args || new builtins.tuple(); + kwargs = kwargs || new builtins.dict(); const candidate = name || this._target_to_str(target); name = this._graph_namespace.create_name(candidate, null); const n = new torch.fx.node.Node(this, name, op, target, args, kwargs, type_expr); @@ -6066,16 +6124,51 @@ python.Execution = class { } }); + 
this.registerType('torch.Type', class {}); + this.registerType('torch.ClassType', class extends torch.Type { + constructor(qualified_name, cu, is_module) { + super(); + this._qualified_name = qualified_name; + this._is_module = is_module; + } + qualified_name() { + return this._qualified_name; + } + name() { + return this._qualified_name.split('.').pop(); + } + is_module() { + return this._is_module; + } + addMethod(/* name, fn */) { + } + addAttribute(/* name */) { + } + hasAttribute(/* name */) { + } + hasConstant(/* name */) { + } + methods() { + } + }); + this.registerType('torch.TupleType', class extends torch.Type {}); + this.registerType('torch.TensorType', class extends torch.Type {}); + this.registerType('torch.IntType', class extends torch.Type {}); this.registerType('torch.Argument', class { - constructor(name, type, default_value /*, alias_info, is_type_dispatched */) { + constructor(name, type, real_type, N, default_value /*, alias_info, is_type_dispatched */) { // torch/aten/src/ATen/core/function_schema.h this.name = name; this.type = type; + this.real_type = real_type; + this.N = N; this.default_value = default_value; // kwarg_only: bool // is_out: bool // alias_info: Optional[AliasInfo] } + has_default_value() { + return this.default_value !== undefined; + } }); this.registerType('torch.FunctionSchema', class { constructor(name, overload_name, args, returns) { @@ -6560,10 +6653,29 @@ python.Execution = class { */ } }); + this.registerType('torch._export.serde.schema.UserOutputSpec', class { + constructor(obj) { + this.arg = new torch._export.serde.schema.Argument(obj.arg); + } + }); this.registerType('torch._export.serde.schema.OutputSpec', class extends torch._export.serde.union._Union { constructor(obj) { super(obj); - Object.assign(this, { ...obj }); + if (this.type === 'user_output') { + this.user_output = new torch._export.serde.schema.UserOutputSpec(this.user_output); + } else if (this.type === 'loss_output') { + this.loss_output = new 
torch._export.serde.schema.LossOutputSpec(this.loss_output); + } else if (this.type === 'buffer_mutation') { + this.buffer_mutation = new torch._export.serde.schema.BufferMutationSpec(this.buffer_mutation); + } else if (this.type === 'gradient_to_parameter') { + this.gradient_to_parameter = new torch._export.serde.schema.GradientToParameterSpec(this.gradient_to_parameter); + } else if (this.type === 'gradient_to_user_input') { + this.gradient_to_user_input = new torch._export.serde.schema.GradientToUserInputSpec(this.gradient_to_user_input); + } else if (this.type === 'user_input_mutation') { + this.user_input_mutation = new torch._export.serde.schema.UserInputMutationSpec(this.user_input_mutation); + } else if (this.type === 'token') { + this.token = new torch._export.serde.schema.OutputTokenSpec(this.token); + } } }); this.registerType('torch._export.serde.schema.TensorArgument', class { @@ -6677,16 +6789,19 @@ python.Execution = class { throw new python.Error(`Unsupported graph node ${output.type}.`); } deserialize_graph(serialized_graph) { - for (const [name, tensor_value] of Object.entries(serialized_graph.tensor_values)) { + for (const [name, tensor_value] of serialized_graph.tensor_values) { const meta_val = this.deserialize_tensor_meta(tensor_value.meta || tensor_value, this.fake_tensor_mode); this.serialized_name_to_meta.set(name, meta_val); } - for (const [name, sym_int_value] of Object.entries(serialized_graph.sym_int_values)) { + for (const [name, sym_int_value] of serialized_graph.sym_int_values) { this.serialized_name_to_meta.set(name, this.deserialize_sym_int(sym_int_value)); } - for (const [name, sym_bool_value] in Object.entries(serialized_graph.sym_bool_values)) { + for (const [name, sym_bool_value] of serialized_graph.sym_bool_values) { this.serialized_name_to_meta.set(name, this.deserialize_sym_bool(sym_bool_value)); } + for (const [name, script_obj_meta] of serialized_graph.custom_obj_values) { + this.serialized_name_to_meta.set(name, 
this.deserialize_script_obj_meta(script_obj_meta)); + } for (let i = 0; i < serialized_graph.inputs.length; i++) { const input = serialized_graph.inputs[i]; if (input.type === 'as_tensor' || input.type === 'as_sym_int' || input.type === 'as_custom_obj') { @@ -6895,102 +7010,118 @@ python.Execution = class { throw new python.Error(`Node ${name} has already been deserialized before.`); } this.serialized_name_to_node.set(name, fx_node); - fx_node.meta.val = this.serialized_name_to_meta.get(name); + fx_node.meta.set('val', this.serialized_name_to_meta.get(name)); } deserialize_sym_op_inputs(inputs) { return inputs.map((input) => this.deserialize_input(input.arg)); } - deserialize_inputs(target /* , serialized_node */) { + deserialize_inputs(target, serialized_node) { const schema_args = target._schema.arguments; - const actual_args = null; - /* - actual_args = { - input.name: this.deserialize_input(input.arg) for input in serialized_node.inputs - } - */ + const actual_args = new Map(serialized_node.inputs.map((input) => [input.name, this.deserialize_input(input.arg)])); const args = []; const kwargs = {}; for (const schema_arg of schema_args) { const is_positional = !schema_arg.has_default_value() && !schema_arg.kwarg_only; if (is_positional) { - args.push(actual_args[schema_arg.name]); - } else if (schema_arg.name in actual_args) { - kwargs[schema_arg.name] = actual_args[schema_arg.name]; + args.push(actual_args.get(schema_arg.name)); + } else if (actual_args.has(schema_arg.name)) { + kwargs[schema_arg.name] = actual_args.get(schema_arg.name); } } return [args, kwargs]; } - deserialize_input(/* inp */) { - /* - value = inp.value - typ_ = inp.type - if typ_ === 'as_none': - # None should converted as None, but is encoded as bool in serialized - # Convert serialized object to torch equivalent - return None - elif typ_ === 'as_tensor': - return this.serialized_name_to_node[inp.as_tensor.name] - elif typ_ === 'as_scalar_type': - return 
_SERIALIZE_TO_TORCH_DTYPE[inp.as_scalar_type] - elif typ_ === 'as_memory_format': - return _SERIALIZE_TO_TORCH_MEMORY_FORMAT[inp.as_memory_format] - elif typ_ === 'as_layout': - return _SERIALIZE_TO_TORCH_LAYOUT[inp.as_layout] - elif typ_ === 'as_graph': - assert isinstance(value, GraphArgument) + deserialize_input(inp) { + const value = inp.value; + const typ_ = inp.type; + if (typ_ === 'as_none') { + return null; + } else if (typ_ === 'as_tensor') { + return this.serialized_name_to_node.get(inp.as_tensor.name); + } else if (typ_ === 'as_scalar_type') { + return torch._export.serde.serialize._SERIALIZE_TO_TORCH_DTYPE[inp.as_scalar_type]; + } else if (typ_ === 'as_memory_format') { + return torch._export.serde.serialize._SERIALIZE_TO_TORCH_MEMORY_FORMAT[inp.as_memory_format]; + } else if (typ_ === 'as_layout') { + return torch._export.serde.serialize._SERIALIZE_TO_TORCH_LAYOUT[inp.as_layout]; + } else if (typ_ === 'as_graph') { + /* assert isinstance(value, GraphArgument) with this.save_graph_module(): this.deserialize_graph(value.graph) - submodule = torch._export.exported_program._create_graph_module_for_export(this.module, this.graph) + submodule = ep._create_graph_module_for_export(this.module, this.graph) this.module.register_module(value.name, submodule) return this.graph.create_node( 'get_attr', value.name, name=value.name, - ) - elif typ_ === 'as_device': - return deserialize_device(inp.as_device) - elif typ_ === 'as_int': - return inp.as_int - elif typ_ === 'as_float': - return inp.as_float - elif typ_ === 'as_bool': - return inp.as_bool - elif typ_ === 'as_string': - return inp.as_string - elif typ_ === 'as_sym_int': - return this.deserialize_sym_argument(inp.as_sym_int) - elif typ_ === 'as_sym_bool': - return this.deserialize_sym_argument(inp.as_sym_bool) - elif isinstance(value, list): - if len(value) === 0: - return [] - elif isinstance(value[0], TensorArgument): - result = [] - for arg in value: - result.append(this.serialized_name_to_node[arg.name]) 
- return result - elif isinstance(value[0], (int, float, bool)): - # convert from serialized.python.types.List to python list - return list(value) - elif isinstance(value[0], (SymIntArgument, SymBoolArgument)): - return [this.deserialize_sym_argument(arg) for arg in value] - elif isinstance(value[0], OptionalTensorArgument): - def deserialize_optional_tensor_args(a): - if a.type === 'as_none': - return None - elif a.type === 'as_tensor': - return this.serialized_name_to_node[a.value] - else: - raise SerializeError(f'Unhandled argument {inp}') - return list(map(deserialize_optional_tensor_args, value)) - else: - raise SerializeError(f'Unhandled argument {inp}') - elif typ_ === 'as_custom_obj': - return this.constants[inp.as_custom_obj.name] - else { - raise SerializeError(`Unhandled argument ${inp}.`); + )*/ + throw new Error(); + } else if (typ_ === 'as_device') { + return this.deserialize_device(inp.as_device); + } else if (typ_ === 'as_int') { + return inp.as_int; + } else if (typ_ === 'as_float') { + return inp.as_float; + } else if (typ_ === 'as_bool') { + return inp.as_bool; + } else if (typ_ === 'as_string') { + return inp.as_string; + } else if (typ_ === 'as_sym_int') { + return this.deserialize_sym_argument(inp.as_sym_int); + } else if (typ_ === 'as_sym_bool') { + return this.deserialize_sym_argument(inp.as_sym_bool); + } else if (Array.isArray(value)) { + if (value.length === 0) { + return []; + } else if (typ_ === 'as_tensors') { + const result = []; + for (const arg of value) { + result.push(this.serialized_name_to_node.get(arg.name)); + } + return result; + } else if (typ_ === 'as_ints' || typ_ === 'as_floats' || typ_ === 'as_bools' || typ_ === 'as_strings') { + return Array.from(value); + } else if (typ_ === 'as_sym_ints' || typ_ === 'as_sym_bools') { + return value.map((arg) => this.deserialize_sym_argument(arg)); + } else if (typ_ === 'as_optional_tensors') { + const deserialize_optional_tensor_args = (a) => { + if (a.type === 'as_none') { + return 
null; + } else if (a.type === 'as_tensor') { + return this.serialized_name_to_node.get(a.value.name); + } + throw new python.Error(`Unsupported argument '${typ_}'.`); + }; + return value.map((item) => deserialize_optional_tensor_args(item)); + } + throw new python.Error(`Unsupported argument '${typ_}'.`); + } else if (typ_ === 'as_custom_obj') { + if (this.serialized_name_to_node.has(inp.as_custom_obj.name)) { + return this.serialized_name_to_node.get(inp.as_custom_obj.name); + } + return this.constants[inp.as_custom_obj.name]; + } else if (typ_ === 'as_operator') { + return this.deserialize_operator(inp.as_operator); + } + throw new python.Error(`Unsupported argument '${typ_}'.`); + } + deserialize_sym_argument(sym_arg) { + if (sym_arg instanceof torch._export.serde.schema.SymIntArgument) { + if (sym_arg.type === 'as_int') { + return sym_arg.as_int; + } else if (sym_arg.type === 'as_name') { + return this.serialized_name_to_node.get(sym_arg.as_name); + } + } else if (sym_arg instanceof torch._export.serde.schema.SymBoolArgument) { + if (sym_arg.type === 'as_bool') { + return sym_arg.as_bool; + } else if (sym_arg.type === 'as_name') { + return this.serialized_name_to_node.get(sym_arg.as_name); + } } - */ + throw new python.Error(`Unsupported symbolic argument type '${sym_arg.type}'.`); + } + deserialize_sym_op_outputs(serialized_node, fx_node) { + this.sync_fx_node(serialized_node.outputs[0].value.as_name, fx_node); } deserialize_outputs(serialized_node, fx_node) { if (serialized_node.outputs.length === 0) { @@ -7008,14 +7139,66 @@ python.Execution = class { } this.deserialize_multiple_outputs(serialized_node, fx_node); } - deserialize_multiple_outputs() { - // debugger; + deserialize_multiple_outputs(serialized_node, fx_node) { + const deserialized_metadata = this.deserialize_metadata(serialized_node.metadata); + const generate_getitem = (meta_val, fx_node, arg, idx) => { + let name = ''; + if (arg instanceof torch._export.serde.schema.TensorArgument) { + name = arg.name; + 
} else if (arg instanceof torch._export.serde.schema.SymIntArgument) { + name = arg.as_name; + } else { + throw new python.Error(`Unsupported argument type '${arg}'.`); + } + const individual_output = this.graph.create_node( + 'call_function', + operator.getitem, + new builtins.tuple([fx_node, idx]), null, + name, + ); + this.sync_fx_node(name, individual_output); + meta_val.push(this.serialized_name_to_meta.get(name)); + individual_output.meta.update(deserialized_metadata); + }; + const generate_getitems = (meta_val, fx_node, args) => { + for (let idx = 0; idx < args.length; idx++) { + let arg = args[idx]; + if (arg instanceof torch._export.serde.schema.Argument) { + arg = arg.value; + } + if (arg instanceof torch._export.serde.schema.TensorArgument || arg instanceof torch._export.serde.schema.SymIntArgument) { + generate_getitem(meta_val, fx_node, arg, idx); + } else if (Array.isArray(arg)) { // arg instanceof (list, tuple)) + const list_output = this.graph.create_node( + 'call_function', + operator.getitem, + new builtins.tuple([fx_node, idx]), + ); + meta_val.push([]); + generate_getitems(meta_val[meta_val.length - 1], list_output, arg); + list_output.meta.update(deserialized_metadata); + list_output.meta.set('val', meta_val[meta_val.length - 1]); + } else { + throw new python.Error(`Unsupported node output type: '${arg}'.`); + } + } + }; + const meta_val = []; + if (serialized_node.outputs.length === 1) { + // assert isinstance(serialized_node.outputs[0].value, list) + // assert isinstance(serialized_node.outputs[0].value[0], TensorArgument) + generate_getitems(meta_val, fx_node, serialized_node.outputs[0].as_tensors); + } else { + generate_getitems(meta_val, fx_node, serialized_node.outputs); + } + fx_node.meta.set('val', new builtins.tuple(meta_val)); + this.serialized_name_to_node.set(fx_node.name, fx_node); } deserialize_metadata(metadata) { - const ret = {}; + const ret = new builtins.dict(); const stack_trace = metadata.stack_trace; if (stack_trace) { - ret.stack_trace = stack_trace; + ret.set('stack_trace', 
stack_trace); } const deserialize_meta_func = (serialized_target) => { let module = null; @@ -7044,7 +7227,7 @@ python.Execution = class { return [key, [path, ty]]; }; const nn_module_stack = new Map(nn_module_stack_str.split(';').map((item) => import_nn_module_stack(...item.split(',')))); - ret.nn_module_stack = nn_module_stack; + ret.set('nn_module_stack', nn_module_stack); } const source_fn_st_str = metadata.source_fn_stack; if (source_fn_st_str) { @@ -7053,16 +7236,16 @@ python.Execution = class { const [name, target_str] = source_fn_str.split(','); source_fn_st.push([name, deserialize_meta_func(target_str)]); } - ret.source_fn_stack = source_fn_st; + ret.set('source_fn_stack', source_fn_st); } return ret; } deserialize_argument_spec(x) { - if (x.type === "as_tensor") { + if (x.type === 'as_tensor') { return new torch.export.graph_signature.TensorArgument(x.as_tensor.name); - } else if (x.type === "as_sym_int") { + } else if (x.type === 'as_sym_int') { return new torch.export.graph_signature.SymIntArgument(x.as_sym_int.as_name); - } else if (x.type === "as_custom_obj") { + } else if (x.type === 'as_custom_obj') { return new torch.export.graph_signature.ConstantArgument(x.as_custom_obj.name, this.deserialize_input(x)); } return new torch.export.graph_signature.ConstantArgument('', this.deserialize_input(x)); @@ -7071,7 +7254,7 @@ python.Execution = class { const sizes = tensor_meta.sizes.map((val) => this.deserialize_sym_int(val)); const strides = tensor_meta.strides.map((val) => this.deserialize_sym_int(val)); const device = this.deserialize_device(tensor_meta.device); - const dtype = null; // _SERIALIZE_TO_TORCH_DTYPE[tensor_meta.dtype], + const dtype = torch._export.serde.serialize._SERIALIZE_TO_TORCH_DTYPE[tensor_meta.dtype]; return torch.empty_strided(sizes, strides, dtype, null, device); } deserialize_sym_int(s) { @@ -7110,10 +7293,10 @@ python.Execution = class { throw new python.Error('SymInt has invalid field type.'); } deserialize_device(d) { - if 
(d.index !== undefined) { - return new torch.device(d.type, d.index); + if (d.index === null) { + return new torch.device(d.type); } - return new torch.device(d.type); + return new torch.device(d.type, d.index); } _get_schema_from_target(target) { if (target instanceof torch._ops.OpOverload) { @@ -7170,6 +7353,7 @@ python.Execution = class { } } }); + this.registerType('torch.memory_format', class {}); this.registerType('torch.dtype', class { constructor(scalar_type, name, itemsize) { this._scalar_type = scalar_type; @@ -7607,6 +7791,24 @@ python.Execution = class { torch.uint16 = new torch.dtype(27, 'uint16', 2); torch.uint32 = new torch.dtype(28, 'uint32', 4); torch.uint64 = new torch.dtype(29, 'uint64', 8); + torch._export.serde.serialize._SERIALIZE_TO_TORCH_DTYPE = Object.fromEntries([ + ['uint8', 'BYTE'], + ['int8', 'CHAR'], ['int16', 'SHORT'], ['int32', 'INT'], ['int64', 'LONG'], + ['float16', 'HALF'], ['float32', 'FLOAT'], ['float64', 'DOUBLE'], + ['complex32', 'COMPLEXHALF'], ['complex64', 'COMPLEXFLOAT'], ['complex128', 'COMPLEXDOUBLE'], + ['bool', 'BOOL'], + ['bfloat16', 'BFLOAT16'] + ].map(([key, value]) => [torch._export.serde.schema.ScalarType[value], torch[key]])); + torch.contiguous_format = new torch.memory_format(); + torch.channels_last = new torch.memory_format(); + torch.channels_last_3d = new torch.memory_format(); + torch.preserve_format = new torch.memory_format(); + torch._export.serde.serialize._SERIALIZE_TO_TORCH_MEMORY_FORMAT = Object.fromEntries([ + ['contiguous_format', 'ContiguousFormat'], + ['channels_last', 'ChannelsLast'], + ['channels_last_3d', 'ChannelsLast3d'], + ['preserve_format', 'PreserveFormat'] + ].map(([key, value]) => [torch._export.serde.schema.MemoryFormat[value], torch[key]])); /* eslint-enable no-multi-assign */ torch.strided = new torch.layout('torch.strided'); torch.sparse_coo = new torch.layout('torch.sparse_coo'); @@ -7615,6 +7817,15 @@ python.Execution = class { torch.sparse_bsr = new 
torch.layout('torch.sparse_bsr'); torch.sparse_bsc = new torch.layout('torch.sparse_bsc'); torch._mkldnn = new torch.layout('torch._mkldnn'); + torch._export.serde.serialize._SERIALIZE_TO_TORCH_LAYOUT = Object.fromEntries([ + ['sparse_coo', 'SparseCoo'], + ['sparse_csr', 'SparseCsr'], + ['sparse_csc', 'SparseCsc'], + ['sparse_bsr', 'SparseBsr'], + ['sparse_bsc', 'SparseBsc'], + ['_mkldnn', '_mkldnn'], + ['strided', 'Strided'], + ].map(([key, value]) => [torch._export.serde.schema.Layout[value], torch[key]])); torch.per_tensor_affine = new torch.qscheme('torch.per_tensor_affine'); torch.per_channel_affine = new torch.qscheme('torch.per_channel_affine'); torch.per_tensor_symmetric = new torch.qscheme('torch.per_tensor_symmetric'); diff --git a/source/pytorch.js b/source/pytorch.js index 9e8f9fd9ef..564271895b 100644 --- a/source/pytorch.js +++ b/source/pytorch.js @@ -419,6 +419,14 @@ pytorch.Node = class { } else if (pytorch.Utility.isInstance(obj, 'torch.fx.node.Node')) { if (obj.op === 'call_function') { this.type = createType(metadata, obj.target.name); + for (const arg of obj.args) { + if (pytorch.Utility.isInstance(arg, 'torch.fx.node.Node')) { + const values = []; + this.inputs.push(new pytorch.Argument('', values)); + } else { + this.inputs.push(new pytorch.Argument('', arg, 'attribute')); + } + } } else if (obj.op === 'placeholder') { this.type = createType(metadata, 'placeholder'); } else { @@ -1364,34 +1372,6 @@ pytorch.Execution = class extends python.Execution { return this.set(name, storage); } }); - this.registerType('torch.Type', class {}); - this.registerType('torch.ClassType', class extends torch.Type { - constructor(qualified_name, cu, is_module) { - super(); - this._qualified_name = qualified_name; - this._is_module = is_module; - } - qualified_name() { - return this._qualified_name; - } - name() { - return this._qualified_name.split('.').pop(); - } - is_module() { - return this._is_module; - } - addMethod(/* name, fn */) { - } - addAttribute(/* 
name */) { - } - hasAttribute(/* name */) { - } - hasConstant(/* name */) { - } - methods() { - } - }); - this.registerType('torch.TupleType', class extends torch.Type {}); this.registerType('torch.ScriptFunction', class { constructor(name, graph /*, function_creator */) { this._name = name; @@ -1685,7 +1665,9 @@ pytorch.Execution = class extends python.Execution { for (const [name, type] of metadata._types) { if (name.indexOf('::') !== -1) { const [name, overload_name] = type.name.split('.'); - const schema = new torch.FunctionSchema(name, overload_name || '', [], []); + const args = type.inputs.map((arg) => new torch.Argument(arg.name)); + const returns = type.outputs.map((arg) => new torch.Argument(arg.name)); + const schema = new torch.FunctionSchema(name, overload_name || '', args, returns); const op = new torch._C.Operator(schema); registry.registerOperator(op); modules.add(type.name.split('::')[0]);