diff options
author | A. Unique TensorFlower <gardener@tensorflow.org> | 2018-09-24 11:21:41 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-09-24 11:33:50 -0700 |
commit | 5fbb064ba1e78bb28f7adbe92e6583c3b2bdfda7 (patch) | |
tree | f6ebb04cd6587f830aaf7382f29c994fcbf01adb /tensorflow/python/keras | |
parent | 834ad88d20a9dbdbe7552ecd8c2ec7c26b444ef2 (diff) |
This CL adds an init_scope to the Keras set & get learning phase functions. This allows the Keras learning phase to work inside functions and defuns.
Note: There might still be bugs in graph mode if the default placeholder is being fed (instead of using set_learning_phase) and a layer is in a function.
PiperOrigin-RevId: 214299002
Diffstat (limited to 'tensorflow/python/keras')
-rw-r--r-- | tensorflow/python/keras/backend.py | 48 |
1 files changed, 29 insertions, 19 deletions
diff --git a/tensorflow/python/keras/backend.py b/tensorflow/python/keras/backend.py index 60ed8e8c8a..a46f9edb1e 100644 --- a/tensorflow/python/keras/backend.py +++ b/tensorflow/python/keras/backend.py @@ -367,18 +367,26 @@ def learning_phase(): Returns: Learning phase (scalar integer tensor or Python integer). """ - if context.executing_eagerly(): - if _DUMMY_EAGER_GRAPH not in _GRAPH_LEARNING_PHASES: - # Fallback to inference mode as default. - return 0 - return _GRAPH_LEARNING_PHASES[_DUMMY_EAGER_GRAPH] + with ops.init_scope(): + # We always check & set the learning phase inside the init_scope, + # otherwise the wrong default_graph will be used to look up the learning + # phase inside of functions & defuns. + # + # This is because functions & defuns (both in graph & in eager mode) + # will always execute non-eagerly using a function-specific default + # subgraph. + if context.executing_eagerly(): + if _DUMMY_EAGER_GRAPH not in _GRAPH_LEARNING_PHASES: + # Fallback to inference mode as default. + return 0 + return _GRAPH_LEARNING_PHASES[_DUMMY_EAGER_GRAPH] - graph = ops.get_default_graph() - if graph not in _GRAPH_LEARNING_PHASES: - phase = array_ops.placeholder_with_default( - False, shape=(), name='keras_learning_phase') - _GRAPH_LEARNING_PHASES[graph] = phase - return _GRAPH_LEARNING_PHASES[graph] + graph = ops.get_default_graph() + if graph not in _GRAPH_LEARNING_PHASES: + phase = array_ops.placeholder_with_default( + False, shape=(), name='keras_learning_phase') + _GRAPH_LEARNING_PHASES[graph] = phase + return _GRAPH_LEARNING_PHASES[graph] @tf_export('keras.backend.set_learning_phase') @@ -394,10 +402,11 @@ def set_learning_phase(value): global _GRAPH_LEARNING_PHASES # pylint: disable=global-variable-not-assigned if value not in {0, 1}: raise ValueError('Expected learning phase to be 0 or 1.') - if context.executing_eagerly(): - _GRAPH_LEARNING_PHASES[_DUMMY_EAGER_GRAPH] = value - else: - _GRAPH_LEARNING_PHASES[ops.get_default_graph()] = value + with ops.init_scope(): + if context.executing_eagerly(): + _GRAPH_LEARNING_PHASES[_DUMMY_EAGER_GRAPH] = value + else: + _GRAPH_LEARNING_PHASES[ops.get_default_graph()] = value @tf_contextlib.contextmanager @@ -423,10 +432,11 @@ def learning_phase_scope(value): yield value finally: # Restore learning phase to initial value. - if context.executing_eagerly(): - _GRAPH_LEARNING_PHASES[_DUMMY_EAGER_GRAPH] = previous_value - else: - _GRAPH_LEARNING_PHASES[ops.get_default_graph()] = previous_value + with ops.init_scope(): + if context.executing_eagerly(): + _GRAPH_LEARNING_PHASES[_DUMMY_EAGER_GRAPH] = previous_value + else: + _GRAPH_LEARNING_PHASES[ops.get_default_graph()] = previous_value @tf_export('keras.backend.get_session') |