author    Yifei Feng <yifeif@google.com>  2018-04-19 19:39:04 -0700
committer Yifei Feng <yifeif@google.com>  2018-04-19 19:39:04 -0700
commit    b1f78498659bd17c2d897a224eb19476d1efce60 (patch)
tree      dccb2ba37396e2980c098e9d2ea4bc1668ca6f9c /tensorflow/contrib/seq2seq
parent    a734919fd8fd6d74edf1e7c3abec3ee11fec83fd (diff)
parent    b001827146ff95c9e0ce5668c85d8cc2daf6b78d (diff)
Merge commit for internal changes
Diffstat (limited to 'tensorflow/contrib/seq2seq')
-rw-r--r--  tensorflow/contrib/seq2seq/python/kernel_tests/attention_wrapper_test.py | 12
1 file changed, 6 insertions(+), 6 deletions(-)
diff --git a/tensorflow/contrib/seq2seq/python/kernel_tests/attention_wrapper_test.py b/tensorflow/contrib/seq2seq/python/kernel_tests/attention_wrapper_test.py
index 6781433a1f..cd162bae25 100644
--- a/tensorflow/contrib/seq2seq/python/kernel_tests/attention_wrapper_test.py
+++ b/tensorflow/contrib/seq2seq/python/kernel_tests/attention_wrapper_test.py
@@ -411,11 +411,11 @@ class AttentionWrapperTest(test.TestCase):
   def testLuongScaledDType(self):
     # Test case for GitHub issue 18099
-    for dtype in [np.float16, np.float32, np.float64]:
+    for dt in [np.float16, np.float32, np.float64]:
       num_units = 128
-      encoder_outputs = array_ops.placeholder(dtype, shape=[64, None, 256])
+      encoder_outputs = array_ops.placeholder(dt, shape=[64, None, 256])
       encoder_sequence_length = array_ops.placeholder(dtypes.int32, shape=[64])
-      decoder_inputs = array_ops.placeholder(dtype, shape=[64, None, 128])
+      decoder_inputs = array_ops.placeholder(dt, shape=[64, None, 128])
       decoder_sequence_length = array_ops.placeholder(dtypes.int32, shape=[64])
       batch_size = 64
       attention_mechanism = wrapper.LuongAttention(
@@ -423,7 +423,7 @@ class AttentionWrapperTest(test.TestCase):
           memory=encoder_outputs,
           memory_sequence_length=encoder_sequence_length,
           scale=True,
-          dtype=dtype,
+          dtype=dt,
       )
       cell = rnn_cell.LSTMCell(num_units)
       cell = wrapper.AttentionWrapper(cell, attention_mechanism)
@@ -434,12 +434,12 @@ class AttentionWrapperTest(test.TestCase):
           cell=cell,
           helper=helper,
           initial_state=cell.zero_state(
-              dtype=dtype, batch_size=batch_size))
+              dtype=dt, batch_size=batch_size))
       final_outputs, final_state, _ = decoder.dynamic_decode(my_decoder)
       self.assertTrue(
           isinstance(final_outputs, basic_decoder.BasicDecoderOutput))
-      self.assertEqual(final_outputs.rnn_output.dtype, dtype)
+      self.assertEqual(final_outputs.rnn_output.dtype, dt)
       self.assertTrue(
           isinstance(final_state, wrapper.AttentionWrapperState))
       self.assertTrue(
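For reference, a minimal standalone sketch of the configuration this test loops over: a scaled LuongAttention built with an explicit dtype, wrapped around an LSTM cell. This is not part of the commit; it assumes the public tf.contrib.seq2seq API of TensorFlow 1.x rather than the internal module aliases (array_ops, wrapper, rnn_cell) used in the test file.

    import numpy as np
    import tensorflow as tf

    # Sketch only (not from the commit): same setup as testLuongScaledDType,
    # written against the public tf.contrib.seq2seq API.
    dt = np.float16  # the test also covers np.float32 and np.float64

    encoder_outputs = tf.placeholder(dt, shape=[64, None, 256])
    encoder_sequence_length = tf.placeholder(tf.int32, shape=[64])

    attention_mechanism = tf.contrib.seq2seq.LuongAttention(
        num_units=128,
        memory=encoder_outputs,
        memory_sequence_length=encoder_sequence_length,
        scale=True,  # learned scaling factor for the Luong score
        dtype=dt)    # keep the mechanism's variables in the loop's dtype

    cell = tf.nn.rnn_cell.LSTMCell(128)
    cell = tf.contrib.seq2seq.AttentionWrapper(cell, attention_mechanism)
    # The test then decodes with a BasicDecoder and asserts that
    # final_outputs.rnn_output.dtype equals dt.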