diff options
author | Jianwei Xie <xiejw@google.com> | 2017-01-18 22:32:12 -0800 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2017-01-18 22:44:15 -0800 |
commit | 83be74027a9d0fef86deff97da6c7f6538c2a5aa (patch) | |
tree | 2287af91206d921473ea2db9f44e3b8b1fd6c553 | |
parent | 88a223c2b7f19e665fb4a94e1de544df3964761d (diff) |
Make train_monitors property getter returns shallow copy of the internal list and restrict the way to mutate it.
Change: 144924445
-rw-r--r-- | tensorflow/contrib/learn/python/learn/experiment.py | 15 | ||||
-rw-r--r-- | tensorflow/contrib/learn/python/learn/experiment_test.py | 22 |
2 files changed, 30 insertions, 7 deletions
diff --git a/tensorflow/contrib/learn/python/learn/experiment.py b/tensorflow/contrib/learn/python/learn/experiment.py index 86a0a82cd1..ed0e546442 100644 --- a/tensorflow/contrib/learn/python/learn/experiment.py +++ b/tensorflow/contrib/learn/python/learn/experiment.py @@ -139,8 +139,8 @@ class Experiment(object): self._continuous_eval_throttle_secs = continuous_eval_throttle_secs self._min_eval_frequency = min_eval_frequency self._delay_workers_by_global_step = delay_workers_by_global_step + self._train_monitors = train_monitors or [] # Mutable fields, using the setters. - self.train_monitors = train_monitors self.eval_hooks = eval_hooks self.export_strategies = export_strategies self.continuous_eval_predicate_fn = continuous_eval_predicate_fn @@ -170,12 +170,9 @@ class Experiment(object): return self._eval_steps @property - def train_monitors(self): - return self._train_monitors - - @train_monitors.setter - def train_monitors(self, value): - self._train_monitors = value or [] + def train_hooks(self): + """Returns a shallow copy of train hooks for inspecting.""" + return [m for m in self._train_monitors] @property def eval_hooks(self): @@ -232,6 +229,10 @@ class Experiment(object): raise ValueError("`export_strategies` must be an ExportStrategy, " "a list of ExportStrategies, or None.") + def extend_train_hooks(self, additional_hooks): + """Extends the hooks for training.""" + self._train_monitors.extend(additional_hooks) + def train(self, delay_secs=None): """Fit the estimator using the training data. diff --git a/tensorflow/contrib/learn/python/learn/experiment_test.py b/tensorflow/contrib/learn/python/learn/experiment_test.py index 545ee38fad..096d334e8c 100644 --- a/tensorflow/contrib/learn/python/learn/experiment_test.py +++ b/tensorflow/contrib/learn/python/learn/experiment_test.py @@ -368,6 +368,28 @@ class ExperimentTest(test.TestCase): self.assertEquals([noop_hook], est.eval_hooks) self.assertTrue(isinstance(est.monitors[0], monitors.ValidationMonitor)) + def test_train_monitors_returns_shallow_copy(self): + noop_hook = _NoopHook() + ex = experiment.Experiment( + TestEstimator(), + train_input_fn='train_input', + eval_input_fn='eval_input', + eval_metrics='eval_metrics', + train_monitors=[noop_hook], + train_steps=100, + eval_steps=100, + local_eval_frequency=10) + self.assertAllEqual([noop_hook], ex.train_hooks) + + another_noop_hook = _NoopHook() + # Assert that the property getter returns a shallow copy. + ex.train_hooks.extend([another_noop_hook]) + self.assertAllEqual([noop_hook], ex.train_hooks) + + # Assert that the extend API mutates the monitors. + ex.extend_train_hooks([another_noop_hook]) + self.assertAllEqual([noop_hook, another_noop_hook], ex.train_hooks) + def test_train_and_evaluate(self): est = TestEstimator() noop_hook = _NoopHook() |