author     Mustafa Ispir <ispir@google.com>                  2017-01-10 14:12:02 -0800
committer  TensorFlower Gardener <gardener@tensorflow.org>   2017-01-10 14:24:45 -0800
commit  61a6797c4f2e205cc8338789ea5cea129ead7cbf (patch)
tree    c77e8e70db1b0eeaa36e6629f7515704aa955661
parent  3e59f0540ede856294eba374cd3d00231d90d5c9 (diff)
Simplified estimator logic by using MonitoredSession; removed graph_actions usage.
Change: 144126485
-rw-r--r--  tensorflow/contrib/factorization/python/ops/gmm.py              102
-rw-r--r--  tensorflow/contrib/learn/python/learn/estimators/estimator.py   565
2 files changed, 319 insertions, 348 deletions
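
The heart of the change is in estimator.py below: BaseEstimator._train_model now drives the training loop with tf.train.MonitoredTrainingSession plus standard SessionRunHooks (NanTensorHook, LoggingTensorHook, StopAtStepHook, CheckpointSaverHook) instead of graph_actions._monitored_train. The following is a minimal, self-contained sketch of that pattern against the public TF 1.x API; the toy quadratic model, learning rate, step limit, and checkpoint directory are illustrative placeholders and are not part of this commit.

import tensorflow as tf  # assumes TensorFlow 1.x

graph = tf.Graph()
with graph.as_default():
  global_step = tf.contrib.framework.get_or_create_global_step()
  w = tf.Variable(5.0, name='w')
  loss = tf.square(w - 3.0, name='loss')          # minimized at w == 3
  train_op = tf.train.GradientDescentOptimizer(0.1).minimize(
      loss, global_step=global_step)              # increments global_step

  hooks = [
      tf.train.NanTensorHook(loss),               # takes over fail_on_nan_loss
      tf.train.LoggingTensorHook({'loss': loss, 'step': global_step},
                                 every_n_iter=100),  # takes over log_every_steps
      tf.train.StopAtStepHook(last_step=200),     # takes over steps/max_steps
  ]
  final_loss = None
  with tf.train.MonitoredTrainingSession(
      checkpoint_dir='/tmp/monitored_session_demo', hooks=hooks) as sess:
    while not sess.should_stop():
      _, final_loss = sess.run([train_op, loss])
  print('Loss for final step: %s.' % final_loss)
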
diff --git a/tensorflow/contrib/factorization/python/ops/gmm.py b/tensorflow/contrib/factorization/python/ops/gmm.py
index dd7e9a3455..86450d4bbd 100644
--- a/tensorflow/contrib/factorization/python/ops/gmm.py
+++ b/tensorflow/contrib/factorization/python/ops/gmm.py
@@ -21,20 +21,28 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
+import time
+
import numpy as np
from tensorflow.contrib import framework
from tensorflow.contrib.factorization.python.ops import gmm_ops
from tensorflow.contrib.framework.python.framework import checkpoint_utils
from tensorflow.contrib.framework.python.ops import variables
-from tensorflow.contrib.learn.python.learn.estimators import estimator
+from tensorflow.contrib.learn.python.learn import graph_actions
+from tensorflow.contrib.learn.python.learn import monitors as monitor_lib
+from tensorflow.contrib.learn.python.learn.estimators import estimator as estimator_lib
+from tensorflow.contrib.learn.python.learn.estimators import model_fn as model_fn_lib
from tensorflow.contrib.learn.python.learn.estimators._sklearn import TransformerMixin
from tensorflow.contrib.learn.python.learn.learn_io import data_feeder
from tensorflow.python.framework import constant_op
+from tensorflow.python.framework import ops
+from tensorflow.python.framework import random_seed as random_seed_lib
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import state_ops
from tensorflow.python.ops.control_flow_ops import with_dependencies
+from tensorflow.python.platform import tf_logging as logging
def _streaming_sum(scalar_tensor):
@@ -44,7 +52,7 @@ def _streaming_sum(scalar_tensor):
return sum_metric, sum_update
-class GMM(estimator.Estimator, TransformerMixin):
+class GMM(estimator_lib.Estimator, TransformerMixin):
"""GMM clustering."""
SCORES = 'scores'
ASSIGNMENTS = 'assignments'
@@ -116,7 +124,8 @@ class GMM(estimator.Estimator, TransformerMixin):
self._data_feeder = data_feeder.setup_train_data_feeder(x, None,
self._num_clusters,
self.batch_size)
- self._train_model(
+ _legacy_train_model( # pylint: disable=protected-access
+ self,
input_fn=self._data_feeder.input_builder,
feed_fn=self._data_feeder.get_feed_dict_fn(),
steps=steps or self.steps,
@@ -218,3 +227,90 @@ class GMM(estimator.Estimator, TransformerMixin):
self._covariance_type,
self._params)
return {GMM.SCORES: _streaming_sum(math_ops.reduce_sum(losses))}
+
+
+# TODO(xavigonzalvo): delete this after implementing model-fn based Estimator.
+def _legacy_train_model(estimator,
+ input_fn,
+ steps,
+ feed_fn=None,
+ init_op=None,
+ init_feed_fn=None,
+ init_fn=None,
+ device_fn=None,
+ monitors=None,
+ log_every_steps=100,
+ fail_on_nan_loss=True,
+ max_steps=None):
+ """Legacy train function of Estimator."""
+ if hasattr(estimator.config, 'execution_mode'):
+ if estimator.config.execution_mode not in ('all', 'train'):
+ return
+
+ # Stagger startup of worker sessions based on task id.
+ sleep_secs = min(
+ estimator.config.training_worker_max_startup_secs,
+ estimator.config.task_id *
+ estimator.config.training_worker_session_startup_stagger_secs)
+ if sleep_secs:
+ logging.info('Waiting %d secs before starting task %d.', sleep_secs,
+ estimator.config.task_id)
+ time.sleep(sleep_secs)
+
+ # Device allocation
+ device_fn = device_fn or estimator._device_fn # pylint: disable=protected-access
+
+ with ops.Graph().as_default() as g, g.device(device_fn):
+ random_seed_lib.set_random_seed(estimator.config.tf_random_seed)
+ global_step = framework.create_global_step(g)
+ features, labels = input_fn()
+ estimator._check_inputs(features, labels) # pylint: disable=protected-access
+
+ # The default return type of _get_train_ops is ModelFnOps. But there are
+ # some subclasses of tf.contrib.learn.Estimator which override this
+ # method and use the legacy signature, namely _get_train_ops returns a
+ # (train_op, loss) tuple. The following else-statement code covers these
+ # cases, but will soon be deleted after the subclasses are updated.
+ # TODO(b/32664904): Update subclasses and delete the else-statement.
+ train_ops = estimator._get_train_ops(features, labels) # pylint: disable=protected-access
+ if isinstance(train_ops, model_fn_lib.ModelFnOps): # Default signature
+ train_op = train_ops.train_op
+ loss_op = train_ops.loss
+ if estimator.config.is_chief:
+ hooks = train_ops.training_chief_hooks + train_ops.training_hooks
+ else:
+ hooks = train_ops.training_hooks
+ else: # Legacy signature
+ if len(train_ops) != 2:
+ raise ValueError('Expected a tuple of train_op and loss, got {}'.format(
+ train_ops))
+ train_op = train_ops[0]
+ loss_op = train_ops[1]
+ hooks = []
+
+ hooks += monitor_lib.replace_monitors_with_hooks(monitors, estimator)
+
+ ops.add_to_collection(ops.GraphKeys.LOSSES, loss_op)
+ return graph_actions._monitored_train( # pylint: disable=protected-access
+ graph=g,
+ output_dir=estimator.model_dir,
+ train_op=train_op,
+ loss_op=loss_op,
+ global_step_tensor=global_step,
+ init_op=init_op,
+ init_feed_dict=init_feed_fn() if init_feed_fn is not None else None,
+ init_fn=init_fn,
+ log_every_steps=log_every_steps,
+ supervisor_is_chief=estimator.config.is_chief,
+ supervisor_master=estimator.config.master,
+ supervisor_save_model_secs=estimator.config.save_checkpoints_secs,
+ supervisor_save_model_steps=estimator.config.save_checkpoints_steps,
+ supervisor_save_summaries_steps=estimator.config.save_summary_steps,
+ keep_checkpoint_max=estimator.config.keep_checkpoint_max,
+ keep_checkpoint_every_n_hours=(
+ estimator.config.keep_checkpoint_every_n_hours),
+ feed_fn=feed_fn,
+ steps=steps,
+ fail_on_nan_loss=fail_on_nan_loss,
+ hooks=hooks,
+ max_steps=max_steps)
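
On the inference side, the estimator.py changes below replace graph_actions.infer / run_feeds / run_feeds_iter with a MonitoredSession built from a ChiefSessionCreator that restores the latest checkpoint (see the new _infer_model and _predict_generator). The following is a minimal sketch of that restore-and-run pattern, assuming TensorFlow 1.x; the variable, prediction dict, and checkpoint directory are illustrative placeholders (the directory matches the training sketch above).

import tensorflow as tf  # assumes TensorFlow 1.x

checkpoint_dir = '/tmp/monitored_session_demo'  # written by a prior training run

with tf.Graph().as_default():
  w = tf.Variable(0.0, name='w')        # restored from the checkpoint if found
  predictions = {'w_plus_one': w + 1.0}

  checkpoint_path = tf.train.latest_checkpoint(checkpoint_dir)
  session_creator = tf.train.ChiefSessionCreator(
      checkpoint_filename_with_path=checkpoint_path)
  with tf.train.MonitoredSession(session_creator=session_creator) as sess:
    while not sess.should_stop():
      print(sess.run(predictions))
      break  # no queue runners feed this graph, so stop after one pass
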
diff --git a/tensorflow/contrib/learn/python/learn/estimators/estimator.py b/tensorflow/contrib/learn/python/learn/estimators/estimator.py
index dce10e7b0f..467d31c331 100644
--- a/tensorflow/contrib/learn/python/learn/estimators/estimator.py
+++ b/tensorflow/contrib/learn/python/learn/estimators/estimator.py
@@ -22,10 +22,8 @@ from __future__ import print_function
import abc
import copy
import inspect
-import itertools
import os
import tempfile
-import time
import numpy as np
import six
@@ -39,10 +37,8 @@ from tensorflow.contrib.framework import deprecated_args
from tensorflow.contrib.framework import list_variables
from tensorflow.contrib.framework import load_variable
from tensorflow.contrib.framework.python.framework import experimental
-from tensorflow.contrib.framework.python.ops import ops as contrib_ops
from tensorflow.contrib.framework.python.ops import variables as contrib_variables
from tensorflow.contrib.learn.python.learn import evaluable
-from tensorflow.contrib.learn.python.learn import graph_actions
from tensorflow.contrib.learn.python.learn import metric_spec
from tensorflow.contrib.learn.python.learn import monitors as monitor_lib
from tensorflow.contrib.learn.python.learn import trainable
@@ -58,7 +54,6 @@ from tensorflow.contrib.learn.python.learn.utils import saved_model_export_utils
from tensorflow.contrib.training.python.training import evaluation
from tensorflow.core.framework import summary_pb2
from tensorflow.python.client import session as tf_session
-from tensorflow.python.framework import errors
from tensorflow.python.framework import ops
from tensorflow.python.framework import random_seed
from tensorflow.python.framework import sparse_tensor
@@ -92,6 +87,25 @@ SCIKIT_DECOUPLE_INSTRUCTIONS = (
' est = Estimator(...) -> est = SKCompat(Estimator(...))')
+def _verify_input_args(x, y, input_fn, feed_fn, batch_size):
+ """Verifies validity of co-existance of input arguments."""
+ if input_fn is None:
+ if x is None:
+ raise ValueError('Either x or input_fn must be provided.')
+
+ if contrib_framework.is_tensor(x) or (y is not None and
+ contrib_framework.is_tensor(y)):
+ raise ValueError('Inputs cannot be tensors. Please provide input_fn.')
+
+ if feed_fn is not None:
+ raise ValueError('Can not provide both feed_fn and x or y.')
+ else:
+ if (x is not None) or (y is not None):
+ raise ValueError('Can not provide both input_fn and x or y.')
+ if batch_size is not None:
+ raise ValueError('Can not provide both input_fn and batch_size.')
+
+
def _get_input_fn(x, y, input_fn, feed_fn, batch_size, shuffle=False, epochs=1):
"""Make inputs into input and feed functions.
@@ -110,29 +124,17 @@ def _get_input_fn(x, y, input_fn, feed_fn, batch_size, shuffle=False, epochs=1):
Raises:
ValueError: Only one of `(x & y)` or `input_fn` must be provided.
"""
- if input_fn is None:
- if x is None:
- raise ValueError('Either x or input_fn must be provided.')
-
- if contrib_framework.is_tensor(x) or (y is not None and
- contrib_framework.is_tensor(y)):
- raise ValueError('Inputs cannot be tensors. Please provide input_fn.')
-
- if feed_fn is not None:
- raise ValueError('Can not provide both feed_fn and x or y.')
-
- df = data_feeder.setup_train_data_feeder(x, y, n_classes=None,
- batch_size=batch_size,
- shuffle=shuffle,
- epochs=epochs)
- return df.input_builder, df.get_feed_dict_fn()
-
- if (x is not None) or (y is not None):
- raise ValueError('Can not provide both input_fn and x or y.')
- if batch_size is not None:
- raise ValueError('Can not provide both input_fn and batch_size.')
-
- return input_fn, feed_fn
+ _verify_input_args(x, y, input_fn, feed_fn, batch_size)
+ if input_fn is not None:
+ return input_fn, feed_fn
+ df = data_feeder.setup_train_data_feeder(
+ x,
+ y,
+ n_classes=None,
+ batch_size=batch_size,
+ shuffle=shuffle,
+ epochs=epochs)
+ return df.input_builder, df.get_feed_dict_fn()
def infer_real_valued_columns_from_input_fn(input_fn):
@@ -311,9 +313,8 @@ def _write_dict_to_summary(output_dir,
dictionary: the `dict` to be written to summary file.
current_global_step: `int`, the current global step.
"""
- logging.info(
- 'Saving dict for global step %d: %s' %
- (current_global_step, _dict_to_str(dictionary)))
+ logging.info('Saving dict for global step %d: %s', current_global_step,
+ _dict_to_str(dictionary))
summary_writer = summary_io.SummaryWriterCache.get(output_dir)
summary_proto = summary_pb2.Summary()
for key in dictionary:
@@ -404,15 +405,24 @@ class BaseEstimator(
"""
if (steps is not None) and (max_steps is not None):
raise ValueError('Can not provide both steps and max_steps.')
+ _verify_input_args(x, y, input_fn, None, batch_size)
+ if x is not None:
+ return SKCompat(self).fit(x, y, batch_size, steps, max_steps, monitors)
- input_fn, feed_fn = _get_input_fn(x, y, input_fn, feed_fn=None,
- batch_size=batch_size, shuffle=True,
- epochs=None)
- loss = self._train_model(input_fn=input_fn,
- feed_fn=feed_fn,
- steps=steps,
- monitors=monitors,
- max_steps=max_steps)
+ if max_steps is not None:
+ try:
+ start_step = load_variable(self._model_dir, ops.GraphKeys.GLOBAL_STEP)
+ if max_steps <= start_step:
+ logging.info('Skipping training since max_steps has already saved.')
+ return None
+ except: # pylint: disable=bare-except
+ pass
+
+ hooks = monitor_lib.replace_monitors_with_hooks(monitors, self)
+ if steps is not None or max_steps is not None:
+ hooks.append(basic_session_run_hooks.StopAtStepHook(steps, max_steps))
+
+ loss = self._train_model(input_fn=input_fn, hooks=hooks)
logging.info('Loss for final step: %s.', loss)
return self
@@ -485,9 +495,10 @@ class BaseEstimator(
`input_fn` or `feed_fn` is provided.
Or if `metrics` is not `None` or `dict`.
"""
- input_fn, feed_fn = _get_input_fn(x, y, input_fn=input_fn,
- feed_fn=feed_fn, batch_size=batch_size,
- shuffle=False, epochs=1)
+ _verify_input_args(x, y, input_fn, feed_fn, batch_size)
+ if x is not None:
+ return SKCompat(self).score(x, y, batch_size, steps, metrics)
+
if metrics is not None and not isinstance(metrics, dict):
raise ValueError('Metrics argument should be None or dict. '
'Got %s.' % metrics)
@@ -537,11 +548,15 @@ class BaseEstimator(
Raises:
ValueError: If x and input_fn are both provided or both `None`.
"""
- input_fn, feed_fn = _get_input_fn(
- x, None, input_fn=input_fn, feed_fn=None, batch_size=batch_size,
- shuffle=False, epochs=1)
+ _verify_input_args(x, None, input_fn, None, batch_size)
+ if x is not None and not as_iterable:
+ return SKCompat(self).predict(x, batch_size)
+
+ input_fn, feed_fn = _get_input_fn(x, None, input_fn, None, batch_size)
return self._infer_model(
- input_fn=input_fn, feed_fn=feed_fn, outputs=outputs,
+ input_fn=input_fn,
+ feed_fn=feed_fn,
+ outputs=outputs,
as_iterable=as_iterable)
def get_variable_value(self, name):
@@ -728,91 +743,6 @@ class BaseEstimator(
self._labels_info = tensor_signature.create_signatures(labels)
logging.debug('Setting labels info to %s', str(self._labels_info))
- def _train_model(self,
- input_fn,
- steps,
- feed_fn=None,
- init_op=None,
- init_feed_fn=None,
- init_fn=None,
- device_fn=None,
- monitors=None,
- log_every_steps=100,
- fail_on_nan_loss=True,
- max_steps=None):
- # TODO(wicke): Remove this once Model and associated code are gone.
- if hasattr(self._config, 'execution_mode'):
- if self._config.execution_mode not in ('all', 'train'):
- return
-
- # Stagger startup of worker sessions based on task id.
- sleep_secs = min(
- self._config.training_worker_max_startup_secs,
- self._config.task_id *
- self._config.training_worker_session_startup_stagger_secs)
- if sleep_secs:
- logging.info('Waiting %d secs before starting task %d.', sleep_secs,
- self._config.task_id)
- time.sleep(sleep_secs)
-
- # Device allocation
- device_fn = device_fn or self._device_fn
-
- self._graph = ops.Graph()
- with self._graph.as_default() as g, g.device(device_fn):
- random_seed.set_random_seed(self._config.tf_random_seed)
- global_step = contrib_framework.create_global_step(g)
- features, labels = input_fn()
- self._check_inputs(features, labels)
-
- # The default return type of _get_train_ops is ModelFnOps. But there are
- # some subclasses of tf.contrib.learn.Estimator which override this
- # method and use the legacy signature, namely _get_train_ops returns a
- # (train_op, loss) tuple. The following else-statement code covers these
- # cases, but will soon be deleted after the subclasses are updated.
- # TODO(b/32664904): Update subclasses and delete the else-statement.
- train_ops = self._get_train_ops(features, labels)
- if isinstance(train_ops, model_fn_lib.ModelFnOps): # Default signature
- train_op = train_ops.train_op
- loss_op = train_ops.loss
- if self.config.is_chief:
- hooks = train_ops.training_chief_hooks + train_ops.training_hooks
- else:
- hooks = train_ops.training_hooks
- else: # Legacy signature
- if len(train_ops) != 2:
- raise ValueError('Expected a tuple of train_op and loss, got {}'.
- format(train_ops))
- train_op = train_ops[0]
- loss_op = train_ops[1]
- hooks = []
-
- hooks += monitor_lib.replace_monitors_with_hooks(monitors, self)
-
- ops.add_to_collection(ops.GraphKeys.LOSSES, loss_op)
- return graph_actions._monitored_train( # pylint: disable=protected-access
- graph=g,
- output_dir=self._model_dir,
- train_op=train_op,
- loss_op=loss_op,
- global_step_tensor=global_step,
- init_op=init_op,
- init_feed_dict=init_feed_fn() if init_feed_fn is not None else None,
- init_fn=init_fn,
- log_every_steps=log_every_steps,
- supervisor_is_chief=self.config.is_chief,
- supervisor_master=self._config.master,
- supervisor_save_model_secs=self._config.save_checkpoints_secs,
- supervisor_save_model_steps=self._config.save_checkpoints_steps,
- supervisor_save_summaries_steps=self._config.save_summary_steps,
- keep_checkpoint_max=self._config.keep_checkpoint_max,
- keep_checkpoint_every_n_hours=self._config.keep_checkpoint_every_n_hours,
- feed_fn=feed_fn,
- steps=steps,
- fail_on_nan_loss=fail_on_nan_loss,
- hooks=hooks,
- max_steps=max_steps)
-
def _extract_metric_update_ops(self, eval_dict):
"""Separate update operations from metric value operations."""
update_ops = []
@@ -915,8 +845,12 @@ class BaseEstimator(
return result[0]
return result
- def _infer_model(
- self, input_fn, feed_fn=None, outputs=None, as_iterable=True):
+ def _infer_model(self,
+ input_fn,
+ feed_fn=None,
+ outputs=None,
+ as_iterable=True,
+ iterate_batches=False):
# Check that model has been trained.
checkpoint_path = saver.latest_checkpoint(self._model_dir)
if not checkpoint_path:
@@ -927,103 +861,152 @@ class BaseEstimator(
random_seed.set_random_seed(self._config.tf_random_seed)
contrib_framework.create_global_step(g)
features = self._get_features_from_input_fn(input_fn)
-
- # The default return type of _get_predict_ops is ModelFnOps. But there are
- # some subclasses of tf.contrib.learn.Estimator which override this
- # method and use the legacy signature, namely _get_predict_ops returns a
- # `predictions` Tensor or dict or Tensors. The following else-statement
- # code covers these cases, but will soon be deleted after the subclasses
- # are updated.
- # TODO(b/32664904): Update subclasses and delete the else-statement.
- infer_ops = self._get_predict_ops(features)
- if isinstance(infer_ops, model_fn_lib.ModelFnOps): # Default signature
- predictions = infer_ops.predictions
- else: # Legacy signature
- predictions = infer_ops
-
- # If predictions is single output - wrap it into dict, and remember to
- # return not a dict.
- return_dict = isinstance(predictions, dict)
- if not return_dict:
- predictions = {'predictions': predictions}
-
- # Filter what to run predictions on, if outputs provided.
- if outputs:
- existing_keys = predictions.keys()
- predictions = {
- key: value
- for key, value in six.iteritems(predictions) if key in outputs
- }
- if not predictions:
- raise ValueError('Expected to run at least one output from %s, '
- 'provided %s.' % (existing_keys, outputs))
-
- if as_iterable:
- return self._infer_model_as_iterable(
- checkpoint_path, predictions, feed_fn, return_dict)
- else:
- return self._infer_model_single(
- checkpoint_path, predictions, feed_fn, return_dict)
-
- def _infer_model_single(
- self, checkpoint_path, predictions, feed_fn, return_dict):
- if feed_fn is None:
- preds = graph_actions.infer(checkpoint_path, predictions)
- else:
- def _feed_fn():
- while True:
- yield feed_fn()
-
- outputs = graph_actions.run_feeds(
- output_dict=predictions,
- feed_dicts=_feed_fn(),
- restore_checkpoint_path=checkpoint_path)
- preds = {
- key: np.concatenate([output[key] for output in outputs], axis=0)
- for key in predictions}
-
- return preds if return_dict else preds['predictions']
-
- def _infer_model_as_iterable(
- self, checkpoint_path, predictions, feed_fn, return_dict):
- if feed_fn is None:
- # If there are no queue_runners, the input `predictions` is a
- # constant, and we should stop after the first epoch. If,
- # instead, there are queue_runners, eventually they should throw
- # an `OutOfRangeError`.
- graph = contrib_ops.get_graph_from_inputs(predictions.values())
- if graph.get_collection(ops.GraphKeys.QUEUE_RUNNERS):
- feed_dicts = itertools.repeat(None)
+ infer_ops = self._call_legacy_get_predict_ops(features)
+ predictions = self._filter_predictions(infer_ops.predictions, outputs)
+ mon_sess = monitored_session.MonitoredSession(
+ session_creator=monitored_session.ChiefSessionCreator(
+ checkpoint_filename_with_path=checkpoint_path))
+ if not as_iterable:
+ with mon_sess:
+ if not mon_sess.should_stop():
+ return mon_sess.run(predictions, feed_fn() if feed_fn else None)
else:
- feed_dicts = [None]
- else:
- def _feed_fn():
- while True:
- yield feed_fn()
- feed_dicts = _feed_fn()
-
- try:
- for output_batch in graph_actions.run_feeds_iter(
- output_dict=predictions,
- feed_dicts=feed_dicts,
- restore_checkpoint_path=checkpoint_path):
- # Unpack batches into individual predictions
- if return_dict:
- first_tensor = list(output_batch.values())[0]
+ return self._predict_generator(mon_sess, predictions, feed_fn,
+ iterate_batches)
+
+ def _predict_generator(self, mon_sess, predictions, feed_fn, iterate_batches):
+ with mon_sess:
+ while not mon_sess.should_stop():
+ preds = mon_sess.run(predictions, feed_fn() if feed_fn else None)
+ if iterate_batches:
+ yield preds
+ elif not isinstance(predictions, dict):
+ for pred in preds:
+ yield pred
+ else:
+ first_tensor = list(preds.values())[0]
if isinstance(first_tensor, sparse_tensor.SparseTensorValue):
batch_length = first_tensor.dense_shape[0]
else:
batch_length = first_tensor.shape[0]
for i in range(batch_length):
- yield {key: value[i] for key, value in six.iteritems(output_batch)}
- else:
- for pred in output_batch['predictions']:
- yield pred
+ yield {key: value[i] for key, value in six.iteritems(preds)}
+ if self._is_input_constant(feed_fn, mon_sess.graph):
+ return
+
+ def _is_input_constant(self, feed_fn, graph):
+ # If there are no queue_runners, the input `predictions` is a
+ # constant, and we should stop after the first epoch. If,
+ # instead, there are queue_runners, eventually they should throw
+ # an `OutOfRangeError`.
+ if graph.get_collection(ops.GraphKeys.QUEUE_RUNNERS):
+ return False
+ # data_feeder uses feed_fn to generate `OutOfRangeError`.
+ if feed_fn is not None:
+ return False
+ return True
+
+ def _filter_predictions(self, predictions, outputs):
+ if not outputs:
+ return predictions
+ if not isinstance(predictions, dict):
+ raise ValueError(
+ 'outputs argument is not valid in case of non-dict predictions.')
+ existing_keys = predictions.keys()
+ predictions = {
+ key: value
+ for key, value in six.iteritems(predictions) if key in outputs
+ }
+ if not predictions:
+ raise ValueError('Expected to run at least one output from %s, '
+ 'provided %s.' % (existing_keys, outputs))
+ return predictions
+
+ def _train_model(self, input_fn, hooks):
+ all_hooks = []
+ self._graph = ops.Graph()
+ with self._graph.as_default() as g, g.device(self._device_fn):
+ random_seed.set_random_seed(self._config.tf_random_seed)
+ global_step = contrib_framework.create_global_step(g)
+ features, labels = input_fn()
+ self._check_inputs(features, labels)
+ model_fn_ops = self._call_legacy_get_train_ops(features, labels)
+ ops.add_to_collection(ops.GraphKeys.LOSSES, model_fn_ops.loss)
+ all_hooks.extend([
+ basic_session_run_hooks.NanTensorHook(model_fn_ops.loss),
+ basic_session_run_hooks.LoggingTensorHook(
+ {
+ 'loss': model_fn_ops.loss,
+ 'step': global_step
+ },
+ every_n_iter=100)
+ ])
+ all_hooks.extend(hooks)
- except errors.OutOfRangeError:
- # We fall out of the above loop naturally if feed_fn raises StopIteration,
- # or we catch an OutOfRangeError if we've reached the end of inputs.
- logging.info('Reached end of inputs for predict_iter.')
+ scaffold = model_fn_ops.training_scaffold or monitored_session.Scaffold()
+ if not (scaffold.saver or ops.get_collection(ops.GraphKeys.SAVERS)):
+ ops.add_to_collection(
+ ops.GraphKeys.SAVERS,
+ saver.Saver(
+ sharded=True,
+ max_to_keep=self._config.keep_checkpoint_max,
+ defer_build=True))
+
+ chief_hooks = []
+ if (self._config.save_checkpoints_secs or
+ self._config.save_checkpoints_steps):
+ saver_hook_exists = any([
+ isinstance(h, basic_session_run_hooks.CheckpointSaverHook)
+ for h in (all_hooks + model_fn_ops.training_hooks + chief_hooks +
+ model_fn_ops.training_chief_hooks)
+ ])
+ if not saver_hook_exists:
+ chief_hooks = [
+ basic_session_run_hooks.CheckpointSaverHook(
+ self._model_dir,
+ save_secs=self._config.save_checkpoints_secs,
+ save_steps=self._config.save_checkpoints_steps,
+ scaffold=scaffold)
+ ]
+ with monitored_session.MonitoredTrainingSession(
+ master=self._config.master,
+ is_chief=self._config.is_chief,
+ checkpoint_dir=self._model_dir,
+ scaffold=scaffold,
+ hooks=all_hooks + model_fn_ops.training_hooks,
+ chief_only_hooks=chief_hooks + model_fn_ops.training_chief_hooks,
+ save_checkpoint_secs=0, # Saving is handled by a hook.
+ save_summaries_steps=self._config.save_summary_steps,
+ config=None) as mon_sess:
+ loss = None
+ while not mon_sess.should_stop():
+ _, loss = mon_sess.run([model_fn_ops.train_op, model_fn_ops.loss])
+ summary_io.SummaryWriterCache.clear()
+ return loss
+
+ def _call_legacy_get_predict_ops(self, features):
+ # The default return type of _get_predict_ops is ModelFnOps. But there are
+ # some subclasses of tf.contrib.learn.Estimator which override this
+ # method and use the legacy signature, namely _get_predict_ops returns a
+ # `predictions` Tensor or dict or Tensors. The following else-statement
+ # code covers these cases, but will soon be deleted after the subclasses
+ # are updated.
+ # TODO(b/32664904): Update subclasses and delete the else-statement.
+ infer_ops = self._get_predict_ops(features)
+ if isinstance(infer_ops, model_fn_lib.ModelFnOps): # Default signature
+ return infer_ops
+ return model_fn_lib.ModelFnOps(
+ mode=model_fn_lib.ModeKeys.INFER, predictions=infer_ops)
+
+ def _call_legacy_get_train_ops(self, features, labels):
+ train_ops = self._get_train_ops(features, labels)
+ if isinstance(train_ops, model_fn_lib.ModelFnOps): # Default signature
+ return train_ops
+ return model_fn_lib.ModelFnOps(
+ mode=model_fn_lib.ModeKeys.TRAIN,
+ predictions=None,
+ loss=train_ops[1],
+ train_op=train_ops[0])
def _identity_feature_engineering_fn(features, labels):
@@ -1177,17 +1160,6 @@ class Estimator(BaseEstimator):
"""
return self._call_model_fn(features, labels, model_fn_lib.ModeKeys.TRAIN)
- # TODO(ispir): delete this function after converting all legacy usages.
- def _call_legacy_get_train_ops(self, features, labels):
- train_ops = self._get_train_ops(features, labels)
- if isinstance(train_ops, model_fn_lib.ModelFnOps): # Default signature
- return train_ops
- return model_fn_lib.ModelFnOps(
- mode=model_fn_lib.ModeKeys.TRAIN,
- predictions=None,
- loss=train_ops[1],
- train_op=train_ops[0])
-
def _get_eval_ops(self, features, labels, metrics):
"""Method that builds model graph and returns evaluation ops.
@@ -1343,114 +1315,6 @@ class Estimator(BaseEstimator):
return export_dir
- @deprecated_args(SCIKIT_DECOUPLE_DATE, SCIKIT_DECOUPLE_INSTRUCTIONS, 'x', 'y',
- 'batch_size')
- def fit(self,
- x=None,
- y=None,
- input_fn=None,
- steps=None,
- batch_size=None,
- monitors=None,
- max_steps=None):
- # pylint: disable=g-doc-args,g-doc-return-or-yield
- """See `Trainable`.
-
- Raises:
- ValueError: If `x` or `y` are not `None` while `input_fn` is not `None`.
- ValueError: If both `steps` and `max_steps` are not `None`.
- """
- if (steps is not None) and (max_steps is not None):
- raise ValueError('Can not provide both steps and max_steps.')
- if max_steps is not None:
- try:
- start_step = load_variable(self._model_dir, ops.GraphKeys.GLOBAL_STEP)
- if max_steps <= start_step:
- logging.info('Skipping training since max_steps has already saved.')
- return None
- except: # pylint: disable=bare-except
- pass
-
- hooks = monitor_lib.replace_monitors_with_hooks(monitors, self)
- if steps is not None or max_steps is not None:
- hooks.append(basic_session_run_hooks.StopAtStepHook(steps, max_steps))
-
- input_fn, feed_fn = _get_input_fn(
- x,
- y,
- input_fn,
- feed_fn=None,
- batch_size=batch_size,
- shuffle=True,
- epochs=None)
- if feed_fn:
- hooks.append(_FeedFnHook(feed_fn))
- loss = self._train_model_v2(input_fn=input_fn, hooks=hooks)
- logging.info('Loss for final step: %s.', loss)
- return self
-
- def _train_model_v2(self, input_fn, hooks):
- all_hooks = []
- self._graph = ops.Graph()
- with self._graph.as_default() as g, g.device(self._device_fn):
- random_seed.set_random_seed(self._config.tf_random_seed)
- global_step = contrib_framework.create_global_step(g)
- features, labels = input_fn()
- self._check_inputs(features, labels)
- model_fn_ops = self._call_legacy_get_train_ops(features, labels)
- ops.add_to_collection(ops.GraphKeys.LOSSES, model_fn_ops.loss)
- all_hooks.extend([
- basic_session_run_hooks.NanTensorHook(model_fn_ops.loss),
- basic_session_run_hooks.LoggingTensorHook(
- {
- 'loss': model_fn_ops.loss,
- 'step': global_step
- },
- every_n_iter=100)
- ])
- all_hooks.extend(hooks)
-
- scaffold = model_fn_ops.training_scaffold or monitored_session.Scaffold()
- if not (scaffold.saver or ops.get_collection(ops.GraphKeys.SAVERS)):
- ops.add_to_collection(
- ops.GraphKeys.SAVERS,
- saver.Saver(
- sharded=True,
- max_to_keep=self._config.keep_checkpoint_max,
- defer_build=True))
-
- chief_hooks = []
- if (self._config.save_checkpoints_secs or
- self._config.save_checkpoints_steps):
- saver_hook_exists = any([
- isinstance(h, basic_session_run_hooks.CheckpointSaverHook)
- for h in (all_hooks + model_fn_ops.training_hooks + chief_hooks +
- model_fn_ops.training_chief_hooks)
- ])
- if not saver_hook_exists:
- chief_hooks = [
- basic_session_run_hooks.CheckpointSaverHook(
- self._model_dir,
- save_secs=self._config.save_checkpoints_secs,
- save_steps=self._config.save_checkpoints_steps,
- scaffold=scaffold)
- ]
- with monitored_session.MonitoredTrainingSession(
- master=self._config.master,
- is_chief=self._config.is_chief,
- checkpoint_dir=self._model_dir,
- scaffold=scaffold,
- hooks=all_hooks + model_fn_ops.training_hooks,
- chief_only_hooks=chief_hooks + model_fn_ops.training_chief_hooks,
- save_checkpoint_secs=0, # Saving is handled by a hook.
- save_summaries_steps=self._config.save_summary_steps,
- config=None) as mon_sess:
- loss = None
- while not mon_sess.should_stop():
- _, loss = mon_sess.run([model_fn_ops.train_op, model_fn_ops.loss])
- summary_io.SummaryWriterCache.clear()
- return loss
-
class _FeedFnHook(session_run_hook.SessionRunHook):
"""Runs feed_fn and sets the feed_dict accordingly."""
@@ -1509,6 +1373,17 @@ class SKCompat(sklearn.BaseEstimator):
input_fn, feed_fn = _get_input_fn(
x, None, input_fn=None, feed_fn=None, batch_size=batch_size,
shuffle=False, epochs=1)
- return self._estimator._infer_model(
- input_fn=input_fn, feed_fn=feed_fn, outputs=outputs,
- as_iterable=False)
+ results = list(
+ self._estimator._infer_model(
+ input_fn=input_fn,
+ feed_fn=feed_fn,
+ outputs=outputs,
+ as_iterable=True,
+ iterate_batches=True))
+ if not isinstance(results[0], dict):
+ return np.concatenate([output for output in results], axis=0)
+ return {
+ key: np.concatenate(
+ [output[key] for output in results], axis=0)
+ for key in results[0]
+ }