author    Avijit <Avijit.Chakraborty@intel.com>  2018-08-12 16:21:41 -0700
committer Avijit <Avijit.Chakraborty@intel.com>  2018-08-12 16:21:41 -0700
commit    9523a98466d16cf01fc76a67b489f1124cf626ac (patch)
tree      bd4c460b67fab60c2fb1a6c56bf22d1cbb5391e6 /tensorflow/python/estimator/estimator.py
parent    93e950c308071071f35d6dcb35b9f91b8a34876c (diff)
parent    1a22b0b982fa1a953651b98af8f3cd30542048fd (diff)
Merge remote-tracking branch 'upstream/master'
Diffstat (limited to 'tensorflow/python/estimator/estimator.py')
-rw-r--r--  tensorflow/python/estimator/estimator.py | 107
1 file changed, 72 insertions(+), 35 deletions(-)
diff --git a/tensorflow/python/estimator/estimator.py b/tensorflow/python/estimator/estimator.py
index 43deb8bc6c..b8cd55c806 100644
--- a/tensorflow/python/estimator/estimator.py
+++ b/tensorflow/python/estimator/estimator.py
@@ -50,9 +50,10 @@ from tensorflow.python.ops import variables
from tensorflow.python.platform import gfile
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.saved_model import builder as saved_model_builder
-from tensorflow.python.saved_model import constants
+from tensorflow.python.saved_model import utils_impl as saved_model_utils
from tensorflow.python.summary import summary
from tensorflow.python.summary.writer import writer_cache
+from tensorflow.python.training import basic_session_run_hooks
from tensorflow.python.training import checkpoint_management
from tensorflow.python.training import device_setter
from tensorflow.python.training import distribute as distribute_lib
@@ -104,7 +105,7 @@ class Estimator(object):
constructor enforces this). Subclasses should use `model_fn` to configure
the base class, and may add methods implementing specialized functionality.
- @compatbility(eager)
+ @compatibility(eager)
Calling methods of `Estimator` will work while eager execution is enabled.
However, the `model_fn` and `input_fn` are not executed eagerly; `Estimator`
will switch to graph mode before calling all user-provided functions (incl.
@@ -128,7 +129,7 @@ class Estimator(object):
```
For more details on warm-start configuration, see
- @{tf.estimator.WarmStartSettings$WarmStartSettings}.
+ `tf.estimator.WarmStartSettings`.
Args:
model_fn: Model function. Follows the signature:
@@ -345,7 +346,23 @@ class Estimator(object):
return self
def _convert_train_steps_to_hooks(self, steps, max_steps):
+ """Create hooks to run correct number of steps in training.
+
+ Args:
+ steps: number of steps to run during training.
+ max_steps: maximum number of steps to be run during training. It'll be
+ the maximum number of steps the model will train to after restoring
+ from checkpoint even across multiple estimator.train calls.
+
+ Returns:
+ List of hooks to be passed to the estimator.
+ """
if steps is not None or max_steps is not None:
+ if self._train_distribution:
+ steps_per_run = getattr(self._train_distribution, 'steps_per_run', 1)
+ if steps_per_run > 1:
+ return [basic_session_run_hooks._MultiStepStopAtStepHook( # pylint: disable=protected-access
+ steps, max_steps, steps_per_run)]
return [training.StopAtStepHook(steps, max_steps)]
else:
return []
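For context, `StopAtStepHook` watches the global step and requests a stop once the limit is reached, while the multi-step variant accounts for strategies that advance the global step several times per `session.run` call. A minimal, self-contained TF 1.x sketch of the single-step hook (the toy `train_op` is an assumption for illustration):

import tensorflow as tf

# Toy training loop bounded by StopAtStepHook; assign_add stands in for
# a real optimizer step.
global_step = tf.train.get_or_create_global_step()
train_op = tf.assign_add(global_step, 1)
hook = tf.train.StopAtStepHook(last_step=10)  # plays the role of max_steps
with tf.train.MonitoredTrainingSession(hooks=[hook]) as sess:
    while not sess.should_stop():
        sess.run(train_op)  # the hook requests a stop at global step 10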
@@ -988,7 +1005,7 @@ class Estimator(object):
def _get_features_and_labels_from_input_fn(self, input_fn, mode,
distribution=None):
"""Extracts the `features` and labels from return values of `input_fn`."""
- if distribution is not None and mode == model_fn_lib.ModeKeys.TRAIN:
+ if distribution is not None:
result = distribution.distribute_dataset(
lambda: self._call_input_fn(input_fn, mode))
else:
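This change routes `input_fn` through the strategy in every mode, not only TRAIN. Stripped of Estimator internals, the branch reduces to the following pattern (`strategy` and `dataset_fn` are stand-ins for the estimator's internals; `distribute_dataset` is the contrib-era strategy method used above):

def get_distributed_inputs(strategy, dataset_fn):
    # With a strategy, hand the dataset function to the strategy so it
    # can split batches across devices; otherwise call it directly.
    if strategy is not None:
        return strategy.distribute_dataset(dataset_fn)
    return dataset_fn()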
@@ -1027,7 +1044,7 @@ class Estimator(object):
"""Creates the global step tensor in graph.
The global step tensor must be an integer type with name 'global_step' and
- be added to the collection @{tf.GraphKeys.GLOBAL_STEP}.
+ be added to the collection `tf.GraphKeys.GLOBAL_STEP`.
Args:
graph: The graph in which to create the global step tensor.
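The contract in this docstring is what `tf.train.get_or_create_global_step` provides in TF 1.x; a short sketch checking each requirement:

import tensorflow as tf

g = tf.Graph()
with g.as_default():
    # An integer variable named 'global_step', registered in the
    # tf.GraphKeys.GLOBAL_STEP collection.
    step = tf.train.get_or_create_global_step(graph=g)
    assert step.op.name == 'global_step'
    assert step.dtype.base_dtype == tf.int64
    assert g.get_collection(tf.GraphKeys.GLOBAL_STEP)[0] is step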
@@ -1184,6 +1201,10 @@ class Estimator(object):
worker_hooks = []
with ops.Graph().as_default() as g:
+ # Create the iterations variable outside the distribution scope: it is
+ # only stored on the host, is mainly used to drive the loop, and does
+ # not need to be a Mirrored/Device variable.
+ steps_per_run_variable = training.get_or_create_steps_per_run_variable()
with self._train_distribution.scope():
random_seed.set_random_seed(self._config.tf_random_seed)
@@ -1215,19 +1236,21 @@ class Estimator(object):
labels,
model_fn_lib.ModeKeys.TRAIN,
self.config)
- ctx.last_step_outputs = estimator_spec.loss
- ctx.non_tensor_outputs = {'estimator_spec': estimator_spec}
- with ops.control_dependencies([estimator_spec.train_op]):
- return array_ops.identity(estimator_spec.loss)
+ ctx.set_last_step_output(
+ name='loss',
+ output=estimator_spec.loss,
+ aggregation=distribute_lib.get_loss_reduction())
+ ctx.set_non_tensor_output(
+ name='estimator_spec', output=estimator_spec)
+ return estimator_spec.train_op
# Create new train_op post graph rewrites
- # TODO(sourabhbajaj): Make sure train_steps and tpu_iterations
- # work correctly. Currently hardcoded at 2
initial_training_loss = constant_op.constant(1e7)
- distributed_train_op, tpu_result, ctx = \
- self._train_distribution._run_steps_on_dataset( # pylint: disable=protected-access
- step_fn, iterator, iterations=2,
- initial_loop_values=initial_training_loss)
+ ctx = self._train_distribution.run_steps_on_dataset(
+ step_fn, iterator, iterations=steps_per_run_variable,
+ initial_loop_values={'loss': initial_training_loss})
+ distributed_train_op = ctx.run_op
+ tpu_result = ctx.last_step_outputs
grouped_estimator_spec = ctx.non_tensor_outputs['estimator_spec']
else:
features, labels, input_hooks = (
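The rewritten block above switches from the private `_run_steps_on_dataset` to `run_steps_on_dataset`, which communicates through a context object instead of multiple return values. A hedged sketch of the calling convention, with the `step_fn` signature and `model_fn` assumed for illustration (this was a contrib-era, unstable interface):

def step_fn(ctx, inputs):
    features, labels = inputs          # assumed input structure
    spec = model_fn(features, labels)  # hypothetical EstimatorSpec builder
    ctx.set_last_step_output(name='loss', output=spec.loss,
                             aggregation=distribute_lib.get_loss_reduction())
    ctx.set_non_tensor_output(name='estimator_spec', output=spec)
    return spec.train_op               # op the strategy runs each iteration

ctx = strategy.run_steps_on_dataset(
    step_fn, iterator, iterations=steps_per_run_variable,
    initial_loop_values={'loss': initial_training_loss})
train_op = ctx.run_op                  # runs `iterations` steps per session.run
loss = ctx.last_step_outputs['loss']   # output captured on the final step only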
@@ -1263,22 +1286,22 @@ class Estimator(object):
# TODO(sourabhbajaj): Merge the two code paths and clean up the code
if is_tpu_strategy:
- distributed_loss = tpu_result
+ loss = tpu_result['loss']
worker_hooks.append(
estimator_util.StrategyInitFinalizeHook(
- self._train_distribution.get_initialization_ops,
- self._train_distribution.get_finalize_ops))
+ self._train_distribution.initialize,
+ self._train_distribution.finalize))
else:
- distributed_loss = grouped_estimator_spec.loss
+ loss = self._train_distribution.unwrap(
+ self._train_distribution.reduce(
+ distribute_lib.get_loss_reduction(),
+ grouped_estimator_spec.loss,
+ destinations='/device:CPU:0'))[0]
distributed_train_op = grouped_estimator_spec.train_op
estimator_spec = model_fn_lib.EstimatorSpec(
mode=grouped_estimator_spec.mode,
- loss=self._train_distribution.unwrap(
- self._train_distribution.reduce(
- distribute_lib.get_loss_reduction(),
- distributed_loss,
- destinations='/device:CPU:0'))[0],
+ loss=loss,
train_op=self._train_distribution.group(distributed_train_op),
training_hooks=training_hooks,
training_chief_hooks=training_chief_hooks,
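For non-TPU strategies, the per-replica losses are now reduced onto the host and unwrapped to a single tensor before being placed in the `EstimatorSpec`. The pattern in isolation (`strategy` and `per_replica_loss` are stand-ins for the estimator's internals):

# Reduce the per-replica values (sum or mean, per the configured loss
# reduction) onto the host CPU, then unwrap the single resulting tensor.
reduced = strategy.reduce(distribute_lib.get_loss_reduction(),
                          per_replica_loss,
                          destinations='/device:CPU:0')
loss = strategy.unwrap(reduced)[0]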
@@ -1783,10 +1806,24 @@ def _write_dict_to_summary(output_dir,
logging.warn('Skipping summary for %s, cannot parse string to Summary.',
key)
continue
+ elif isinstance(dictionary[key], np.ndarray):
+ value = summary_proto.value.add()
+ value.tag = key
+ value.node_name = key
+ tensor_proto = tensor_util.make_tensor_proto(dictionary[key])
+ value.tensor.CopyFrom(tensor_proto)
+ # pylint: disable=line-too-long
+ logging.info(
+ 'Summary for np.ndarray is not visible in TensorBoard by default. '
+ 'Consider using a TensorBoard plugin for visualization (see '
+ 'https://github.com/tensorflow/tensorboard-plugin-example/blob/master/README.md'
+ ' for more information).')
+ # pylint: enable=line-too-long
else:
logging.warn(
'Skipping summary for %s, must be a float, np.float32, np.int64, '
- 'np.int32 or int or a serialized string of Summary.', key)
+ 'np.int32 or int or np.ndarray or a serialized string of Summary.',
+ key)
summary_writer.add_summary(summary_proto, current_global_step)
summary_writer.flush()
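The new `np.ndarray` branch packs the array into the summary's tensor field. A self-contained sketch producing the same kind of `Summary` proto with public TF 1.x APIs:

import numpy as np
import tensorflow as tf

arr = np.array([[1.0, 2.0], [3.0, 4.0]], dtype=np.float32)
summary_proto = tf.Summary()
value = summary_proto.value.add()
value.tag = 'my_array'
value.node_name = 'my_array'
# Serialize the array into a TensorProto, as the diff does via
# tensor_util.make_tensor_proto.
value.tensor.CopyFrom(tf.make_tensor_proto(arr))
# Tensor-valued summaries are not rendered by TensorBoard's default
# dashboards; a plugin is needed to visualize them.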
@@ -2001,14 +2038,11 @@ class WarmStartSettings(
def _get_saved_model_ckpt(saved_model_dir):
"""Return path to variables checkpoint in a SavedModel directory."""
if not gfile.Exists(
- os.path.join(compat.as_bytes(saved_model_dir),
- compat.as_bytes('variables/variables.index'))):
+ os.path.join(saved_model_utils.get_variables_dir(saved_model_dir),
+ compat.as_text('variables.index'))):
raise ValueError('Directory provided has an invalid SavedModel format: %s'
% saved_model_dir)
- return os.path.join(
- compat.as_bytes(saved_model_dir),
- compat.as_bytes('{}/{}'.format(constants.VARIABLES_DIRECTORY,
- constants.VARIABLES_FILENAME)))
+ return saved_model_utils.get_variables_path(saved_model_dir)
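`saved_model_utils.get_variables_dir` and `get_variables_path` are internal helpers; under the standard SavedModel layout they resolve to the same paths the removed code built by hand. A sketch of the equivalent resolution (an assumption based on the constants the old code used, not the helpers' actual source):

import os

def get_variables_dir(saved_model_dir):
    return os.path.join(saved_model_dir, 'variables')

def get_variables_path(saved_model_dir):
    # Checkpoint *prefix*: the real files on disk are variables.index and
    # variables.data-* alongside this prefix.
    return os.path.join(get_variables_dir(saved_model_dir), 'variables')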
def _get_default_warm_start_settings(warm_start_from):
@@ -2030,12 +2064,15 @@ def _get_default_warm_start_settings(warm_start_from):
if isinstance(warm_start_from, (six.string_types, six.binary_type)):
# Infer that this is a SavedModel if export_path +
# 'variables/variables.index' exists, and if so, construct the
- # WarmStartSettings pointing to export_path + 'variables/variables'.
- if gfile.Exists(os.path.join(compat.as_bytes(warm_start_from),
- compat.as_bytes('variables/variables.index'))):
+ # WarmStartSettings pointing to the variables path
+ # (export_path + 'variables/variables').
+ if gfile.Exists(os.path.join(
+ saved_model_utils.get_variables_dir(warm_start_from),
+ compat.as_text('variables.index'))):
logging.info('Warm-starting from a SavedModel')
return WarmStartSettings(
- ckpt_to_initialize_from=_get_saved_model_ckpt(warm_start_from))
+ ckpt_to_initialize_from=saved_model_utils.get_variables_path(
+ warm_start_from))
return WarmStartSettings(ckpt_to_initialize_from=warm_start_from)
elif isinstance(warm_start_from, WarmStartSettings):
return warm_start_from
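In practice, the settings this helper infers behave the same as constructing `tf.estimator.WarmStartSettings` directly. A usage sketch (the export path and feature column are hypothetical):

import tensorflow as tf

ws = tf.estimator.WarmStartSettings(
    ckpt_to_initialize_from='/tmp/export/1534112501/variables/variables',
    vars_to_warm_start='.*')  # default regex: warm-start all trainable vars

estimator = tf.estimator.DNNClassifier(
    hidden_units=[32, 16],
    feature_columns=[tf.feature_column.numeric_column('x')],
    warm_start_from=ws)  # a plain checkpoint or SavedModel path also works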