about summary refs log tree commit diff homepage
diff options
context:
space:
mode:
authorGravatar Pavithra Vijay <psv@google.com>2018-08-22 15:09:08 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-08-22 15:21:16 -0700
commitd97e525c1fa601b4b1b0a3cffd9e579201d167ef (patch)
tree4ea8d86943b547c5f0fc37165d3d2a7753ff9aee
parentcd199a89dbffdd55aa2fc89acb874763382f196d (diff)
Switch to using variable._in_graph_mode instead of context.executing_eagerly() in optimizer.variables()
Remove global collection usage from the Keras model to estimator flow. PiperOrigin-RevId: 209837803
-rw-r--r--tensorflow/python/estimator/keras.py11
-rw-r--r--tensorflow/python/keras/backend.py5
-rw-r--r--tensorflow/python/keras/models.py1
-rw-r--r--tensorflow/python/training/optimizer.py7
4 files changed, 7 insertions, 17 deletions
diff --git a/tensorflow/python/estimator/keras.py b/tensorflow/python/estimator/keras.py
index ce6ad47c01..6361c6acc1 100644
--- a/tensorflow/python/estimator/keras.py
+++ b/tensorflow/python/estimator/keras.py
@@ -36,7 +36,6 @@ from tensorflow.python.keras import optimizers
from tensorflow.python.ops import check_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import metrics as metrics_module
-from tensorflow.python.ops import variables as variables_module
from tensorflow.python.platform import gfile
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.saved_model import signature_constants
@@ -315,15 +314,7 @@ def _save_first_checkpoint(keras_model, custom_objects, config):
if not model.train_function:
# pylint: disable=protected-access
model._make_train_function()
- # We are using global variables collection here because:
- # estimator runs eager mode under context.graph_mode() context manager
- # When we try to get all the TF optimizer variables using
- # optimizer.variables() we try to return variables that belong to the
- # current graph. This check (variable.op.graph is current_graph) will
- # error as the context is graph mode but variables are eager.
- # TODO(psv): investigate this and see if we can remove the usage of
- # collection here.
- K._initialize_variables(sess, variables_module.global_variables())
+ K._initialize_variables(sess)
# pylint: enable=protected-access
saver = saver_lib.Saver()
latest_path = os.path.join(keras_model_dir, 'keras_model.ckpt')
diff --git a/tensorflow/python/keras/backend.py b/tensorflow/python/keras/backend.py
index 62433a400b..b52ab7f05c 100644
--- a/tensorflow/python/keras/backend.py
+++ b/tensorflow/python/keras/backend.py
@@ -696,10 +696,9 @@ def _get_variables(graph=None):
return variables
-def _initialize_variables(session, variables=None):
+def _initialize_variables(session):
"""Utility to initialize uninitialized variables on the fly."""
- if variables is None:
- variables = _get_variables(ops.get_default_graph())
+ variables = _get_variables(ops.get_default_graph())
candidate_vars = []
for v in variables:
if not getattr(v, '_keras_initialized', False):
diff --git a/tensorflow/python/keras/models.py b/tensorflow/python/keras/models.py
index e3032acbfd..6bc256d2ec 100644
--- a/tensorflow/python/keras/models.py
+++ b/tensorflow/python/keras/models.py
@@ -447,6 +447,7 @@ def clone_and_build_model(
elif model.optimizer:
if isinstance(model.optimizer, optimizers.TFOptimizer):
optimizer = model.optimizer
+ K.track_tf_optimizer(optimizer)
else:
optimizer_config = model.optimizer.get_config()
optimizer = model.optimizer.__class__.from_config(optimizer_config)
diff --git a/tensorflow/python/training/optimizer.py b/tensorflow/python/training/optimizer.py
index 1b6bce2865..2304a461c1 100644
--- a/tensorflow/python/training/optimizer.py
+++ b/tensorflow/python/training/optimizer.py
@@ -772,16 +772,15 @@ class Optimizer(
Returns:
A list of variables.
"""
- executing_eagerly = context.executing_eagerly()
current_graph = ops.get_default_graph()
def _from_current_graph(variable):
- if executing_eagerly:
+ if variable._in_graph_mode: # pylint: disable=protected-access
+ return variable.op.graph is current_graph
+ else:
# No variable.op in eager mode. We don't expect lots of eager graphs,
# but behavior should be consistent with graph mode.
return variable._graph_key == current_graph._graph_key # pylint: disable=protected-access
- else:
- return variable.op.graph is current_graph
optimizer_variables = [v for v in self._non_slot_variables()
if _from_current_graph(v)]