aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/seq2seq/python/kernel_tests
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2017-05-30 11:23:21 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2017-05-30 11:30:04 -0700
commit367ec84f8c0d79cf489dc9dbf20a3c56f57f5f09 (patch)
tree00b9ffd18936bff82437a3502b0c82fe827eea84 /tensorflow/contrib/seq2seq/python/kernel_tests
parenta3ba225d5b327013709a1732688bfd4346b3c86e (diff)
Add SampleEmbeddingHelper to do sampling at inference time
PiperOrigin-RevId: 157487623
Diffstat (limited to 'tensorflow/contrib/seq2seq/python/kernel_tests')
-rw-r--r--tensorflow/contrib/seq2seq/python/kernel_tests/basic_decoder_test.py72
1 files changed, 72 insertions, 0 deletions
diff --git a/tensorflow/contrib/seq2seq/python/kernel_tests/basic_decoder_test.py b/tensorflow/contrib/seq2seq/python/kernel_tests/basic_decoder_test.py
index 600adea189..cb12bc9450 100644
--- a/tensorflow/contrib/seq2seq/python/kernel_tests/basic_decoder_test.py
+++ b/tensorflow/contrib/seq2seq/python/kernel_tests/basic_decoder_test.py
@@ -27,8 +27,10 @@ from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import tensor_shape
from tensorflow.python.layers import core as layers_core
+from tensorflow.python.ops import init_ops
from tensorflow.python.ops import rnn_cell
from tensorflow.python.ops import variables
+from tensorflow.python.ops import variable_scope
from tensorflow.python.platform import test
# pylint: enable=g-import-not-at-top
@@ -189,6 +191,76 @@ class BasicDecoderTest(test.TestCase):
self.assertAllEqual(expected_step_next_inputs,
sess_results["step_next_inputs"])
+ def testStepWithSampleEmbeddingHelper(self):
+ batch_size = 5
+ vocabulary_size = 7
+ cell_depth = vocabulary_size # cell's logits must match vocabulary size
+ input_depth = 10
+ np.random.seed(0)
+ start_tokens = np.random.randint(0, vocabulary_size, size=batch_size)
+ end_token = 1
+
+ with self.test_session(use_gpu=True) as sess:
+ with variable_scope.variable_scope(
+ "testStepWithSampleEmbeddingHelper",
+ initializer=init_ops.constant_initializer(0.01)):
+ embeddings = np.random.randn(vocabulary_size,
+ input_depth).astype(np.float32)
+ cell = rnn_cell.LSTMCell(vocabulary_size)
+ helper = helper_py.SampleEmbeddingHelper(embeddings, start_tokens,
+ end_token, seed=0)
+ my_decoder = basic_decoder.BasicDecoder(
+ cell=cell,
+ helper=helper,
+ initial_state=cell.zero_state(
+ dtype=dtypes.float32, batch_size=batch_size))
+ output_size = my_decoder.output_size
+ output_dtype = my_decoder.output_dtype
+ self.assertEqual(
+ basic_decoder.BasicDecoderOutput(cell_depth,
+ tensor_shape.TensorShape([])),
+ output_size)
+ self.assertEqual(
+ basic_decoder.BasicDecoderOutput(dtypes.float32, dtypes.int32),
+ output_dtype)
+
+ (first_finished, first_inputs, first_state) = my_decoder.initialize()
+ (step_outputs, step_state, step_next_inputs,
+ step_finished) = my_decoder.step(
+ constant_op.constant(0), first_inputs, first_state)
+ batch_size_t = my_decoder.batch_size
+
+ self.assertTrue(isinstance(first_state, rnn_cell.LSTMStateTuple))
+ self.assertTrue(isinstance(step_state, rnn_cell.LSTMStateTuple))
+ self.assertTrue(
+ isinstance(step_outputs, basic_decoder.BasicDecoderOutput))
+ self.assertEqual((batch_size, cell_depth), step_outputs[0].get_shape())
+ self.assertEqual((batch_size,), step_outputs[1].get_shape())
+ self.assertEqual((batch_size, cell_depth), first_state[0].get_shape())
+ self.assertEqual((batch_size, cell_depth), first_state[1].get_shape())
+ self.assertEqual((batch_size, cell_depth), step_state[0].get_shape())
+ self.assertEqual((batch_size, cell_depth), step_state[1].get_shape())
+
+ sess.run(variables.global_variables_initializer())
+ sess_results = sess.run({
+ "batch_size": batch_size_t,
+ "first_finished": first_finished,
+ "first_inputs": first_inputs,
+ "first_state": first_state,
+ "step_outputs": step_outputs,
+ "step_state": step_state,
+ "step_next_inputs": step_next_inputs,
+ "step_finished": step_finished
+ })
+
+ sample_ids = sess_results["step_outputs"].sample_id
+ expected_step_finished = (sample_ids == end_token)
+ expected_step_next_inputs = embeddings[sample_ids]
+ self.assertAllEqual(expected_step_finished,
+ sess_results["step_finished"])
+ self.assertAllEqual(expected_step_next_inputs,
+ sess_results["step_next_inputs"])
+
def testStepWithScheduledEmbeddingTrainingHelper(self):
sequence_length = [3, 4, 3, 1, 0]
batch_size = 5