diff options
author | Mustafa Ispir <ispir@google.com> | 2017-03-07 16:29:44 -0800 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2017-03-07 16:48:17 -0800 |
commit | b69dd297e19af8514ac46c9e85a916c838222dbb (patch) | |
tree | 123ada26ef29c12a0ed70562da0846c0aa2e99ba /tensorflow | |
parent | 78756dcc25eb1224c12f78fe5adf79718c53683c (diff) |
EstimatorSpec: Define default scaffold in case of None provided. This will reduce the clutter in the usage of EstimatorSpec. Forexample, We can check scaffold.saver instead of if scaffold and scaffold.saver...
Change: 149482422
Diffstat (limited to 'tensorflow')
-rw-r--r-- | tensorflow/python/estimator/estimator.py | 8 | ||||
-rw-r--r-- | tensorflow/python/estimator/model_fn.py | 4 | ||||
-rw-r--r-- | tensorflow/python/estimator/model_fn_test.py | 8 |
3 files changed, 14 insertions, 6 deletions
diff --git a/tensorflow/python/estimator/estimator.py b/tensorflow/python/estimator/estimator.py index a52d3fb2e0..472e7a4390 100644 --- a/tensorflow/python/estimator/estimator.py +++ b/tensorflow/python/estimator/estimator.py @@ -416,8 +416,8 @@ class Estimator(object): all_hooks.extend(hooks) all_hooks.extend(estimator_spec.training_hooks) - scaffold = estimator_spec.scaffold or training.Scaffold() - if not (scaffold.saver or ops.get_collection(ops.GraphKeys.SAVERS)): + if not (estimator_spec.scaffold.saver or + ops.get_collection(ops.GraphKeys.SAVERS)): ops.add_to_collection(ops.GraphKeys.SAVERS, training.Saver( sharded=True, @@ -438,13 +438,13 @@ class Estimator(object): self._model_dir, save_secs=self._config.save_checkpoints_secs, save_steps=self._config.save_checkpoints_steps, - scaffold=scaffold) + scaffold=estimator_spec.scaffold) ] with training.MonitoredTrainingSession( master=self._config.master, is_chief=self._config.is_chief, checkpoint_dir=self._model_dir, - scaffold=scaffold, + scaffold=estimator_spec.scaffold, hooks=all_hooks, chief_only_hooks=chief_hooks + estimator_spec.training_chief_hooks, save_checkpoint_secs=0, # Saving is handled by a hook. diff --git a/tensorflow/python/estimator/model_fn.py b/tensorflow/python/estimator/model_fn.py index 374ec6c8ca..a8936894ee 100644 --- a/tensorflow/python/estimator/model_fn.py +++ b/tensorflow/python/estimator/model_fn.py @@ -244,9 +244,9 @@ class EstimatorSpec( 'All hooks must be SessionRunHook instances, given: {}'.format( hook)) + scaffold = scaffold or monitored_session.Scaffold() # Validate scaffold. - if (scaffold is not None and - not isinstance(scaffold, monitored_session.Scaffold)): + if not isinstance(scaffold, monitored_session.Scaffold): raise TypeError( 'scaffold must be tf.train.Scaffold. Given: {}'.format(scaffold)) diff --git a/tensorflow/python/estimator/model_fn_test.py b/tensorflow/python/estimator/model_fn_test.py index cf6608ae4c..29ca79ff51 100644 --- a/tensorflow/python/estimator/model_fn_test.py +++ b/tensorflow/python/estimator/model_fn_test.py @@ -182,6 +182,14 @@ class EstimatorSpecTrainTest(test.TestCase): train_op=control_flow_ops.no_op(), scaffold=_InvalidScaffold()) + def testReturnDefaultScaffold(self): + with ops.Graph().as_default(), self.test_session(): + estimator_spec = model_fn.EstimatorSpec( + mode=model_fn.ModeKeys.FIT, + loss=constant_op.constant(1.), + train_op=control_flow_ops.no_op()) + self.assertIsNotNone(estimator_spec.scaffold) + class EstimatorSpecEvalTest(test.TestCase): """Tests EstimatorSpec in eval mode.""" |