Skip to content

Commit

Permalink
Update mlir.js (#1044)
Browse files Browse the repository at this point in the history
  • Loading branch information
lutzroeder committed Oct 1, 2024
1 parent aa41bf2 commit ef35b8b
Show file tree
Hide file tree
Showing 2 changed files with 220 additions and 95 deletions.
313 changes: 219 additions & 94 deletions source/mlir.js
Original file line number Diff line number Diff line change
Expand Up @@ -1193,108 +1193,221 @@ mlir.BytecodeReader = class {

constructor(context) {
this._reader = new mlir.BinaryReader(context);
this._decoder = new TextDecoder('utf-8');
}

read() {
const reader = this._reader;
reader.read(4); // signature 'ML\xEFR'
this.version = reader.varint().toNumber();
this.producer = reader.string();
this.sections = [];
this.sections = new Map();
while (reader.position < reader.length) {
// https://mlir.llvm.org/docs/BytecodeFormat/
const code = reader.byte();
const identifier = code & 0x7F;
// https://github.com/llvm/llvm-project/blob/main/mlir/lib/Bytecode/Reader/BytecodeReader.cpp
const sectionIDAndHasAlignment = reader.byte();
const sectionID = sectionIDAndHasAlignment & 0x7F;
const length = reader.varint().toNumber();
if (code >> 7) {
const hasAlignment = sectionIDAndHasAlignment & 0x80;
if (sectionID >= 9) {
throw new mlir.Error(`Unsupported section identifier '${sectionID}'.`);
}
if (hasAlignment) {
const alignment = reader.varint();
reader.skip(alignment);
}
const next = reader.position + length;
switch (identifier) {
case 0: { // string
const lengths = new Array(reader.varint().toNumber());
for (let i = 0; i < lengths.length; i++) {
lengths[i] = reader.varint().toNumber();
}
this.strings = new Array(lengths.length);
for (let i = 0; i < this.strings.length; i++) {
const size = lengths[lengths.length - i - 1];
const buffer = reader.read(size);
this.strings[i] = this._decoder.decode(buffer);
}
break;
}
case 1: { // dialect
const numDialects = reader.varint().toNumber();
this.dialectNames = new Array(numDialects);
this.opNames = new Array(numDialects);
for (let i = 0; i < this.dialectNames.length; i++) {
const group = {};
const nameAndIsVersioned = reader.varint();
group.name = (nameAndIsVersioned >> 1n).toNumber();
if (nameAndIsVersioned & 1n) {
const size = reader.varint().toNumber();
group.version = reader.read(size);
}
this.dialectNames[i] = group;
}
for (let i = 0; i < this.opNames.length; i++) {
const dialect_ops_group = {};
dialect_ops_group.dialect = reader.varint();
dialect_ops_group.opNames = new Array(reader.varint().toNumber());
for (let j = 0; j < dialect_ops_group.opNames.length; j++) {
const op_name_group = {};
const nameAndIsRegistered = reader.varint();
op_name_group.isRegistered = (nameAndIsRegistered & 1n) === 1n;
op_name_group.name = (nameAndIsRegistered >> 1n).toNumber();
dialect_ops_group.opNames[j] = op_name_group;
}
this.opNames[i] = dialect_ops_group;
}
break;
}
case 2: // attrType
case 3: { // attrTypeOffset
/*
const numAttrs = reader.varint().toNumber();
const numTypes = reader.varint().toNumber();
for (let i = 0; i < (numAttrs + numTypes); i++) {
const offset = reader.position;
reader.skip(length);
this.sections.set(sectionID, { start: offset, end: reader.position });
}
if (!this.sections.has(0) || !this.sections.has(1) ||
!this.sections.has(2) || !this.sections.has(3) ||
!this.sections.has(4) || (this.version >= 5 && !this.sections.has(8))) {
throw new mlir.Error('Missing required section.');
}
this._parseStringSection();
if (this.sections.has(8)) {
this._parsePropertiesSection();
}
this._parseDialectSection();
this._parseResourceSection();
this._parseAttrTypeSection();
}

}
break;
*/
this.attrType = reader.stream(length);
break;
}
case 4: { // IR
reader.skip(length);
break;
}
case 5: { // resource
this.resource = reader.stream(length);
break;
}
case 6: { // resourceOffset
reader.skip(length);
break;
}
case 7: { // dialectVersions
reader.skip(length);
break;
_parseStringSection() {
const section = this.sections.get(0);
const reader = this._reader;
reader.seek(section.start);
const lengths = new Array(reader.varint().toNumber());
for (let i = 0; i < lengths.length; i++) {
lengths[i] = reader.varint().toNumber();
}
const decoder = new TextDecoder('utf-8');
this.strings = new Array(lengths.length);
for (let i = 0; i < this.strings.length; i++) {
const size = lengths[lengths.length - 1 - i];
const buffer = reader.read(size);
this.strings[i] = decoder.decode(buffer);
}
if (reader.position !== section.end) {
throw new mlir.Error(`Invalid string section size.`);
}
}

_parseDialectSection() {
const section = this.sections.get(1);
const reader = this._reader;
reader.seek(section.start);
const numDialects = reader.varint().toNumber();
this.dialects = new Array(numDialects);
for (let i = 0; i < this.dialects.length; i++) {
this.dialects[i] = {};
if (this.version < 1) { // kDialectVersioning
const entryIdx = reader.varint().toNumber();
this.dialects[i].name = this.strings[entryIdx];
continue;
}
const nameAndIsVersioned = reader.varint();
const dialectNameIdx = (nameAndIsVersioned >> 1n).toNumber();
this.dialects[i].name = this.strings[dialectNameIdx];
if (nameAndIsVersioned & 1n) {
const size = reader.varint().toNumber();
this.dialects[i].version = reader.read(size);
}
}
let numOps = -1;
this.opNames = [];
if (this.version > 4) { // kElideUnknownBlockArgLocation
numOps = reader.varint().toNumber();
this.opNames = new Array(numOps);
}
let i = 0;
while (reader.position < section.end) {
const dialect = this.dialects[reader.varint().toNumber()];
const numEntries = reader.varint().toNumber();
for (let j = 0; j < numEntries; j++) {
const opName = {};
if (this.version < 5) { // kNativePropertiesEncoding
opName.name = this.strings[reader.varint().toNumber()];
opName.dialect = dialect;
} else {
const nameAndIsRegistered = reader.varint();
opName.name = this.strings[(nameAndIsRegistered >> 1n).toNumber()];
opName.dialect = dialect;
opName.isRegistered = (nameAndIsRegistered & 1n) === 1n;
}
case 8: { // properties
reader.skip(length);
break;
if (numOps < 0) {
this.opNames.push(opName);
} else {
this.opNames[i++] = opName;
}
default: {
throw new mlir.Error(`Unsupported section identifier '${identifier}'.`);
}
}
if (reader.position !== section.end) {
throw new mlir.Error(`Invalid dialect section size.`);
}
}

_parseResourceSection() {
const section = this.sections.get(6);
const reader = this._reader;
reader.seek(section.start);
const numExternalResourceGroups = reader.varint().toNumber();
if (numExternalResourceGroups > 0) {
throw new mlir.Error(`Unsupported resource section.`);
}
/*
for (let i = 0; i < numExternalResourceGroups; i++) {
const numResources = reader.varint().toNumber();
for (let j = 0; j < numResources; j++) {
const resource = {};
resource.key = this.strings[reader.varint().toNumber()];
resource.offset = reader.varint().toNumber();
resource.kind = reader.byte();
}
}
*/
if (reader.position !== section.end) {
throw new mlir.Error(`Invalid dialect section size.`);
}
}

_parseAttrTypeSection() {
const section = this.sections.get(3);
const reader = this._reader;
reader.seek(section.start);
this.attributes = new Array(reader.varint().toNumber());
this.types = new Array(reader.varint().toNumber());
let offset = 0;
const parseEntries = (range) => {
for (let i = 0; i < range.length;) {
const dialect = this.dialects[reader.varint().toNumber()];
const numEntries = reader.varint().toNumber();
for (let j = 0; j < numEntries; j++) {
const entry = {};
const entrySizeWithFlag = reader.varint();
entry.hasCustomEncoding = (entrySizeWithFlag & 1n) === 1n;
entry.size = (entrySizeWithFlag >> 1n).toNumber();
entry.offset = offset;
entry.dialect = dialect;
offset += entry.size;
range[i++] = entry;
}
}
if (reader.position !== next) {
throw new mlir.Error('Invalid section length.');
};
parseEntries(this.attributes);
parseEntries(this.types);
if (reader.position !== section.end) {
throw new mlir.Error(`Invalid dialect section size.`);
}
offset = this.sections.get(2).start;
const parseCustomEntry = (entry, reader, entryType) => {
// throw new mlir.Error(`Unsupported custom encoding.`);
if (entryType === 'type') {
// debugger;
} else {
// debugger;
}
};
const parseAsmEntry = (entry, reader, entryType) => {
if (entryType === 'type') {
// debugger;
} else {
// debugger;
}
};
const resolveEntries = (range, entryType) => {
for (const entry of this.attributes) {
reader.seek(offset + entry.offset);
if (entry.hasCustomEncoding) {
parseCustomEntry(entry, reader);
} else {
parseAsmEntry(entry, reader, entryType);
}
// if (reader.position !== (offset + entry.offset + entry.size)) {
// throw new mlir.Error(`Invalid '${entryType}' section size.`);
// }
// delete entry.offset;
// delete entry.size;
}
};
resolveEntries(this.attributes, 'attribute');
resolveEntries(this.types, 'type');
}

_parsePropertiesSection() {
const section = this.sections.get(8);
const reader = this._reader;
reader.seek(section.start);
const count = reader.varint().toNumber();
const offsetTable = new Array(count);
for (let i = 0; i < offsetTable.length; i++) {
const offset = reader.position;
const size = reader.varint().toNumber();
const data = reader.read(size);
offsetTable[i] = { offset, data };
}
if (reader.position !== section.end) {
throw new mlir.Error(`Invalid properties section size.`);
}
}
};
Expand All @@ -1317,6 +1430,10 @@ mlir.BinaryReader = class {
this._reader.skip(length);
}

seek(offset) {
this._reader.seek(offset);
}

read(length) {
return this._reader.read(length);
}
Expand All @@ -1330,17 +1447,25 @@ mlir.BinaryReader = class {
}

varint() {
let value = 0n;
let shift = 0n;
for (let i = 0; i < 10 && this._reader.position < this._reader.length; i++) {
const byte = this._reader.byte();
value |= BigInt(byte >> 1) << shift;
if ((byte & 1) === 1) {
return value;
}
shift += 7n;
}
throw new mlir.Error('Invalid varint value.');
let result = this._reader.byte();
if (result & 1) {
return BigInt(result >> 1);
}
if (result === 0) {
return this._reader.uint64();
}
result = BigInt(result);
let mask = 1n;
let numBytes = 0n;
let shift = 8n;
while (result > 0n && (result & mask) === 0n) {
result |= (BigInt(this._reader.byte()) << shift);
mask <<= 1n;
shift += 8n;
numBytes++;
}
result >>= BigInt(numBytes + 1n);
return result;
}

string() {
Expand Down
2 changes: 1 addition & 1 deletion test/models.json
Original file line number Diff line number Diff line change
Expand Up @@ -3113,7 +3113,7 @@
"target": "model.mlirbc",
"source": "https://github.com/user-attachments/files/17179955/model.mlirbc.zip[model.mlirbc]",
"format": "MLIR",
"error": "Invalid section length.",
"error": "Invalid content. File contains MLIR bytecode data.",
"link": "https://github.com/lutzroeder/netron/issues/1044"
},
{
Expand Down

0 comments on commit ef35b8b

Please sign in to comment.