Skip to content

Commit

Permalink
Update TensorFlow script
Browse files Browse the repository at this point in the history
  • Loading branch information
lutzroeder committed Jul 8, 2023
1 parent 837c196 commit 58c224f
Showing 1 changed file with 54 additions and 52 deletions.
106 changes: 54 additions & 52 deletions tools/tf_metadata.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,9 @@
import json
import os
import google.protobuf # pylint: disable=import-error
from tensorflow.core.framework import api_def_pb2 # pylint: disable=import-error
from tensorflow.core.framework import op_def_pb2 # pylint: disable=import-error
from tensorflow.core.framework import types_pb2 # pylint: disable=import-error
from tensorflow.core.framework import api_def_pb2 # pylint: disable=import-error,no-name-in-module
from tensorflow.core.framework import op_def_pb2 # pylint: disable=import-error,no-name-in-module
from tensorflow.core.framework import types_pb2 # pylint: disable=import-error,no-name-in-module

def _read(path):
with open(path, 'r', encoding='utf-8') as file:
Expand Down Expand Up @@ -80,7 +80,7 @@ def _pbtxt_from_multiline(multiline_pbtxt):
return pbtxt

def _read_op_list(file):
op_list = op_def_pb2.OpList()
op_list = op_def_pb2.OpList() # pylint: disable=no-member
content = _read(file)
google.protobuf.text_format.Merge(content, op_list)
return op_list
Expand All @@ -89,7 +89,7 @@ def _read_api_def_map(folder):
api_def_map = {}
for filename in sorted(os.listdir(folder)):
if filename.endswith('.pbtxt'):
api_defs = api_def_pb2.ApiDefs()
api_defs = api_def_pb2.ApiDefs() # pylint: disable=no-member
filename = folder + '/' + filename
with open(filename, 'r', encoding='utf-8') as file:
multiline_pbtxt = file.read()
Expand Down Expand Up @@ -174,53 +174,55 @@ def _convert_attr_value(attr_value):
raise NotImplementedError()
return value

DataType = types_pb2.DataType # pylint: disable=no-member

