Skip to content

Commit

Permalink
Add missing dependency onnxconverter_common, fix multi regression wit…
Browse files Browse the repository at this point in the history
…h xgboost (#679)

* add missing dependency onnxconverter_common

Signed-off-by: Xavier Dupre <[email protected]>

* issue 676

Signed-off-by: Xavier Dupre <[email protected]>

* issue 676

Signed-off-by: Xavier Dupre <[email protected]>

* fix issue 676

Signed-off-by: Xavier Dupre <[email protected]>

* fix shape calcultator

Signed-off-by: Xavier Dupre <[email protected]>

* fix new name for sparse arguement

Signed-off-by: Xavier Dupre <[email protected]>

* changelogs

Signed-off-by: Xavier Dupre <[email protected]>

* improves stability

Signed-off-by: Xavier Dupre <[email protected]>

---------

Signed-off-by: Xavier Dupre <[email protected]>
  • Loading branch information
xadupre authored Jan 24, 2024
1 parent 180e733 commit eb21c0e
Show file tree
Hide file tree
Showing 8 changed files with 75 additions and 3 deletions.
7 changes: 7 additions & 0 deletions CHANGELOGS.md
Original file line number Diff line number Diff line change
@@ -1,5 +1,12 @@
# Change Logs

## 1.13.0 (development)

* Add missing dependency onnxconverter_common, fix multi regression with xgboost,
[#679](https://github.com/onnx/onnxmltools/pull/679),
fixes issues [No module named 'onnxconverter_common'](https://github.com/onnx/onnxmltools/issues/673),
[onnx converted : xgboostRegressor multioutput model predicts 1 dimension instead of original 210 dimensions.](https://github.com/onnx/onnxmltools/issues/676)

## 1.12.0

* Fix early stopping for XGBClassifier and xgboost > 2
Expand Down
2 changes: 1 addition & 1 deletion onnxmltools/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
This framework converts any machine learned model into onnx format
which is a common language to describe any machine learned model.
"""
__version__ = "1.12.0"
__version__ = "1.13.0"
__author__ = "ONNX"
__producer__ = "OnnxMLTools"
__producer_version__ = __version__
Expand Down
6 changes: 6 additions & 0 deletions onnxmltools/convert/xgboost/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,12 @@ def get_xgb_params(xgb_node):
bs = float(config["learner"]["learner_model_param"]["base_score"])
# xgboost >= 2.0
params["base_score"] = bs
if "num_target" in config["learner"]["learner_model_param"]:
params["n_targets"] = int(
config["learner"]["learner_model_param"]["num_target"]
)
else:
params["n_targets"] = 1

bst = xgb_node.get_booster()
if hasattr(bst, "best_ntree_limit"):
Expand Down
3 changes: 3 additions & 0 deletions onnxmltools/convert/xgboost/operator_converters/XGBoost.py
Original file line number Diff line number Diff line change
Expand Up @@ -254,6 +254,9 @@ def convert(scope, operator, container):
js_trees, attr_pairs, [1 for _ in js_trees], False
)

params = XGBConverter.get_xgb_params(xgb_node)
attr_pairs["n_targets"] = params["n_targets"]

# add nodes
if objective == "count:poisson":
names = [scope.get_unique_variable_name("tree")]
Expand Down
1 change: 1 addition & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
@@ -1,2 +1,3 @@
numpy
onnx
onnxconverter_common
2 changes: 1 addition & 1 deletion tests/xgboost/test_xgboost_converters_base_score.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,7 +103,7 @@ def test_xgbclassifier_sparse_base_score(self):
assert_almost_equal(expected.reshape((-1, 2)), got, decimal=4)

def test_xgbclassifier_sparse_no_base_score(self):
X, y = make_regression(n_samples=200, n_features=10, random_state=0)
X, y = make_regression(n_samples=400, n_features=10, random_state=0)
mask = np.random.randint(0, 50, size=(X.shape)) != 0
X[mask] = 0
y = (y + mask.sum(axis=1, keepdims=0)).astype(np.float32)
Expand Down
55 changes: 55 additions & 0 deletions tests/xgboost/test_xgboost_issues.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
# SPDX-License-Identifier: Apache-2.0

import unittest


class TestXGBoostIssues(unittest.TestCase):
def test_issue_676(self):
import json
import onnxruntime
import xgboost
import numpy as np
from skl2onnx import convert_sklearn
from skl2onnx.common.data_types import FloatTensorType
from skl2onnx import update_registered_converter
from onnxmltools.convert.xgboost.operator_converters.XGBoost import (
convert_xgboost,
)

def xgbregressor_shape_calculator(operator):
config = json.loads(operator.raw_operator.get_booster().save_config())
n_targets = int(config["learner"]["learner_model_param"]["num_target"])
operator.outputs[0].type.shape = [None, n_targets]

update_registered_converter(
xgboost.XGBRegressor,
"XGBoostXGBRegressor",
xgbregressor_shape_calculator,
convert_xgboost,
)
# Your data and labels
X = np.random.rand(100, 10)
y = np.random.rand(100, 2)

# Train XGBoost regressor
model = xgboost.XGBRegressor(
objective="reg:squarederror", n_estimators=2, maxdepth=2
)
model.fit(X, y)

# Define input type (adjust shape according to your input)
initial_type = [("float_input", FloatTensorType([None, X.shape[1]]))]

# Convert XGBoost model to ONNX
onnx_model = convert_sklearn(model, initial_types=initial_type, target_opset=12)
self.assertIn("dim_value: 2", str(onnx_model.graph.output))

sess = onnxruntime.InferenceSession(
onnx_model.SerializeToString(), providers=["CPUExecutionProvider"]
)
got = sess.run(None, {"float_input": X.astype(np.float32)})
self.assertEqual(got[0].shape, (100, 2))


if __name__ == "__main__":
unittest.main()
2 changes: 1 addition & 1 deletion tests/xgboost/test_xgboost_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ def transformer_for_column(column):
if column.dtype in ["bool"]:
return "passthrough"
if column.dtype in ["O"]:
return OneHotEncoder(sparse=False)
return OneHotEncoder(sparse_output=False)
raise ValueError()

return ColumnTransformer(
Expand Down

0 comments on commit eb21c0e

Please sign in to comment.