aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/seq2seq
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2018-03-11 10:00:02 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-03-11 10:04:13 -0700
commitd58f2b50b66d555790de51d5036320949101afa1 (patch)
treecb6d59884aab90648cab0e5f03cef8bfec52afce /tensorflow/contrib/seq2seq
parent0c0ee52e7841f7d14b4c8465a5825aaa2fef0fdb (diff)
Improve errors raised when an object does not match the RNNCell interface.
PiperOrigin-RevId: 188651070
Diffstat (limited to 'tensorflow/contrib/seq2seq')
-rw-r--r--tensorflow/contrib/seq2seq/python/ops/attention_wrapper.py4
-rw-r--r--tensorflow/contrib/seq2seq/python/ops/basic_decoder.py3
-rw-r--r--tensorflow/contrib/seq2seq/python/ops/beam_search_decoder.py3
3 files changed, 3 insertions, 7 deletions
diff --git a/tensorflow/contrib/seq2seq/python/ops/attention_wrapper.py b/tensorflow/contrib/seq2seq/python/ops/attention_wrapper.py
index 0a53fd66db..f8da5a3e17 100644
--- a/tensorflow/contrib/seq2seq/python/ops/attention_wrapper.py
+++ b/tensorflow/contrib/seq2seq/python/ops/attention_wrapper.py
@@ -1152,9 +1152,7 @@ class AttentionWrapper(rnn_cell_impl.RNNCell):
is a list, and its length does not match that of `attention_layer_size`.
"""
super(AttentionWrapper, self).__init__(name=name)
- if not rnn_cell_impl._like_rnncell(cell): # pylint: disable=protected-access
- raise TypeError(
- "cell must be an RNNCell, saw type: %s" % type(cell).__name__)
+ rnn_cell_impl.assert_like_rnncell("cell", cell)
if isinstance(attention_mechanism, (list, tuple)):
self._is_multi = True
attention_mechanisms = attention_mechanism
diff --git a/tensorflow/contrib/seq2seq/python/ops/basic_decoder.py b/tensorflow/contrib/seq2seq/python/ops/basic_decoder.py
index ed226239b8..7eb95e5a70 100644
--- a/tensorflow/contrib/seq2seq/python/ops/basic_decoder.py
+++ b/tensorflow/contrib/seq2seq/python/ops/basic_decoder.py
@@ -59,8 +59,7 @@ class BasicDecoder(decoder.Decoder):
Raises:
TypeError: if `cell`, `helper` or `output_layer` have an incorrect type.
"""
- if not rnn_cell_impl._like_rnncell(cell): # pylint: disable=protected-access
- raise TypeError("cell must be an RNNCell, received: %s" % type(cell))
+ rnn_cell_impl.assert_like_rnncell("cell", cell)
if not isinstance(helper, helper_py.Helper):
raise TypeError("helper must be a Helper, received: %s" % type(helper))
if (output_layer is not None
diff --git a/tensorflow/contrib/seq2seq/python/ops/beam_search_decoder.py b/tensorflow/contrib/seq2seq/python/ops/beam_search_decoder.py
index d6184d6109..22dc7f2eda 100644
--- a/tensorflow/contrib/seq2seq/python/ops/beam_search_decoder.py
+++ b/tensorflow/contrib/seq2seq/python/ops/beam_search_decoder.py
@@ -195,8 +195,7 @@ class BeamSearchDecoder(decoder.Decoder):
ValueError: If `start_tokens` is not a vector or
`end_token` is not a scalar.
"""
- if not rnn_cell_impl._like_rnncell(cell): # pylint: disable=protected-access
- raise TypeError("cell must be an RNNCell, received: %s" % type(cell))
+ rnn_cell_impl.assert_like_rnncell("cell", cell) # pylint: disable=protected-access
if (output_layer is not None and
not isinstance(output_layer, layers_base.Layer)):
raise TypeError(