aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/rnn
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2018-03-11 10:00:02 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-03-11 10:04:13 -0700
commitd58f2b50b66d555790de51d5036320949101afa1 (patch)
treecb6d59884aab90648cab0e5f03cef8bfec52afce /tensorflow/contrib/rnn
parent0c0ee52e7841f7d14b4c8465a5825aaa2fef0fdb (diff)
Improve errors raised when an object does not match the RNNCell interface.
PiperOrigin-RevId: 188651070
Diffstat (limited to 'tensorflow/contrib/rnn')
-rw-r--r--tensorflow/contrib/rnn/python/kernel_tests/rnn_cell_test.py8
-rw-r--r--tensorflow/contrib/rnn/python/ops/core_rnn_cell.py10
-rw-r--r--tensorflow/contrib/rnn/python/ops/rnn_cell.py3
3 files changed, 8 insertions, 13 deletions
diff --git a/tensorflow/contrib/rnn/python/kernel_tests/rnn_cell_test.py b/tensorflow/contrib/rnn/python/kernel_tests/rnn_cell_test.py
index 7de55a0bb3..69f7b8e107 100644
--- a/tensorflow/contrib/rnn/python/kernel_tests/rnn_cell_test.py
+++ b/tensorflow/contrib/rnn/python/kernel_tests/rnn_cell_test.py
@@ -455,8 +455,8 @@ class RNNCellTest(test.TestCase):
self.assertAllClose(np.concatenate(res[1], axis=1), expected_state)
def testAttentionCellWrapperFailures(self):
- with self.assertRaisesRegexp(TypeError,
- "The parameter cell is not RNNCell."):
+ with self.assertRaisesRegexp(
+ TypeError, rnn_cell_impl.ASSERT_LIKE_RNNCELL_ERROR_REGEXP):
contrib_rnn_cell.AttentionCellWrapper(None, 0)
num_units = 8
@@ -1203,7 +1203,7 @@ class LayerNormBasicLSTMCellTest(test.TestCase):
h1 = array_ops.zeros([1, 2])
state1 = rnn_cell.LSTMStateTuple(c1, h1)
state = (state0, state1)
- single_cell = lambda: contrib_rnn_cell.LayerNormBasicLSTMCell(2, layer_norm=False)
+ single_cell = lambda: contrib_rnn_cell.LayerNormBasicLSTMCell(2, layer_norm=False) # pylint: disable=line-too-long
cell = rnn_cell.MultiRNNCell([single_cell() for _ in range(2)])
g, out_m = cell(x, state)
sess.run([variables.global_variables_initializer()])
@@ -1235,7 +1235,7 @@ class LayerNormBasicLSTMCellTest(test.TestCase):
self.assertAllClose(expected_state1_h, actual_state1_h, 1e-5)
with variable_scope.variable_scope(
- "other", initializer=init_ops.constant_initializer(0.5)) as vs:
+ "other", initializer=init_ops.constant_initializer(0.5)):
x = array_ops.zeros(
[1, 3]) # Test BasicLSTMCell with input_size != num_units.
c = array_ops.zeros([1, 2])
diff --git a/tensorflow/contrib/rnn/python/ops/core_rnn_cell.py b/tensorflow/contrib/rnn/python/ops/core_rnn_cell.py
index 8109ebc718..645f82624b 100644
--- a/tensorflow/contrib/rnn/python/ops/core_rnn_cell.py
+++ b/tensorflow/contrib/rnn/python/ops/core_rnn_cell.py
@@ -40,7 +40,6 @@ from tensorflow.python.util import nest
# pylint: disable=protected-access,invalid-name
RNNCell = rnn_cell_impl.RNNCell
-_like_rnncell = rnn_cell_impl._like_rnncell
_WEIGHTS_VARIABLE_NAME = rnn_cell_impl._WEIGHTS_VARIABLE_NAME
_BIAS_VARIABLE_NAME = rnn_cell_impl._BIAS_VARIABLE_NAME
# pylint: enable=protected-access,invalid-name
@@ -221,8 +220,7 @@ class EmbeddingWrapper(RNNCell):
ValueError: if embedding_classes is not positive.
"""
super(EmbeddingWrapper, self).__init__(_reuse=reuse)
- if not _like_rnncell(cell):
- raise TypeError("The parameter cell is not RNNCell.")
+ rnn_cell_impl.assert_like_rnncell("cell", cell)
if embedding_classes <= 0 or embedding_size <= 0:
raise ValueError("Both embedding_classes and embedding_size must be > 0: "
"%d, %d." % (embedding_classes, embedding_size))
@@ -301,8 +299,7 @@ class InputProjectionWrapper(RNNCell):
super(InputProjectionWrapper, self).__init__(_reuse=reuse)
if input_size is not None:
logging.warn("%s: The input_size parameter is deprecated.", self)
- if not _like_rnncell(cell):
- raise TypeError("The parameter cell is not RNNCell.")
+ rnn_cell_impl.assert_like_rnncell("cell", cell)
self._cell = cell
self._num_proj = num_proj
self._activation = activation
@@ -356,8 +353,7 @@ class OutputProjectionWrapper(RNNCell):
ValueError: if output_size is not positive.
"""
super(OutputProjectionWrapper, self).__init__(_reuse=reuse)
- if not _like_rnncell(cell):
- raise TypeError("The parameter cell is not RNNCell.")
+ rnn_cell_impl.assert_like_rnncell("cell", cell)
if output_size < 1:
raise ValueError("Parameter output_size must be > 0: %d." % output_size)
self._cell = cell
diff --git a/tensorflow/contrib/rnn/python/ops/rnn_cell.py b/tensorflow/contrib/rnn/python/ops/rnn_cell.py
index 6bea8d4a21..3028edad1b 100644
--- a/tensorflow/contrib/rnn/python/ops/rnn_cell.py
+++ b/tensorflow/contrib/rnn/python/ops/rnn_cell.py
@@ -1143,8 +1143,7 @@ class AttentionCellWrapper(rnn_cell_impl.RNNCell):
`state_is_tuple` is `False` or if attn_length is zero or less.
"""
super(AttentionCellWrapper, self).__init__(_reuse=reuse)
- if not rnn_cell_impl._like_rnncell(cell): # pylint: disable=protected-access
- raise TypeError("The parameter cell is not RNNCell.")
+ rnn_cell_impl.assert_like_rnncell("cell", cell)
if nest.is_sequence(cell.state_size) and not state_is_tuple:
raise ValueError(
"Cell returns tuple of states, but the flag "