diff options
author | 2016-12-08 01:16:07 -0800 | |
---|---|---|
committer | 2016-12-08 01:24:30 -0800 | |
commit | 7b3f3abd2e8c6b51c0a256f181225e007fc3f76d (patch) | |
tree | 8f8acd9182e0e2be473db80cda5b78b0d7cf7068 /tensorflow/contrib/legacy_seq2seq | |
parent | 202f77420beb8a556c6a3d685e8f5daab0a8275c (diff) |
Prepare all callers of sampled_softmax_loss to be ready for inputs and labels swap.
Change: 141410359
Diffstat (limited to 'tensorflow/contrib/legacy_seq2seq')
-rw-r--r-- | tensorflow/contrib/legacy_seq2seq/python/kernel_tests/seq2seq_test.py | 8 |
1 files changed, 7 insertions, 1 deletions
diff --git a/tensorflow/contrib/legacy_seq2seq/python/kernel_tests/seq2seq_test.py b/tensorflow/contrib/legacy_seq2seq/python/kernel_tests/seq2seq_test.py index 403d0b2661..c908fbc655 100644 --- a/tensorflow/contrib/legacy_seq2seq/python/kernel_tests/seq2seq_test.py +++ b/tensorflow/contrib/legacy_seq2seq/python/kernel_tests/seq2seq_test.py @@ -611,7 +611,13 @@ class Seq2SeqTest(tf.test.TestCase): targets = [dec_inp[i+1] for i in range(len(dec_inp) - 1)] + [0] def SampledLoss(labels, inputs): labels = tf.reshape(labels, [-1, 1]) - return tf.nn.sampled_softmax_loss(w_t, b, inputs, labels, 8, classes) + return tf.nn.sampled_softmax_loss( + weights=w_t, + biases=b, + labels=labels, + inputs=inputs, + num_sampled=8, + num_classes=classes) return tf.contrib.legacy_seq2seq.model_with_buckets( enc_inp, dec_inp, targets, weights, buckets, GRUSeq2Seq, softmax_loss_function=SampledLoss) |