diff options
author | 2018-09-17 13:31:40 -0700 | |
---|---|---|
committer | 2018-09-17 13:43:19 -0700 | |
commit | cd767b617ab00ffba993d62e4ff1f2028791fe4e (patch) | |
tree | b1cbe3be9517f1ebdb80211fc8eca033362cc227 /tensorflow/contrib/seq2seq | |
parent | 32ed8d488ad8088b63f046cde0c665e3b2aab8e7 (diff) |
Compute `axes` and `free` statically during graph creation.
PiperOrigin-RevId: 213327709
Diffstat (limited to 'tensorflow/contrib/seq2seq')
-rw-r--r-- | tensorflow/contrib/seq2seq/python/kernel_tests/attention_wrapper_test.py | 39 |
1 files changed, 18 insertions, 21 deletions
diff --git a/tensorflow/contrib/seq2seq/python/kernel_tests/attention_wrapper_test.py b/tensorflow/contrib/seq2seq/python/kernel_tests/attention_wrapper_test.py index f2c43f30d4..1f3b533de9 100644 --- a/tensorflow/contrib/seq2seq/python/kernel_tests/attention_wrapper_test.py +++ b/tensorflow/contrib/seq2seq/python/kernel_tests/attention_wrapper_test.py @@ -919,31 +919,28 @@ class AttentionWrapperTest(test.TestCase): wrapper.BahdanauAttention, wrapper.LuongAttention) expected_final_output = BasicDecoderOutput( - rnn_output=ResultSummary(shape=(5, 3, 20), - dtype=dtype('float32'), - mean=0.11723966), - sample_id=ResultSummary(shape=(5, 3), - dtype=dtype('int32'), - mean=9.2666666666666675)) + rnn_output=ResultSummary( + shape=(5, 3, 20), dtype=dtype('float32'), mean=0.11723966), + sample_id=ResultSummary( + shape=(5, 3), dtype=dtype('int32'), mean=7.266666666666667)) expected_final_state = AttentionWrapperState( cell_state=LSTMStateTuple( - c=ResultSummary(shape=(5, 9), - dtype=dtype('float32'), - mean=-0.003545674), - h=ResultSummary(shape=(5, 9), - dtype=dtype('float32'), - mean=-0.0018327223)), - attention=ResultSummary(shape=(5, 20), - dtype=dtype('float32'), - mean=0.11728073), + c=ResultSummary( + shape=(5, 9), dtype=dtype('float32'), mean=-0.003545674), + h=ResultSummary( + shape=(5, 9), dtype=dtype('float32'), mean=-0.0018327223)), + attention=ResultSummary( + shape=(5, 20), dtype=dtype('float32'), mean=0.11601614207), time=3, - alignments=( - ResultSummary(shape=(5, 8), dtype=dtype('float32'), mean=0.125), - ResultSummary(shape=(5, 8), dtype=dtype('float32'), mean=0.125)), + alignments=(ResultSummary( + shape=(5, 8), dtype=dtype('float32'), mean=0.125), + ResultSummary( + shape=(5, 8), dtype=dtype('float32'), mean=0.125)), alignment_history=(), - attention_state=( - ResultSummary(shape=(5, 8), dtype=dtype('float32'), mean=0.125), - ResultSummary(shape=(5, 8), dtype=dtype('float32'), mean=0.125))) + attention_state=(ResultSummary( + shape=(5, 8), dtype=dtype('float32'), mean=0.125), + ResultSummary( + shape=(5, 8), dtype=dtype('float32'), mean=0.125))) expected_final_alignment_history = ( ResultSummary(shape=(3, 5, 8), dtype=dtype('float32'), mean=0.125), ResultSummary(shape=(3, 5, 8), dtype=dtype('float32'), mean=0.125)) |