diff options
author | 2017-05-30 11:23:21 -0700 | |
---|---|---|
committer | 2017-05-30 11:30:04 -0700 | |
commit | 367ec84f8c0d79cf489dc9dbf20a3c56f57f5f09 (patch) | |
tree | 00b9ffd18936bff82437a3502b0c82fe827eea84 /tensorflow/contrib/seq2seq/python/kernel_tests | |
parent | a3ba225d5b327013709a1732688bfd4346b3c86e (diff) |
Add SampleEmbeddingHelper to do sampling at inference time
PiperOrigin-RevId: 157487623
Diffstat (limited to 'tensorflow/contrib/seq2seq/python/kernel_tests')
-rw-r--r-- | tensorflow/contrib/seq2seq/python/kernel_tests/basic_decoder_test.py | 72 |
1 files changed, 72 insertions, 0 deletions
diff --git a/tensorflow/contrib/seq2seq/python/kernel_tests/basic_decoder_test.py b/tensorflow/contrib/seq2seq/python/kernel_tests/basic_decoder_test.py index 600adea189..cb12bc9450 100644 --- a/tensorflow/contrib/seq2seq/python/kernel_tests/basic_decoder_test.py +++ b/tensorflow/contrib/seq2seq/python/kernel_tests/basic_decoder_test.py @@ -27,8 +27,10 @@ from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes from tensorflow.python.framework import tensor_shape from tensorflow.python.layers import core as layers_core +from tensorflow.python.ops import init_ops from tensorflow.python.ops import rnn_cell from tensorflow.python.ops import variables +from tensorflow.python.ops import variable_scope from tensorflow.python.platform import test # pylint: enable=g-import-not-at-top @@ -189,6 +191,76 @@ class BasicDecoderTest(test.TestCase): self.assertAllEqual(expected_step_next_inputs, sess_results["step_next_inputs"]) + def testStepWithSampleEmbeddingHelper(self): + batch_size = 5 + vocabulary_size = 7 + cell_depth = vocabulary_size # cell's logits must match vocabulary size + input_depth = 10 + np.random.seed(0) + start_tokens = np.random.randint(0, vocabulary_size, size=batch_size) + end_token = 1 + + with self.test_session(use_gpu=True) as sess: + with variable_scope.variable_scope( + "testStepWithSampleEmbeddingHelper", + initializer=init_ops.constant_initializer(0.01)): + embeddings = np.random.randn(vocabulary_size, + input_depth).astype(np.float32) + cell = rnn_cell.LSTMCell(vocabulary_size) + helper = helper_py.SampleEmbeddingHelper(embeddings, start_tokens, + end_token, seed=0) + my_decoder = basic_decoder.BasicDecoder( + cell=cell, + helper=helper, + initial_state=cell.zero_state( + dtype=dtypes.float32, batch_size=batch_size)) + output_size = my_decoder.output_size + output_dtype = my_decoder.output_dtype + self.assertEqual( + basic_decoder.BasicDecoderOutput(cell_depth, + tensor_shape.TensorShape([])), + output_size) + self.assertEqual( + basic_decoder.BasicDecoderOutput(dtypes.float32, dtypes.int32), + output_dtype) + + (first_finished, first_inputs, first_state) = my_decoder.initialize() + (step_outputs, step_state, step_next_inputs, + step_finished) = my_decoder.step( + constant_op.constant(0), first_inputs, first_state) + batch_size_t = my_decoder.batch_size + + self.assertTrue(isinstance(first_state, rnn_cell.LSTMStateTuple)) + self.assertTrue(isinstance(step_state, rnn_cell.LSTMStateTuple)) + self.assertTrue( + isinstance(step_outputs, basic_decoder.BasicDecoderOutput)) + self.assertEqual((batch_size, cell_depth), step_outputs[0].get_shape()) + self.assertEqual((batch_size,), step_outputs[1].get_shape()) + self.assertEqual((batch_size, cell_depth), first_state[0].get_shape()) + self.assertEqual((batch_size, cell_depth), first_state[1].get_shape()) + self.assertEqual((batch_size, cell_depth), step_state[0].get_shape()) + self.assertEqual((batch_size, cell_depth), step_state[1].get_shape()) + + sess.run(variables.global_variables_initializer()) + sess_results = sess.run({ + "batch_size": batch_size_t, + "first_finished": first_finished, + "first_inputs": first_inputs, + "first_state": first_state, + "step_outputs": step_outputs, + "step_state": step_state, + "step_next_inputs": step_next_inputs, + "step_finished": step_finished + }) + + sample_ids = sess_results["step_outputs"].sample_id + expected_step_finished = (sample_ids == end_token) + expected_step_next_inputs = embeddings[sample_ids] + self.assertAllEqual(expected_step_finished, + sess_results["step_finished"]) + self.assertAllEqual(expected_step_next_inputs, + sess_results["step_next_inputs"]) + def testStepWithScheduledEmbeddingTrainingHelper(self): sequence_length = [3, 4, 3, 1, 0] batch_size = 5 |