aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2018-10-04 11:59:37 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-10-04 12:09:21 -0700
commit2667ed3bf01e7153f466b27c450fc2b662c00bdd (patch)
treeacabe1fc78fe4bc215463def3ae0f815425c978a /tensorflow/python
parentb82c4dad705bffac6d14a189605c9ece89f8c17b (diff)
Makes sure Keras Layer's `__call__` is always used in Eager.
Currently if a Layer is invoked with the Functional API in Eager, `__call__` is only used during setup, and thereafter `call` is used internally. This limits the ability to add pre/post processing steps to `call` in Eager in the future. Additionally, the Subclassed Model API already always uses `__call__` in Eager. PiperOrigin-RevId: 215778408
Diffstat (limited to 'tensorflow/python')
-rw-r--r--tensorflow/python/keras/engine/network.py10
1 files changed, 8 insertions, 2 deletions
diff --git a/tensorflow/python/keras/engine/network.py b/tensorflow/python/keras/engine/network.py
index 8d34006967..918488bd7a 100644
--- a/tensorflow/python/keras/engine/network.py
+++ b/tensorflow/python/keras/engine/network.py
@@ -1028,7 +1028,10 @@ class Network(base_layer.Layer):
output_tensors, output_masks = layer._call_and_compute_mask(
computed_tensor, **kwargs)
else:
- output_tensors = layer.call(computed_tensor, **kwargs)
+ if context.executing_eagerly():
+ output_tensors = layer(computed_tensor, **kwargs)
+ else:
+ output_tensors = layer.call(computed_tensor, **kwargs)
if hasattr(layer, 'compute_mask'):
output_masks = layer.compute_mask(computed_tensor,
computed_mask)
@@ -1049,7 +1052,10 @@ class Network(base_layer.Layer):
output_tensors, output_masks = layer._call_and_compute_mask(
computed_tensors, **kwargs)
else:
- output_tensors = layer.call(computed_tensors, **kwargs)
+ if context.executing_eagerly():
+ output_tensors = layer(computed_tensors, **kwargs)
+ else:
+ output_tensors = layer.call(computed_tensors, **kwargs)
if hasattr(layer, 'compute_mask'):
output_masks = layer.compute_mask(computed_tensors,
computed_masks)