diff --git a/keras_tuner/engine/metrics_tracking.py b/keras_tuner/engine/metrics_tracking.py index 7c612a6aa..eb181297f 100644 --- a/keras_tuner/engine/metrics_tracking.py +++ b/keras_tuner/engine/metrics_tracking.py @@ -404,23 +404,21 @@ def to_proto(self) -> object: """Create proto from MetricsTracker instance.""" Mt = proto.MetricsTracker # type:ignore # noqa: PGH003 return Mt( - to_register=[ - { - name: self.metrics[name].to_proto() - for name in list(self.metrics.keys()) - } - ] + metrics={ + name: self.metrics[name].to_proto() + for name in list(self.metrics.keys()) + } ) @classmethod def from_proto(cls, proto: "MetricsTracker") -> "MetricsTracker": """Create a MetricsTracker instance from a proto.""" - metrics: _MetricNameToHistory = { - name: MetricHistory.from_proto(proto.metrics[name]) + metrics: _MetricsHistoriesInput = [ + {name: MetricHistory.from_proto(proto.metrics[name])} for name in list(proto.metrics.keys()) - } + ] - return cls(to_register=[metrics]) + return cls(to_register=metrics) def _assert_exists(self, name: str) -> None: """Ensure that name is a metric.""" diff --git a/keras_tuner/engine/metrics_tracking_test.py b/keras_tuner/engine/metrics_tracking_test.py index c140550fb..eb15a7a94 100644 --- a/keras_tuner/engine/metrics_tracking_test.py +++ b/keras_tuner/engine/metrics_tracking_test.py @@ -229,17 +229,16 @@ def test_metric_history_proto(): ] -def test_metricstracker_proto(): +def test_metrics_tracker_proto(): tracker = metrics_tracking.MetricsTracker() tracker.register("score", direction="max") - tracker.append_execution_value("score", value=10) - tracker.append_execution_value("score", value=20) + tracker.append_execution_value("score", value=[10, 20]) tracker.append_execution_value("score", value=30) proto = tracker.to_proto() - obs = proto.metrics["score"].get_history() - assert obs[0].value == [10, 20] - assert obs[1].value == [30] + executions = proto.metrics["score"].executions + assert executions[0].value == [10, 20] + assert executions[1].value == [30] assert proto.metrics["score"].direction new_tracker = metrics_tracking.MetricsTracker.from_proto(proto)