aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow
diff options
context:
space:
mode:
authorGravatar Mustafa Ispir <ispir@google.com>2017-03-07 16:29:44 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2017-03-07 16:48:17 -0800
commitb69dd297e19af8514ac46c9e85a916c838222dbb (patch)
tree123ada26ef29c12a0ed70562da0846c0aa2e99ba /tensorflow
parent78756dcc25eb1224c12f78fe5adf79718c53683c (diff)
EstimatorSpec: Define default scaffold in case of None provided. This will reduce the clutter in the usage of EstimatorSpec. Forexample, We can check scaffold.saver instead of if scaffold and scaffold.saver...
Change: 149482422
Diffstat (limited to 'tensorflow')
-rw-r--r--tensorflow/python/estimator/estimator.py8
-rw-r--r--tensorflow/python/estimator/model_fn.py4
-rw-r--r--tensorflow/python/estimator/model_fn_test.py8
3 files changed, 14 insertions, 6 deletions
diff --git a/tensorflow/python/estimator/estimator.py b/tensorflow/python/estimator/estimator.py
index a52d3fb2e0..472e7a4390 100644
--- a/tensorflow/python/estimator/estimator.py
+++ b/tensorflow/python/estimator/estimator.py
@@ -416,8 +416,8 @@ class Estimator(object):
all_hooks.extend(hooks)
all_hooks.extend(estimator_spec.training_hooks)
- scaffold = estimator_spec.scaffold or training.Scaffold()
- if not (scaffold.saver or ops.get_collection(ops.GraphKeys.SAVERS)):
+ if not (estimator_spec.scaffold.saver or
+ ops.get_collection(ops.GraphKeys.SAVERS)):
ops.add_to_collection(ops.GraphKeys.SAVERS,
training.Saver(
sharded=True,
@@ -438,13 +438,13 @@ class Estimator(object):
self._model_dir,
save_secs=self._config.save_checkpoints_secs,
save_steps=self._config.save_checkpoints_steps,
- scaffold=scaffold)
+ scaffold=estimator_spec.scaffold)
]
with training.MonitoredTrainingSession(
master=self._config.master,
is_chief=self._config.is_chief,
checkpoint_dir=self._model_dir,
- scaffold=scaffold,
+ scaffold=estimator_spec.scaffold,
hooks=all_hooks,
chief_only_hooks=chief_hooks + estimator_spec.training_chief_hooks,
save_checkpoint_secs=0, # Saving is handled by a hook.
diff --git a/tensorflow/python/estimator/model_fn.py b/tensorflow/python/estimator/model_fn.py
index 374ec6c8ca..a8936894ee 100644
--- a/tensorflow/python/estimator/model_fn.py
+++ b/tensorflow/python/estimator/model_fn.py
@@ -244,9 +244,9 @@ class EstimatorSpec(
'All hooks must be SessionRunHook instances, given: {}'.format(
hook))
+ scaffold = scaffold or monitored_session.Scaffold()
# Validate scaffold.
- if (scaffold is not None and
- not isinstance(scaffold, monitored_session.Scaffold)):
+ if not isinstance(scaffold, monitored_session.Scaffold):
raise TypeError(
'scaffold must be tf.train.Scaffold. Given: {}'.format(scaffold))
diff --git a/tensorflow/python/estimator/model_fn_test.py b/tensorflow/python/estimator/model_fn_test.py
index cf6608ae4c..29ca79ff51 100644
--- a/tensorflow/python/estimator/model_fn_test.py
+++ b/tensorflow/python/estimator/model_fn_test.py
@@ -182,6 +182,14 @@ class EstimatorSpecTrainTest(test.TestCase):
train_op=control_flow_ops.no_op(),
scaffold=_InvalidScaffold())
+ def testReturnDefaultScaffold(self):
+ with ops.Graph().as_default(), self.test_session():
+ estimator_spec = model_fn.EstimatorSpec(
+ mode=model_fn.ModeKeys.FIT,
+ loss=constant_op.constant(1.),
+ train_op=control_flow_ops.no_op())
+ self.assertIsNotNone(estimator_spec.scaffold)
+
class EstimatorSpecEvalTest(test.TestCase):
"""Tests EstimatorSpec in eval mode."""