aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python/keras/_impl/keras/layers/recurrent.py
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/python/keras/_impl/keras/layers/recurrent.py')
-rw-r--r--tensorflow/python/keras/_impl/keras/layers/recurrent.py3
1 files changed, 2 insertions, 1 deletions
diff --git a/tensorflow/python/keras/_impl/keras/layers/recurrent.py b/tensorflow/python/keras/_impl/keras/layers/recurrent.py
index b34b92c763..2e9003f52d 100644
--- a/tensorflow/python/keras/_impl/keras/layers/recurrent.py
+++ b/tensorflow/python/keras/_impl/keras/layers/recurrent.py
@@ -19,6 +19,7 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
+import numbers
import numpy as np
from tensorflow.python.framework import tensor_shape
@@ -413,7 +414,7 @@ class RNN(Layer):
@property
def states(self):
if self._states is None:
- if isinstance(self.cell.state_size, int):
+ if isinstance(self.cell.state_size, numbers.Integral):
num_states = 1
else:
num_states = len(self.cell.state_size)