| author | Neal Wu <wun@google.com> | 2017-03-23 15:26:59 -0800 |
|---|---|---|
| committer | TensorFlower Gardener <gardener@tensorflow.org> | 2017-03-23 16:46:46 -0700 |
| commit | 5c5b6643848b04cf636c71e110e817748c1a9702 | |
| tree | 4611394655da71c2599de7a79b9f98b608e8fa61 /tensorflow/contrib/legacy_seq2seq | |
| parent | 63c115860cb0628cf5dc8c108682589774807653 | |
`softmax_loss_function` in `tf.contrib.legacy_seq2seq` and `tf.contrib.seq2seq` now requires named arguments, to prevent mixing up labels and logits.
Also applied lint fixes and fixed a broken, hidden test that was not being run.
Change: 151069880
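To make the new contract concrete, here is a minimal sketch (not part of the commit) of a compliant custom loss hook. The parameter names must be exactly `labels` and `logits`, because the library now invokes the hook with keyword arguments; the projection variables and sizes below (`w_t`, `b`, `num_classes = 10`, dimension 4) are illustrative stand-ins modeled on the `SampledLoss` helper in the test diff.

```python
import tensorflow as tf

# Illustrative output-projection variables (shapes are placeholders).
num_classes = 10
w_t = tf.get_variable("proj_w_t", [num_classes, 4])  # transposed weights
b = tf.get_variable("proj_b", [num_classes])

def sampled_loss(labels, logits):
  # The library now calls softmax_loss_function(labels=..., logits=...),
  # so a hook whose second parameter is still named `inputs` raises a
  # TypeError instead of silently receiving swapped tensors.
  labels = tf.reshape(labels, [-1, 1])
  return tf.nn.sampled_softmax_loss(
      weights=w_t, biases=b, labels=labels, inputs=logits,
      num_sampled=8, num_classes=num_classes)
```

Note that with sampled softmax the `logits` argument actually carries the pre-projection decoder outputs; the enforced names guard against swapped arguments rather than changing what is passed.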
Diffstat (limited to 'tensorflow/contrib/legacy_seq2seq'):

| Mode | File | Lines changed |
|---|---|---|
| -rw-r--r-- | tensorflow/contrib/legacy_seq2seq/python/kernel_tests/seq2seq_test.py | 68 |
| -rw-r--r-- | tensorflow/contrib/legacy_seq2seq/python/ops/seq2seq.py | 18 |

2 files changed, 46 insertions, 40 deletions
```diff
diff --git a/tensorflow/contrib/legacy_seq2seq/python/kernel_tests/seq2seq_test.py b/tensorflow/contrib/legacy_seq2seq/python/kernel_tests/seq2seq_test.py
index d71cfd4669..8dcfb775b2 100644
--- a/tensorflow/contrib/legacy_seq2seq/python/kernel_tests/seq2seq_test.py
+++ b/tensorflow/contrib/legacy_seq2seq/python/kernel_tests/seq2seq_test.py
@@ -455,37 +455,37 @@ class Seq2SeqTest(test.TestCase):
         self.assertEqual((2, 2), res[0][1].c.shape)
         self.assertEqual((2, 2), res[0][1].h.shape)
 
-    def testDynamicAttentionDecoderStateIsTuple(self):
-      with self.test_session() as sess:
-        with variable_scope.variable_scope(
-            "root", initializer=init_ops.constant_initializer(0.5)):
-          cell_fn = lambda: core_rnn_cell_impl.MultiRNNCell(  # pylint: disable=g-long-lambda
-              cells=[core_rnn_cell_impl.BasicLSTMCell(2) for _ in range(2)])
-          cell = cell_fn()
-          inp = constant_op.constant(0.5, shape=[2, 2, 2])
-          enc_outputs, enc_state = core_rnn.static_rnn(
-              cell, inp, dtype=dtypes.float32)
-          attn_states = array_ops.concat([
-              array_ops.reshape(e, [-1, 1, cell.output_size])
-              for e in enc_outputs
-          ], 1)
-          dec_inp = [constant_op.constant(0.4, shape=[2, 2])] * 3
-
-          # Use a new cell instance since the attention decoder uses a
-          # different variable scope.
-          dec, mem = seq2seq_lib.attention_decoder(
-              dec_inp, enc_state, attn_states, cell_fn(), output_size=4)
-          sess.run([variables.global_variables_initializer()])
-          res = sess.run(dec)
-          self.assertEqual(3, len(res))
-          self.assertEqual((2, 4), res[0].shape)
+  def testDynamicAttentionDecoderStateIsTuple(self):
+    with self.test_session() as sess:
+      with variable_scope.variable_scope(
+          "root", initializer=init_ops.constant_initializer(0.5)):
+        cell_fn = lambda: core_rnn_cell_impl.MultiRNNCell(  # pylint: disable=g-long-lambda
+            cells=[core_rnn_cell_impl.BasicLSTMCell(2) for _ in range(2)])
+        cell = cell_fn()
+        inp = [constant_op.constant(0.5, shape=[2, 2])] * 2
+        enc_outputs, enc_state = core_rnn.static_rnn(
+            cell, inp, dtype=dtypes.float32)
+        attn_states = array_ops.concat([
+            array_ops.reshape(e, [-1, 1, cell.output_size])
+            for e in enc_outputs
+        ], 1)
+        dec_inp = [constant_op.constant(0.4, shape=[2, 2])] * 3
 
-          res = sess.run([mem])
-          self.assertEqual(2, len(res[0]))
-          self.assertEqual((2, 2), res[0][0].c.shape)
-          self.assertEqual((2, 2), res[0][0].h.shape)
-          self.assertEqual((2, 2), res[0][1].c.shape)
-          self.assertEqual((2, 2), res[0][1].h.shape)
+        # Use a new cell instance since the attention decoder uses a
+        # different variable scope.
+        dec, mem = seq2seq_lib.attention_decoder(
+            dec_inp, enc_state, attn_states, cell_fn(), output_size=4)
+        sess.run([variables.global_variables_initializer()])
+        res = sess.run(dec)
+        self.assertEqual(3, len(res))
+        self.assertEqual((2, 4), res[0].shape)
+
+        res = sess.run([mem])
+        self.assertEqual(2, len(res[0]))
+        self.assertEqual((2, 2), res[0][0].c.shape)
+        self.assertEqual((2, 2), res[0][0].h.shape)
+        self.assertEqual((2, 2), res[0][1].c.shape)
+        self.assertEqual((2, 2), res[0][1].h.shape)
 
   def testEmbeddingAttentionDecoder(self):
     with self.test_session() as sess:
@@ -876,13 +876,13 @@ class Seq2SeqTest(test.TestCase):
       targets = [dec_inp[i + 1] for i in range(len(dec_inp) - 1)] + [0]
 
-      def SampledLoss(labels, inputs):
+      def SampledLoss(labels, logits):
         labels = array_ops.reshape(labels, [-1, 1])
         return nn_impl.sampled_softmax_loss(
             weights=w_t,
             biases=b,
             labels=labels,
-            inputs=inputs,
+            inputs=logits,
             num_sampled=8,
             num_classes=classes)
@@ -912,7 +912,7 @@ class Seq2SeqTest(test.TestCase):
       with variable_scope.variable_scope("root"):
         _, losses = SampleGRUSeq2Seq(inp, out, weights)
         updates = []
-        params = variables.all_variables()
+        params = variables.global_variables()
        optimizer = adam.AdamOptimizer(0.03, epsilon=1e-5)
         for i in range(len(buckets)):
           full_grads = gradients_impl.gradients(losses[i], params)
@@ -1007,7 +1007,7 @@ class Seq2SeqTest(test.TestCase):
 
       dec_op_fp_true, update_fp_true, variables_fp_true = ForwardBackward(
           enc_inp, dec_inp_fp_true, feed_previous=True)
-      dec_op_fp_false, update_fp_false, variables_fp_false = ForwardBackward(
+      _, update_fp_false, variables_fp_false = ForwardBackward(
           enc_inp, dec_inp_holder_fp_false, feed_previous=False)
 
       sess.run(variables.global_variables_initializer())
diff --git a/tensorflow/contrib/legacy_seq2seq/python/ops/seq2seq.py b/tensorflow/contrib/legacy_seq2seq/python/ops/seq2seq.py
index 1202f961cc..a80b898156 100644
--- a/tensorflow/contrib/legacy_seq2seq/python/ops/seq2seq.py
+++ b/tensorflow/contrib/legacy_seq2seq/python/ops/seq2seq.py
@@ -381,7 +381,7 @@ def embedding_rnn_seq2seq(encoder_inputs,
   def decoder(feed_previous_bool):
     reuse = None if feed_previous_bool else True
     with variable_scope.variable_scope(
-        variable_scope.get_variable_scope(), reuse=reuse) as scope:
+        variable_scope.get_variable_scope(), reuse=reuse):
       outputs, state = embedding_rnn_decoder(
           decoder_inputs,
           encoder_state,
@@ -884,7 +884,7 @@ def embedding_attention_seq2seq(encoder_inputs,
   def decoder(feed_previous_bool):
     reuse = None if feed_previous_bool else True
     with variable_scope.variable_scope(
-        variable_scope.get_variable_scope(), reuse=reuse) as scope:
+        variable_scope.get_variable_scope(), reuse=reuse):
       outputs, state = embedding_attention_decoder(
           decoder_inputs,
           encoder_state,
@@ -1060,8 +1060,10 @@ def sequence_loss_by_example(logits,
     weights: List of 1D batch-sized float-Tensors of the same length as logits.
     average_across_timesteps: If set, divide the returned cost by the total
       label weight.
-    softmax_loss_function: Function (labels-batch, inputs-batch) -> loss-batch
+    softmax_loss_function: Function (labels, logits) -> loss-batch
       to be used instead of the standard softmax (the default if this is None).
+      **Note that to avoid confusion, it is required for the function to accept
+      named arguments.**
     name: Optional name for this operation, default: "sequence_loss_by_example".
 
   Returns:
@@ -1085,7 +1087,7 @@ def sequence_loss_by_example(logits,
         crossent = nn_ops.sparse_softmax_cross_entropy_with_logits(
             labels=target, logits=logit)
       else:
-        crossent = softmax_loss_function(target, logit)
+        crossent = softmax_loss_function(labels=target, logits=logit)
       log_perp_list.append(crossent * weight)
   log_perps = math_ops.add_n(log_perp_list)
   if average_across_timesteps:
@@ -1111,8 +1113,10 @@ def sequence_loss(logits,
     average_across_timesteps: If set, divide the returned cost by the total
       label weight.
     average_across_batch: If set, divide the returned cost by the batch size.
-    softmax_loss_function: Function (labels-batch, inputs-batch) -> loss-batch
+    softmax_loss_function: Function (labels, logits) -> loss-batch
       to be used instead of the standard softmax (the default if this is None).
+      **Note that to avoid confusion, it is required for the function to accept
+      named arguments.**
     name: Optional name for this operation, defaults to "sequence_loss".
 
   Returns:
@@ -1160,8 +1164,10 @@ def model_with_buckets(encoder_inputs,
     seq2seq: A sequence-to-sequence model function; it takes 2 input that
       agree with encoder_inputs and decoder_inputs, and returns a pair
       consisting of outputs and states (as, e.g., basic_rnn_seq2seq).
-    softmax_loss_function: Function (labels-batch, inputs-batch) -> loss-batch
+    softmax_loss_function: Function (labels, logits) -> loss-batch
       to be used instead of the standard softmax (the default if this is None).
+      **Note that to avoid confusion, it is required for the function to accept
+      named arguments.**
     per_example_loss: Boolean. If set, the returned loss will be a batch-sized
       tensor of losses for each sequence in the batch. If unset, it will be
       a scalar with the averaged loss from all examples.
```
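For completeness, a hedged usage sketch wiring such a hook into `sequence_loss`, assuming the `sampled_loss` function from the sketch above and placeholder shapes (3 decoder steps, batch size 2, output dimension 4 to match `w_t`):

```python
# Per-step decoder outputs, integer targets, and loss weights (dummy values).
logits = [tf.zeros([2, 4]) for _ in range(3)]
targets = [tf.zeros([2], dtype=tf.int32) for _ in range(3)]
weights = [tf.ones([2]) for _ in range(3)]

# After this change the hook is invoked with named arguments, so a
# signature like `def loss_fn(inputs, labels)` fails loudly with a
# TypeError rather than silently mixing up labels and logits.
loss = tf.contrib.legacy_seq2seq.sequence_loss(
    logits, targets, weights, softmax_loss_function=sampled_loss)
```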