author     Eugene Brevdo <ebrevdo@google.com>               2017-03-27 13:48:35 -0800
committer  TensorFlower Gardener <gardener@tensorflow.org>  2017-03-27 15:08:22 -0700
commit     8cc451c2a599ab0ce127875832c20e0155553a9d (patch)
tree       25c1aa071d606fbf38f7ad8239b4a10f3384d9f2
parent     4eb07180f4a0fce2cb265010233bace014e2c026 (diff)
[contrib seq2seq] Rename DynamicAttentionWrapper to AttentionWrapper
Change: 151374177
-rw-r--r--  tensorflow/contrib/seq2seq/BUILD                                                                                                              |  4
-rw-r--r--  tensorflow/contrib/seq2seq/__init__.py                                                                                                        |  6
-rw-r--r--  tensorflow/contrib/seq2seq/python/kernel_tests/attention_wrapper_test.py (renamed from tensorflow/contrib/seq2seq/python/kernel_tests/dynamic_attention_wrapper_test.py) | 20
-rw-r--r--  tensorflow/contrib/seq2seq/python/ops/attention_wrapper.py (renamed from tensorflow/contrib/seq2seq/python/ops/dynamic_attention_wrapper.py)  | 24
4 files changed, 27 insertions(+), 27 deletions(-)
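
For context, here is a minimal caller-side sketch of the rename. It is not part of the commit: `encoder_outputs` and the size values are placeholders, and the constructor arguments follow the signatures visible in the diff below.

import tensorflow as tf
from tensorflow.contrib import seq2seq
from tensorflow.contrib.rnn import core_rnn_cell

# Placeholder encoder output with shape [batch, time, depth].
encoder_outputs = tf.placeholder(tf.float32, [None, None, 64])

attention_mechanism = seq2seq.BahdanauAttention(
    num_units=64, memory=encoder_outputs)
cell = core_rnn_cell.LSTMCell(128)

# Before this change:
#   cell = seq2seq.DynamicAttentionWrapper(
#       cell, attention_mechanism, attention_size=64)
# After this change, only the class name differs:
cell = seq2seq.AttentionWrapper(
    cell, attention_mechanism, attention_size=64)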
diff --git a/tensorflow/contrib/seq2seq/BUILD b/tensorflow/contrib/seq2seq/BUILD
index 0621905705..652bbba85e 100644
--- a/tensorflow/contrib/seq2seq/BUILD
+++ b/tensorflow/contrib/seq2seq/BUILD
@@ -88,9 +88,9 @@ cuda_py_test(
)
cuda_py_test(
- name = "dynamic_attention_wrapper_test",
+ name = "attention_wrapper_test",
size = "medium",
- srcs = ["python/kernel_tests/dynamic_attention_wrapper_test.py"],
+ srcs = ["python/kernel_tests/attention_wrapper_test.py"],
additional_deps = [
":seq2seq_py",
"//tensorflow/contrib/layers:layers_py",
diff --git a/tensorflow/contrib/seq2seq/__init__.py b/tensorflow/contrib/seq2seq/__init__.py
index 29bce7bbae..277434c160 100644
--- a/tensorflow/contrib/seq2seq/__init__.py
+++ b/tensorflow/contrib/seq2seq/__init__.py
@@ -35,8 +35,8 @@ See the @{$python/contrib.seq2seq} guide.
@@hardmax
-@@DynamicAttentionWrapperState
-@@DynamicAttentionWrapper
+@@AttentionWrapperState
+@@AttentionWrapper
"""
from __future__ import absolute_import
@@ -44,9 +44,9 @@ from __future__ import division
from __future__ import print_function
# pylint: disable=unused-import,wildcard-import,line-too-long
+from tensorflow.contrib.seq2seq.python.ops.attention_wrapper import *
from tensorflow.contrib.seq2seq.python.ops.basic_decoder import *
from tensorflow.contrib.seq2seq.python.ops.decoder import *
-from tensorflow.contrib.seq2seq.python.ops.dynamic_attention_wrapper import *
from tensorflow.contrib.seq2seq.python.ops.helper import *
from tensorflow.contrib.seq2seq.python.ops.loss import *
# pylint: enable=unused-import,widcard-import,line-too-long
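
As a hedged illustration (not in the commit), the wildcard import above is what exposes the renamed symbols at the package level, while the old names disappear along with the old module:

from tensorflow.contrib import seq2seq

assert hasattr(seq2seq, "AttentionWrapper")
assert hasattr(seq2seq, "AttentionWrapperState")
assert not hasattr(seq2seq, "DynamicAttentionWrapper")  # old name is no longer exported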
diff --git a/tensorflow/contrib/seq2seq/python/kernel_tests/dynamic_attention_wrapper_test.py b/tensorflow/contrib/seq2seq/python/kernel_tests/attention_wrapper_test.py
index 47cf102f39..5d952e97b7 100644
--- a/tensorflow/contrib/seq2seq/python/kernel_tests/dynamic_attention_wrapper_test.py
+++ b/tensorflow/contrib/seq2seq/python/kernel_tests/attention_wrapper_test.py
@@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
-"""Tests for contrib.seq2seq.python.ops.dynamic_attention_wrapper."""
+"""Tests for contrib.seq2seq.python.ops.attention_wrapper."""
# pylint: disable=unused-import,g-bad-import-order
from __future__ import absolute_import
from __future__ import division
@@ -25,7 +25,7 @@ import numpy as np
from tensorflow.contrib.rnn import core_rnn_cell
from tensorflow.contrib.seq2seq.python.ops import decoder
-from tensorflow.contrib.seq2seq.python.ops import dynamic_attention_wrapper as wrapper
+from tensorflow.contrib.seq2seq.python.ops import attention_wrapper as wrapper
from tensorflow.contrib.seq2seq.python.ops import helper as helper_py
from tensorflow.contrib.seq2seq.python.ops import basic_decoder
from tensorflow.python.framework import dtypes
@@ -38,12 +38,12 @@ from tensorflow.python.util import nest
# pylint: enable=g-import-not-at-top
-class DynamicAttentionWrapperTest(test.TestCase):
+class AttentionWrapperTest(test.TestCase):
def assertAllClose(self, *args, **kwargs):
kwargs["atol"] = 1e-4 # For GPU tests
kwargs["rtol"] = 1e-4 # For GPU tests
- return super(DynamicAttentionWrapperTest, self).assertAllClose(
+ return super(AttentionWrapperTest, self).assertAllClose(
*args, **kwargs)
def _testWithAttention(self,
@@ -76,7 +76,7 @@ class DynamicAttentionWrapperTest(test.TestCase):
"root",
initializer=init_ops.random_normal_initializer(stddev=0.01, seed=3)):
cell = core_rnn_cell.LSTMCell(cell_depth)
- cell = wrapper.DynamicAttentionWrapper(
+ cell = wrapper.AttentionWrapper(
cell, attention_mechanism, attention_size=attention_depth)
helper = helper_py.TrainingHelper(decoder_inputs,
decoder_sequence_length)
@@ -91,7 +91,7 @@ class DynamicAttentionWrapperTest(test.TestCase):
self.assertTrue(
isinstance(final_outputs, basic_decoder.BasicDecoderOutput))
self.assertTrue(
- isinstance(final_state, wrapper.DynamicAttentionWrapperState))
+ isinstance(final_state, wrapper.AttentionWrapperState))
self.assertTrue(
isinstance(final_state.cell_state, core_rnn_cell.LSTMStateTuple))
@@ -178,7 +178,7 @@ class DynamicAttentionWrapperTest(test.TestCase):
[[0, 0, 0], [0, 0, 0], [0, 0, 0], [0, 0, 0], [0, 0, 0]],
dtype=int32))
- expected_final_state = wrapper.DynamicAttentionWrapperState(
+ expected_final_state = wrapper.AttentionWrapperState(
cell_state=core_rnn_cell.LSTMStateTuple(
c=array(
[[
@@ -306,7 +306,7 @@ class DynamicAttentionWrapperTest(test.TestCase):
[[0, 0, 0], [0, 0, 0], [0, 0, 0], [0, 0, 0], [0, 0, 0]],
dtype=int32))
- expected_final_state = wrapper.DynamicAttentionWrapperState(
+ expected_final_state = wrapper.AttentionWrapperState(
cell_state=core_rnn_cell.LSTMStateTuple(
c=array(
[[
@@ -434,7 +434,7 @@ class DynamicAttentionWrapperTest(test.TestCase):
sample_id=array(
[[0, 0, 0], [0, 0, 0], [0, 0, 0], [0, 0, 0], [0, 0, 0]],
dtype=int32))
- expected_final_state = wrapper.DynamicAttentionWrapperState(
+ expected_final_state = wrapper.AttentionWrapperState(
cell_state=core_rnn_cell.LSTMStateTuple(
c=array(
[[
@@ -566,7 +566,7 @@ class DynamicAttentionWrapperTest(test.TestCase):
sample_id=array(
[[0, 0, 0], [0, 0, 0], [0, 0, 0], [0, 0, 0], [0, 0, 0]],
dtype=int32))
- expected_final_state = wrapper.DynamicAttentionWrapperState(
+ expected_final_state = wrapper.AttentionWrapperState(
cell_state=core_rnn_cell.LSTMStateTuple(
c=array(
[[
diff --git a/tensorflow/contrib/seq2seq/python/ops/dynamic_attention_wrapper.py b/tensorflow/contrib/seq2seq/python/ops/attention_wrapper.py
index 9fb67ed175..0af001d274 100644
--- a/tensorflow/contrib/seq2seq/python/ops/dynamic_attention_wrapper.py
+++ b/tensorflow/contrib/seq2seq/python/ops/attention_wrapper.py
@@ -37,8 +37,8 @@ from tensorflow.python.util import nest
__all__ = [
- "DynamicAttentionWrapper",
- "DynamicAttentionWrapperState",
+ "AttentionWrapper",
+ "AttentionWrapperState",
"LuongAttention",
"BahdanauAttention",
"hardmax",
@@ -376,10 +376,10 @@ class BahdanauAttention(_BaseAttentionMechanism):
return score
-class DynamicAttentionWrapperState(
+class AttentionWrapperState(
collections.namedtuple(
- "DynamicAttentionWrapperState", ("cell_state", "attention"))):
- """`namedtuple` storing the state of a `DynamicAttentionWrapper`.
+ "AttentionWrapperState", ("cell_state", "attention"))):
+ """`namedtuple` storing the state of a `AttentionWrapper`.
Contains:
@@ -410,7 +410,7 @@ def hardmax(logits, name=None):
math_ops.argmax(logits, -1), depth, dtype=logits.dtype)
-class DynamicAttentionWrapper(core_rnn_cell.RNNCell):
+class AttentionWrapper(core_rnn_cell.RNNCell):
"""Wraps another `RNNCell` with attention.
"""
@@ -422,7 +422,7 @@ class DynamicAttentionWrapper(core_rnn_cell.RNNCell):
probability_fn=None,
output_attention=True,
name=None):
- """Construct the `DynamicAttentionWrapper`.
+ """Construct the `AttentionWrapper`.
Args:
cell: An instance of `RNNCell`.
@@ -484,13 +484,13 @@ class DynamicAttentionWrapper(core_rnn_cell.RNNCell):
@property
def state_size(self):
- return DynamicAttentionWrapperState(
+ return AttentionWrapperState(
cell_state=self._cell.state_size,
attention=self._attention_size)
def zero_state(self, batch_size, dtype):
with ops.name_scope(type(self).__name__ + "ZeroState", values=[batch_size]):
- return DynamicAttentionWrapperState(
+ return AttentionWrapperState(
cell_state=self._cell.zero_state(batch_size, dtype),
attention=_zero_state_tensors(
self._attention_size, batch_size, dtype))
@@ -511,7 +511,7 @@ class DynamicAttentionWrapper(core_rnn_cell.RNNCell):
Args:
inputs: (Possibly nested tuple of) Tensor, the input at this time step.
- state: An instance of `DynamicAttentionWrapperState` containing
+ state: An instance of `AttentionWrapperState` containing
tensors from the previous time step.
scope: Must be `None`.
@@ -519,7 +519,7 @@ class DynamicAttentionWrapper(core_rnn_cell.RNNCell):
A tuple `(attention, next_state)`, where:
- `attention` is the attention passed to the layer above.
- - `next_state` is an instance of `DynamicAttentionWrapperState`
+ - `next_state` is an instance of `AttentionWrapperState`
containing the state calculated at this time step.
Raises:
@@ -555,7 +555,7 @@ class DynamicAttentionWrapper(core_rnn_cell.RNNCell):
attention = self._attention_layer(
array_ops.concat([cell_output, context], 1))
- next_state = DynamicAttentionWrapperState(
+ next_state = AttentionWrapperState(
cell_state=next_cell_state,
attention=attention)
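
Finally, a hedged sketch (again not from the commit) of one step through the renamed wrapper, reusing the `cell` built in the first sketch above; it only illustrates the `(attention, next_state)` contract and the two `AttentionWrapperState` fields documented in this file.

# Placeholder step input; the input depth (32) is arbitrary.
step_input = tf.placeholder(tf.float32, [8, 32])

# zero_state now returns an AttentionWrapperState (formerly
# DynamicAttentionWrapperState) with `cell_state` and `attention` fields.
state = cell.zero_state(batch_size=8, dtype=tf.float32)
attention, state = cell(step_input, state)

print(state.cell_state)  # state of the wrapped LSTMCell
print(state.attention)   # attention vector carried to the next time step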