diff options
author | Mustafa Ispir <ispir@google.com> | 2016-12-15 12:40:24 -0800 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2016-12-15 12:45:48 -0800 |
commit | 79a73505635e23da4a4cd66feebb4095abab1fe7 (patch) | |
tree | 76021c367e8bf4a62d57a9e34516c79d7eaa8b22 | |
parent | 1e8b5eb885b6726fd72d5eec981af601b667fa80 (diff) |
Adding Scaffold to the tf.learn.ModelFn, so that users can play with initialization, saving, and so on.
Change: 142175064
4 files changed, 221 insertions, 40 deletions
diff --git a/tensorflow/contrib/learn/python/learn/estimators/estimator.py b/tensorflow/contrib/learn/python/learn/estimators/estimator.py index 2d40caa656..7bc8f80563 100644 --- a/tensorflow/contrib/learn/python/learn/estimators/estimator.py +++ b/tensorflow/contrib/learn/python/learn/estimators/estimator.py @@ -66,8 +66,12 @@ from tensorflow.python.platform import gfile from tensorflow.python.platform import tf_logging as logging from tensorflow.python.saved_model import builder as saved_model_builder from tensorflow.python.saved_model import tag_constants +from tensorflow.python.training import basic_session_run_hooks from tensorflow.python.training import device_setter +from tensorflow.python.training import monitored_session from tensorflow.python.training import saver +from tensorflow.python.training import session_run_hook +from tensorflow.python.training import summary_io from tensorflow.python.util import compat @@ -1086,6 +1090,17 @@ class Estimator(BaseEstimator): """ return self._call_model_fn(features, labels, model_fn_lib.ModeKeys.TRAIN) + # TODO(ispir): delete this function after converting all legacy usages. + def _call_legacy_get_train_ops(self, features, labels): + train_ops = self._get_train_ops(features, labels) + if isinstance(train_ops, model_fn_lib.ModelFnOps): # Default signature + return train_ops + return model_fn_lib.ModelFnOps( + mode=model_fn_lib.ModeKeys.TRAIN, + predictions=None, + loss=train_ops[1], + train_op=train_ops[0]) + def _get_eval_ops(self, features, labels, metrics): """Method that builds model graph and returns evaluation ops. @@ -1241,6 +1256,119 @@ class Estimator(BaseEstimator): return export_dir + @deprecated_args(SCIKIT_DECOUPLE_DATE, SCIKIT_DECOUPLE_INSTRUCTIONS, 'x', 'y', + 'batch_size') + def fit(self, + x=None, + y=None, + input_fn=None, + steps=None, + batch_size=None, + monitors=None, + max_steps=None): + # pylint: disable=g-doc-args,g-doc-return-or-yield + """See `Trainable`. + + Raises: + ValueError: If `x` or `y` are not `None` while `input_fn` is not `None`. + ValueError: If both `steps` and `max_steps` are not `None`. + """ + if (steps is not None) and (max_steps is not None): + raise ValueError('Can not provide both steps and max_steps.') + if max_steps is not None: + try: + start_step = load_variable(self._model_dir, ops.GraphKeys.GLOBAL_STEP) + if max_steps <= start_step: + logging.info('Skipping training since max_steps has already saved.') + return None + except: # pylint: disable=bare-except + pass + + hooks = monitor_lib.replace_monitors_with_hooks(monitors, self) + if steps is not None or max_steps is not None: + hooks.append(basic_session_run_hooks.StopAtStepHook(steps, max_steps)) + + input_fn, feed_fn = _get_input_fn( + x, + y, + input_fn, + feed_fn=None, + batch_size=batch_size, + shuffle=True, + epochs=None) + if feed_fn: + hooks.append(_FeedFnHook(feed_fn)) + loss = self._train_model_v2(input_fn=input_fn, hooks=hooks) + logging.info('Loss for final step: %s.', loss) + return self + + def _train_model_v2(self, input_fn, hooks): + all_hooks = [] + self._graph = ops.Graph() + with self._graph.as_default() as g, g.device(self._device_fn): + random_seed.set_random_seed(self._config.tf_random_seed) + global_step = contrib_framework.create_global_step(g) + features, labels = input_fn() + self._check_inputs(features, labels) + model_fn_ops = self._call_legacy_get_train_ops(features, labels) + ops.add_to_collection(ops.GraphKeys.LOSSES, model_fn_ops.loss) + all_hooks.extend([ + basic_session_run_hooks.NanTensorHook(model_fn_ops.loss), + basic_session_run_hooks.LoggingTensorHook( + { + 'loss': model_fn_ops.loss, + 'step': global_step + }, + every_n_iter=100) + ]) + all_hooks.extend(hooks) + + scaffold = model_fn_ops.training_scaffold or monitored_session.Scaffold() + if not (scaffold.saver or ops.get_collection(ops.GraphKeys.SAVERS)): + ops.add_to_collection( + ops.GraphKeys.SAVERS, + saver.Saver( + sharded=True, + max_to_keep=self._config.keep_checkpoint_max, + defer_build=True)) + + chief_hooks = [] + if (self._config.save_checkpoints_secs or + self._config.save_checkpoints_steps): + chief_hooks = [ + basic_session_run_hooks.CheckpointSaverHook( + self._model_dir, + save_secs=self._config.save_checkpoints_secs, + save_steps=self._config.save_checkpoints_steps, + scaffold=scaffold) + ] + with monitored_session.MonitoredTrainingSession( + master=self._config.master, + is_chief=self._config.is_chief, + checkpoint_dir=self._model_dir, + scaffold=scaffold, + hooks=all_hooks + model_fn_ops.training_hooks, + chief_only_hooks=chief_hooks + model_fn_ops.training_chief_hooks, + save_checkpoint_secs=0, # Saving is handled by a hook. + save_summaries_steps=self._config.save_summary_steps, + config=None) as mon_sess: + loss = None + while not mon_sess.should_stop(): + _, loss = mon_sess.run([model_fn_ops.train_op, model_fn_ops.loss]) + summary_io.SummaryWriterCache.clear() + return loss + + +class _FeedFnHook(session_run_hook.SessionRunHook): + """Runs feed_fn and sets the feed_dict accordingly.""" + + def __init__(self, feed_fn): + self.feed_fn = feed_fn + + def before_run(self, run_context): # pylint: disable=unused-argument + return session_run_hook.SessionRunArgs( + fetches=None, feed_dict=self.feed_fn()) + # For time of deprecation x,y from Estimator allow direct access. # pylint: disable=protected-access @@ -1252,19 +1380,19 @@ class SKCompat(sklearn.BaseEstimator): def fit(self, x, y, batch_size=128, steps=None, max_steps=None, monitors=None): - if (steps is not None) and (max_steps is not None): - raise ValueError('Can not provide both steps and max_steps.') - input_fn, feed_fn = _get_input_fn(x, y, input_fn=None, feed_fn=None, batch_size=batch_size, shuffle=True, epochs=None) - loss = self._estimator._train_model( - input_fn=input_fn, - feed_fn=feed_fn, - steps=steps, - monitors=monitors, - max_steps=max_steps) - logging.info('Loss for final step: %s.', loss) + all_monitors = [] + if feed_fn: + all_monitors = [_FeedFnHook(feed_fn)] + if monitors: + all_monitors.extend(monitors) + + self._estimator.fit(input_fn=input_fn, + steps=steps, + max_steps=max_steps, + monitors=all_monitors) return self def score(self, x, y, batch_size=128, steps=None, metrics=None): diff --git a/tensorflow/contrib/learn/python/learn/estimators/estimator_test.py b/tensorflow/contrib/learn/python/learn/estimators/estimator_test.py index c339a89faa..bb8b32d797 100644 --- a/tensorflow/contrib/learn/python/learn/estimators/estimator_test.py +++ b/tensorflow/contrib/learn/python/learn/estimators/estimator_test.py @@ -288,6 +288,26 @@ class EstimatorTest(tf.test.TestCase): input_fn=functools.partial(boston_input_fn, num_epochs=1), as_iterable=True) + def testModelFnScaffold(self): + self.is_init_fn_called = False + + def _init_fn(scaffold, session): + _, _ = scaffold, session + self.is_init_fn_called = True + + def _model_fn_scaffold(features, labels, mode): + _, _ = features, labels + return model_fn.ModelFnOps( + mode=mode, + predictions=tf.constant(0.), + loss=tf.constant(0.), + train_op=tf.constant(0.), + training_scaffold=tf.train.Scaffold(init_fn=_init_fn)) + + est = tf.contrib.learn.Estimator(model_fn=_model_fn_scaffold) + est.fit(input_fn=boston_input_fn, steps=1) + self.assertTrue(self.is_init_fn_called) + def testCustomConfig(self): test_random_seed = 5783452 @@ -331,21 +351,40 @@ class EstimatorTest(tf.test.TestCase): def testBadInput(self): est = tf.contrib.learn.Estimator(model_fn=linear_model_fn) - self.assertRaisesRegexp(ValueError, - 'Either x or input_fn must be provided.', - est.fit, x=None, input_fn=None) - self.assertRaisesRegexp(ValueError, - 'Can not provide both input_fn and x or y', - est.fit, x='X', input_fn=iris_input_fn) - self.assertRaisesRegexp(ValueError, - 'Can not provide both input_fn and x or y', - est.fit, y='Y', input_fn=iris_input_fn) - self.assertRaisesRegexp(ValueError, - 'Can not provide both input_fn and batch_size', - est.fit, input_fn=iris_input_fn, batch_size=100) self.assertRaisesRegexp( - ValueError, 'Inputs cannot be tensors. Please provide input_fn.', - est.fit, x=tf.constant(1.)) + ValueError, + 'Either x or input_fn must be provided.', + est.fit, + x=None, + input_fn=None, + steps=1) + self.assertRaisesRegexp( + ValueError, + 'Can not provide both input_fn and x or y', + est.fit, + x='X', + input_fn=iris_input_fn, + steps=1) + self.assertRaisesRegexp( + ValueError, + 'Can not provide both input_fn and x or y', + est.fit, + y='Y', + input_fn=iris_input_fn, + steps=1) + self.assertRaisesRegexp( + ValueError, + 'Can not provide both input_fn and batch_size', + est.fit, + input_fn=iris_input_fn, + batch_size=100, + steps=1) + self.assertRaisesRegexp( + ValueError, + 'Inputs cannot be tensors. Please provide input_fn.', + est.fit, + x=tf.constant(1.), + steps=1) def testUntrained(self): boston = tf.contrib.learn.datasets.load_boston() diff --git a/tensorflow/contrib/learn/python/learn/estimators/model_fn.py b/tensorflow/contrib/learn/python/learn/estimators/model_fn.py index 434b3105a0..49a9915fe8 100644 --- a/tensorflow/contrib/learn/python/learn/estimators/model_fn.py +++ b/tensorflow/contrib/learn/python/learn/estimators/model_fn.py @@ -48,18 +48,27 @@ class ModeKeys(object): # TODO(roumposg): Pass output_signature_fn instead of signature_fn. -class ModelFnOps(collections.namedtuple( - 'ModelFnOps', - ['predictions', 'loss', 'train_op', 'eval_metric_ops', 'signature_fn', - 'output_alternatives', 'training_chief_hooks', 'training_hooks'])): +class ModelFnOps( + collections.namedtuple('ModelFnOps', [ + 'predictions', 'loss', 'train_op', 'eval_metric_ops', 'signature_fn', + 'output_alternatives', 'training_chief_hooks', 'training_hooks', + 'training_scaffold' + ])): """Ops returned from a model_fn.""" # TODO(soergel): remove signature_fn once sessionbundle export is deprecated. - def __new__(cls, mode, predictions=None, loss=None, train_op=None, - eval_metric_ops=None, signature_fn=None, - output_alternatives=None, training_chief_hooks=None, - training_hooks=None): + def __new__(cls, + mode, + predictions=None, + loss=None, + train_op=None, + eval_metric_ops=None, + signature_fn=None, + output_alternatives=None, + training_chief_hooks=None, + training_hooks=None, + training_scaffold=None): """Creates a validated `ModelFnOps` instance. For a multi-headed model, the predictions dict here will contain the outputs @@ -99,6 +108,8 @@ class ModelFnOps(collections.namedtuple( run on the chief worker during training. training_hooks: A list of `SessionRunHook` objects that will be run on all workers during training. + training_scaffold: A `tf.train.Scaffold` object that can be used to set + initialization, saver, and more to be used in training. Returns: A validated `ModelFnOps` object. @@ -169,4 +180,5 @@ class ModelFnOps(collections.namedtuple( return super(ModelFnOps, cls).__new__(cls, predictions, loss, train_op, eval_metric_ops, signature_fn, output_alternatives, - training_chief_hooks, training_hooks) + training_chief_hooks, training_hooks, + training_scaffold) diff --git a/tensorflow/python/training/monitored_session.py b/tensorflow/python/training/monitored_session.py index a195596b06..e6d550ab61 100644 --- a/tensorflow/python/training/monitored_session.py +++ b/tensorflow/python/training/monitored_session.py @@ -281,15 +281,15 @@ def MonitoredTrainingSession(master='', # pylint: disable=invalid-name Returns: A `MonitoredSession` object. """ - hooks = hooks or [] scaffold = scaffold or Scaffold() if not is_chief: session_creator = WorkerSessionCreator( scaffold=scaffold, master=master, config=config) - return MonitoredSession(session_creator=session_creator, hooks=hooks) + return MonitoredSession(session_creator=session_creator, hooks=hooks or []) + all_hooks = [] if chief_only_hooks: - hooks.extend(chief_only_hooks) + all_hooks.extend(chief_only_hooks) session_creator = ChiefSessionCreator( scaffold=scaffold, checkpoint_dir=checkpoint_dir, @@ -297,19 +297,21 @@ def MonitoredTrainingSession(master='', # pylint: disable=invalid-name config=config) if checkpoint_dir: - hooks.append( + all_hooks.append( basic_session_run_hooks.StepCounterHook(output_dir=checkpoint_dir)) if save_summaries_steps > 0: - hooks.append(basic_session_run_hooks.SummarySaverHook( + all_hooks.append(basic_session_run_hooks.SummarySaverHook( scaffold=scaffold, save_steps=save_summaries_steps, output_dir=checkpoint_dir)) if save_checkpoint_secs > 0: - hooks.append(basic_session_run_hooks.CheckpointSaverHook( + all_hooks.append(basic_session_run_hooks.CheckpointSaverHook( checkpoint_dir, save_secs=save_checkpoint_secs, scaffold=scaffold)) - return MonitoredSession(session_creator=session_creator, hooks=hooks) + if hooks: + all_hooks.extend(hooks) + return MonitoredSession(session_creator=session_creator, hooks=all_hooks) class SessionCreator(object): |