diff options
author | 2018-04-19 19:39:04 -0700 | |
---|---|---|
committer | 2018-04-19 19:39:04 -0700 | |
commit | b1f78498659bd17c2d897a224eb19476d1efce60 (patch) | |
tree | dccb2ba37396e2980c098e9d2ea4bc1668ca6f9c /tensorflow/contrib/seq2seq | |
parent | a734919fd8fd6d74edf1e7c3abec3ee11fec83fd (diff) | |
parent | b001827146ff95c9e0ce5668c85d8cc2daf6b78d (diff) |
Merge commit for internal changes
Diffstat (limited to 'tensorflow/contrib/seq2seq')
-rw-r--r-- | tensorflow/contrib/seq2seq/python/kernel_tests/attention_wrapper_test.py | 12 |
1 files changed, 6 insertions, 6 deletions
diff --git a/tensorflow/contrib/seq2seq/python/kernel_tests/attention_wrapper_test.py b/tensorflow/contrib/seq2seq/python/kernel_tests/attention_wrapper_test.py index 6781433a1f..cd162bae25 100644 --- a/tensorflow/contrib/seq2seq/python/kernel_tests/attention_wrapper_test.py +++ b/tensorflow/contrib/seq2seq/python/kernel_tests/attention_wrapper_test.py @@ -411,11 +411,11 @@ class AttentionWrapperTest(test.TestCase): def testLuongScaledDType(self): # Test case for GitHub issue 18099 - for dtype in [np.float16, np.float32, np.float64]: + for dt in [np.float16, np.float32, np.float64]: num_units = 128 - encoder_outputs = array_ops.placeholder(dtype, shape=[64, None, 256]) + encoder_outputs = array_ops.placeholder(dt, shape=[64, None, 256]) encoder_sequence_length = array_ops.placeholder(dtypes.int32, shape=[64]) - decoder_inputs = array_ops.placeholder(dtype, shape=[64, None, 128]) + decoder_inputs = array_ops.placeholder(dt, shape=[64, None, 128]) decoder_sequence_length = array_ops.placeholder(dtypes.int32, shape=[64]) batch_size = 64 attention_mechanism = wrapper.LuongAttention( @@ -423,7 +423,7 @@ class AttentionWrapperTest(test.TestCase): memory=encoder_outputs, memory_sequence_length=encoder_sequence_length, scale=True, - dtype=dtype, + dtype=dt, ) cell = rnn_cell.LSTMCell(num_units) cell = wrapper.AttentionWrapper(cell, attention_mechanism) @@ -434,12 +434,12 @@ class AttentionWrapperTest(test.TestCase): cell=cell, helper=helper, initial_state=cell.zero_state( - dtype=dtype, batch_size=batch_size)) + dtype=dt, batch_size=batch_size)) final_outputs, final_state, _ = decoder.dynamic_decode(my_decoder) self.assertTrue( isinstance(final_outputs, basic_decoder.BasicDecoderOutput)) - self.assertEqual(final_outputs.rnn_output.dtype, dtype) + self.assertEqual(final_outputs.rnn_output.dtype, dt) self.assertTrue( isinstance(final_state, wrapper.AttentionWrapperState)) self.assertTrue( |