aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/legacy_seq2seq
diff options
context:
space:
mode:
authorGravatar Jianwei Xie <xiejw@google.com>2016-12-08 01:16:07 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2016-12-08 01:24:30 -0800
commit7b3f3abd2e8c6b51c0a256f181225e007fc3f76d (patch)
tree8f8acd9182e0e2be473db80cda5b78b0d7cf7068 /tensorflow/contrib/legacy_seq2seq
parent202f77420beb8a556c6a3d685e8f5daab0a8275c (diff)
Prepare all callers of sampled_softmax_loss to be ready for inputs and labels swap.
Change: 141410359
Diffstat (limited to 'tensorflow/contrib/legacy_seq2seq')
-rw-r--r--tensorflow/contrib/legacy_seq2seq/python/kernel_tests/seq2seq_test.py8
1 files changed, 7 insertions, 1 deletions
diff --git a/tensorflow/contrib/legacy_seq2seq/python/kernel_tests/seq2seq_test.py b/tensorflow/contrib/legacy_seq2seq/python/kernel_tests/seq2seq_test.py
index 403d0b2661..c908fbc655 100644
--- a/tensorflow/contrib/legacy_seq2seq/python/kernel_tests/seq2seq_test.py
+++ b/tensorflow/contrib/legacy_seq2seq/python/kernel_tests/seq2seq_test.py
@@ -611,7 +611,13 @@ class Seq2SeqTest(tf.test.TestCase):
targets = [dec_inp[i+1] for i in range(len(dec_inp) - 1)] + [0]
def SampledLoss(labels, inputs):
labels = tf.reshape(labels, [-1, 1])
- return tf.nn.sampled_softmax_loss(w_t, b, inputs, labels, 8, classes)
+ return tf.nn.sampled_softmax_loss(
+ weights=w_t,
+ biases=b,
+ labels=labels,
+ inputs=inputs,
+ num_sampled=8,
+ num_classes=classes)
return tf.contrib.legacy_seq2seq.model_with_buckets(
enc_inp, dec_inp, targets, weights, buckets, GRUSeq2Seq,
softmax_loss_function=SampledLoss)