Skip to content

Commit

Permalink
types: metrics history typefix.
Browse files Browse the repository at this point in the history
  • Loading branch information
Mah Neh committed Sep 26, 2024
1 parent 0315a92 commit 2461a03
Show file tree
Hide file tree
Showing 2 changed files with 13 additions and 16 deletions.
18 changes: 8 additions & 10 deletions keras_tuner/engine/metrics_tracking.py
Original file line number Diff line number Diff line change
Expand Up @@ -404,23 +404,21 @@ def to_proto(self) -> object:
"""Create proto from MetricsTracker instance."""
Mt = proto.MetricsTracker # type:ignore # noqa: PGH003
return Mt(
to_register=[
{
name: self.metrics[name].to_proto()
for name in list(self.metrics.keys())
}
]
metrics={
name: self.metrics[name].to_proto()
for name in list(self.metrics.keys())
}
)

@classmethod
def from_proto(cls, proto: "MetricsTracker") -> "MetricsTracker":
"""Create a MetricsTracker instance from a proto."""
metrics: _MetricNameToHistory = {
name: MetricHistory.from_proto(proto.metrics[name])
metrics: _MetricsHistoriesInput = [
{name: MetricHistory.from_proto(proto.metrics[name])}
for name in list(proto.metrics.keys())
}
]

return cls(to_register=[metrics])
return cls(to_register=metrics)

def _assert_exists(self, name: str) -> None:
"""Ensure that name is a metric."""
Expand Down
11 changes: 5 additions & 6 deletions keras_tuner/engine/metrics_tracking_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -229,17 +229,16 @@ def test_metric_history_proto():
]


def test_metricstracker_proto():
def test_metrics_tracker_proto():
tracker = metrics_tracking.MetricsTracker()
tracker.register("score", direction="max")
tracker.append_execution_value("score", value=10)
tracker.append_execution_value("score", value=20)
tracker.append_execution_value("score", value=[10, 20])
tracker.append_execution_value("score", value=30)

proto = tracker.to_proto()
obs = proto.metrics["score"].get_history()
assert obs[0].value == [10, 20]
assert obs[1].value == [30]
executions = proto.metrics["score"].executions
assert executions[0].value == [10, 20]
assert executions[1].value == [30]
assert proto.metrics["score"].direction

new_tracker = metrics_tracking.MetricsTracker.from_proto(proto)
Expand Down

0 comments on commit 2461a03

Please sign in to comment.