aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorGravatar Jianwei Xie <xiejw@google.com>2017-01-18 22:32:12 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2017-01-18 22:44:15 -0800
commit83be74027a9d0fef86deff97da6c7f6538c2a5aa (patch)
tree2287af91206d921473ea2db9f44e3b8b1fd6c553
parent88a223c2b7f19e665fb4a94e1de544df3964761d (diff)
Make train_monitors property getter returns shallow copy of the internal list and restrict the way to mutate it.
Change: 144924445
-rw-r--r--tensorflow/contrib/learn/python/learn/experiment.py15
-rw-r--r--tensorflow/contrib/learn/python/learn/experiment_test.py22
2 files changed, 30 insertions, 7 deletions
diff --git a/tensorflow/contrib/learn/python/learn/experiment.py b/tensorflow/contrib/learn/python/learn/experiment.py
index 86a0a82cd1..ed0e546442 100644
--- a/tensorflow/contrib/learn/python/learn/experiment.py
+++ b/tensorflow/contrib/learn/python/learn/experiment.py
@@ -139,8 +139,8 @@ class Experiment(object):
self._continuous_eval_throttle_secs = continuous_eval_throttle_secs
self._min_eval_frequency = min_eval_frequency
self._delay_workers_by_global_step = delay_workers_by_global_step
+ self._train_monitors = train_monitors or []
# Mutable fields, using the setters.
- self.train_monitors = train_monitors
self.eval_hooks = eval_hooks
self.export_strategies = export_strategies
self.continuous_eval_predicate_fn = continuous_eval_predicate_fn
@@ -170,12 +170,9 @@ class Experiment(object):
return self._eval_steps
@property
- def train_monitors(self):
- return self._train_monitors
-
- @train_monitors.setter
- def train_monitors(self, value):
- self._train_monitors = value or []
+ def train_hooks(self):
+ """Returns a shallow copy of train hooks for inspecting."""
+ return [m for m in self._train_monitors]
@property
def eval_hooks(self):
@@ -232,6 +229,10 @@ class Experiment(object):
raise ValueError("`export_strategies` must be an ExportStrategy, "
"a list of ExportStrategies, or None.")
+ def extend_train_hooks(self, additional_hooks):
+ """Extends the hooks for training."""
+ self._train_monitors.extend(additional_hooks)
+
def train(self, delay_secs=None):
"""Fit the estimator using the training data.
diff --git a/tensorflow/contrib/learn/python/learn/experiment_test.py b/tensorflow/contrib/learn/python/learn/experiment_test.py
index 545ee38fad..096d334e8c 100644
--- a/tensorflow/contrib/learn/python/learn/experiment_test.py
+++ b/tensorflow/contrib/learn/python/learn/experiment_test.py
@@ -368,6 +368,28 @@ class ExperimentTest(test.TestCase):
self.assertEquals([noop_hook], est.eval_hooks)
self.assertTrue(isinstance(est.monitors[0], monitors.ValidationMonitor))
+ def test_train_monitors_returns_shallow_copy(self):
+ noop_hook = _NoopHook()
+ ex = experiment.Experiment(
+ TestEstimator(),
+ train_input_fn='train_input',
+ eval_input_fn='eval_input',
+ eval_metrics='eval_metrics',
+ train_monitors=[noop_hook],
+ train_steps=100,
+ eval_steps=100,
+ local_eval_frequency=10)
+ self.assertAllEqual([noop_hook], ex.train_hooks)
+
+ another_noop_hook = _NoopHook()
+ # Assert that the property getter returns a shallow copy.
+ ex.train_hooks.extend([another_noop_hook])
+ self.assertAllEqual([noop_hook], ex.train_hooks)
+
+ # Assert that the extend API mutates the monitors.
+ ex.extend_train_hooks([another_noop_hook])
+ self.assertAllEqual([noop_hook, another_noop_hook], ex.train_hooks)
+
def test_train_and_evaluate(self):
est = TestEstimator()
noop_hook = _NoopHook()