From d97e525c1fa601b4b1b0a3cffd9e579201d167ef Mon Sep 17 00:00:00 2001 From: Pavithra Vijay Date: Wed, 22 Aug 2018 15:09:08 -0700 Subject: Switch to using variable._in_graph_mode instead of context.executing_eagerly() in optimizer.variables() Remove global collection usage from the Keras model to estimator flow. PiperOrigin-RevId: 209837803 --- tensorflow/python/estimator/keras.py | 11 +---------- tensorflow/python/keras/backend.py | 5 ++--- tensorflow/python/keras/models.py | 1 + tensorflow/python/training/optimizer.py | 7 +++---- 4 files changed, 7 insertions(+), 17 deletions(-) diff --git a/tensorflow/python/estimator/keras.py b/tensorflow/python/estimator/keras.py index ce6ad47c01..6361c6acc1 100644 --- a/tensorflow/python/estimator/keras.py +++ b/tensorflow/python/estimator/keras.py @@ -36,7 +36,6 @@ from tensorflow.python.keras import optimizers from tensorflow.python.ops import check_ops from tensorflow.python.ops import math_ops from tensorflow.python.ops import metrics as metrics_module -from tensorflow.python.ops import variables as variables_module from tensorflow.python.platform import gfile from tensorflow.python.platform import tf_logging as logging from tensorflow.python.saved_model import signature_constants @@ -315,15 +314,7 @@ def _save_first_checkpoint(keras_model, custom_objects, config): if not model.train_function: # pylint: disable=protected-access model._make_train_function() - # We are using global variables collection here because: - # estimator runs eager mode under context.graph_mode() context manager - # When we try to get all the TF optimizer variables using - # optimizer.variables() we try to return variables that belong to the - # current graph. This check (variable.op.graph is current_graph) will - # error as the context is graph mode but variables are eager. - # TODO(psv): investigate this and see if we can remove the usage of - # collection here. - K._initialize_variables(sess, variables_module.global_variables()) + K._initialize_variables(sess) # pylint: enable=protected-access saver = saver_lib.Saver() latest_path = os.path.join(keras_model_dir, 'keras_model.ckpt') diff --git a/tensorflow/python/keras/backend.py b/tensorflow/python/keras/backend.py index 62433a400b..b52ab7f05c 100644 --- a/tensorflow/python/keras/backend.py +++ b/tensorflow/python/keras/backend.py @@ -696,10 +696,9 @@ def _get_variables(graph=None): return variables -def _initialize_variables(session, variables=None): +def _initialize_variables(session): """Utility to initialize uninitialized variables on the fly.""" - if variables is None: - variables = _get_variables(ops.get_default_graph()) + variables = _get_variables(ops.get_default_graph()) candidate_vars = [] for v in variables: if not getattr(v, '_keras_initialized', False): diff --git a/tensorflow/python/keras/models.py b/tensorflow/python/keras/models.py index e3032acbfd..6bc256d2ec 100644 --- a/tensorflow/python/keras/models.py +++ b/tensorflow/python/keras/models.py @@ -447,6 +447,7 @@ def clone_and_build_model( elif model.optimizer: if isinstance(model.optimizer, optimizers.TFOptimizer): optimizer = model.optimizer + K.track_tf_optimizer(optimizer) else: optimizer_config = model.optimizer.get_config() optimizer = model.optimizer.__class__.from_config(optimizer_config) diff --git a/tensorflow/python/training/optimizer.py b/tensorflow/python/training/optimizer.py index 1b6bce2865..2304a461c1 100644 --- a/tensorflow/python/training/optimizer.py +++ b/tensorflow/python/training/optimizer.py @@ -772,16 +772,15 @@ class Optimizer( Returns: A list of variables. """ - executing_eagerly = context.executing_eagerly() current_graph = ops.get_default_graph() def _from_current_graph(variable): - if executing_eagerly: + if variable._in_graph_mode: # pylint: disable=protected-access + return variable.op.graph is current_graph + else: # No variable.op in eager mode. We don't expect lots of eager graphs, # but behavior should be consistent with graph mode. return variable._graph_key == current_graph._graph_key # pylint: disable=protected-access - else: - return variable.op.graph is current_graph optimizer_variables = [v for v in self._non_slot_variables() if _from_current_graph(v)] -- cgit v1.2.3