Skip to content

Commit

Permalink
add option flags for shuffling feature order
Browse files Browse the repository at this point in the history
  • Loading branch information
paulbkoch committed Aug 7, 2024
1 parent 058d66a commit d49a7db
Show file tree
Hide file tree
Showing 3 changed files with 16 additions and 4 deletions.
2 changes: 0 additions & 2 deletions docs/benchmarks/ebm-benchmark.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -34,8 +34,6 @@
"source": [
"# install powerlift if not already installed\n",
"\n",
"# !! IMPORTANT !! : until the next release, install locally with \"pip install -e .[datasets,postgres]\" from powerlift directory\n",
"\n",
"try:\n",
" import powerlift\n",
"except ModuleNotFoundError:\n",
Expand Down
7 changes: 7 additions & 0 deletions python/interpret-core/interpret/develop.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,15 @@
_current_module = sys.modules[__name__]
_current_module.is_debug_mode = False

# Global options
_purify_boosting = False
_purify_result = False
_randomize_initial_feature_order = True
# TODO: investigate if _randomize_feature_order actually decreases accuracy
# https://github.com/interpretml/interpret/issues/563#issuecomment-2240820952
# this seems to decrease accuracy slightly, but helps with collinearity
_randomize_greedy_feature_order = True # randomize feature order only if greedy enabled
_randomize_feature_order = False # randomize feature order always


def print_debug_info(file=None):
Expand Down
11 changes: 9 additions & 2 deletions python/interpret-core/interpret/glassbox/_ebm/_boost.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@

import heapq
from ...utils._native import Native
from ... import develop

import logging

Expand Down Expand Up @@ -74,7 +75,7 @@ def boost(
_log.info("Start boosting")
native = Native.get_native_singleton()
nominals = native.extract_nominals(dataset)
random_cyclic_ordering = np.empty(len(term_features), np.int64)
random_cyclic_ordering = np.arange(len(term_features), dtype=np.int64)

while step_idx < max_steps:
term_boost_flags_local = term_boost_flags
Expand All @@ -85,7 +86,13 @@ def boost(
bestkey = None
heap = []
# if pure cyclical then only randomize at start
if 0 < greedy_steps or step_idx == 0:
if (
step_idx == 0
and develop._randomize_initial_feature_order
or develop._randomize_greedy_feature_order
and 0 < greedy_steps
or develop._randomize_feature_order
):
# TODO: test if shuffling during pure cyclic is better
native.shuffle(rng, random_cyclic_ordering)

Expand Down

0 comments on commit d49a7db

Please sign in to comment.