Diffstat (limited to 'tensorflow/python/estimator/estimator.py')
-rw-r--r--  tensorflow/python/estimator/estimator.py | 75
1 file changed, 54 insertions(+), 21 deletions(-)
diff --git a/tensorflow/python/estimator/estimator.py b/tensorflow/python/estimator/estimator.py
index 350a95eea1..cc5a61b54e 100644
--- a/tensorflow/python/estimator/estimator.py
+++ b/tensorflow/python/estimator/estimator.py
@@ -29,8 +29,6 @@ import six
from google.protobuf import message
from tensorflow.core.framework import summary_pb2
-from tensorflow.core.protobuf import config_pb2
-from tensorflow.core.protobuf import rewriter_config_pb2
from tensorflow.python.client import session as tf_session
from tensorflow.python.eager import context
from tensorflow.python.estimator import model_fn as model_fn_lib
@@ -216,11 +214,7 @@ class Estimator(object):
logging.info('Using config: %s', str(vars(self._config)))
if self._config.session_config is None:
- rewrite_opts = rewriter_config_pb2.RewriterConfig(
- meta_optimizer_iterations=rewriter_config_pb2.RewriterConfig.ONE)
- graph_opts = config_pb2.GraphOptions(rewrite_options=rewrite_opts)
- self._session_config = config_pb2.ConfigProto(
- allow_soft_placement=True, graph_options=graph_opts)
+ self._session_config = run_config.get_default_session_config()
else:
self._session_config = self._config.session_config
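The default session configuration is now built by run_config.get_default_session_config() instead of inline. A minimal sketch of the configuration the removed lines constructed, which the new helper presumably reproduces (an assumption based only on the deleted code above):

    from tensorflow.core.protobuf import config_pb2
    from tensorflow.core.protobuf import rewriter_config_pb2

    def default_session_config():
      # One Grappler meta-optimizer pass plus soft device placement, as in
      # the removed lines above.
      rewrite_opts = rewriter_config_pb2.RewriterConfig(
          meta_optimizer_iterations=rewriter_config_pb2.RewriterConfig.ONE)
      graph_opts = config_pb2.GraphOptions(rewrite_options=rewrite_opts)
      return config_pb2.ConfigProto(
          allow_soft_placement=True, graph_options=graph_opts)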
@@ -573,12 +567,19 @@ class Estimator(object):
def _assert_members_are_not_overridden(self):
"""Asserts members of `Estimator` are not overridden."""
+ # TPUEstimator is special cased (owned by TF).
+ if self.__class__.__name__ == 'TPUEstimator':
+ return
+
allowed_overrides = set([
- '_call_input_fn', '_create_global_step',
+ '_call_input_fn', '_call_model_fn',
'_convert_train_steps_to_hooks', '_convert_eval_steps_to_hooks',
- '_tf_api_names', '_estimator_api_names', '_estimator_api_constants',
+ '_create_global_step', '_create_and_assert_global_step',
+ '_tf_api_names', '_tf_api_names_v1', '_estimator_api_names',
+ '_estimator_api_names_v1', '_estimator_api_constants',
+ '_estimator_api_constants_v1',
'_validate_features_in_predict_input',
- '_call_model_fn', '_add_meta_graph_for_mode'
+ '_add_meta_graph_for_mode'
])
estimator_members = set([m for m in Estimator.__dict__.keys()
if not m.startswith('__')])
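The expanded allowlist covers the helpers that subclasses may legitimately redefine, plus the _tf_api_names/_estimator_api_* attributes attached by the API export decorators, while TPUEstimator skips the check entirely. A hypothetical standalone sketch of the check itself, written against the public class (names and error text are illustrative, not taken from the file):

    import tensorflow as tf

    def check_no_overrides(subclass, allowed_overrides):
      # Any member that Estimator defines and the subclass redefines must be
      # on the allowlist.
      base = tf.estimator.Estimator
      base_members = set(m for m in base.__dict__ if not m.startswith('__'))
      overridden = sorted(m for m in base_members - set(allowed_overrides)
                          if m in vars(subclass))
      if overridden:
        raise ValueError('Subclass %s overrides Estimator members: %s'
                         % (subclass.__name__, overridden))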
@@ -905,9 +906,10 @@ class Estimator(object):
with tf_session.Session(config=self._session_config) as session:
- local_init_op = (
- estimator_spec.scaffold.local_init_op or
- monitored_session.Scaffold.default_local_init_op())
+ if estimator_spec.scaffold.local_init_op is not None:
+ local_init_op = estimator_spec.scaffold.local_init_op
+ else:
+ local_init_op = monitored_session.Scaffold.default_local_init_op()
# This saver will be used both for restoring variables now,
# and in saving out the metagraph below. This ensures that any
@@ -1159,13 +1161,19 @@ class Estimator(object):
with ops.Graph().as_default() as g, g.device(self._device_fn):
random_seed.set_random_seed(self._config.tf_random_seed)
global_step_tensor = self._create_and_assert_global_step(g)
- training_util._get_or_create_global_step_read() # pylint: disable=protected-access
+
+ # Skip creating a read variable if _create_and_assert_global_step
+ # returns None (e.g. tf.contrib.estimator.SavedModelEstimator).
+ if global_step_tensor is not None:
+ training_util._get_or_create_global_step_read(g) # pylint: disable=protected-access
+
features, labels, input_hooks = (
self._get_features_and_labels_from_input_fn(
input_fn, model_fn_lib.ModeKeys.TRAIN))
worker_hooks.extend(input_hooks)
estimator_spec = self._call_model_fn(
features, labels, model_fn_lib.ModeKeys.TRAIN, self.config)
+ global_step_tensor = training_util.get_global_step(g)
return self._train_with_estimator_spec(estimator_spec, worker_hooks,
hooks, global_step_tensor,
saving_listeners)
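The training path now tolerates _create_and_assert_global_step returning None and re-reads the global step from the graph after the model_fn has run, so a step created inside the model_fn is still picked up. A minimal sketch of that graph-level lookup using the public TF 1.x wrappers around training_util (illustrative, not the Estimator's own code):

    import tensorflow as tf

    g = tf.Graph()
    with g.as_default():
      # Before anything creates a global step, the lookup returns None.
      assert tf.train.get_global_step(g) is None
      # Once one is created (by the Estimator or by the model_fn), the same
      # lookup finds it in the graph's GLOBAL_STEP collection.
      tf.train.create_global_step(g)
      assert tf.train.get_global_step(g) is not None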
@@ -1452,13 +1460,13 @@ class Estimator(object):
def _evaluate_build_graph(self, input_fn, hooks=None, checkpoint_path=None):
"""Builds the graph and related hooks to run evaluation."""
random_seed.set_random_seed(self._config.tf_random_seed)
- global_step_tensor = self._create_and_assert_global_step(
- ops.get_default_graph())
+ self._create_and_assert_global_step(ops.get_default_graph())
features, labels, input_hooks = (
self._get_features_and_labels_from_input_fn(input_fn,
model_fn_lib.ModeKeys.EVAL))
estimator_spec = self._call_model_fn(
features, labels, model_fn_lib.ModeKeys.EVAL, self.config)
+ global_step_tensor = training_util.get_global_step(ops.get_default_graph())
# Call to warm_start has to be after model_fn is called.
self._maybe_warm_start(checkpoint_path)
@@ -1484,7 +1492,21 @@ class Estimator(object):
all_hooks.extend(hooks)
all_hooks.extend(list(estimator_spec.evaluation_hooks or []))
- return estimator_spec.scaffold, update_op, eval_dict, all_hooks
+ # New local variables have been added, so update the estimator spec's
+ # local init op if it was defined.
+ scaffold = estimator_spec.scaffold
+ if estimator_spec.scaffold and estimator_spec.scaffold.local_init_op:
+ # Ensure that eval step has been created before updating local init op.
+ evaluation._get_or_create_eval_step() # pylint: disable=protected-access
+
+ scaffold = monitored_session.Scaffold(
+ local_init_op=control_flow_ops.group(
+ estimator_spec.scaffold.local_init_op,
+ monitored_session.Scaffold.default_local_init_op()),
+ copy_from_scaffold=scaffold
+ )
+
+ return scaffold, update_op, eval_dict, all_hooks
def _evaluate_run(self, checkpoint_path, scaffold, update_op, eval_dict,
all_hooks, output_dir):
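Merging rather than replacing the user-supplied local_init_op matters because local variables (metric accumulators, the eval step) are created after the scaffold was built; the hunk above also forces creation of the eval step before grouping the ops. A sketch of the same merge pattern using the public tf.train.Scaffold API (a sketch, not the exact code above):

    import tensorflow as tf

    def merge_local_init(scaffold):
      # Group the user-supplied local_init_op with the default one so that
      # local variables created after the scaffold was built still get
      # initialized; copy_from_scaffold carries over the remaining fields.
      if scaffold is None or scaffold.local_init_op is None:
        return scaffold
      merged = tf.group(scaffold.local_init_op,
                        tf.train.Scaffold.default_local_init_op())
      return tf.train.Scaffold(local_init_op=merged,
                               copy_from_scaffold=scaffold)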
@@ -1915,6 +1937,19 @@ class WarmStartSettings(
)
+def _get_saved_model_ckpt(saved_model_dir):
+ """Return path to variables checkpoint in a SavedModel directory."""
+ if not gfile.Exists(
+ os.path.join(compat.as_bytes(saved_model_dir),
+ compat.as_bytes('variables/variables.index'))):
+ raise ValueError('Directory provided has an invalid SavedModel format: %s'
+ % saved_model_dir)
+ return os.path.join(
+ compat.as_bytes(saved_model_dir),
+ compat.as_bytes('{}/{}'.format(constants.VARIABLES_DIRECTORY,
+ constants.VARIABLES_FILENAME)))
+
+
def _get_default_warm_start_settings(warm_start_from):
"""Returns default WarmStartSettings.
@@ -1938,10 +1973,8 @@ def _get_default_warm_start_settings(warm_start_from):
if gfile.Exists(os.path.join(compat.as_bytes(warm_start_from),
compat.as_bytes('variables/variables.index'))):
logging.info('Warm-starting from a SavedModel')
- return WarmStartSettings(ckpt_to_initialize_from=os.path.join(
- compat.as_bytes(warm_start_from),
- compat.as_bytes('{}/{}'.format(constants.VARIABLES_DIRECTORY,
- constants.VARIABLES_FILENAME))))
+ return WarmStartSettings(
+ ckpt_to_initialize_from=_get_saved_model_ckpt(warm_start_from))
return WarmStartSettings(ckpt_to_initialize_from=warm_start_from)
elif isinstance(warm_start_from, WarmStartSettings):
return warm_start_from
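This normalization is what lets warm_start_from be a checkpoint path, a SavedModel export directory, or an explicit WarmStartSettings. A small sketch of the settings the two path forms resolve to, using the public tf.estimator.WarmStartSettings (paths are illustrative):

    import tensorflow as tf

    # A plain checkpoint path is used as-is.
    from_ckpt = tf.estimator.WarmStartSettings(
        ckpt_to_initialize_from='/tmp/prev_run/model.ckpt-1000')

    # A SavedModel directory is first resolved to its variables checkpoint
    # prefix via _get_saved_model_ckpt above.
    from_saved_model = tf.estimator.WarmStartSettings(
        ckpt_to_initialize_from='/tmp/export/1534433044/variables/variables')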