author     2018-08-12 16:21:41 -0700
committer  2018-08-12 16:21:41 -0700
commit     9523a98466d16cf01fc76a67b489f1124cf626ac (patch)
tree       bd4c460b67fab60c2fb1a6c56bf22d1cbb5391e6 /tensorflow/python/estimator/estimator.py
parent     93e950c308071071f35d6dcb35b9f91b8a34876c (diff)
parent     1a22b0b982fa1a953651b98af8f3cd30542048fd (diff)
Merge remote-tracking branch 'upstream/master'
Diffstat (limited to 'tensorflow/python/estimator/estimator.py')
-rw-r--r--  tensorflow/python/estimator/estimator.py  107
1 file changed, 72 insertions, 35 deletions
````diff
diff --git a/tensorflow/python/estimator/estimator.py b/tensorflow/python/estimator/estimator.py
index 43deb8bc6c..b8cd55c806 100644
--- a/tensorflow/python/estimator/estimator.py
+++ b/tensorflow/python/estimator/estimator.py
@@ -50,9 +50,10 @@ from tensorflow.python.ops import variables
 from tensorflow.python.platform import gfile
 from tensorflow.python.platform import tf_logging as logging
 from tensorflow.python.saved_model import builder as saved_model_builder
-from tensorflow.python.saved_model import constants
+from tensorflow.python.saved_model import utils_impl as saved_model_utils
 from tensorflow.python.summary import summary
 from tensorflow.python.summary.writer import writer_cache
+from tensorflow.python.training import basic_session_run_hooks
 from tensorflow.python.training import checkpoint_management
 from tensorflow.python.training import device_setter
 from tensorflow.python.training import distribute as distribute_lib
@@ -104,7 +105,7 @@ class Estimator(object):
   constructor enforces this). Subclasses should use `model_fn` to configure
   the base class, and may add methods implementing specialized functionality.
 
-  @compatbility(eager)
+  @compatibility(eager)
   Calling methods of `Estimator` will work while eager execution is enabled.
   However, the `model_fn` and `input_fn` is not executed eagerly, `Estimator`
   will switch to graph model before calling all user-provided functions (incl.
@@ -128,7 +129,7 @@ class Estimator(object):
   ```
 
   For more details on warm-start configuration, see
-  @{tf.estimator.WarmStartSettings$WarmStartSettings}.
+  `tf.estimator.WarmStartSettings`.
 
   Args:
     model_fn: Model function. Follows the signature:
@@ -345,7 +346,23 @@ class Estimator(object):
     return self
 
   def _convert_train_steps_to_hooks(self, steps, max_steps):
+    """Create hooks to run correct number of steps in training.
+
+    Args:
+      steps: number of steps to run during training.
+      max_steps: maximum number of steps to be run during training. It'll be
+        the maximum number of steps the model will train to after restoring
+        from checkpoint even across multiple estimator.train calls.
+
+    Returns:
+      List of hooks to be passed to the estimator.
+    """
     if steps is not None or max_steps is not None:
+      if self._train_distribution:
+        steps_per_run = getattr(self._train_distribution, 'steps_per_run', 1)
+        if steps_per_run > 1:
+          return [basic_session_run_hooks._MultiStepStopAtStepHook(  # pylint: disable=protected-access
+              steps, max_steps, steps_per_run)]
       return [training.StopAtStepHook(steps, max_steps)]
     else:
       return []
@@ -988,7 +1005,7 @@ class Estimator(object):
   def _get_features_and_labels_from_input_fn(self, input_fn, mode,
                                              distribution=None):
     """Extracts the `features` and labels from return values of `input_fn`."""
-    if distribution is not None and mode == model_fn_lib.ModeKeys.TRAIN:
+    if distribution is not None:
       result = distribution.distribute_dataset(
           lambda: self._call_input_fn(input_fn, mode))
     else:
@@ -1027,7 +1044,7 @@ class Estimator(object):
     """Creates the global step tensor in graph.
 
     The global step tensor must be an integer type with name 'global_step' and
-    be added to the collection @{tf.GraphKeys.GLOBAL_STEP}.
+    be added to the collection `tf.GraphKeys.GLOBAL_STEP`.
 
     Args:
       graph: The graph in which to create the global step tensor.
@@ -1184,6 +1201,10 @@ class Estimator(object):
 
     worker_hooks = []
     with ops.Graph().as_default() as g:
+      # We want to create the iterations variable outside the distribution scope
+      # as that is just stored on the host and mainly used to drive the loop
+      # and doesn't need to be a Mirrored/Device variable.
+      steps_per_run_variable = training.get_or_create_steps_per_run_variable()
       with self._train_distribution.scope():
         random_seed.set_random_seed(self._config.tf_random_seed)
 
@@ -1215,19 +1236,21 @@ class Estimator(object):
               labels,
               model_fn_lib.ModeKeys.TRAIN,
               self.config)
-          ctx.last_step_outputs = estimator_spec.loss
-          ctx.non_tensor_outputs = {'estimator_spec': estimator_spec}
-          with ops.control_dependencies([estimator_spec.train_op]):
-            return array_ops.identity(estimator_spec.loss)
+          ctx.set_last_step_output(
+              name='loss',
+              output=estimator_spec.loss,
+              aggregation=distribute_lib.get_loss_reduction())
+          ctx.set_non_tensor_output(
+              name='estimator_spec', output=estimator_spec)
+          return estimator_spec.train_op
 
         # Create new train_op post graph rewrites
-        # TODO(sourabhbajaj): Make sure train_steps and tpu_iterations
-        # work correctly. Currently hardcoded at 2
         initial_training_loss = constant_op.constant(1e7)
-        distributed_train_op, tpu_result, ctx = \
-            self._train_distribution._run_steps_on_dataset(  # pylint: disable=protected-access
-                step_fn, iterator, iterations=2,
-                initial_loop_values=initial_training_loss)
+        ctx = self._train_distribution.run_steps_on_dataset(
+            step_fn, iterator, iterations=steps_per_run_variable,
+            initial_loop_values={'loss': initial_training_loss})
+        distributed_train_op = ctx.run_op
+        tpu_result = ctx.last_step_outputs
         grouped_estimator_spec = ctx.non_tensor_outputs['estimator_spec']
       else:
         features, labels, input_hooks = (
@@ -1263,22 +1286,22 @@ class Estimator(object):
 
         # TODO(sourabhbajaj): Merge the two code paths and clean up the code
         if is_tpu_strategy:
-          distributed_loss = tpu_result
+          loss = tpu_result['loss']
           worker_hooks.append(
               estimator_util.StrategyInitFinalizeHook(
-                  self._train_distribution.get_initialization_ops,
-                  self._train_distribution.get_finalize_ops))
+                  self._train_distribution.initialize,
+                  self._train_distribution.finalize))
         else:
-          distributed_loss = grouped_estimator_spec.loss
+          loss = self._train_distribution.unwrap(
+              self._train_distribution.reduce(
+                  distribute_lib.get_loss_reduction(),
+                  grouped_estimator_spec.loss,
+                  destinations='/device:CPU:0'))[0]
           distributed_train_op = grouped_estimator_spec.train_op
 
         estimator_spec = model_fn_lib.EstimatorSpec(
             mode=grouped_estimator_spec.mode,
-            loss=self._train_distribution.unwrap(
-                self._train_distribution.reduce(
-                    distribute_lib.get_loss_reduction(),
-                    distributed_loss,
-                    destinations='/device:CPU:0'))[0],
+            loss=loss,
             train_op=self._train_distribution.group(distributed_train_op),
             training_hooks=training_hooks,
             training_chief_hooks=training_chief_hooks,
@@ -1783,10 +1806,24 @@ def _write_dict_to_summary(output_dir,
         logging.warn('Skipping summary for %s, cannot parse string to Summary.',
                      key)
         continue
+    elif isinstance(dictionary[key], np.ndarray):
+      value = summary_proto.value.add()
+      value.tag = key
+      value.node_name = key
+      tensor_proto = tensor_util.make_tensor_proto(dictionary[key])
+      value.tensor.CopyFrom(tensor_proto)
+      # pylint: disable=line-too-long
+      logging.info(
+          'Summary for np.ndarray is not visible in Tensorboard by default. '
+          'Consider using a Tensorboard plugin for visualization (see '
+          'https://github.com/tensorflow/tensorboard-plugin-example/blob/master/README.md'
+          ' for more information).')
+      # pylint: enable=line-too-long
     else:
       logging.warn(
           'Skipping summary for %s, must be a float, np.float32, np.int64, '
-          'np.int32 or int or a serialized string of Summary.', key)
+          'np.int32 or int or np.ndarray or a serialized string of Summary.',
+          key)
 
   summary_writer.add_summary(summary_proto, current_global_step)
   summary_writer.flush()
@@ -2001,14 +2038,11 @@ class WarmStartSettings(
 def _get_saved_model_ckpt(saved_model_dir):
   """Return path to variables checkpoint in a SavedModel directory."""
   if not gfile.Exists(
-      os.path.join(compat.as_bytes(saved_model_dir),
-                   compat.as_bytes('variables/variables.index'))):
+      os.path.join(saved_model_utils.get_variables_dir(saved_model_dir),
+                   compat.as_text('variables.index'))):
     raise ValueError('Directory provided has an invalid SavedModel format: %s'
                      % saved_model_dir)
-  return os.path.join(
-      compat.as_bytes(saved_model_dir),
-      compat.as_bytes('{}/{}'.format(constants.VARIABLES_DIRECTORY,
-                                     constants.VARIABLES_FILENAME)))
+  return saved_model_utils.get_variables_path(saved_model_dir)
 
 
 def _get_default_warm_start_settings(warm_start_from):
@@ -2030,12 +2064,15 @@ def _get_default_warm_start_settings(warm_start_from):
   if isinstance(warm_start_from, (six.string_types, six.binary_type)):
     # Infer that this is a SavedModel if export_path +
     # 'variables/variables.index' exists, and if so, construct the
-    # WarmStartSettings pointing to export_path + 'variables/variables'.
-    if gfile.Exists(os.path.join(compat.as_bytes(warm_start_from),
-                                 compat.as_bytes('variables/variables.index'))):
+    # WarmStartSettings pointing to the variables path
+    # (export_path + 'variables/variables').
+    if gfile.Exists(os.path.join(
+        saved_model_utils.get_variables_dir(warm_start_from),
+        compat.as_text('variables.index'))):
       logging.info('Warm-starting from a SavedModel')
       return WarmStartSettings(
-          ckpt_to_initialize_from=_get_saved_model_ckpt(warm_start_from))
+          ckpt_to_initialize_from=saved_model_utils.get_variables_path(
+              warm_start_from))
     return WarmStartSettings(ckpt_to_initialize_from=warm_start_from)
   elif isinstance(warm_start_from, WarmStartSettings):
     return warm_start_from
````
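The new `_convert_train_steps_to_hooks` logic above switches to a multi-step stop hook whenever the active distribution strategy advances more than one step per session run. The toy loop below is a framework-free sketch (no TensorFlow API; `train_until` is a hypothetical helper) of why a plain stop-at-step check is no longer sufficient: with a stride of `steps_per_run`, a naive between-run check overshoots the requested step count, which is the situation `_MultiStepStopAtStepHook` is meant to handle.

```python
# Hypothetical sketch: each distributed "run" advances the global step by
# steps_per_run, so a stop condition checked only between runs can overshoot
# the requested maximum.
def train_until(max_steps, steps_per_run):
    global_step = 0
    while global_step < max_steps:
        global_step += steps_per_run  # one run = steps_per_run local steps
    return global_step

print(train_until(10, 1))  # 10 -- plain StopAtStepHook stops exactly on target
print(train_until(10, 4))  # 12 -- overshoots; a multi-step-aware hook is needed
```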
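The `np.ndarray` branch added to `_write_dict_to_summary` stores the array in the summary value's `tensor` field via `make_tensor_proto`, rather than as a scalar `simple_value`. A minimal sketch of that packing, assuming the TF 1.x internal module paths already imported elsewhere in this file (`summary_pb2`, `tensor_util`):

```python
import numpy as np
from tensorflow.core.framework import summary_pb2
from tensorflow.python.framework import tensor_util

def ndarray_summary(tag, array):
    """Pack an np.ndarray into a Summary proto, mirroring the new elif branch."""
    summary_proto = summary_pb2.Summary()
    value = summary_proto.value.add()
    value.tag = tag
    value.node_name = tag
    # Serialize the array as a TensorProto on the summary value.
    value.tensor.CopyFrom(tensor_util.make_tensor_proto(array))
    return summary_proto

summary = ndarray_summary('confusion_matrix', np.eye(3, dtype=np.int64))
```

As the added `logging.info` message notes, such tensor summaries are not rendered by TensorBoard's standard dashboards; a plugin is needed to visualize them.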
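Finally, the warm-start helpers now defer to `saved_model_utils` for SavedModel path construction instead of joining `'variables/variables'` by hand. Functionally the helpers reduce to the joins below; this is a simplified sketch, assuming only the documented SavedModel layout (the real `utils_impl` versions also normalize bytes vs. text via `compat`):

```python
import os

def get_variables_dir(export_dir):
    # SavedModel layout: <export_dir>/variables/
    return os.path.join(export_dir, 'variables')

def get_variables_path(export_dir):
    # Checkpoint prefix: <export_dir>/variables/variables
    return os.path.join(get_variables_dir(export_dir), 'variables')

assert get_variables_path('/tmp/export') == '/tmp/export/variables/variables'
```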