type_to_string_map = {
types_pb2.DataType.DT_HALF: "float16",
types_pb2.DataType.DT_FLOAT: "float32",
types_pb2.DataType.DT_DOUBLE: "float64",
types_pb2.DataType.DT_INT32: "int32",
types_pb2.DataType.DT_UINT8: "uint8",
types_pb2.DataType.DT_UINT16: "uint16",
types_pb2.DataType.DT_UINT32: "uint32",
types_pb2.DataType.DT_UINT64: "uint64",
types_pb2.DataType.DT_INT16: "int16",
types_pb2.DataType.DT_INT8: "int8",
types_pb2.DataType.DT_STRING: "string",
types_pb2.DataType.DT_COMPLEX64: "complex64",
types_pb2.DataType.DT_COMPLEX128: "complex128",
types_pb2.DataType.DT_INT64: "int64",
types_pb2.DataType.DT_BOOL: "bool",
types_pb2.DataType.DT_QINT8: "qint8",
types_pb2.DataType.DT_QUINT8: "quint8",
types_pb2.DataType.DT_QINT16: "qint16",
types_pb2.DataType.DT_QUINT16: "quint16",
types_pb2.DataType.DT_QINT32: "qint32",
types_pb2.DataType.DT_BFLOAT16: "bfloat16",
types_pb2.DataType.DT_RESOURCE: "resource",
types_pb2.DataType.DT_VARIANT: "variant",
types_pb2.DataType.DT_HALF_REF: "float16_ref",
types_pb2.DataType.DT_FLOAT_REF: "float32_ref",
types_pb2.DataType.DT_DOUBLE_REF: "float64_ref",
types_pb2.DataType.DT_INT32_REF: "int32_ref",
types_pb2.DataType.DT_UINT32_REF: "uint32_ref",
types_pb2.DataType.DT_UINT8_REF: "uint8_ref",
types_pb2.DataType.DT_UINT16_REF: "uint16_ref",
types_pb2.DataType.DT_INT16_REF: "int16_ref",
types_pb2.DataType.DT_INT8_REF: "int8_ref",
types_pb2.DataType.DT_STRING_REF: "string_ref",
types_pb2.DataType.DT_COMPLEX64_REF: "complex64_ref",
types_pb2.DataType.DT_COMPLEX128_REF: "complex128_ref",
types_pb2.DataType.DT_INT64_REF: "int64_ref",
types_pb2.DataType.DT_UINT64_REF: "uint64_ref",
types_pb2.DataType.DT_BOOL_REF: "bool_ref",
types_pb2.DataType.DT_QINT8_REF: "qint8_ref",
types_pb2.DataType.DT_QUINT8_REF: "quint8_ref",
types_pb2.DataType.DT_QINT16_REF: "qint16_ref",
types_pb2.DataType.DT_QUINT16_REF: "quint16_ref",
types_pb2.DataType.DT_QINT32_REF: "qint32_ref",
types_pb2.DataType.DT_BFLOAT16_REF: "bfloat16_ref",
types_pb2.DataType.DT_RESOURCE_REF: "resource_ref",
types_pb2.DataType.DT_VARIANT_REF: "variant_ref",
DataType.DT_HALF: "float16",
DataType.DT_FLOAT: "float32",
DataType.DT_DOUBLE: "float64",
DataType.DT_INT32: "int32",
DataType.DT_UINT8: "uint8",
DataType.DT_UINT16: "uint16",
DataType.DT_UINT32: "uint32",
DataType.DT_UINT64: "uint64",
DataType.DT_INT16: "int16",
DataType.DT_INT8: "int8",
DataType.DT_STRING: "string",
DataType.DT_COMPLEX64: "complex64",
DataType.DT_COMPLEX128: "complex128",
DataType.DT_INT64: "int64",
DataType.DT_BOOL: "bool",
DataType.DT_QINT8: "qint8",
DataType.DT_QUINT8: "quint8",
DataType.DT_QINT16: "qint16",
DataType.DT_QUINT16: "quint16",
DataType.DT_QINT32: "qint32",
DataType.DT_BFLOAT16: "bfloat16",
DataType.DT_RESOURCE: "resource",
DataType.DT_VARIANT: "variant",
DataType.DT_HALF_REF: "float16_ref",
DataType.DT_FLOAT_REF: "float32_ref",
DataType.DT_DOUBLE_REF: "float64_ref",
DataType.DT_INT32_REF: "int32_ref",
DataType.DT_UINT32_REF: "uint32_ref",
DataType.DT_UINT8_REF: "uint8_ref",
DataType.DT_UINT16_REF: "uint16_ref",
DataType.DT_INT16_REF: "int16_ref",
DataType.DT_INT8_REF: "int8_ref",
DataType.DT_STRING_REF: "string_ref",
DataType.DT_COMPLEX64_REF: "complex64_ref",
DataType.DT_COMPLEX128_REF: "complex128_ref",
DataType.DT_INT64_REF: "int64_ref",
DataType.DT_UINT64_REF: "uint64_ref",
DataType.DT_BOOL_REF: "bool_ref",
DataType.DT_QINT8_REF: "qint8_ref",
DataType.DT_QUINT8_REF: "quint8_ref",
DataType.DT_QINT16_REF: "qint16_ref",
DataType.DT_QUINT16_REF: "quint16_ref",
DataType.DT_QINT32_REF: "qint32_ref",
DataType.DT_BFLOAT16_REF: "bfloat16_ref",
DataType.DT_RESOURCE_REF: "resource_ref",
DataType.DT_VARIANT_REF: "variant_ref",
}

def _format_data_type(data_type):
Expand Down Expand Up @@ -375,7 +377,7 @@ def _metadata():
json_schema['name'] = operator.name
if operator.name in categories:
json_schema['category'] = categories[operator.name]
api_def = api_def_pb2.ApiDef()
api_def = api_def_pb2.ApiDef() # pylint: disable=no-member
if operator.name in api_def_map:
api_def = api_def_map[operator.name]
if api_def.summary:
Expand Down

0 comments on commit 58c224f

Please sign in to comment.