diff --git a/causalml/inference/tree/_tree/_classes.py b/causalml/inference/tree/_tree/_classes.py index 70031b6d..6dab4af5 100644 --- a/causalml/inference/tree/_tree/_classes.py +++ b/causalml/inference/tree/_tree/_classes.py @@ -89,7 +89,13 @@ def __init__( @abstractmethod def fit( - self, X, y, sample_weight=None, check_input=True, X_idx_sorted="deprecated" + self, + X, + treatment, + y, + sample_weight=None, + check_input=True, + X_idx_sorted="deprecated", ): pass diff --git a/causalml/inference/tree/_tree/_criterion.pxd b/causalml/inference/tree/_tree/_criterion.pxd index 762e3509..83ced966 100644 --- a/causalml/inference/tree/_tree/_criterion.pxd +++ b/causalml/inference/tree/_tree/_criterion.pxd @@ -29,6 +29,7 @@ cdef class Criterion: # Internal structures cdef const DOUBLE_t[:, ::1] y # Values of y + cdef DOUBLE_t* treatment # Treatment assignment cdef DOUBLE_t* sample_weight # Sample weights cdef SIZE_t* samples # Sample indices in X, y @@ -56,7 +57,7 @@ cdef class Criterion: # statistics correspond to samples[start:pos] and samples[pos:end]. # Methods - cdef int init(self, const DOUBLE_t[:, ::1] y, DOUBLE_t* sample_weight, + cdef int init(self, const DOUBLE_t[:, ::1] y, DOUBLE_t* treatment, DOUBLE_t* sample_weight, double weighted_n_samples, SIZE_t* samples, SIZE_t start, SIZE_t end) nogil except -1 cdef int reset(self) nogil except -1 diff --git a/causalml/inference/tree/_tree/_criterion.pyx b/causalml/inference/tree/_tree/_criterion.pyx index ae281edd..5d6f61a1 100755 --- a/causalml/inference/tree/_tree/_criterion.pyx +++ b/causalml/inference/tree/_tree/_criterion.pyx @@ -48,7 +48,7 @@ cdef class Criterion: def __setstate__(self, d): pass - cdef int init(self, const DOUBLE_t[:, ::1] y, DOUBLE_t* sample_weight, + cdef int init(self, const DOUBLE_t[:, ::1] y, DOUBLE_t* treatment, DOUBLE_t* sample_weight, double weighted_n_samples, SIZE_t* samples, SIZE_t start, SIZE_t end) nogil except -1: """Placeholder for a method which will initialize the criterion. @@ -60,6 +60,8 @@ cdef class Criterion: ---------- y : array-like, dtype=DOUBLE_t y is a buffer that can store values for n_outputs target variables + treatment : array-like, dtype=DOUBLE_t + The treatment assignment of each sample. sample_weight : array-like, dtype=DOUBLE_t The weight of each sample weighted_n_samples : double @@ -224,6 +226,7 @@ cdef class RegressionCriterion(Criterion): The total number of samples to fit on """ # Default values + self.treatment = NULL self.sample_weight = NULL self.samples = NULL @@ -259,7 +262,7 @@ cdef class RegressionCriterion(Criterion): def __reduce__(self): return (type(self), (self.n_outputs, self.n_samples), self.__getstate__()) - cdef int init(self, const DOUBLE_t[:, ::1] y, DOUBLE_t* sample_weight, + cdef int init(self, const DOUBLE_t[:, ::1] y, DOUBLE_t* treatment, DOUBLE_t* sample_weight, double weighted_n_samples, SIZE_t* samples, SIZE_t start, SIZE_t end) nogil except -1: """Initialize the criterion. @@ -269,6 +272,7 @@ cdef class RegressionCriterion(Criterion): """ # Initialize fields self.y = y + self.treatment = treatment self.sample_weight = sample_weight self.samples = samples self.start = start diff --git a/causalml/inference/tree/_tree/_splitter.pxd b/causalml/inference/tree/_tree/_splitter.pxd index 5b63fc1d..41854b47 100644 --- a/causalml/inference/tree/_tree/_splitter.pxd +++ b/causalml/inference/tree/_tree/_splitter.pxd @@ -63,6 +63,7 @@ cdef class Splitter: cdef SIZE_t end # End position for the current node cdef const DOUBLE_t[:, ::1] y + cdef DOUBLE_t* treatment cdef DOUBLE_t* sample_weight # The samples vector `samples` is maintained by the Splitter object such @@ -83,7 +84,7 @@ cdef class Splitter: # Methods cdef int init(self, object X, const DOUBLE_t[:, ::1] y, - DOUBLE_t* sample_weight) except -1 + DOUBLE_t* treatment, DOUBLE_t* sample_weight) except -1 cdef int node_reset(self, SIZE_t start, SIZE_t end, double* weighted_n_node_samples) nogil except -1 diff --git a/causalml/inference/tree/_tree/_splitter.pyx b/causalml/inference/tree/_tree/_splitter.pyx index f72bcce9..22683ef9 100644 --- a/causalml/inference/tree/_tree/_splitter.pyx +++ b/causalml/inference/tree/_tree/_splitter.pyx @@ -94,6 +94,7 @@ cdef class Splitter: self.n_features = 0 self.feature_values = NULL + self.treatment = NULL self.sample_weight = NULL self.max_features = max_features @@ -118,6 +119,7 @@ cdef class Splitter: cdef int init(self, object X, const DOUBLE_t[:, ::1] y, + DOUBLE_t* treatment, DOUBLE_t* sample_weight) except -1: """Initialize the splitter. @@ -134,6 +136,9 @@ cdef class Splitter: y : ndarray, dtype=DOUBLE_t This is the vector of targets, or true labels, for the samples + treatment : DOUBLE_t* + The treatment assignments of the samples. + sample_weight : DOUBLE_t* The weights of the samples, where higher weighted samples are fit closer than lower weight samples. If not provided, all samples @@ -180,6 +185,7 @@ cdef class Splitter: self.y = y self.sample_weight = sample_weight + self.treatment = treatment return 0 cdef int node_reset(self, SIZE_t start, SIZE_t end, @@ -203,6 +209,7 @@ cdef class Splitter: self.end = end self.criterion.init(self.y, + self.treatment, self.sample_weight, self.weighted_n_samples, self.samples, @@ -243,6 +250,7 @@ cdef class BaseDenseSplitter(Splitter): cdef int init(self, object X, const DOUBLE_t[:, ::1] y, + DOUBLE_t* treatment, DOUBLE_t* sample_weight) except -1: """Initialize the splitter @@ -251,7 +259,7 @@ cdef class BaseDenseSplitter(Splitter): """ # Call parent init - Splitter.init(self, X, y, sample_weight) + Splitter.init(self, X, y, treatment, sample_weight) self.X = X return 0 @@ -802,6 +810,7 @@ cdef class BaseSparseSplitter(Splitter): cdef int init(self, object X, const DOUBLE_t[:, ::1] y, + DOUBLE_t* treatment, DOUBLE_t* sample_weight) except -1: """Initialize the splitter @@ -809,7 +818,7 @@ cdef class BaseSparseSplitter(Splitter): or 0 otherwise. """ # Call parent init - Splitter.init(self, X, y, sample_weight) + Splitter.init(self, X, y, treatment, sample_weight) if not isinstance(X, csc_matrix): raise ValueError("X should be in csc format") diff --git a/causalml/inference/tree/_tree/_tree.pxd b/causalml/inference/tree/_tree/_tree.pxd index d32b5eae..7e514c5e 100644 --- a/causalml/inference/tree/_tree/_tree.pxd +++ b/causalml/inference/tree/_tree/_tree.pxd @@ -103,6 +103,7 @@ cdef class TreeBuilder: Tree tree, object X, cnp.ndarray y, + cnp.ndarray treatment, cnp.ndarray sample_weight=*, ) @@ -110,5 +111,6 @@ cdef class TreeBuilder: self, object X, cnp.ndarray y, + cnp.ndarray treatment, cnp.ndarray sample_weight, ) diff --git a/causalml/inference/tree/_tree/_tree.pyx b/causalml/inference/tree/_tree/_tree.pyx index fbe11454..5202e9a4 100755 --- a/causalml/inference/tree/_tree/_tree.pyx +++ b/causalml/inference/tree/_tree/_tree.pyx @@ -97,11 +97,13 @@ cdef class TreeBuilder: """Interface for different tree building strategies.""" cpdef build(self, Tree tree, object X, cnp.ndarray y, + cnp.ndarray treatment, cnp.ndarray sample_weight=None): """Build a decision tree from the training set (X, y).""" pass cdef inline _check_input(self, object X, cnp.ndarray y, + cnp.ndarray treatment, cnp.ndarray sample_weight): """Check input dtype, layout and format""" if issparse(X): @@ -122,13 +124,16 @@ cdef class TreeBuilder: if y.dtype != DOUBLE or not y.flags.contiguous: y = np.ascontiguousarray(y, dtype=DOUBLE) + if treatment.dtype != DOUBLE or not treatment.flags.contiguous: + treatment = np.ascontiguousarray(treatment, dtype=DOUBLE) + if (sample_weight is not None and (sample_weight.dtype != DOUBLE or not sample_weight.flags.contiguous)): sample_weight = np.asarray(sample_weight, dtype=DOUBLE, order="C") - return X, y, sample_weight + return X, y, treatment, sample_weight # Depth first builder --------------------------------------------------------- @@ -146,12 +151,14 @@ cdef class DepthFirstTreeBuilder(TreeBuilder): self.min_impurity_decrease = min_impurity_decrease cpdef build(self, Tree tree, object X, cnp.ndarray y, + cnp.ndarray treatment, cnp.ndarray sample_weight=None): """Build a decision tree from the training set (X, y).""" # check input - X, y, sample_weight = self._check_input(X, y, sample_weight) + X, y, treatment, sample_weight = self._check_input(X, y, treatment, sample_weight) + cdef DOUBLE_t* treatment_ptr = treatment.data cdef DOUBLE_t* sample_weight_ptr = NULL if sample_weight is not None: sample_weight_ptr = sample_weight.data @@ -175,7 +182,7 @@ cdef class DepthFirstTreeBuilder(TreeBuilder): cdef double min_impurity_decrease = self.min_impurity_decrease # Recursive partition (without actual recursion) - splitter.init(X, y, sample_weight_ptr) + splitter.init(X, y, treatment_ptr, sample_weight_ptr) cdef SIZE_t start cdef SIZE_t end @@ -328,12 +335,14 @@ cdef class BestFirstTreeBuilder(TreeBuilder): self.min_impurity_decrease = min_impurity_decrease cpdef build(self, Tree tree, object X, cnp.ndarray y, + cnp.ndarray treatment, cnp.ndarray sample_weight=None): """Build a decision tree from the training set (X, y).""" # check input - X, y, sample_weight = self._check_input(X, y, sample_weight) + X, y, treatment, sample_weight = self._check_input(X, y, treatment, sample_weight) + cdef DOUBLE_t* treatment_ptr = treatment.data cdef DOUBLE_t* sample_weight_ptr = NULL if sample_weight is not None: sample_weight_ptr = sample_weight.data @@ -346,7 +355,7 @@ cdef class BestFirstTreeBuilder(TreeBuilder): cdef SIZE_t min_samples_split = self.min_samples_split # Recursive partition (without actual recursion) - splitter.init(X, y, sample_weight_ptr) + splitter.init(X, y, treatment_ptr, sample_weight_ptr) cdef vector[FrontierRecord] frontier cdef FrontierRecord record diff --git a/causalml/inference/tree/causal/_builder.pyx b/causalml/inference/tree/causal/_builder.pyx index 8bac93e4..cacd5ec2 100755 --- a/causalml/inference/tree/causal/_builder.pyx +++ b/causalml/inference/tree/causal/_builder.pyx @@ -51,12 +51,14 @@ cdef class DepthFirstCausalTreeBuilder(TreeBuilder): self.min_impurity_decrease = min_impurity_decrease cpdef build(self, Tree tree, object X, np.ndarray y, + np.ndarray treatment, np.ndarray sample_weight=None): """Build a decision tree from the training set (X, y).""" # check input - X, y, sample_weight = self._check_input(X, y, sample_weight) + X, y, treatment, sample_weight = self._check_input(X, y, treatment, sample_weight) + cdef DOUBLE_t* treatment_ptr = treatment.data cdef DOUBLE_t* sample_weight_ptr = NULL if sample_weight is not None: sample_weight_ptr = sample_weight.data @@ -80,7 +82,7 @@ cdef class DepthFirstCausalTreeBuilder(TreeBuilder): cdef double min_impurity_decrease = self.min_impurity_decrease # Recursive partition (without actual recursion) - splitter.init(X, y, sample_weight_ptr) + splitter.init(X, y, treatment_ptr, sample_weight_ptr) cdef SIZE_t start cdef SIZE_t end @@ -239,13 +241,20 @@ cdef class BestFirstCausalTreeBuilder(TreeBuilder): self.max_leaf_nodes = max_leaf_nodes self.min_impurity_decrease = min_impurity_decrease - cpdef build(self, Tree tree, object X, np.ndarray y, - np.ndarray sample_weight=None): + cpdef build( + self, + Tree tree, + object X, + np.ndarray y, + np.ndarray treatment, + np.ndarray sample_weight=None + ): """Build a decision tree from the training set (X, y).""" # check input - X, y, sample_weight = self._check_input(X, y, sample_weight) + X, y, treatment, sample_weight = self._check_input(X, y, treatment, sample_weight) + cdef DOUBLE_t* treatment_ptr = treatment.data cdef DOUBLE_t* sample_weight_ptr = NULL if sample_weight is not None: sample_weight_ptr = sample_weight.data @@ -258,7 +267,7 @@ cdef class BestFirstCausalTreeBuilder(TreeBuilder): cdef SIZE_t min_samples_split = self.min_samples_split # Recursive partition (without actual recursion) - splitter.init(X, y, sample_weight_ptr) + splitter.init(X, y, treatment_ptr, sample_weight_ptr) cdef vector[FrontierRecord] frontier cdef FrontierRecord record diff --git a/causalml/inference/tree/causal/_criterion.pyx b/causalml/inference/tree/causal/_criterion.pyx index 4ab9366f..f80f9dee 100755 --- a/causalml/inference/tree/causal/_criterion.pyx +++ b/causalml/inference/tree/causal/_criterion.pyx @@ -15,17 +15,24 @@ cdef class CausalRegressionCriterion(RegressionCriterion): """ cdef public SplitState state cdef public double groups_penalty - cdef public double eps - cdef int init(self, const DOUBLE_t[:, ::1] y, DOUBLE_t* sample_weight, - double weighted_n_samples, SIZE_t* samples, SIZE_t start, - SIZE_t end) nogil except -1: + cdef int init( + self, + const DOUBLE_t[:, ::1] y, + DOUBLE_t* treatment, + DOUBLE_t* sample_weight, + double weighted_n_samples, + SIZE_t* samples, + SIZE_t start, + SIZE_t end + ) nogil except -1: """Initialize the criterion. This initializes the criterion at node samples[start:end] and children samples[start:start] and samples[start:end]. """ # Initialize fields self.y = y + self.treatment = treatment self.sample_weight = sample_weight self.samples = samples self.start = start @@ -43,14 +50,13 @@ cdef class CausalRegressionCriterion(RegressionCriterion): memset(&self.sum_total[0], 0, self.n_outputs * sizeof(double)) self.sq_sum_total = 0. - self.eps = 1e-5 self.state.node = [0., 0., 0., 0., 0., 0., 0., 0., 1.] self.state.left = [0., 0., 0., 0., 0., 0., 0., 0., 1.] self.state.right = [0., 0., 0., 0., 0., 0., 0., 0., 1.] for p in range(start, end): i = samples[p] - is_treated = sample_weight[i] - self.eps + is_treated = treatment[i] self.sum_total[k] += self.y[i, k] self.sq_sum_total += self.y[i, k] * self.y[i, k] @@ -139,6 +145,7 @@ cdef class CausalRegressionCriterion(RegressionCriterion): cdef int update(self, SIZE_t new_pos) nogil except -1: """Updated statistics by moving samples[pos:new_pos] to the left.""" cdef double * sample_weight = self.sample_weight + cdef double * treatment = self.treatment cdef SIZE_t * samples = self.samples cdef SIZE_t pos = self.pos @@ -159,7 +166,7 @@ cdef class CausalRegressionCriterion(RegressionCriterion): if (new_pos - pos) <= (end - new_pos): for p in range(pos, new_pos): i = samples[p] - is_treated = sample_weight[i] - self.eps + is_treated = treatment[i] self.sum_left[k] += self.y[i, k] self.state.left.tr_y_sum += is_treated * self.y[i, k] @@ -175,7 +182,7 @@ cdef class CausalRegressionCriterion(RegressionCriterion): for p in range(end - 1, new_pos - 1, -1): i = samples[p] - is_treated = sample_weight[i] - self.eps + is_treated = treatment[i] self.sum_left[k] -= self.y[i, k] self.state.left.tr_y_sum -= is_treated * self.y[i, k] @@ -267,8 +274,11 @@ cdef class StandardMSE(CausalRegressionCriterion): return (proxy_impurity_left / self.weighted_n_left + proxy_impurity_right / self.weighted_n_right) - cdef void children_impurity(self, double * impurity_left, - double * impurity_right) nogil: + cdef void children_impurity( + self, + double * impurity_left, + double * impurity_right + ) nogil: """Evaluate the impurity in children nodes. i.e. the impurity of the left child (samples[start:pos]) and the impurity the right child (samples[pos:end]). diff --git a/causalml/inference/tree/causal/_tree.py b/causalml/inference/tree/causal/_tree.py index df6c1f5d..77b06979 100755 --- a/causalml/inference/tree/causal/_tree.py +++ b/causalml/inference/tree/causal/_tree.py @@ -42,7 +42,13 @@ def _support_missing_values(self, X) -> bool: return False def fit( - self, X, y, sample_weight=None, check_input=True, X_idx_sorted="deprecated" + self, + X, + treatment, + y, + sample_weight=None, + check_input=True, + X_idx_sorted="deprecated", ): random_state = check_random_state(self.random_state) @@ -83,7 +89,7 @@ def fit( y = np.reshape(y, (-1, 1)) # For memory allocation to store control, treatment outcomes - self.n_outputs_ = np.unique(sample_weight).astype(int).size + self.n_outputs_ = np.unique(treatment).astype(int).size if getattr(y, "dtype", None) != DOUBLE or not y.flags.contiguous: y = np.ascontiguousarray(y, dtype=DOUBLE) @@ -204,7 +210,6 @@ def fit( criterion = CAUSAL_TREES_CRITERIA[self.criterion]( self.n_outputs_, n_samples ) - criterion.eps = self.eps criterion.groups_penalty = self.groups_penalty else: # Make a deepcopy in case the criterion has mutable attributes that @@ -249,7 +254,7 @@ def fit( self.min_impurity_decrease, ) - builder.build(self.tree_, X, y, sample_weight) + builder.build(self.tree_, X, y, treatment, sample_weight) self._prune_tree() diff --git a/causalml/inference/tree/causal/causalforest.py b/causalml/inference/tree/causal/causalforest.py index 77992ced..6b071826 100755 --- a/causalml/inference/tree/causal/causalforest.py +++ b/causalml/inference/tree/causal/causalforest.py @@ -33,6 +33,7 @@ def _parallel_build_trees( tree, forest, X, + treatment, y, sample_weight, tree_idx, @@ -48,11 +49,16 @@ def _parallel_build_trees( if forest.bootstrap: n_samples = X.shape[0] + if sample_weight is None: + curr_sample_weight = np.ones((n_samples,), dtype=np.float64) + else: + curr_sample_weight = sample_weight.copy() + indices = _generate_sample_indices( tree.random_state, n_samples, n_samples_bootstrap ) - X, y = X[indices].copy(), y[indices].copy() - curr_sample_weight = sample_weight[indices].copy() + sample_counts = np.bincount(indices, minlength=n_samples) + curr_sample_weight *= sample_counts if class_weight == "subsample": with catch_warnings(): @@ -61,9 +67,9 @@ def _parallel_build_trees( elif class_weight == "balanced_subsample": curr_sample_weight *= compute_sample_weight("balanced", y, indices=indices) - tree.fit(X, y, sample_weight=curr_sample_weight, check_input=False) + tree.fit(X, treatment, y, sample_weight=curr_sample_weight, check_input=False) else: - tree.fit(X, y, sample_weight=sample_weight, check_input=False) + tree.fit(X, treatment, y, sample_weight=sample_weight, check_input=False) return tree @@ -199,7 +205,13 @@ def __init__( self.alpha = alpha self.groups_cnt = groups_cnt - def _fit(self, X: np.ndarray, y: np.ndarray, sample_weight: np.ndarray = None): + def _fit( + self, + X: np.ndarray, + treatment: np.ndarray, + y: np.ndarray, + sample_weight: np.ndarray = None, + ): """ Build a forest of trees from the training set (X, y). With modified _parallel_build_trees for Causal Trees used in BaseForest.fit() @@ -212,6 +224,9 @@ def _fit(self, X: np.ndarray, y: np.ndarray, sample_weight: np.ndarray = None): to ``dtype=np.float32``. If a sparse matrix is provided, it will be converted into a sparse ``csc_matrix``. + treatment : array-like of shape (n_samples,) + The treatment assignments. + y : array-like of shape (n_samples,) or (n_samples, n_outputs) The target values (class labels in classification, real numbers in regression). @@ -267,8 +282,7 @@ def _fit(self, X: np.ndarray, y: np.ndarray, sample_weight: np.ndarray = None): "is necessary for Poisson regression." ) - self.n_outputs_ = np.unique(sample_weight).astype(int).size + 1 - self.max_outputs_ = self.n_outputs_ + self.max_outputs_ = np.unique(treatment).astype(int).size + 1 y, expanded_class_weight = self._validate_y_class_weight(y) if getattr(y, "dtype", None) != DOUBLE or not y.flags.contiguous: @@ -335,13 +349,14 @@ def _fit(self, X: np.ndarray, y: np.ndarray, sample_weight: np.ndarray = None): **_joblib_parallel_args, )( delayed(_parallel_build_trees)( - t, - self, - X, - y, - sample_weight, - i, - len(trees), + tree=t, + forest=self, + X=X, + treatment=treatment, + y=y, + sample_weight=sample_weight, + tree_idx=i, + n_trees=len(trees), verbose=self.verbose, class_weight=self.class_weight, n_samples_bootstrap=n_samples_bootstrap, @@ -368,18 +383,25 @@ def _fit(self, X: np.ndarray, y: np.ndarray, sample_weight: np.ndarray = None): return self - def fit(self, X: np.ndarray, treatment: np.ndarray, y: np.ndarray): + def fit( + self, + X: np.ndarray, + treatment: np.ndarray, + y: np.ndarray, + sample_weight: np.ndarray = None, + ): """ Fit Causal RandomForest Args: X: (np.ndarray), feature matrix treatment: (np.ndarray), treatment vector y: (np.ndarray), outcome vector + sample_weight: (np.ndarray), sample weights Returns: self """ - X, y, w = self._estimator._prepare_data(X=X, y=y, treatment=treatment) - return self._fit(X=X, y=y, sample_weight=w) + X, y, w = self._estimator._prepare_data(X=X, treatment=treatment, y=y) + return self._fit(X=X, treatment=w, y=y, sample_weight=sample_weight) def predict(self, X: np.ndarray, with_outcomes: bool = False) -> np.ndarray: """Predict individual treatment effects diff --git a/causalml/inference/tree/causal/causaltree.py b/causalml/inference/tree/causal/causaltree.py index f37941e9..533b2b6d 100755 --- a/causalml/inference/tree/causal/causaltree.py +++ b/causalml/inference/tree/causal/causaltree.py @@ -129,7 +129,6 @@ def __init__( self.min_samples_leaf = min_samples_leaf self.random_state = random_state - self.eps = 1e-5 self._classes = {} self.groups_cnt = groups_cnt self.groups_cnt_mode = groups_cnt_mode @@ -153,19 +152,19 @@ def __init__( def fit( self, X: np.ndarray, + treatment: np.ndarray, y: np.ndarray, - treatment: np.ndarray = None, sample_weight: np.ndarray = None, check_input=False, ): """ Fit CausalTreeRegressor Args: - X: : (np.ndarray), feature matrix - y: : (np.ndarray), outcome vector - treatment: : (np.ndarray), treatment vector - sample_weight: (np.ndarray), sample_weight - check_input: (bool) + X (np.ndarray): feature matrix + treatment (np.ndarray): treatment vector + y (np.ndarray): outcome vector + sample_weight (np.ndarray): sample_weight + check_input (bool, optional): default=False Returns: self """ @@ -177,16 +176,12 @@ def fit( "min_impurity_decrease must be set to -inf for causal_mse criterion" ) - if treatment is None and sample_weight is None: - raise ValueError("`treatment` or `sample_weight` must be provided") - - if treatment is None: - X, y, w = X, y, sample_weight - else: - X, y, w = self._prepare_data(X=X, y=y, treatment=treatment) + X, y, w = self._prepare_data(X=X, y=y, treatment=treatment) self.treatment_groups = np.unique(w) - super().fit(X=X, y=y, sample_weight=self.eps + w, check_input=check_input) + super().fit( + X=X, treatment=w, y=y, sample_weight=sample_weight, check_input=check_input + ) if self.groups_cnt: self._groups_cnt = self._count_groups_distribution(X=X, treatment=w) diff --git a/tests/test_causal_trees.py b/tests/test_causal_trees.py index c407faf8..d05c99d6 100644 --- a/tests/test_causal_trees.py +++ b/tests/test_causal_trees.py @@ -115,7 +115,7 @@ def test_fit_predict( def test_predict(self, generate_regression_data): y, X, treatment, tau, b, e = generate_regression_data(mode=2) ctree = self.prepare_model() - ctree.fit(X=X, y=y, treatment=treatment) + ctree.fit(X=X, treatment=treatment, y=y) y_pred = ctree.predict(X[:1, :]) y_pred_with_outcomes = ctree.predict(X[:1, :], with_outcomes=True) assert y_pred.shape == (1,)