Skip to content

Commit

Permalink
signed
Browse files Browse the repository at this point in the history
Signed-off-by: Xavier Dupre <[email protected]>
  • Loading branch information
xadupre committed Aug 1, 2023
1 parent 8a0a683 commit f4cc4e4
Showing 1 changed file with 5 additions and 6 deletions.
11 changes: 5 additions & 6 deletions onnxmltools/convert/sparkml/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,9 +14,9 @@ def buildInitialTypesSimple(dataframe):


def getTensorTypeFromSpark(sparktype):
if sparktype == "StringType" or sparktype == "StringType()":
if sparktype in ("StringType", "StringType()"):
return StringTensorType([1, 1])
elif (
if (
sparktype == "DecimalType"
or sparktype == "DecimalType()"
or sparktype == "DoubleType"
Expand All @@ -34,17 +34,16 @@ def getTensorTypeFromSpark(sparktype):
or sparktype == "BooleanType"
or sparktype == "BooleanType()"
):
return FloatTensorType([1, 1])
else:
raise TypeError("Cannot map this type to Onnx types: " + sparktype)
return FloatTensorType([None, 1])
raise TypeError(f"Cannot map this type to Onnx types: {sparktype}.")


def buildInputDictSimple(dataframe):
import numpy

result = {}
for field in dataframe.schema.fields:
if str(field.dataType) == "StringType":
if str(field.dataType) in ("StringType", "StringType()"):
result[field.name] = dataframe.select(field.name).toPandas().values
else:
result[field.name] = (
Expand Down

0 comments on commit f4cc4e4

Please sign in to comment.