author     Eugene Brevdo <ebrevdo@google.com>               2017-03-27 13:48:35 -0800
committer  TensorFlower Gardener <gardener@tensorflow.org>  2017-03-27 15:08:22 -0700
commit     8cc451c2a599ab0ce127875832c20e0155553a9d (patch)
tree       25c1aa071d606fbf38f7ad8239b4a10f3384d9f2
parent     4eb07180f4a0fce2cb265010233bace014e2c026 (diff)
[contrib seq2seq] Rename DynamicAttentionWrapper to AttentionWrapper
Change: 151374177
-rw-r--r--  tensorflow/contrib/seq2seq/BUILD                                             4
-rw-r--r--  tensorflow/contrib/seq2seq/__init__.py                                       6
-rw-r--r--  tensorflow/contrib/seq2seq/python/kernel_tests/attention_wrapper_test.py (renamed from tensorflow/contrib/seq2seq/python/kernel_tests/dynamic_attention_wrapper_test.py)  20
-rw-r--r--  tensorflow/contrib/seq2seq/python/ops/attention_wrapper.py (renamed from tensorflow/contrib/seq2seq/python/ops/dynamic_attention_wrapper.py)  24
4 files changed, 27 insertions, 27 deletions
diff --git a/tensorflow/contrib/seq2seq/BUILD b/tensorflow/contrib/seq2seq/BUILD
index 0621905705..652bbba85e 100644
--- a/tensorflow/contrib/seq2seq/BUILD
+++ b/tensorflow/contrib/seq2seq/BUILD
@@ -88,9 +88,9 @@ cuda_py_test(
 )
 
 cuda_py_test(
-    name = "dynamic_attention_wrapper_test",
+    name = "attention_wrapper_test",
     size = "medium",
-    srcs = ["python/kernel_tests/dynamic_attention_wrapper_test.py"],
+    srcs = ["python/kernel_tests/attention_wrapper_test.py"],
     additional_deps = [
         ":seq2seq_py",
         "//tensorflow/contrib/layers:layers_py",
diff --git a/tensorflow/contrib/seq2seq/__init__.py b/tensorflow/contrib/seq2seq/__init__.py
index 29bce7bbae..277434c160 100644
--- a/tensorflow/contrib/seq2seq/__init__.py
+++ b/tensorflow/contrib/seq2seq/__init__.py
@@ -35,8 +35,8 @@ See the @{$python/contrib.seq2seq} guide.
 
 @@hardmax
 
-@@DynamicAttentionWrapperState
-@@DynamicAttentionWrapper
+@@AttentionWrapperState
+@@AttentionWrapper
 """
 
 from __future__ import absolute_import
@@ -44,9 +44,9 @@ from __future__ import division
 from __future__ import print_function
 
 # pylint: disable=unused-import,wildcard-import,line-too-long
+from tensorflow.contrib.seq2seq.python.ops.attention_wrapper import *
 from tensorflow.contrib.seq2seq.python.ops.basic_decoder import *
 from tensorflow.contrib.seq2seq.python.ops.decoder import *
-from tensorflow.contrib.seq2seq.python.ops.dynamic_attention_wrapper import *
 from tensorflow.contrib.seq2seq.python.ops.helper import *
 from tensorflow.contrib.seq2seq.python.ops.loss import *
 # pylint: enable=unused-import,widcard-import,line-too-long
diff --git a/tensorflow/contrib/seq2seq/python/kernel_tests/dynamic_attention_wrapper_test.py b/tensorflow/contrib/seq2seq/python/kernel_tests/attention_wrapper_test.py
index 47cf102f39..5d952e97b7 100644
--- a/tensorflow/contrib/seq2seq/python/kernel_tests/dynamic_attention_wrapper_test.py
+++ b/tensorflow/contrib/seq2seq/python/kernel_tests/attention_wrapper_test.py
@@ -12,7 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 # ==============================================================================
-"""Tests for contrib.seq2seq.python.ops.dynamic_attention_wrapper."""
+"""Tests for contrib.seq2seq.python.ops.attention_wrapper."""
 # pylint: disable=unused-import,g-bad-import-order
 from __future__ import absolute_import
 from __future__ import division
@@ -25,7 +25,7 @@ import numpy as np
 
 from tensorflow.contrib.rnn import core_rnn_cell
 from tensorflow.contrib.seq2seq.python.ops import decoder
-from tensorflow.contrib.seq2seq.python.ops import dynamic_attention_wrapper as wrapper
+from tensorflow.contrib.seq2seq.python.ops import attention_wrapper as wrapper
 from tensorflow.contrib.seq2seq.python.ops import helper as helper_py
 from tensorflow.contrib.seq2seq.python.ops import basic_decoder
 from tensorflow.python.framework import dtypes
@@ -38,12 +38,12 @@ from tensorflow.python.util import nest
 # pylint: enable=g-import-not-at-top
 
 
-class DynamicAttentionWrapperTest(test.TestCase):
+class AttentionWrapperTest(test.TestCase):
 
   def assertAllClose(self, *args, **kwargs):
     kwargs["atol"] = 1e-4  # For GPU tests
     kwargs["rtol"] = 1e-4  # For GPU tests
-    return super(DynamicAttentionWrapperTest, self).assertAllClose(
+    return super(AttentionWrapperTest, self).assertAllClose(
         *args, **kwargs)
 
   def _testWithAttention(self,
@@ -76,7 +76,7 @@ class DynamicAttentionWrapperTest(test.TestCase):
         "root",
         initializer=init_ops.random_normal_initializer(stddev=0.01, seed=3)):
       cell = core_rnn_cell.LSTMCell(cell_depth)
-      cell = wrapper.DynamicAttentionWrapper(
+      cell = wrapper.AttentionWrapper(
           cell, attention_mechanism, attention_size=attention_depth)
       helper = helper_py.TrainingHelper(decoder_inputs,
                                         decoder_sequence_length)
@@ -91,7 +91,7 @@ class DynamicAttentionWrapperTest(test.TestCase):
       self.assertTrue(
           isinstance(final_outputs, basic_decoder.BasicDecoderOutput))
       self.assertTrue(
-          isinstance(final_state, wrapper.DynamicAttentionWrapperState))
+          isinstance(final_state, wrapper.AttentionWrapperState))
       self.assertTrue(
           isinstance(final_state.cell_state, core_rnn_cell.LSTMStateTuple))
 
@@ -178,7 +178,7 @@ class DynamicAttentionWrapperTest(test.TestCase):
            [[0, 0, 0], [0, 0, 0], [0, 0, 0], [0, 0, 0], [0, 0, 0]],
            dtype=int32))
 
-    expected_final_state = wrapper.DynamicAttentionWrapperState(
+    expected_final_state = wrapper.AttentionWrapperState(
        cell_state=core_rnn_cell.LSTMStateTuple(
            c=array(
                [[
@@ -306,7 +306,7 @@ class DynamicAttentionWrapperTest(test.TestCase):
            [[0, 0, 0], [0, 0, 0], [0, 0, 0], [0, 0, 0], [0, 0, 0]],
            dtype=int32))
 
-    expected_final_state = wrapper.DynamicAttentionWrapperState(
+    expected_final_state = wrapper.AttentionWrapperState(
        cell_state=core_rnn_cell.LSTMStateTuple(
            c=array(
                [[
@@ -434,7 +434,7 @@ class DynamicAttentionWrapperTest(test.TestCase):
        sample_id=array(
            [[0, 0, 0], [0, 0, 0], [0, 0, 0], [0, 0, 0], [0, 0, 0]],
            dtype=int32))
-    expected_final_state = wrapper.DynamicAttentionWrapperState(
+    expected_final_state = wrapper.AttentionWrapperState(
        cell_state=core_rnn_cell.LSTMStateTuple(
            c=array(
                [[
@@ -566,7 +566,7 @@ class DynamicAttentionWrapperTest(test.TestCase):
        sample_id=array(
            [[0, 0, 0], [0, 0, 0], [0, 0, 0], [0, 0, 0], [0, 0, 0]],
            dtype=int32))
-    expected_final_state = wrapper.DynamicAttentionWrapperState(
+    expected_final_state = wrapper.AttentionWrapperState(
        cell_state=core_rnn_cell.LSTMStateTuple(
            c=array(
                [[
diff --git a/tensorflow/contrib/seq2seq/python/ops/dynamic_attention_wrapper.py b/tensorflow/contrib/seq2seq/python/ops/attention_wrapper.py
index 9fb67ed175..0af001d274 100644
--- a/tensorflow/contrib/seq2seq/python/ops/dynamic_attention_wrapper.py
+++ b/tensorflow/contrib/seq2seq/python/ops/attention_wrapper.py
@@ -37,8 +37,8 @@ from tensorflow.python.util import nest
 
 
 __all__ = [
-    "DynamicAttentionWrapper",
-    "DynamicAttentionWrapperState",
+    "AttentionWrapper",
+    "AttentionWrapperState",
     "LuongAttention",
     "BahdanauAttention",
     "hardmax",
@@ -376,10 +376,10 @@ class BahdanauAttention(_BaseAttentionMechanism):
     return score
 
 
-class DynamicAttentionWrapperState(
+class AttentionWrapperState(
     collections.namedtuple(
-        "DynamicAttentionWrapperState", ("cell_state", "attention"))):
-  """`namedtuple` storing the state of a `DynamicAttentionWrapper`.
+        "AttentionWrapperState", ("cell_state", "attention"))):
+  """`namedtuple` storing the state of a `AttentionWrapper`.
 
   Contains:
 
@@ -410,7 +410,7 @@ def hardmax(logits, name=None):
       math_ops.argmax(logits, -1), depth, dtype=logits.dtype)
 
 
-class DynamicAttentionWrapper(core_rnn_cell.RNNCell):
+class AttentionWrapper(core_rnn_cell.RNNCell):
   """Wraps another `RNNCell` with attention.
   """
 
@@ -422,7 +422,7 @@ class DynamicAttentionWrapper(core_rnn_cell.RNNCell):
                probability_fn=None,
                output_attention=True,
                name=None):
-    """Construct the `DynamicAttentionWrapper`.
+    """Construct the `AttentionWrapper`.
 
     Args:
       cell: An instance of `RNNCell`.
@@ -484,13 +484,13 @@ class DynamicAttentionWrapper(core_rnn_cell.RNNCell):
 
   @property
   def state_size(self):
-    return DynamicAttentionWrapperState(
+    return AttentionWrapperState(
         cell_state=self._cell.state_size,
         attention=self._attention_size)
 
   def zero_state(self, batch_size, dtype):
     with ops.name_scope(type(self).__name__ + "ZeroState", values=[batch_size]):
-      return DynamicAttentionWrapperState(
+      return AttentionWrapperState(
           cell_state=self._cell.zero_state(batch_size, dtype),
           attention=_zero_state_tensors(
               self._attention_size, batch_size, dtype))
@@ -511,7 +511,7 @@ class DynamicAttentionWrapper(core_rnn_cell.RNNCell):
 
     Args:
       inputs: (Possibly nested tuple of) Tensor, the input at this time step.
-      state: An instance of `DynamicAttentionWrapperState` containing
+      state: An instance of `AttentionWrapperState` containing
         tensors from the previous time step.
       scope: Must be `None`.
 
@@ -519,7 +519,7 @@ class DynamicAttentionWrapper(core_rnn_cell.RNNCell):
     Returns:
       A tuple `(attention, next_state)`, where:
       - `attention` is the attention passed to the layer above.
-      - `next_state` is an instance of `DynamicAttentionWrapperState`
+      - `next_state` is an instance of `AttentionWrapperState`
         containing the state calculated at this time step.
 
     Raises:
@@ -555,7 +555,7 @@ class DynamicAttentionWrapper(core_rnn_cell.RNNCell):
       attention = self._attention_layer(
           array_ops.concat([cell_output, context], 1))
 
-    next_state = DynamicAttentionWrapperState(
+    next_state = AttentionWrapperState(
        cell_state=next_cell_state,
        attention=attention)
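For downstream code this change is a pure rename of the exported symbols: `DynamicAttentionWrapper` becomes `AttentionWrapper` and `DynamicAttentionWrapperState` becomes `AttentionWrapperState`; the constructor arguments visible in this diff (cell, attention_mechanism, attention_size) are unchanged. Below is a minimal sketch of updated caller code against the TF 1.x contrib API of this era; the tensor shapes, sizes, and the `BahdanauAttention` arguments are illustrative assumptions and not taken from this diff.

# Sketch only: symbols renamed by this change are AttentionWrapper and
# AttentionWrapperState; everything else here is assumed/illustrative.
import tensorflow as tf
from tensorflow.contrib.rnn import core_rnn_cell
from tensorflow.contrib.seq2seq.python.ops import attention_wrapper as wrapper

batch_size, max_time, num_units, attention_depth = 8, 10, 16, 12

# Encoder outputs act as the attention memory (placeholder used for brevity).
memory = tf.placeholder(tf.float32, [batch_size, max_time, num_units])
attention_mechanism = wrapper.BahdanauAttention(
    num_units=num_units, memory=memory)

# Before this change: cell = wrapper.DynamicAttentionWrapper(...)
cell = core_rnn_cell.LSTMCell(num_units)
cell = wrapper.AttentionWrapper(
    cell, attention_mechanism, attention_size=attention_depth)

# zero_state now returns an AttentionWrapperState
# (previously DynamicAttentionWrapperState).
initial_state = cell.zero_state(batch_size, tf.float32)

Since the module file itself was renamed (dynamic_attention_wrapper.py to attention_wrapper.py), code that imported the module path directly, as in the test file above, needs the import path updated as well; code that only used the `tf.contrib.seq2seq` namespace just needs the new class names.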