about summary refs log tree commit diff homepage
diff options
context:
space:
mode:
authorGravatar Mustafa Ispir <ispir@google.com>2016-12-15 12:40:24 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2016-12-15 12:45:48 -0800
commit79a73505635e23da4a4cd66feebb4095abab1fe7 (patch)
tree76021c367e8bf4a62d57a9e34516c79d7eaa8b22
parent1e8b5eb885b6726fd72d5eec981af601b667fa80 (diff)
Adding Scaffold to the tf.learn.ModelFn, so that users can play with initialization, saving, and so on.
Change: 142175064
-rw-r--r--tensorflow/contrib/learn/python/learn/estimators/estimator.py148
-rw-r--r--tensorflow/contrib/learn/python/learn/estimators/estimator_test.py67
-rw-r--r--tensorflow/contrib/learn/python/learn/estimators/model_fn.py30
-rw-r--r--tensorflow/python/training/monitored_session.py16
4 files changed, 221 insertions, 40 deletions
diff --git a/tensorflow/contrib/learn/python/learn/estimators/estimator.py b/tensorflow/contrib/learn/python/learn/estimators/estimator.py
index 2d40caa656..7bc8f80563 100644
--- a/tensorflow/contrib/learn/python/learn/estimators/estimator.py
+++ b/tensorflow/contrib/learn/python/learn/estimators/estimator.py
@@ -66,8 +66,12 @@ from tensorflow.python.platform import gfile
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.saved_model import builder as saved_model_builder
from tensorflow.python.saved_model import tag_constants
+from tensorflow.python.training import basic_session_run_hooks
from tensorflow.python.training import device_setter
+from tensorflow.python.training import monitored_session
from tensorflow.python.training import saver
+from tensorflow.python.training import session_run_hook
+from tensorflow.python.training import summary_io
from tensorflow.python.util import compat
@@ -1086,6 +1090,17 @@ class Estimator(BaseEstimator):
"""
return self._call_model_fn(features, labels, model_fn_lib.ModeKeys.TRAIN)
+ # TODO(ispir): delete this function after converting all legacy usages.
+ def _call_legacy_get_train_ops(self, features, labels):
+ train_ops = self._get_train_ops(features, labels)
+ if isinstance(train_ops, model_fn_lib.ModelFnOps): # Default signature
+ return train_ops
+ return model_fn_lib.ModelFnOps(
+ mode=model_fn_lib.ModeKeys.TRAIN,
+ predictions=None,
+ loss=train_ops[1],
+ train_op=train_ops[0])
+
def _get_eval_ops(self, features, labels, metrics):
"""Method that builds model graph and returns evaluation ops.
@@ -1241,6 +1256,119 @@ class Estimator(BaseEstimator):
return export_dir
+ @deprecated_args(SCIKIT_DECOUPLE_DATE, SCIKIT_DECOUPLE_INSTRUCTIONS, 'x', 'y',
+ 'batch_size')
+ def fit(self,
+ x=None,
+ y=None,
+ input_fn=None,
+ steps=None,
+ batch_size=None,
+ monitors=None,
+ max_steps=None):
+ # pylint: disable=g-doc-args,g-doc-return-or-yield
+ """See `Trainable`.
+
+ Raises:
+ ValueError: If `x` or `y` are not `None` while `input_fn` is not `None`.
+ ValueError: If both `steps` and `max_steps` are not `None`.
+ """
+ if (steps is not None) and (max_steps is not None):
+ raise ValueError('Can not provide both steps and max_steps.')
+ if max_steps is not None:
+ try:
+ start_step = load_variable(self._model_dir, ops.GraphKeys.GLOBAL_STEP)
+ if max_steps <= start_step:
+ logging.info('Skipping training since max_steps has already saved.')
+ return None
+ except: # pylint: disable=bare-except
+ pass
+
+ hooks = monitor_lib.replace_monitors_with_hooks(monitors, self)
+ if steps is not None or max_steps is not None:
+ hooks.append(basic_session_run_hooks.StopAtStepHook(steps, max_steps))
+
+ input_fn, feed_fn = _get_input_fn(
+ x,
+ y,
+ input_fn,
+ feed_fn=None,
+ batch_size=batch_size,
+ shuffle=True,
+ epochs=None)
+ if feed_fn:
+ hooks.append(_FeedFnHook(feed_fn))
+ loss = self._train_model_v2(input_fn=input_fn, hooks=hooks)
+ logging.info('Loss for final step: %s.', loss)
+ return self
+
+ def _train_model_v2(self, input_fn, hooks):
+ all_hooks = []
+ self._graph = ops.Graph()
+ with self._graph.as_default() as g, g.device(self._device_fn):
+ random_seed.set_random_seed(self._config.tf_random_seed)
+ global_step = contrib_framework.create_global_step(g)
+ features, labels = input_fn()
+ self._check_inputs(features, labels)
+ model_fn_ops = self._call_legacy_get_train_ops(features, labels)
+ ops.add_to_collection(ops.GraphKeys.LOSSES, model_fn_ops.loss)
+ all_hooks.extend([
+ basic_session_run_hooks.NanTensorHook(model_fn_ops.loss),
+ basic_session_run_hooks.LoggingTensorHook(
+ {
+ 'loss': model_fn_ops.loss,
+ 'step': global_step
+ },
+ every_n_iter=100)
+ ])
+ all_hooks.extend(hooks)
+
+ scaffold = model_fn_ops.training_scaffold or monitored_session.Scaffold()
+ if not (scaffold.saver or ops.get_collection(ops.GraphKeys.SAVERS)):
+ ops.add_to_collection(
+ ops.GraphKeys.SAVERS,
+ saver.Saver(
+ sharded=True,
+ max_to_keep=self._config.keep_checkpoint_max,
+ defer_build=True))
+
+ chief_hooks = []
+ if (self._config.save_checkpoints_secs or
+ self._config.save_checkpoints_steps):
+ chief_hooks = [
+ basic_session_run_hooks.CheckpointSaverHook(
+ self._model_dir,
+ save_secs=self._config.save_checkpoints_secs,
+ save_steps=self._config.save_checkpoints_steps,
+ scaffold=scaffold)
+ ]
+ with monitored_session.MonitoredTrainingSession(
+ master=self._config.master,
+ is_chief=self._config.is_chief,
+ checkpoint_dir=self._model_dir,
+ scaffold=scaffold,
+ hooks=all_hooks + model_fn_ops.training_hooks,
+ chief_only_hooks=chief_hooks + model_fn_ops.training_chief_hooks,
+ save_checkpoint_secs=0, # Saving is handled by a hook.
+ save_summaries_steps=self._config.save_summary_steps,
+ config=None) as mon_sess:
+ loss = None
+ while not mon_sess.should_stop():
+ _, loss = mon_sess.run([model_fn_ops.train_op, model_fn_ops.loss])
+ summary_io.SummaryWriterCache.clear()
+ return loss
+
+
+class _FeedFnHook(session_run_hook.SessionRunHook):
+ """Runs feed_fn and sets the feed_dict accordingly."""
+
+ def __init__(self, feed_fn):
+ self.feed_fn = feed_fn
+
+ def before_run(self, run_context): # pylint: disable=unused-argument
+ return session_run_hook.SessionRunArgs(
+ fetches=None, feed_dict=self.feed_fn())
+
# For time of deprecation x,y from Estimator allow direct access.
# pylint: disable=protected-access
@@ -1252,19 +1380,19 @@ class SKCompat(sklearn.BaseEstimator):
def fit(self, x, y, batch_size=128, steps=None, max_steps=None,
monitors=None):
- if (steps is not None) and (max_steps is not None):
- raise ValueError('Can not provide both steps and max_steps.')
-
input_fn, feed_fn = _get_input_fn(x, y, input_fn=None, feed_fn=None,
batch_size=batch_size, shuffle=True,
epochs=None)
- loss = self._estimator._train_model(
- input_fn=input_fn,
- feed_fn=feed_fn,
- steps=steps,
- monitors=monitors,
- max_steps=max_steps)
- logging.info('Loss for final step: %s.', loss)
+ all_monitors = []
+ if feed_fn:
+ all_monitors = [_FeedFnHook(feed_fn)]
+ if monitors:
+ all_monitors.extend(monitors)
+
+ self._estimator.fit(input_fn=input_fn,
+ steps=steps,
+ max_steps=max_steps,
+ monitors=all_monitors)
return self
def score(self, x, y, batch_size=128, steps=None, metrics=None):
diff --git a/tensorflow/contrib/learn/python/learn/estimators/estimator_test.py b/tensorflow/contrib/learn/python/learn/estimators/estimator_test.py
index c339a89faa..bb8b32d797 100644
--- a/tensorflow/contrib/learn/python/learn/estimators/estimator_test.py
+++ b/tensorflow/contrib/learn/python/learn/estimators/estimator_test.py
@@ -288,6 +288,26 @@ class EstimatorTest(tf.test.TestCase):
input_fn=functools.partial(boston_input_fn, num_epochs=1),
as_iterable=True)
+ def testModelFnScaffold(self):
+ self.is_init_fn_called = False
+
+ def _init_fn(scaffold, session):
+ _, _ = scaffold, session
+ self.is_init_fn_called = True
+
+ def _model_fn_scaffold(features, labels, mode):
+ _, _ = features, labels
+ return model_fn.ModelFnOps(
+ mode=mode,
+ predictions=tf.constant(0.),
+ loss=tf.constant(0.),
+ train_op=tf.constant(0.),
+ training_scaffold=tf.train.Scaffold(init_fn=_init_fn))
+
+ est = tf.contrib.learn.Estimator(model_fn=_model_fn_scaffold)
+ est.fit(input_fn=boston_input_fn, steps=1)
+ self.assertTrue(self.is_init_fn_called)
+
def testCustomConfig(self):
test_random_seed = 5783452
@@ -331,21 +351,40 @@ class EstimatorTest(tf.test.TestCase):
def testBadInput(self):
est = tf.contrib.learn.Estimator(model_fn=linear_model_fn)
- self.assertRaisesRegexp(ValueError,
- 'Either x or input_fn must be provided.',
- est.fit, x=None, input_fn=None)
- self.assertRaisesRegexp(ValueError,
- 'Can not provide both input_fn and x or y',
- est.fit, x='X', input_fn=iris_input_fn)
- self.assertRaisesRegexp(ValueError,
- 'Can not provide both input_fn and x or y',
- est.fit, y='Y', input_fn=iris_input_fn)
- self.assertRaisesRegexp(ValueError,
- 'Can not provide both input_fn and batch_size',
- est.fit, input_fn=iris_input_fn, batch_size=100)
self.assertRaisesRegexp(
- ValueError, 'Inputs cannot be tensors. Please provide input_fn.',
- est.fit, x=tf.constant(1.))
+ ValueError,
+ 'Either x or input_fn must be provided.',
+ est.fit,
+ x=None,
+ input_fn=None,
+ steps=1)
+ self.assertRaisesRegexp(
+ ValueError,
+ 'Can not provide both input_fn and x or y',
+ est.fit,
+ x='X',
+ input_fn=iris_input_fn,
+ steps=1)
+ self.assertRaisesRegexp(
+ ValueError,
+ 'Can not provide both input_fn and x or y',
+ est.fit,
+ y='Y',
+ input_fn=iris_input_fn,
+ steps=1)
+ self.assertRaisesRegexp(
+ ValueError,
+ 'Can not provide both input_fn and batch_size',
+ est.fit,
+ input_fn=iris_input_fn,
+ batch_size=100,
+ steps=1)
+ self.assertRaisesRegexp(
+ ValueError,
+ 'Inputs cannot be tensors. Please provide input_fn.',
+ est.fit,
+ x=tf.constant(1.),
+ steps=1)
def testUntrained(self):
boston = tf.contrib.learn.datasets.load_boston()
diff --git a/tensorflow/contrib/learn/python/learn/estimators/model_fn.py b/tensorflow/contrib/learn/python/learn/estimators/model_fn.py
index 434b3105a0..49a9915fe8 100644
--- a/tensorflow/contrib/learn/python/learn/estimators/model_fn.py
+++ b/tensorflow/contrib/learn/python/learn/estimators/model_fn.py
@@ -48,18 +48,27 @@ class ModeKeys(object):
# TODO(roumposg): Pass output_signature_fn instead of signature_fn.
-class ModelFnOps(collections.namedtuple(
- 'ModelFnOps',
- ['predictions', 'loss', 'train_op', 'eval_metric_ops', 'signature_fn',
- 'output_alternatives', 'training_chief_hooks', 'training_hooks'])):
+class ModelFnOps(
+ collections.namedtuple('ModelFnOps', [
+ 'predictions', 'loss', 'train_op', 'eval_metric_ops', 'signature_fn',
+ 'output_alternatives', 'training_chief_hooks', 'training_hooks',
+ 'training_scaffold'
+ ])):
"""Ops returned from a model_fn."""
# TODO(soergel): remove signature_fn once sessionbundle export is deprecated.
- def __new__(cls, mode, predictions=None, loss=None, train_op=None,
- eval_metric_ops=None, signature_fn=None,
- output_alternatives=None, training_chief_hooks=None,
- training_hooks=None):
+ def __new__(cls,
+ mode,
+ predictions=None,
+ loss=None,
+ train_op=None,
+ eval_metric_ops=None,
+ signature_fn=None,
+ output_alternatives=None,
+ training_chief_hooks=None,
+ training_hooks=None,
+ training_scaffold=None):
"""Creates a validated `ModelFnOps` instance.
For a multi-headed model, the predictions dict here will contain the outputs
@@ -99,6 +108,8 @@ class ModelFnOps(collections.namedtuple(
run on the chief worker during training.
training_hooks: A list of `SessionRunHook` objects that will be run on
all workers during training.
+ training_scaffold: A `tf.train.Scaffold` object that can be used to set
+ initialization, saver, and more to be used in training.
Returns:
A validated `ModelFnOps` object.
@@ -169,4 +180,5 @@ class ModelFnOps(collections.namedtuple(
return super(ModelFnOps, cls).__new__(cls, predictions, loss, train_op,
eval_metric_ops, signature_fn,
output_alternatives,
- training_chief_hooks, training_hooks)
+ training_chief_hooks, training_hooks,
+ training_scaffold)
diff --git a/tensorflow/python/training/monitored_session.py b/tensorflow/python/training/monitored_session.py
index a195596b06..e6d550ab61 100644
--- a/tensorflow/python/training/monitored_session.py
+++ b/tensorflow/python/training/monitored_session.py
@@ -281,15 +281,15 @@ def MonitoredTrainingSession(master='', # pylint: disable=invalid-name
Returns:
A `MonitoredSession` object.
"""
- hooks = hooks or []
scaffold = scaffold or Scaffold()
if not is_chief:
session_creator = WorkerSessionCreator(
scaffold=scaffold, master=master, config=config)
- return MonitoredSession(session_creator=session_creator, hooks=hooks)
+ return MonitoredSession(session_creator=session_creator, hooks=hooks or [])
+ all_hooks = []
if chief_only_hooks:
- hooks.extend(chief_only_hooks)
+ all_hooks.extend(chief_only_hooks)
session_creator = ChiefSessionCreator(
scaffold=scaffold,
checkpoint_dir=checkpoint_dir,
@@ -297,19 +297,21 @@ def MonitoredTrainingSession(master='', # pylint: disable=invalid-name
config=config)
if checkpoint_dir:
- hooks.append(
+ all_hooks.append(
basic_session_run_hooks.StepCounterHook(output_dir=checkpoint_dir))
if save_summaries_steps > 0:
- hooks.append(basic_session_run_hooks.SummarySaverHook(
+ all_hooks.append(basic_session_run_hooks.SummarySaverHook(
scaffold=scaffold,
save_steps=save_summaries_steps,
output_dir=checkpoint_dir))
if save_checkpoint_secs > 0:
- hooks.append(basic_session_run_hooks.CheckpointSaverHook(
+ all_hooks.append(basic_session_run_hooks.CheckpointSaverHook(
checkpoint_dir, save_secs=save_checkpoint_secs, scaffold=scaffold))
- return MonitoredSession(session_creator=session_creator, hooks=hooks)
+ if hooks:
+ all_hooks.extend(hooks)
+ return MonitoredSession(session_creator=session_creator, hooks=all_hooks)
class SessionCreator(object):