aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2018-01-02 12:28:57 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-01-02 12:33:07 -0800
commit138ce5760011b862375fb2ac750de4493a7c1919 (patch)
tree3acef21b53aa2e2b38de85e5a2d2adfebb9e4faf
parenta071dd520e8a32f2ee4a585905d635503e596546 (diff)
Internal cleanup.
PiperOrigin-RevId: 180578376
-rw-r--r--tensorflow/python/ops/rnn_cell_impl.py6
1 files changed, 4 insertions, 2 deletions
diff --git a/tensorflow/python/ops/rnn_cell_impl.py b/tensorflow/python/ops/rnn_cell_impl.py
index 7cb9f7762d..b41aff76d4 100644
--- a/tensorflow/python/ops/rnn_cell_impl.py
+++ b/tensorflow/python/ops/rnn_cell_impl.py
@@ -238,7 +238,8 @@ class RNNCell(base_layer.Layer):
# Try to use the last cached zero_state. This is done to avoid recreating
# zeros, especially when eager execution is enabled.
state_size = self.state_size
- if hasattr(self, "_last_zero_state"):
+ is_eager = context.in_eager_mode()
+ if is_eager and hasattr(self, "_last_zero_state"):
(last_state_size, last_batch_size, last_dtype,
last_output) = getattr(self, "_last_zero_state")
if (last_batch_size == batch_size and
@@ -247,7 +248,8 @@ class RNNCell(base_layer.Layer):
return last_output
with ops.name_scope(type(self).__name__ + "ZeroState", values=[batch_size]):
output = _zero_state_tensors(state_size, batch_size, dtype)
- self._last_zero_state = (state_size, batch_size, dtype, output)
+ if is_eager:
+ self._last_zero_state = (state_size, batch_size, dtype, output)
return output