diff options
Diffstat (limited to 'tensorflow/python/estimator/estimator.py')
-rw-r--r-- | tensorflow/python/estimator/estimator.py | 75 |
1 files changed, 54 insertions, 21 deletions
diff --git a/tensorflow/python/estimator/estimator.py b/tensorflow/python/estimator/estimator.py index 350a95eea1..cc5a61b54e 100644 --- a/tensorflow/python/estimator/estimator.py +++ b/tensorflow/python/estimator/estimator.py @@ -29,8 +29,6 @@ import six from google.protobuf import message from tensorflow.core.framework import summary_pb2 -from tensorflow.core.protobuf import config_pb2 -from tensorflow.core.protobuf import rewriter_config_pb2 from tensorflow.python.client import session as tf_session from tensorflow.python.eager import context from tensorflow.python.estimator import model_fn as model_fn_lib @@ -216,11 +214,7 @@ class Estimator(object): logging.info('Using config: %s', str(vars(self._config))) if self._config.session_config is None: - rewrite_opts = rewriter_config_pb2.RewriterConfig( - meta_optimizer_iterations=rewriter_config_pb2.RewriterConfig.ONE) - graph_opts = config_pb2.GraphOptions(rewrite_options=rewrite_opts) - self._session_config = config_pb2.ConfigProto( - allow_soft_placement=True, graph_options=graph_opts) + self._session_config = run_config.get_default_session_config() else: self._session_config = self._config.session_config @@ -573,12 +567,19 @@ class Estimator(object): def _assert_members_are_not_overridden(self): """Asserts members of `Estimator` are not overridden.""" + # TPUEstimator is special cased (owned by TF). + if self.__class__.__name__ == 'TPUEstimator': + return + allowed_overrides = set([ - '_call_input_fn', '_create_global_step', + '_call_input_fn', '_call_model_fn', '_convert_train_steps_to_hooks', '_convert_eval_steps_to_hooks', - '_tf_api_names', '_estimator_api_names', '_estimator_api_constants', + '_create_global_step', '_create_and_assert_global_step', + '_tf_api_names', '_tf_api_names_v1', '_estimator_api_names', + '_estimator_api_names_v1', '_estimator_api_constants', + '_estimator_api_constants_v1', '_validate_features_in_predict_input', - '_call_model_fn', '_add_meta_graph_for_mode' + '_add_meta_graph_for_mode' ]) estimator_members = set([m for m in Estimator.__dict__.keys() if not m.startswith('__')]) @@ -905,9 +906,10 @@ class Estimator(object): with tf_session.Session(config=self._session_config) as session: - local_init_op = ( - estimator_spec.scaffold.local_init_op or - monitored_session.Scaffold.default_local_init_op()) + if estimator_spec.scaffold.local_init_op is not None: + local_init_op = estimator_spec.scaffold.local_init_op + else: + local_init_op = monitored_session.Scaffold.default_local_init_op() # This saver will be used both for restoring variables now, # and in saving out the metagraph below. This ensures that any @@ -1159,13 +1161,19 @@ class Estimator(object): with ops.Graph().as_default() as g, g.device(self._device_fn): random_seed.set_random_seed(self._config.tf_random_seed) global_step_tensor = self._create_and_assert_global_step(g) - training_util._get_or_create_global_step_read() # pylint: disable=protected-access + + # Skip creating a read variable if _create_and_assert_global_step + # returns None (e.g. tf.contrib.estimator.SavedModelEstimator). + if global_step_tensor is not None: + training_util._get_or_create_global_step_read(g) # pylint: disable=protected-access + features, labels, input_hooks = ( self._get_features_and_labels_from_input_fn( input_fn, model_fn_lib.ModeKeys.TRAIN)) worker_hooks.extend(input_hooks) estimator_spec = self._call_model_fn( features, labels, model_fn_lib.ModeKeys.TRAIN, self.config) + global_step_tensor = training_util.get_global_step(g) return self._train_with_estimator_spec(estimator_spec, worker_hooks, hooks, global_step_tensor, saving_listeners) @@ -1452,13 +1460,13 @@ class Estimator(object): def _evaluate_build_graph(self, input_fn, hooks=None, checkpoint_path=None): """Builds the graph and related hooks to run evaluation.""" random_seed.set_random_seed(self._config.tf_random_seed) - global_step_tensor = self._create_and_assert_global_step( - ops.get_default_graph()) + self._create_and_assert_global_step(ops.get_default_graph()) features, labels, input_hooks = ( self._get_features_and_labels_from_input_fn(input_fn, model_fn_lib.ModeKeys.EVAL)) estimator_spec = self._call_model_fn( features, labels, model_fn_lib.ModeKeys.EVAL, self.config) + global_step_tensor = training_util.get_global_step(ops.get_default_graph()) # Call to warm_start has to be after model_fn is called. self._maybe_warm_start(checkpoint_path) @@ -1484,7 +1492,21 @@ class Estimator(object): all_hooks.extend(hooks) all_hooks.extend(list(estimator_spec.evaluation_hooks or [])) - return estimator_spec.scaffold, update_op, eval_dict, all_hooks + # New local variables have been added, so update the estimator spec's + # local init op if it was defined. + scaffold = estimator_spec.scaffold + if estimator_spec.scaffold and estimator_spec.scaffold.local_init_op: + # Ensure that eval step has been created before updating local init op. + evaluation._get_or_create_eval_step() # pylint: disable=protected-access + + scaffold = monitored_session.Scaffold( + local_init_op=control_flow_ops.group( + estimator_spec.scaffold.local_init_op, + monitored_session.Scaffold.default_local_init_op()), + copy_from_scaffold=scaffold + ) + + return scaffold, update_op, eval_dict, all_hooks def _evaluate_run(self, checkpoint_path, scaffold, update_op, eval_dict, all_hooks, output_dir): @@ -1915,6 +1937,19 @@ class WarmStartSettings( ) +def _get_saved_model_ckpt(saved_model_dir): + """Return path to variables checkpoint in a SavedModel directory.""" + if not gfile.Exists( + os.path.join(compat.as_bytes(saved_model_dir), + compat.as_bytes('variables/variables.index'))): + raise ValueError('Directory provided has an invalid SavedModel format: %s' + % saved_model_dir) + return os.path.join( + compat.as_bytes(saved_model_dir), + compat.as_bytes('{}/{}'.format(constants.VARIABLES_DIRECTORY, + constants.VARIABLES_FILENAME))) + + def _get_default_warm_start_settings(warm_start_from): """Returns default WarmStartSettings. @@ -1938,10 +1973,8 @@ def _get_default_warm_start_settings(warm_start_from): if gfile.Exists(os.path.join(compat.as_bytes(warm_start_from), compat.as_bytes('variables/variables.index'))): logging.info('Warm-starting from a SavedModel') - return WarmStartSettings(ckpt_to_initialize_from=os.path.join( - compat.as_bytes(warm_start_from), - compat.as_bytes('{}/{}'.format(constants.VARIABLES_DIRECTORY, - constants.VARIABLES_FILENAME)))) + return WarmStartSettings( + ckpt_to_initialize_from=_get_saved_model_ckpt(warm_start_from)) return WarmStartSettings(ckpt_to_initialize_from=warm_start_from) elif isinstance(warm_start_from, WarmStartSettings): return warm_start_from |