diff --git a/source/mlir.js b/source/mlir.js index 1f5a62d4ab..bdc484f73b 100644 --- a/source/mlir.js +++ b/source/mlir.js @@ -1193,7 +1193,6 @@ mlir.BytecodeReader = class { constructor(context) { this._reader = new mlir.BinaryReader(context); - this._decoder = new TextDecoder('utf-8'); } read() { @@ -1201,100 +1200,214 @@ mlir.BytecodeReader = class { reader.read(4); // signature 'ML\xEFR' this.version = reader.varint().toNumber(); this.producer = reader.string(); - this.sections = []; + this.sections = new Map(); while (reader.position < reader.length) { // https://mlir.llvm.org/docs/BytecodeFormat/ - const code = reader.byte(); - const identifier = code & 0x7F; + // https://github.com/llvm/llvm-project/blob/main/mlir/lib/Bytecode/Reader/BytecodeReader.cpp + const sectionIDAndHasAlignment = reader.byte(); + const sectionID = sectionIDAndHasAlignment & 0x7F; const length = reader.varint().toNumber(); - if (code >> 7) { + const hasAlignment = sectionIDAndHasAlignment & 0x80; + if (sectionID >= 9) { + throw new mlir.Error(`Unsupported section identifier '${sectionID}'.`); + } + if (hasAlignment) { const alignment = reader.varint(); reader.skip(alignment); } - const next = reader.position + length; - switch (identifier) { - case 0: { // string - const lengths = new Array(reader.varint().toNumber()); - for (let i = 0; i < lengths.length; i++) { - lengths[i] = reader.varint().toNumber(); - } - this.strings = new Array(lengths.length); - for (let i = 0; i < this.strings.length; i++) { - const size = lengths[lengths.length - i - 1]; - const buffer = reader.read(size); - this.strings[i] = this._decoder.decode(buffer); - } - break; - } - case 1: { // dialect - const numDialects = reader.varint().toNumber(); - this.dialectNames = new Array(numDialects); - this.opNames = new Array(numDialects); - for (let i = 0; i < this.dialectNames.length; i++) { - const group = {}; - const nameAndIsVersioned = reader.varint(); - group.name = (nameAndIsVersioned >> 1n).toNumber(); - if (nameAndIsVersioned & 1n) { - const size = reader.varint().toNumber(); - group.version = reader.read(size); - } - this.dialectNames[i] = group; - } - for (let i = 0; i < this.opNames.length; i++) { - const dialect_ops_group = {}; - dialect_ops_group.dialect = reader.varint(); - dialect_ops_group.opNames = new Array(reader.varint().toNumber()); - for (let j = 0; j < dialect_ops_group.opNames.length; j++) { - const op_name_group = {}; - const nameAndIsRegistered = reader.varint(); - op_name_group.isRegistered = (nameAndIsRegistered & 1n) === 1n; - op_name_group.name = (nameAndIsRegistered >> 1n).toNumber(); - dialect_ops_group.opNames[j] = op_name_group; - } - this.opNames[i] = dialect_ops_group; - } - break; - } - case 2: // attrType - case 3: { // attrTypeOffset - /* - const numAttrs = reader.varint().toNumber(); - const numTypes = reader.varint().toNumber(); - for (let i = 0; i < (numAttrs + numTypes); i++) { + const offset = reader.position; + reader.skip(length); + this.sections.set(sectionID, { start: offset, end: reader.position }); + } + if (!this.sections.has(0) || !this.sections.has(1) || + !this.sections.has(2) || !this.sections.has(3) || + !this.sections.has(4) || (this.version >= 5 && !this.sections.has(8))) { + throw new mlir.Error('Missing required section.'); + } + this._parseStringSection(); + if (this.sections.has(8)) { + this._parsePropertiesSection(); + } + this._parseDialectSection(); + this._parseResourceSection(); + this._parseAttrTypeSection(); + } - } - break; - */ - this.attrType = reader.stream(length); - break; - } - case 4: { // IR - reader.skip(length); - break; - } - case 5: { // resource - this.resource = reader.stream(length); - break; - } - case 6: { // resourceOffset - reader.skip(length); - break; - } - case 7: { // dialectVersions - reader.skip(length); - break; + _parseStringSection() { + const section = this.sections.get(0); + const reader = this._reader; + reader.seek(section.start); + const lengths = new Array(reader.varint().toNumber()); + for (let i = 0; i < lengths.length; i++) { + lengths[i] = reader.varint().toNumber(); + } + const decoder = new TextDecoder('utf-8'); + this.strings = new Array(lengths.length); + for (let i = 0; i < this.strings.length; i++) { + const size = lengths[lengths.length - 1 - i]; + const buffer = reader.read(size); + this.strings[i] = decoder.decode(buffer); + } + if (reader.position !== section.end) { + throw new mlir.Error(`Invalid string section size.`); + } + } + + _parseDialectSection() { + const section = this.sections.get(1); + const reader = this._reader; + reader.seek(section.start); + const numDialects = reader.varint().toNumber(); + this.dialects = new Array(numDialects); + for (let i = 0; i < this.dialects.length; i++) { + this.dialects[i] = {}; + if (this.version < 1) { // kDialectVersioning + const entryIdx = reader.varint().toNumber(); + this.dialects[i].name = this.strings[entryIdx]; + continue; + } + const nameAndIsVersioned = reader.varint(); + const dialectNameIdx = (nameAndIsVersioned >> 1n).toNumber(); + this.dialects[i].name = this.strings[dialectNameIdx]; + if (nameAndIsVersioned & 1n) { + const size = reader.varint().toNumber(); + this.dialects[i].version = reader.read(size); + } + } + let numOps = -1; + this.opNames = []; + if (this.version > 4) { // kElideUnknownBlockArgLocation + numOps = reader.varint().toNumber(); + this.opNames = new Array(numOps); + } + let i = 0; + while (reader.position < section.end) { + const dialect = this.dialects[reader.varint().toNumber()]; + const numEntries = reader.varint().toNumber(); + for (let j = 0; j < numEntries; j++) { + const opName = {}; + if (this.version < 5) { // kNativePropertiesEncoding + opName.name = this.strings[reader.varint().toNumber()]; + opName.dialect = dialect; + } else { + const nameAndIsRegistered = reader.varint(); + opName.name = this.strings[(nameAndIsRegistered >> 1n).toNumber()]; + opName.dialect = dialect; + opName.isRegistered = (nameAndIsRegistered & 1n) === 1n; } - case 8: { // properties - reader.skip(length); - break; + if (numOps < 0) { + this.opNames.push(opName); + } else { + this.opNames[i++] = opName; } - default: { - throw new mlir.Error(`Unsupported section identifier '${identifier}'.`); + } + } + if (reader.position !== section.end) { + throw new mlir.Error(`Invalid dialect section size.`); + } + } + + _parseResourceSection() { + const section = this.sections.get(6); + const reader = this._reader; + reader.seek(section.start); + const numExternalResourceGroups = reader.varint().toNumber(); + if (numExternalResourceGroups > 0) { + throw new mlir.Error(`Unsupported resource section.`); + } + /* + for (let i = 0; i < numExternalResourceGroups; i++) { + const numResources = reader.varint().toNumber(); + for (let j = 0; j < numResources; j++) { + const resource = {}; + resource.key = this.strings[reader.varint().toNumber()]; + resource.offset = reader.varint().toNumber(); + resource.kind = reader.byte(); + } + } + */ + if (reader.position !== section.end) { + throw new mlir.Error(`Invalid dialect section size.`); + } + } + + _parseAttrTypeSection() { + const section = this.sections.get(3); + const reader = this._reader; + reader.seek(section.start); + this.attributes = new Array(reader.varint().toNumber()); + this.types = new Array(reader.varint().toNumber()); + let offset = 0; + const parseEntries = (range) => { + for (let i = 0; i < range.length;) { + const dialect = this.dialects[reader.varint().toNumber()]; + const numEntries = reader.varint().toNumber(); + for (let j = 0; j < numEntries; j++) { + const entry = {}; + const entrySizeWithFlag = reader.varint(); + entry.hasCustomEncoding = (entrySizeWithFlag & 1n) === 1n; + entry.size = (entrySizeWithFlag >> 1n).toNumber(); + entry.offset = offset; + entry.dialect = dialect; + offset += entry.size; + range[i++] = entry; } } - if (reader.position !== next) { - throw new mlir.Error('Invalid section length.'); + }; + parseEntries(this.attributes); + parseEntries(this.types); + if (reader.position !== section.end) { + throw new mlir.Error(`Invalid dialect section size.`); + } + offset = this.sections.get(2).start; + const parseCustomEntry = (entry, reader, entryType) => { + // throw new mlir.Error(`Unsupported custom encoding.`); + if (entryType === 'type') { + // debugger; + } else { + // debugger; } + }; + const parseAsmEntry = (entry, reader, entryType) => { + if (entryType === 'type') { + // debugger; + } else { + // debugger; + } + }; + const resolveEntries = (range, entryType) => { + for (const entry of this.attributes) { + reader.seek(offset + entry.offset); + if (entry.hasCustomEncoding) { + parseCustomEntry(entry, reader); + } else { + parseAsmEntry(entry, reader, entryType); + } + // if (reader.position !== (offset + entry.offset + entry.size)) { + // throw new mlir.Error(`Invalid '${entryType}' section size.`); + // } + // delete entry.offset; + // delete entry.size; + } + }; + resolveEntries(this.attributes, 'attribute'); + resolveEntries(this.types, 'type'); + } + + _parsePropertiesSection() { + const section = this.sections.get(8); + const reader = this._reader; + reader.seek(section.start); + const count = reader.varint().toNumber(); + const offsetTable = new Array(count); + for (let i = 0; i < offsetTable.length; i++) { + const offset = reader.position; + const size = reader.varint().toNumber(); + const data = reader.read(size); + offsetTable[i] = { offset, data }; + } + if (reader.position !== section.end) { + throw new mlir.Error(`Invalid properties section size.`); } } }; @@ -1317,6 +1430,10 @@ mlir.BinaryReader = class { this._reader.skip(length); } + seek(offset) { + this._reader.seek(offset); + } + read(length) { return this._reader.read(length); } @@ -1330,17 +1447,25 @@ mlir.BinaryReader = class { } varint() { - let value = 0n; - let shift = 0n; - for (let i = 0; i < 10 && this._reader.position < this._reader.length; i++) { - const byte = this._reader.byte(); - value |= BigInt(byte >> 1) << shift; - if ((byte & 1) === 1) { - return value; - } - shift += 7n; - } - throw new mlir.Error('Invalid varint value.'); + let result = this._reader.byte(); + if (result & 1) { + return BigInt(result >> 1); + } + if (result === 0) { + return this._reader.uint64(); + } + result = BigInt(result); + let mask = 1n; + let numBytes = 0n; + let shift = 8n; + while (result > 0n && (result & mask) === 0n) { + result |= (BigInt(this._reader.byte()) << shift); + mask <<= 1n; + shift += 8n; + numBytes++; + } + result >>= BigInt(numBytes + 1n); + return result; } string() { diff --git a/test/models.json b/test/models.json index 4a5ef2f9b6..ac1b56d18a 100644 --- a/test/models.json +++ b/test/models.json @@ -3113,7 +3113,7 @@ "target": "model.mlirbc", "source": "https://github.com/user-attachments/files/17179955/model.mlirbc.zip[model.mlirbc]", "format": "MLIR", - "error": "Invalid section length.", + "error": "Invalid content. File contains MLIR bytecode data.", "link": "https://github.com/lutzroeder/netron/issues/1044" }, {