Skip to content

Commit

Permalink
reformat with black
Browse files Browse the repository at this point in the history
  • Loading branch information
jeongyoonlee committed May 20, 2024
1 parent 9d8bba3 commit 3bbce11
Show file tree
Hide file tree
Showing 3 changed files with 24 additions and 4 deletions.
8 changes: 7 additions & 1 deletion causalml/inference/tree/causal/_tree.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,13 @@ def _support_missing_values(self, X) -> bool:
return False

def fit(
self, X, y, treatment, sample_weight=None, check_input=True, X_idx_sorted="deprecated"
self,
X,
y,
treatment,
sample_weight=None,
check_input=True,
X_idx_sorted="deprecated",
):
random_state = check_random_state(self.random_state)

Expand Down
16 changes: 14 additions & 2 deletions causalml/inference/tree/causal/causalforest.py
Original file line number Diff line number Diff line change
Expand Up @@ -205,7 +205,13 @@ def __init__(
self.alpha = alpha
self.groups_cnt = groups_cnt

def _fit(self, X: np.ndarray, treatment: np.ndarray, y: np.ndarray, sample_weight: np.ndarray = None):
def _fit(
self,
X: np.ndarray,
treatment: np.ndarray,
y: np.ndarray,
sample_weight: np.ndarray = None,
):
"""
Build a forest of trees from the training set (X, y).
With modified _parallel_build_trees for Causal Trees used in BaseForest.fit()
Expand Down Expand Up @@ -377,7 +383,13 @@ def _fit(self, X: np.ndarray, treatment: np.ndarray, y: np.ndarray, sample_weigh

return self

def fit(self, X: np.ndarray, treatment: np.ndarray, y: np.ndarray, sample_weight: np.ndarray = None):
def fit(
self,
X: np.ndarray,
treatment: np.ndarray,
y: np.ndarray,
sample_weight: np.ndarray = None,
):
"""
Fit Causal RandomForest
Args:
Expand Down
4 changes: 3 additions & 1 deletion causalml/inference/tree/causal/causaltree.py
Original file line number Diff line number Diff line change
Expand Up @@ -179,7 +179,9 @@ def fit(
X, y, w = self._prepare_data(X=X, y=y, treatment=treatment)
self.treatment_groups = np.unique(w)

super().fit(X=X, treatment=w, y=y, sample_weight=sample_weight, check_input=check_input)
super().fit(
X=X, treatment=w, y=y, sample_weight=sample_weight, check_input=check_input
)

if self.groups_cnt:
self._groups_cnt = self._count_groups_distribution(X=X, treatment=w)
Expand Down

0 comments on commit 3bbce11

Please sign in to comment.