aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/seq2seq
diff options
context:
space:
mode:
authorGravatar Alexandre Passos <apassos@google.com>2018-03-26 15:39:54 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-03-26 15:42:46 -0700
commitc83a54adcface7d4bb666d7c4fd3968ba980a50d (patch)
tree5010155c5e2a46ef47bb9f9933e3bbf0b4628a7e /tensorflow/contrib/seq2seq
parent290632966fae0619db30c1ba777634db9a43b757 (diff)
Makes tf.gather not silently snapshot resource variables.
PiperOrigin-RevId: 190537320
Diffstat (limited to 'tensorflow/contrib/seq2seq')
-rw-r--r--tensorflow/contrib/seq2seq/python/kernel_tests/attention_wrapper_test.py29
1 files changed, 17 insertions, 12 deletions
diff --git a/tensorflow/contrib/seq2seq/python/kernel_tests/attention_wrapper_test.py b/tensorflow/contrib/seq2seq/python/kernel_tests/attention_wrapper_test.py
index c4139dde49..07b3ad71d4 100644
--- a/tensorflow/contrib/seq2seq/python/kernel_tests/attention_wrapper_test.py
+++ b/tensorflow/contrib/seq2seq/python/kernel_tests/attention_wrapper_test.py
@@ -785,26 +785,31 @@ class AttentionWrapperTest(test.TestCase):
wrapper.BahdanauAttention, wrapper.LuongAttention)
expected_final_output = BasicDecoderOutput(
- rnn_output=ResultSummary(
- shape=(5, 3, 20), dtype=dtype('float32'), mean=0.11798714846372604),
- sample_id=ResultSummary(
- shape=(5, 3), dtype=dtype('int32'), mean=7.933333333333334))
+ rnn_output=ResultSummary(shape=(5, 3, 20),
+ dtype=dtype('float32'),
+ mean=0.11723966),
+ sample_id=ResultSummary(shape=(5, 3),
+ dtype=dtype('int32'),
+ mean=9.2666666666666675))
expected_final_state = AttentionWrapperState(
cell_state=LSTMStateTuple(
- c=ResultSummary(
- shape=(5, 9), dtype=dtype('float32'), mean=-0.0036486709),
- h=ResultSummary(
- shape=(5, 9), dtype=dtype('float32'), mean=-0.0018835809)),
- attention=ResultSummary(
- shape=(5, 20), dtype=dtype('float32'), mean=0.11798714846372604),
+ c=ResultSummary(shape=(5, 9),
+ dtype=dtype('float32'),
+ mean=-0.003545674),
+ h=ResultSummary(shape=(5, 9),
+ dtype=dtype('float32'),
+ mean=-0.0018327223)),
+ attention=ResultSummary(shape=(5, 20),
+ dtype=dtype('float32'),
+ mean=0.11728073),
time=3,
alignments=(
ResultSummary(shape=(5, 8), dtype=dtype('float32'), mean=0.125),
ResultSummary(shape=(5, 8), dtype=dtype('float32'), mean=0.125)),
+ alignment_history=(),
attention_state=(
ResultSummary(shape=(5, 8), dtype=dtype('float32'), mean=0.125),
- ResultSummary(shape=(5, 8), dtype=dtype('float32'), mean=0.125)),
- alignment_history=())
+ ResultSummary(shape=(5, 8), dtype=dtype('float32'), mean=0.125)))
expected_final_alignment_history = (
ResultSummary(shape=(3, 5, 8), dtype=dtype('float32'), mean=0.125),
ResultSummary(shape=(3, 5, 8), dtype=dtype('float32'), mean=0.125))