diff options
author | A. Unique TensorFlower <gardener@tensorflow.org> | 2018-03-11 10:00:02 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-03-11 10:04:13 -0700 |
commit | d58f2b50b66d555790de51d5036320949101afa1 (patch) | |
tree | cb6d59884aab90648cab0e5f03cef8bfec52afce /tensorflow/contrib/seq2seq | |
parent | 0c0ee52e7841f7d14b4c8465a5825aaa2fef0fdb (diff) |
Improve errors raised when an object does not match the RNNCell interface.
PiperOrigin-RevId: 188651070
Diffstat (limited to 'tensorflow/contrib/seq2seq')
3 files changed, 3 insertions, 7 deletions
diff --git a/tensorflow/contrib/seq2seq/python/ops/attention_wrapper.py b/tensorflow/contrib/seq2seq/python/ops/attention_wrapper.py index 0a53fd66db..f8da5a3e17 100644 --- a/tensorflow/contrib/seq2seq/python/ops/attention_wrapper.py +++ b/tensorflow/contrib/seq2seq/python/ops/attention_wrapper.py @@ -1152,9 +1152,7 @@ class AttentionWrapper(rnn_cell_impl.RNNCell): is a list, and its length does not match that of `attention_layer_size`. """ super(AttentionWrapper, self).__init__(name=name) - if not rnn_cell_impl._like_rnncell(cell): # pylint: disable=protected-access - raise TypeError( - "cell must be an RNNCell, saw type: %s" % type(cell).__name__) + rnn_cell_impl.assert_like_rnncell("cell", cell) if isinstance(attention_mechanism, (list, tuple)): self._is_multi = True attention_mechanisms = attention_mechanism diff --git a/tensorflow/contrib/seq2seq/python/ops/basic_decoder.py b/tensorflow/contrib/seq2seq/python/ops/basic_decoder.py index ed226239b8..7eb95e5a70 100644 --- a/tensorflow/contrib/seq2seq/python/ops/basic_decoder.py +++ b/tensorflow/contrib/seq2seq/python/ops/basic_decoder.py @@ -59,8 +59,7 @@ class BasicDecoder(decoder.Decoder): Raises: TypeError: if `cell`, `helper` or `output_layer` have an incorrect type. """ - if not rnn_cell_impl._like_rnncell(cell): # pylint: disable=protected-access - raise TypeError("cell must be an RNNCell, received: %s" % type(cell)) + rnn_cell_impl.assert_like_rnncell("cell", cell) if not isinstance(helper, helper_py.Helper): raise TypeError("helper must be a Helper, received: %s" % type(helper)) if (output_layer is not None diff --git a/tensorflow/contrib/seq2seq/python/ops/beam_search_decoder.py b/tensorflow/contrib/seq2seq/python/ops/beam_search_decoder.py index d6184d6109..22dc7f2eda 100644 --- a/tensorflow/contrib/seq2seq/python/ops/beam_search_decoder.py +++ b/tensorflow/contrib/seq2seq/python/ops/beam_search_decoder.py @@ -195,8 +195,7 @@ class BeamSearchDecoder(decoder.Decoder): ValueError: If `start_tokens` is not a vector or `end_token` is not a scalar. """ - if not rnn_cell_impl._like_rnncell(cell): # pylint: disable=protected-access - raise TypeError("cell must be an RNNCell, received: %s" % type(cell)) + rnn_cell_impl.assert_like_rnncell("cell", cell) # pylint: disable=protected-access if (output_layer is not None and not isinstance(output_layer, layers_base.Layer)): raise TypeError( |