diff --git a/tools/circle_script.js b/tools/circle_script.js index 53c86e81a13..28371271abf 100644 --- a/tools/circle_script.js +++ b/tools/circle_script.js @@ -1,72 +1,72 @@ const path = require('path'); const flatc = require('./flatc'); -const fs = require('fs'); +const fs = require('fs').promises; -const schema = path.join(__dirname, '..', 'third_party', 'source', 'circle', 'nnpackage', 'schema', 'circle_schema.fbs'); -const file = path.join(__dirname, '..', 'source', 'circle-metadata.json'); - -const input = fs.readFileSync(file, 'utf-8'); -const json = JSON.parse(input); - -const operators = new Map(); -const attributes = new Map(); -for (const operator of json) { - if (operators.has(operator.name)) { - throw new Error("Duplicate operator '" + operator.name + "'."); - } - operators.set(operator.name, operator); - if (operator && operator.attributes) { - for (const attribute of operator.attributes) { - const name = operator.name + ':' + attribute.name; - attributes.set(name, attribute); +const main = async () => { + const schema = path.join(__dirname, '..', 'third_party', 'source', 'circle', 'nnpackage', 'schema', 'circle_schema.fbs'); + const file = path.join(__dirname, '..', 'source', 'circle-metadata.json'); + const input = await fs.readFile(file, 'utf-8'); + const json = JSON.parse(input); + const operators = new Map(); + const attributes = new Map(); + for (const operator of json) { + if (operators.has(operator.name)) { + throw new Error("Duplicate operator '" + operator.name + "'."); } - } -} - -const root = new flatc.Root('circle', [], [ schema ]); -const namespace = root.find('circle', flatc.Namespace); - -const builtOperator = namespace.find('circle.BuiltinOperator', flatc.Type); -const upperCase = new Set([ '2D', 'LSH', 'SVDF', 'RNN', 'L2', 'LSTM' ]); -for (const op of builtOperator.values.keys()) { - let op_key = op === 'BATCH_MATMUL' ? 'BATCH_MAT_MUL' : op; - op_key = op_key.split('_').map((s) => (s.length < 1 || upperCase.has(s)) ? s : s[0] + s.substring(1).toLowerCase()).join(''); - const table = namespace.find('circle.' + op_key + 'Options', flatc.Type); - if (table && table.fields.size > 0) { - if (!operators.has(op_key)) { - const operator = { name: op_key }; - operators.set(op_key, operator); - json.push(operator); + operators.set(operator.name, operator); + if (operator && operator.attributes) { + for (const attribute of operator.attributes) { + const name = operator.name + ':' + attribute.name; + attributes.set(name, attribute); + } } - const operator = operators.get(op_key); - operator.attributes = operator.attributes || []; - for (const field of table.fields.values()) { - const attr_key = op_key + ':' + field.name; - if (!attributes.has(attr_key)) { - const attribute = { name: field.name }; - attributes.set(attr_key, attribute); - operator.attributes.push(attribute); + } + const root = new flatc.Root('circle'); + await root.load([], [ schema ]); + const namespace = root.find('circle', flatc.Namespace); + const builtOperator = namespace.find('circle.BuiltinOperator', flatc.Type); + const upperCase = new Set([ '2D', 'LSH', 'SVDF', 'RNN', 'L2', 'LSTM' ]); + for (const op of builtOperator.values.keys()) { + let op_key = op === 'BATCH_MATMUL' ? 'BATCH_MAT_MUL' : op; + op_key = op_key.split('_').map((s) => (s.length < 1 || upperCase.has(s)) ? s : s[0] + s.substring(1).toLowerCase()).join(''); + const table = namespace.find('circle.' + op_key + 'Options', flatc.Type); + if (table && table.fields.size > 0) { + if (!operators.has(op_key)) { + const operator = { name: op_key }; + operators.set(op_key, operator); + json.push(operator); } - const attribute = attributes.get(attr_key); - const type = field.type; - let defaultValue = field.defaultValue; - if (type instanceof flatc.Enum) { - if (!type.keys.has(defaultValue)) { - throw new Error("Invalid '" + type.name + "' default value '" + defaultValue + "'."); + const operator = operators.get(op_key); + operator.attributes = operator.attributes || []; + for (const field of table.fields.values()) { + const attr_key = op_key + ':' + field.name; + if (!attributes.has(attr_key)) { + const attribute = { name: field.name }; + attributes.set(attr_key, attribute); + operator.attributes.push(attribute); + } + const attribute = attributes.get(attr_key); + const type = field.type; + let defaultValue = field.defaultValue; + if (type instanceof flatc.Enum) { + if (!type.keys.has(defaultValue)) { + throw new Error("Invalid '" + type.name + "' default value '" + defaultValue + "'."); + } + defaultValue = type.keys.get(defaultValue); } - defaultValue = type.keys.get(defaultValue); + attribute.type = type.name === 'bool' ? 'boolean' : type.name + (field.repeated ? '[]' : ''); + attribute.default = defaultValue; } - attribute.type = type.name === 'bool' ? 'boolean' : type.name + (field.repeated ? '[]' : ''); - attribute.default = defaultValue; } } -} + json.sort((a, b) => a.name.localeCompare(b.name)); + let output = JSON.stringify(json, null, 2); + output = output.replace(/\s {8}/g, ' '); + output = output.replace(/,\s {8}/g, ', '); + output = output.replace(/\s {6}}/g, ' }'); + await fs.writeFile(file, output, 'utf-8'); +}; -json.sort((a, b) => a.name.localeCompare(b.name)); +main(); -let output = JSON.stringify(json, null, 2); -output = output.replace(/\s {8}/g, ' '); -output = output.replace(/,\s {8}/g, ', '); -output = output.replace(/\s {6}}/g, ' }'); -fs.writeFileSync(file, output, 'utf-8'); diff --git a/tools/flatc.js b/tools/flatc.js index 38f4fd31c00..abada98c62a 100644 --- a/tools/flatc.js +++ b/tools/flatc.js @@ -1,6 +1,6 @@ const flatc = {}; -const fs = require('fs'); +const fs = require('fs').promises; const path = require('path'); flatc.Object = class { @@ -769,13 +769,18 @@ flatc.Tokenizer = class { flatc.Root = class extends flatc.Object { - constructor(root, paths, files) { + constructor(root) { super(null, root); this._namespaces = new Map(); this._files = new Set(); this.root_type = new Set(); + } + + async load(paths, files) { for (const file of files) { - this._parseFile(paths, file); + /* eslint-disable no-await-in-loop */ + await this._parseFile(paths, file); + /* eslint-enable no-await-in-loop */ } this.resolve(); } @@ -821,16 +826,20 @@ flatc.Root = class extends flatc.Object { return super.find(name, type); } - _parseFile(paths, file) { + async _parseFile(paths, file) { if (!this._files.has(file)) { this._files.add(file); - const content = fs.readFileSync(file, 'utf-8'); + const content = await fs.readFile(file, 'utf-8'); const parser = new flatc.Parser(content, file, this); const includes = parser.include(); for (const include of includes) { - const includeFile = this._resolve(paths, file, include); + /* eslint-disable no-await-in-loop */ + const includeFile = await this._resolve(paths, file, include); + /* eslint-enable no-await-in-loop */ if (includeFile) { - this._parseFile(paths, includeFile); + /* eslint-disable no-await-in-loop */ + await this._parseFile(paths, includeFile); + /* eslint-enable no-await-in-loop */ continue; } throw new flatc.Error("Include '" + include + "' not found."); @@ -839,14 +848,22 @@ flatc.Root = class extends flatc.Object { } } - _resolve(paths, origin, target) { + async _resolve(paths, origin, target) { + const exists = async (path) => { + try { + await fs.access(path); + return true; + } catch (error) { + return false; + } + }; const file = path.join(path.dirname(origin), target); - if (fs.existsSync(file)) { + if (await exists(file)) { return file; } for (const current of paths) { const file = path.join(current, target); - if (fs.existsSync(file)) { + if (await exists(file)) { return file; } } @@ -1211,8 +1228,7 @@ flatc.Error = class extends Error { } }; -const main = (args) => { - +const main = async (args) => { const options = { verbose: false, root: 'default', out: '', text: false, paths: [], files: [] }; while (args.length > 0) { const arg = args.shift(); @@ -1240,12 +1256,12 @@ const main = (args) => { break; } } - try { - const root = new flatc.Root(options.root, options.paths, options.files); + const root = new flatc.Root(options.root); + await root.load(options.paths, options.files); const generator = new flatc.Generator(root, options.text); if (options.out) { - fs.writeFileSync(options.out, generator.content, 'utf-8'); + await fs.writeFile(options.out, generator.content, 'utf-8'); } } catch (err) { if (err instanceof flatc.Error && !options.verbose) { @@ -1253,16 +1269,15 @@ const main = (args) => { } else { process.stderr.write(err.stack + '\n'); } - return 1; + process.exit(1); } - return 0; + process.exit(0); }; if (typeof process === 'object' && Array.isArray(process.argv) && process.argv.length > 1 && process.argv[1] === __filename) { const args = process.argv.slice(2); - const code = main(args); - process.exit(code); + main(args); } if (typeof module !== 'undefined' && typeof module.exports === 'object') { diff --git a/tools/megengine_script.js b/tools/megengine_script.js index 5f945722995..8e9e4e6cb49 100644 --- a/tools/megengine_script.js +++ b/tools/megengine_script.js @@ -1,112 +1,113 @@ const path = require('path'); const flatc = require('./flatc'); -const fs = require('fs'); +const fs = require('fs').promises; -const schema = path.join(__dirname, '..', 'third_party', 'source', 'megengine', 'src', 'serialization', 'fbs', 'schema_v2.fbs'); -const file = path.join(__dirname, '..', 'source', 'megengine-metadata.json'); - -const input = fs.readFileSync(file, 'utf-8'); -const json = JSON.parse(input); - -const category = { - Host2DeviceCopy: 'Data', - Dimshuffle: 'Shape', - Flip: 'Shape', - Images2Neibs: 'Shape', - Reshape: 'Shape', - Concat: 'Tensor', - GetVarShape: 'Shape', - Subtensor: 'Tensor', - Padding: 'Layer', - AdaptivePooling: 'Activation', - ConvPooling: 'Pool', - TQT: 'Quantization', - LSQ: 'Quantization', - Pooling: 'Pool', - PoolingForward: 'Pool', - AdaptivePoolingForward: 'Pool', - SlidingWindowTranspose: 'Transform', - LRN: 'Normalization', - BatchNormForward: 'Normalization', - BN: 'Normalization', - LayerNorm: 'Normalization', - Convolution: 'Layer', - ConvolutionForward: 'Layer', - Convolution3D: 'Layer', - SeparableConv: 'Layer', - SeparableConv3D: 'Layer', - ConvBiasForward: 'Layer', - ConvBias: 'Layer', - Conv3DBias: 'Layer', - Dropout: 'Dropout', - Softmax: 'Activation', - RNN: 'Layer', - RNNCell: 'Layer', - LSTM: 'Layer' -}; -const operators = new Map(); -const attributes = new Map(); -for (const operator of json) { - if (operators.has(operator.name)) { - throw new Error("Duplicate operator '" + operator.name + "'."); - } - operators.set(operator.name, operator); - if (operator && operator.attributes) { - for (const attribute of operator.attributes) { - const name = operator.name + ':' + attribute.name; - attributes.set(name, attribute); +const main = async () => { + const schema = path.join(__dirname, '..', 'third_party', 'source', 'megengine', 'src', 'serialization', 'fbs', 'schema_v2.fbs'); + const file = path.join(__dirname, '..', 'source', 'megengine-metadata.json'); + const input = await fs.readFile(file, 'utf-8'); + const json = JSON.parse(input); + const category = { + Host2DeviceCopy: 'Data', + Dimshuffle: 'Shape', + Flip: 'Shape', + Images2Neibs: 'Shape', + Reshape: 'Shape', + Concat: 'Tensor', + GetVarShape: 'Shape', + Subtensor: 'Tensor', + Padding: 'Layer', + AdaptivePooling: 'Activation', + ConvPooling: 'Pool', + TQT: 'Quantization', + LSQ: 'Quantization', + Pooling: 'Pool', + PoolingForward: 'Pool', + AdaptivePoolingForward: 'Pool', + SlidingWindowTranspose: 'Transform', + LRN: 'Normalization', + BatchNormForward: 'Normalization', + BN: 'Normalization', + LayerNorm: 'Normalization', + Convolution: 'Layer', + ConvolutionForward: 'Layer', + Convolution3D: 'Layer', + SeparableConv: 'Layer', + SeparableConv3D: 'Layer', + ConvBiasForward: 'Layer', + ConvBias: 'Layer', + Conv3DBias: 'Layer', + Dropout: 'Dropout', + Softmax: 'Activation', + RNN: 'Layer', + RNNCell: 'Layer', + LSTM: 'Layer' + }; + const operators = new Map(); + const attributes = new Map(); + for (const operator of json) { + if (operators.has(operator.name)) { + throw new Error("Duplicate operator '" + operator.name + "'."); } - } -} - -const root = new flatc.Root('megengine', [], [ schema ]); -const namespace = root.find('mgb.serialization.fbs.param', flatc.Namespace); - -const operatorParams = namespace.children; -for (const op of operatorParams) { - const op_key = op[ 0 ]; - const op_table = op[ 1 ]; - if (op_table instanceof flatc.Enum) { - continue; - } - if (op_table && op_table.fields.size > 0) { - if (!operators.has(op_key)) { - const operator = { name: op_key }; - operators.set(op_key, operator); - json.push(operator); + operators.set(operator.name, operator); + if (operator && operator.attributes) { + for (const attribute of operator.attributes) { + const name = operator.name + ':' + attribute.name; + attributes.set(name, attribute); + } } - const operator = operators.get(op_key); - if (category[ op_key.replace(/V(\d+)$/, '') ]) { - operator.category = category[ op_key.replace(/V(\d+)$/, '') ]; + } + const root = new flatc.Root('megengine'); + await root.load([], [ schema ]); + const namespace = root.find('mgb.serialization.fbs.param', flatc.Namespace); + const operatorParams = namespace.children; + for (const op of operatorParams) { + const op_key = op[ 0 ]; + const op_table = op[ 1 ]; + if (op_table instanceof flatc.Enum) { + continue; } - operator.attributes = operator.attributes || []; - for (const field of op_table.fields) { - const field_name = field[ 0 ]; - const field_table = field[ 1 ]; - const attr_key = op_key + ':' + field_name; - if (!attributes.has(attr_key)) { - const attribute = { name: field_name }; - attributes.set(attr_key, attribute); - operator.attributes.push(attribute); + if (op_table && op_table.fields.size > 0) { + if (!operators.has(op_key)) { + const operator = { name: op_key }; + operators.set(op_key, operator); + json.push(operator); + } + const operator = operators.get(op_key); + if (category[ op_key.replace(/V(\d+)$/, '') ]) { + operator.category = category[ op_key.replace(/V(\d+)$/, '') ]; } - const attribute = attributes.get(attr_key); - const type = field_table.type; - let defaultValue = field_table.defaultValue; - if (type instanceof flatc.Enum) { - if (!type.keys.has(defaultValue)) { - throw new Error("Invalid '" + type.name + "' default value '" + defaultValue + "'."); + operator.attributes = operator.attributes || []; + for (const field of op_table.fields) { + const field_name = field[ 0 ]; + const field_table = field[ 1 ]; + const attr_key = op_key + ':' + field_name; + if (!attributes.has(attr_key)) { + const attribute = { name: field_name }; + attributes.set(attr_key, attribute); + operator.attributes.push(attribute); } - defaultValue = type.keys.get(defaultValue); + const attribute = attributes.get(attr_key); + const type = field_table.type; + let defaultValue = field_table.defaultValue; + if (type instanceof flatc.Enum) { + if (!type.keys.has(defaultValue)) { + throw new Error("Invalid '" + type.name + "' default value '" + defaultValue + "'."); + } + defaultValue = type.keys.get(defaultValue); + } + attribute.type = type.name + (field.repeated ? '[]' : ''); + attribute.default = defaultValue; } - attribute.type = type.name + (field.repeated ? '[]' : ''); - attribute.default = defaultValue; } } -} -// json.sort((a, b) => a.name.localeCompare(b.name)) + // json.sort((a, b) => a.name.localeCompare(b.name)) + + let output = JSON.stringify(json, null, 2); + output = output.replace(/\s {8}/g, ' '); + output = output.replace(/,\s {8}/g, ', '); + output = output.replace(/\s {6}}/g, ' }'); + await fs.writeFile(file, output, 'utf-8'); +}; -let output = JSON.stringify(json, null, 2); -output = output.replace(/\s {8}/g, ' '); -output = output.replace(/,\s {8}/g, ', '); -output = output.replace(/\s {6}}/g, ' }'); -fs.writeFileSync(file, output, 'utf-8'); +main(); \ No newline at end of file diff --git a/tools/mslite_metadata.js b/tools/mslite_metadata.js index d6e50ebe208..e34104b012c 100644 --- a/tools/mslite_metadata.js +++ b/tools/mslite_metadata.js @@ -1,81 +1,80 @@ const path = require('path'); const flatc = require('./flatc'); -const fs = require('fs'); +const fs = require('fs').promises; -const schema = path.join(__dirname, '..', 'third_party', 'source', 'mindspore', 'mindspore', 'lite', 'schema', 'ops.fbs'); -const file = path.join(__dirname, '..', 'source', 'mslite-metadata.json'); - -const input = fs.readFileSync(file, 'utf-8'); -const json = JSON.parse(input); - -const operators = new Map(); -const attributes = new Map(); -for (const operator of json) { - if (operators.has(operator.name)) { - throw new Error("Duplicate operator '" + operator.name + "'."); - } - operators.set(operator.name, operator); - if (operator && operator.attributes) { - for (const attribute of operator.attributes) { - const name = operator.name + ':' + attribute.name; - attributes.set(name, attribute); +const main = async () => { + const schema = path.join(__dirname, '..', 'third_party', 'source', 'mindspore', 'mindspore', 'lite', 'schema', 'ops.fbs'); + const file = path.join(__dirname, '..', 'source', 'mslite-metadata.json'); + const input = await fs.readFile(file, 'utf-8'); + const json = JSON.parse(input); + const operators = new Map(); + const attributes = new Map(); + for (const operator of json) { + if (operators.has(operator.name)) { + throw new Error("Duplicate operator '" + operator.name + "'."); } - } -} - -const root = new flatc.Root('mslite', [], [ schema ]); -const namespace = root.find('mindspore.schema', flatc.Namespace); - -const primitiveType = namespace.find('mindspore.schema.PrimitiveType', flatc.Type); -for (const table of primitiveType.values.values()) { - const op_key = table.name; - if (!operators.has(op_key)) { - const operator = { name: op_key }; - operators.set(op_key, operator); - json.push(operator); - } - const operator = operators.get(op_key); - if (table && table.fields.size > 0) { - operator.attributes = operator.attributes || []; - const inputs = operator.inputs; - const outputs = operator.outputs; - delete operator.inputs; - delete operator.outputs; - if (inputs) { - operator.inputs = inputs; + operators.set(operator.name, operator); + if (operator && operator.attributes) { + for (const attribute of operator.attributes) { + const name = operator.name + ':' + attribute.name; + attributes.set(name, attribute); + } } - if (outputs) { - operator.outputs = outputs; + } + const root = new flatc.Root('mslite'); + await root.load([], [ schema ]); + const namespace = root.find('mindspore.schema', flatc.Namespace); + const primitiveType = namespace.find('mindspore.schema.PrimitiveType', flatc.Type); + for (const table of primitiveType.values.values()) { + const op_key = table.name; + if (!operators.has(op_key)) { + const operator = { name: op_key }; + operators.set(op_key, operator); + json.push(operator); } - for (const field of table.fields.values()) { - const attr_key = op_key + ':' + field.name; - if (!attributes.has(attr_key)) { - const attribute = { name: field.name }; - attributes.set(attr_key, attribute); - operator.attributes.push(attribute); + const operator = operators.get(op_key); + if (table && table.fields.size > 0) { + operator.attributes = operator.attributes || []; + const inputs = operator.inputs; + const outputs = operator.outputs; + delete operator.inputs; + delete operator.outputs; + if (inputs) { + operator.inputs = inputs; } - const attribute = attributes.get(attr_key); - const type = field.type; - let defaultValue = field.defaultValue; - if (type instanceof flatc.Enum) { - if (!type.keys.has(defaultValue)) { - throw new Error("Invalid '" + type.name + "' default value '" + defaultValue + "'."); - } - defaultValue = type.keys.get(defaultValue); + if (outputs) { + operator.outputs = outputs; } - attribute.type = type.name === 'bool' ? 'boolean' : type.name + (field.repeated ? '[]' : ''); - if (attribute.default === undefined) { - attribute.default = defaultValue; + for (const field of table.fields.values()) { + const attr_key = op_key + ':' + field.name; + if (!attributes.has(attr_key)) { + const attribute = { name: field.name }; + attributes.set(attr_key, attribute); + operator.attributes.push(attribute); + } + const attribute = attributes.get(attr_key); + const type = field.type; + let defaultValue = field.defaultValue; + if (type instanceof flatc.Enum) { + if (!type.keys.has(defaultValue)) { + throw new Error("Invalid '" + type.name + "' default value '" + defaultValue + "'."); + } + defaultValue = type.keys.get(defaultValue); + } + attribute.type = type.name === 'bool' ? 'boolean' : type.name + (field.repeated ? '[]' : ''); + if (attribute.default === undefined) { + attribute.default = defaultValue; + } } } } -} - -json.sort((a, b) => a.name.localeCompare(b.name)); + json.sort((a, b) => a.name.localeCompare(b.name)); + let output = JSON.stringify(json, null, 2); + output = output.replace(/\s {8}/g, ' '); + output = output.replace(/,\s {8}/g, ', '); + output = output.replace(/\s {6}}/g, ' }'); + await fs.writeFile(file, output, 'utf-8'); +}; -let output = JSON.stringify(json, null, 2); -output = output.replace(/\s {8}/g, ' '); -output = output.replace(/,\s {8}/g, ', '); -output = output.replace(/\s {6}}/g, ' }'); -fs.writeFileSync(file, output, 'utf-8'); +main(); \ No newline at end of file diff --git a/tools/tflite_metadata.js b/tools/tflite_metadata.js index 090a9a7e91b..23a2e48a4df 100644 --- a/tools/tflite_metadata.js +++ b/tools/tflite_metadata.js @@ -1,72 +1,71 @@ const path = require('path'); const flatc = require('./flatc'); -const fs = require('fs'); +const fs = require('fs').promises; -const schema = path.join(__dirname, '..', 'third_party', 'source', 'tensorflow', 'tensorflow', 'lite', 'schema', 'schema.fbs'); -const file = path.join(__dirname, '..', 'source', 'tflite-metadata.json'); - -const input = fs.readFileSync(file, 'utf-8'); -const json = JSON.parse(input); - -const operators = new Map(); -const attributes = new Map(); -for (const operator of json) { - if (operators.has(operator.name)) { - throw new Error("Duplicate operator '" + operator.name + "'."); - } - operators.set(operator.name, operator); - if (operator && operator.attributes) { - for (const attribute of operator.attributes) { - const name = operator.name + ':' + attribute.name; - attributes.set(name, attribute); +const main = async () => { + const schema = path.join(__dirname, '..', 'third_party', 'source', 'tensorflow', 'tensorflow', 'lite', 'schema', 'schema.fbs'); + const file = path.join(__dirname, '..', 'source', 'tflite-metadata.json'); + const input = await fs.readFile(file, 'utf-8'); + const json = JSON.parse(input); + const operators = new Map(); + const attributes = new Map(); + for (const operator of json) { + if (operators.has(operator.name)) { + throw new Error("Duplicate operator '" + operator.name + "'."); } - } -} - -const root = new flatc.Root('tflite', [], [ schema ]); -const namespace = root.find('tflite', flatc.Namespace); - -const builtOperator = namespace.find('tflite.BuiltinOperator', flatc.Type); -const upperCase = new Set([ '2D', 'LSH', 'SVDF', 'RNN', 'L2', 'LSTM' ]); -for (const op of builtOperator.values.keys()) { - let op_key = op === 'BATCH_MATMUL' ? 'BATCH_MAT_MUL' : op; - op_key = op_key.split('_').map((s) => (s.length < 1 || upperCase.has(s)) ? s : s[0] + s.substring(1).toLowerCase()).join(''); - const table = namespace.find('tflite.' + op_key + 'Options', flatc.Type); - if (table && table.fields.size > 0) { - if (!operators.has(op_key)) { - const operator = { name: op_key }; - operators.set(op_key, operator); - json.push(operator); + operators.set(operator.name, operator); + if (operator && operator.attributes) { + for (const attribute of operator.attributes) { + const name = operator.name + ':' + attribute.name; + attributes.set(name, attribute); + } } - const operator = operators.get(op_key); - operator.attributes = operator.attributes || []; - for (const field of table.fields.values()) { - const attr_key = op_key + ':' + field.name; - if (!attributes.has(attr_key)) { - const attribute = { name: field.name }; - attributes.set(attr_key, attribute); - operator.attributes.push(attribute); + } + const root = new flatc.Root('tflite'); + await root.load([], [ schema ]); + const namespace = root.find('tflite', flatc.Namespace); + const builtOperator = namespace.find('tflite.BuiltinOperator', flatc.Type); + const upperCase = new Set([ '2D', 'LSH', 'SVDF', 'RNN', 'L2', 'LSTM' ]); + for (const op of builtOperator.values.keys()) { + let op_key = op === 'BATCH_MATMUL' ? 'BATCH_MAT_MUL' : op; + op_key = op_key.split('_').map((s) => (s.length < 1 || upperCase.has(s)) ? s : s[0] + s.substring(1).toLowerCase()).join(''); + const table = namespace.find('tflite.' + op_key + 'Options', flatc.Type); + if (table && table.fields.size > 0) { + if (!operators.has(op_key)) { + const operator = { name: op_key }; + operators.set(op_key, operator); + json.push(operator); } - const attribute = attributes.get(attr_key); - const type = field.type; - let defaultValue = field.defaultValue; - if (type instanceof flatc.Enum) { - if (!type.keys.has(defaultValue)) { - throw new Error("Invalid '" + type.name + "' default value '" + defaultValue + "'."); + const operator = operators.get(op_key); + operator.attributes = operator.attributes || []; + for (const field of table.fields.values()) { + const attr_key = op_key + ':' + field.name; + if (!attributes.has(attr_key)) { + const attribute = { name: field.name }; + attributes.set(attr_key, attribute); + operator.attributes.push(attribute); } - defaultValue = type.keys.get(defaultValue); + const attribute = attributes.get(attr_key); + const type = field.type; + let defaultValue = field.defaultValue; + if (type instanceof flatc.Enum) { + if (!type.keys.has(defaultValue)) { + throw new Error("Invalid '" + type.name + "' default value '" + defaultValue + "'."); + } + defaultValue = type.keys.get(defaultValue); + } + attribute.type = type.name === 'bool' ? 'boolean' : type.name + (field.repeated ? '[]' : ''); + attribute.default = defaultValue; } - attribute.type = type.name === 'bool' ? 'boolean' : type.name + (field.repeated ? '[]' : ''); - attribute.default = defaultValue; } } -} - -json.sort((a, b) => a.name.localeCompare(b.name)); + json.sort((a, b) => a.name.localeCompare(b.name)); + let output = JSON.stringify(json, null, 2); + output = output.replace(/\s {8}/g, ' '); + output = output.replace(/,\s {8}/g, ', '); + output = output.replace(/\s {6}}/g, ' }'); + await fs.writeFile(file, output, 'utf-8'); +}; -let output = JSON.stringify(json, null, 2); -output = output.replace(/\s {8}/g, ' '); -output = output.replace(/,\s {8}/g, ', '); -output = output.replace(/\s {6}}/g, ' }'); -fs.writeFileSync(file, output, 'utf-8'); +main(); \ No newline at end of file