diff options
author | A. Unique TensorFlower <gardener@tensorflow.org> | 2018-01-02 12:28:57 -0800 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-01-02 12:33:07 -0800 |
commit | 138ce5760011b862375fb2ac750de4493a7c1919 (patch) | |
tree | 3acef21b53aa2e2b38de85e5a2d2adfebb9e4faf | |
parent | a071dd520e8a32f2ee4a585905d635503e596546 (diff) |
Internal cleanup.
PiperOrigin-RevId: 180578376
-rw-r--r-- | tensorflow/python/ops/rnn_cell_impl.py | 6 |
1 files changed, 4 insertions, 2 deletions
diff --git a/tensorflow/python/ops/rnn_cell_impl.py b/tensorflow/python/ops/rnn_cell_impl.py index 7cb9f7762d..b41aff76d4 100644 --- a/tensorflow/python/ops/rnn_cell_impl.py +++ b/tensorflow/python/ops/rnn_cell_impl.py @@ -238,7 +238,8 @@ class RNNCell(base_layer.Layer): # Try to use the last cached zero_state. This is done to avoid recreating # zeros, especially when eager execution is enabled. state_size = self.state_size - if hasattr(self, "_last_zero_state"): + is_eager = context.in_eager_mode() + if is_eager and hasattr(self, "_last_zero_state"): (last_state_size, last_batch_size, last_dtype, last_output) = getattr(self, "_last_zero_state") if (last_batch_size == batch_size and @@ -247,7 +248,8 @@ class RNNCell(base_layer.Layer): return last_output with ops.name_scope(type(self).__name__ + "ZeroState", values=[batch_size]): output = _zero_state_tensors(state_size, batch_size, dtype) - self._last_zero_state = (state_size, batch_size, dtype, output) + if is_eager: + self._last_zero_state = (state_size, batch_size, dtype, output) return output |