aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/seq2seq
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2018-09-17 13:31:40 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-09-17 13:43:19 -0700
commitcd767b617ab00ffba993d62e4ff1f2028791fe4e (patch)
treeb1cbe3be9517f1ebdb80211fc8eca033362cc227 /tensorflow/contrib/seq2seq
parent32ed8d488ad8088b63f046cde0c665e3b2aab8e7 (diff)
Compute `axes` and `free` statically during graph creation.
PiperOrigin-RevId: 213327709
Diffstat (limited to 'tensorflow/contrib/seq2seq')
-rw-r--r--tensorflow/contrib/seq2seq/python/kernel_tests/attention_wrapper_test.py39
1 files changed, 18 insertions, 21 deletions
diff --git a/tensorflow/contrib/seq2seq/python/kernel_tests/attention_wrapper_test.py b/tensorflow/contrib/seq2seq/python/kernel_tests/attention_wrapper_test.py
index f2c43f30d4..1f3b533de9 100644
--- a/tensorflow/contrib/seq2seq/python/kernel_tests/attention_wrapper_test.py
+++ b/tensorflow/contrib/seq2seq/python/kernel_tests/attention_wrapper_test.py
@@ -919,31 +919,28 @@ class AttentionWrapperTest(test.TestCase):
wrapper.BahdanauAttention, wrapper.LuongAttention)
expected_final_output = BasicDecoderOutput(
- rnn_output=ResultSummary(shape=(5, 3, 20),
- dtype=dtype('float32'),
- mean=0.11723966),
- sample_id=ResultSummary(shape=(5, 3),
- dtype=dtype('int32'),
- mean=9.2666666666666675))
+ rnn_output=ResultSummary(
+ shape=(5, 3, 20), dtype=dtype('float32'), mean=0.11723966),
+ sample_id=ResultSummary(
+ shape=(5, 3), dtype=dtype('int32'), mean=7.266666666666667))
expected_final_state = AttentionWrapperState(
cell_state=LSTMStateTuple(
- c=ResultSummary(shape=(5, 9),
- dtype=dtype('float32'),
- mean=-0.003545674),
- h=ResultSummary(shape=(5, 9),
- dtype=dtype('float32'),
- mean=-0.0018327223)),
- attention=ResultSummary(shape=(5, 20),
- dtype=dtype('float32'),
- mean=0.11728073),
+ c=ResultSummary(
+ shape=(5, 9), dtype=dtype('float32'), mean=-0.003545674),
+ h=ResultSummary(
+ shape=(5, 9), dtype=dtype('float32'), mean=-0.0018327223)),
+ attention=ResultSummary(
+ shape=(5, 20), dtype=dtype('float32'), mean=0.11601614207),
time=3,
- alignments=(
- ResultSummary(shape=(5, 8), dtype=dtype('float32'), mean=0.125),
- ResultSummary(shape=(5, 8), dtype=dtype('float32'), mean=0.125)),
+ alignments=(ResultSummary(
+ shape=(5, 8), dtype=dtype('float32'), mean=0.125),
+ ResultSummary(
+ shape=(5, 8), dtype=dtype('float32'), mean=0.125)),
alignment_history=(),
- attention_state=(
- ResultSummary(shape=(5, 8), dtype=dtype('float32'), mean=0.125),
- ResultSummary(shape=(5, 8), dtype=dtype('float32'), mean=0.125)))
+ attention_state=(ResultSummary(
+ shape=(5, 8), dtype=dtype('float32'), mean=0.125),
+ ResultSummary(
+ shape=(5, 8), dtype=dtype('float32'), mean=0.125)))
expected_final_alignment_history = (
ResultSummary(shape=(3, 5, 8), dtype=dtype('float32'), mean=0.125),
ResultSummary(shape=(3, 5, 8), dtype=dtype('float32'), mean=0.125))