Skip to content

Commit

Permalink
Weighted metrics without metrics (#474)
Browse files Browse the repository at this point in the history
* Compile weighted_metrics even if metrics is None

Fixes #454

* Update trainer_test.py

* fixed formatting
  • Loading branch information
mihirparadkar authored Jul 14, 2023
1 parent 71527e9 commit f0c65c5
Show file tree
Hide file tree
Showing 2 changed files with 17 additions and 1 deletion.
2 changes: 1 addition & 1 deletion keras_core/trainers/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ def compile(
)
else:
self._compile_loss = None
if metrics is not None:
if metrics is not None or weighted_metrics is not None:
self._compile_metrics = CompileMetrics(
metrics, weighted_metrics, output_names=output_names
)
Expand Down
16 changes: 16 additions & 0 deletions keras_core/trainers/trainer_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,6 +109,22 @@ def __init__(self, units):
# And those weights are tracked at the model level
self.assertEqual(len(model.metrics_variables), 6)

# Models with only weighted_metrics should have the same 3 metrics
model_weighted = ModelWithMetric(units=3)
model_weighted.compile(
optimizer=optimizers.SGD(),
loss=losses.MeanSquaredError(),
weighted_metrics=[metrics.MeanSquaredError()],
)
model_weighted.fit(
x,
y,
batch_size=2,
epochs=1,
sample_weight=np.ones(2),
)
self.assertEqual(len(model_weighted.metrics), 3)

@parameterized.named_parameters(
[
("eager", True, False, False),
Expand Down

0 comments on commit f0c65c5

Please sign in to comment.