about summary refs log tree commit diff homepage
path: root/tensorflow/python/keras
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2018-09-24 11:21:41 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-09-24 11:33:50 -0700
commit5fbb064ba1e78bb28f7adbe92e6583c3b2bdfda7 (patch)
treef6ebb04cd6587f830aaf7382f29c994fcbf01adb /tensorflow/python/keras
parent834ad88d20a9dbdbe7552ecd8c2ec7c26b444ef2 (diff)
This CL adds an init_scope to the Keras set & get learning phase functions. This allows the Keras learning phase to work inside functions and defuns.
Note: There might still be bugs in graph mode if the default placeholder is being fed (instead of using set_learning_phase) and a layer is in a function. PiperOrigin-RevId: 214299002
Diffstat (limited to 'tensorflow/python/keras')
-rw-r--r--tensorflow/python/keras/backend.py48
1 file changed, 29 insertions(+), 19 deletions(-)
diff --git a/tensorflow/python/keras/backend.py b/tensorflow/python/keras/backend.py
index 60ed8e8c8a..a46f9edb1e 100644
--- a/tensorflow/python/keras/backend.py
+++ b/tensorflow/python/keras/backend.py
@@ -367,18 +367,26 @@ def learning_phase():
Returns:
Learning phase (scalar integer tensor or Python integer).
"""
- if context.executing_eagerly():
- if _DUMMY_EAGER_GRAPH not in _GRAPH_LEARNING_PHASES:
- # Fallback to inference mode as default.
- return 0
- return _GRAPH_LEARNING_PHASES[_DUMMY_EAGER_GRAPH]
+ with ops.init_scope():
+ # We always check & set the learning phase inside the init_scope,
+ # otherwise the wrong default_graph will be used to look up the learning
+ # phase inside of functions & defuns.
+ #
+ # This is because functions & defuns (both in graph & in eager mode)
+ # will always execute non-eagerly using a function-specific default
+ # subgraph.
+ if context.executing_eagerly():
+ if _DUMMY_EAGER_GRAPH not in _GRAPH_LEARNING_PHASES:
+ # Fallback to inference mode as default.
+ return 0
+ return _GRAPH_LEARNING_PHASES[_DUMMY_EAGER_GRAPH]
- graph = ops.get_default_graph()
- if graph not in _GRAPH_LEARNING_PHASES:
- phase = array_ops.placeholder_with_default(
- False, shape=(), name='keras_learning_phase')
- _GRAPH_LEARNING_PHASES[graph] = phase
- return _GRAPH_LEARNING_PHASES[graph]
+ graph = ops.get_default_graph()
+ if graph not in _GRAPH_LEARNING_PHASES:
+ phase = array_ops.placeholder_with_default(
+ False, shape=(), name='keras_learning_phase')
+ _GRAPH_LEARNING_PHASES[graph] = phase
+ return _GRAPH_LEARNING_PHASES[graph]
@tf_export('keras.backend.set_learning_phase')
@@ -394,10 +402,11 @@ def set_learning_phase(value):
global _GRAPH_LEARNING_PHASES # pylint: disable=global-variable-not-assigned
if value not in {0, 1}:
raise ValueError('Expected learning phase to be 0 or 1.')
- if context.executing_eagerly():
- _GRAPH_LEARNING_PHASES[_DUMMY_EAGER_GRAPH] = value
- else:
- _GRAPH_LEARNING_PHASES[ops.get_default_graph()] = value
+ with ops.init_scope():
+ if context.executing_eagerly():
+ _GRAPH_LEARNING_PHASES[_DUMMY_EAGER_GRAPH] = value
+ else:
+ _GRAPH_LEARNING_PHASES[ops.get_default_graph()] = value
@tf_contextlib.contextmanager
@@ -423,10 +432,11 @@ def learning_phase_scope(value):
yield value
finally:
# Restore learning phase to initial value.
- if context.executing_eagerly():
- _GRAPH_LEARNING_PHASES[_DUMMY_EAGER_GRAPH] = previous_value
- else:
- _GRAPH_LEARNING_PHASES[ops.get_default_graph()] = previous_value
+ with ops.init_scope():
+ if context.executing_eagerly():
+ _GRAPH_LEARNING_PHASES[_DUMMY_EAGER_GRAPH] = previous_value
+ else:
+ _GRAPH_LEARNING_PHASES[ops.get_default_graph()] = previous_value
@tf_export('keras.backend.get_session')