aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
author: Eugene Brevdo <ebrevdo@google.com>, 2017-01-25 14:54:36 -0800
committer: TensorFlower Gardener <gardener@tensorflow.org>, 2017-01-25 15:04:56 -0800
commit: 8291536b277379ff37dbb8d3198d947776aaff2f (patch)
tree: d0362b7aa3a204fbfb6438080c7740f934b35caa
parent: 239493a6825f33c96d64b6a36be6616fbb41e42b (diff)
Add ArgmaxEmbeddingInferenceSampler to new tf.contrib.seq2seq API
Also add a properly caching varscope to the decoder fn. Change: 145600420
-rw-r--r-- tensorflow/contrib/seq2seq/python/kernel_tests/decoder_test.py 9
-rw-r--r-- tensorflow/contrib/seq2seq/python/kernel_tests/sampling_decoder_test.py 71
-rw-r--r-- tensorflow/contrib/seq2seq/python/ops/decoder.py 171
-rw-r--r-- tensorflow/contrib/seq2seq/python/ops/sampling_decoder.py 84
4 files changed, 250 insertions, 85 deletions
diff --git a/tensorflow/contrib/seq2seq/python/kernel_tests/decoder_test.py b/tensorflow/contrib/seq2seq/python/kernel_tests/decoder_test.py
index b3c6c593c5..f884a83095 100644
--- a/tensorflow/contrib/seq2seq/python/kernel_tests/decoder_test.py
+++ b/tensorflow/contrib/seq2seq/python/kernel_tests/decoder_test.py
@@ -125,16 +125,17 @@ class DynamicDecodeRNNTest(test.TestCase):
# Match the variable scope of dynamic_rnn below so we end up
# using the same variables
- with vs.variable_scope("rnn"):
+ with vs.variable_scope("root") as scope:
final_decoder_outputs, final_decoder_state = decoder.dynamic_decode_rnn(
- my_decoder)
+ my_decoder, scope=scope)
- with vs.variable_scope(vs.get_variable_scope(), reuse=True):
+ with vs.variable_scope(scope, reuse=True) as scope:
final_rnn_outputs, final_rnn_state = rnn.dynamic_rnn(
cell,
inputs,
sequence_length=sequence_length,
- initial_state=zero_state)
+ initial_state=zero_state,
+ scope=scope)
sess.run(variables.global_variables_initializer())
sess_results = sess.run({
diff --git a/tensorflow/contrib/seq2seq/python/kernel_tests/sampling_decoder_test.py b/tensorflow/contrib/seq2seq/python/kernel_tests/sampling_decoder_test.py
index ba945a0ecb..616fdcf554 100644
--- a/tensorflow/contrib/seq2seq/python/kernel_tests/sampling_decoder_test.py
+++ b/tensorflow/contrib/seq2seq/python/kernel_tests/sampling_decoder_test.py
@@ -104,6 +104,77 @@ class BasicSamplingDecoderTest(test.TestCase):
sess_results["step_finished"])
self.assertAllEqual([-1] * 5, sess_results["step_outputs"].sample_id)
+ def testStepWithArgmaxEmbeddingInferenceSampler(self):
+ """Checks one `step` of BasicSamplingDecoder driven by the argmax sampler.
+
+ Builds the decoder around an LSTMCell and an
+ ArgmaxEmbeddingInferenceSampler, runs `initialize()` followed by a single
+ `step()` at time 0, and compares the sampled ids, finished flags and next
+ inputs against values recomputed with NumPy from the fetched rnn_output.
+ """
+ batch_size = 5
+ max_time = 8
+ vocabulary_size = 7
+ cell_depth = vocabulary_size # cell's logits must match vocabulary size
+ input_depth = 10
+ start_tokens = [0] * batch_size
+ end_token = 1
+
+ with self.test_session() as sess:
+ # Random embedding matrix: one row of size input_depth per vocabulary id.
+ embeddings = np.random.randn(vocabulary_size,
+ input_depth).astype(np.float32)
+ cell = core_rnn_cell.LSTMCell(vocabulary_size)
+ sampler = sampling_decoder.ArgmaxEmbeddingInferenceSampler(
+ embeddings, start_tokens, end_token, max_time=max_time)
+ my_decoder = sampling_decoder.BasicSamplingDecoder(
+ cell=cell,
+ sampler=sampler,
+ initial_state=cell.zero_state(
+ dtype=dtypes.float32, batch_size=batch_size))
+ output_size = my_decoder.output_size
+ output_dtype = my_decoder.output_dtype
+ batch_size_t = my_decoder.batch_size
+ # Static properties: (rnn_output of size cell_depth, scalar sample_id)
+ # with dtypes (float32, int32).
+ self.assertEqual(
+ sampling_decoder.SamplingDecoderOutput(cell_depth,
+ tensor_shape.TensorShape([])),
+ output_size)
+ self.assertEqual(
+ sampling_decoder.SamplingDecoderOutput(dtypes.float32, dtypes.int32),
+ output_dtype)
+
+ # Build the graph for initialize() and a single step at time 0.
+ (first_finished, first_inputs, first_state) = my_decoder.initialize()
+ (step_outputs, step_state, step_next_inputs,
+ step_finished) = my_decoder.step(
+ constant_op.constant(0), first_inputs, first_state)
+
+ # Structure and static-shape checks; these need no session run.
+ self.assertTrue(isinstance(first_state, core_rnn_cell.LSTMStateTuple))
+ self.assertTrue(isinstance(step_state, core_rnn_cell.LSTMStateTuple))
+ self.assertTrue(
+ isinstance(step_outputs, sampling_decoder.SamplingDecoderOutput))
+ self.assertEqual((batch_size, cell_depth), step_outputs[0].get_shape())
+ self.assertEqual((batch_size,), step_outputs[1].get_shape())
+ self.assertEqual((batch_size, cell_depth), first_state[0].get_shape())
+ self.assertEqual((batch_size, cell_depth), first_state[1].get_shape())
+ self.assertEqual((batch_size, cell_depth), step_state[0].get_shape())
+ self.assertEqual((batch_size, cell_depth), step_state[1].get_shape())
+
+ sess.run(variables.global_variables_initializer())
+ sess_results = sess.run({
+ "batch_size": batch_size_t,
+ "first_finished": first_finished,
+ "first_inputs": first_inputs,
+ "first_state": first_state,
+ "step_outputs": step_outputs,
+ "step_state": step_state,
+ "step_next_inputs": step_next_inputs,
+ "step_finished": step_finished
+ })
+
+ # Recompute the expected sampler results from the fetched rnn_output, so
+ # the check does not depend on the randomly initialized cell weights.
+ expected_sample_ids = np.argmax(sess_results["step_outputs"].rnn_output,
+ -1)
+ expected_step_finished = (expected_sample_ids == end_token)
+ expected_step_next_inputs = embeddings[expected_sample_ids]
+ # max_time is 8 (non-zero), so no batch entry is finished at initialization.
+ self.assertAllEqual([False, False, False, False, False],
+ sess_results["first_finished"])
+ self.assertAllEqual(expected_step_finished, sess_results["step_finished"])
+ self.assertAllEqual(expected_sample_ids,
+ sess_results["step_outputs"].sample_id)
+ self.assertAllEqual(expected_step_next_inputs,
+ sess_results["step_next_inputs"])
+
if __name__ == "__main__":
test.main()
diff --git a/tensorflow/contrib/seq2seq/python/ops/decoder.py b/tensorflow/contrib/seq2seq/python/ops/decoder.py
index 3ab6cb0e8c..7afe79e1f0 100644
--- a/tensorflow/contrib/seq2seq/python/ops/decoder.py
+++ b/tensorflow/contrib/seq2seq/python/ops/decoder.py
@@ -32,6 +32,7 @@ from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import tensor_array_ops
+from tensorflow.python.ops import variable_scope
from tensorflow.python.util import nest
__all__ = ["Decoder", "dynamic_decode_rnn"]
@@ -132,7 +133,8 @@ def _create_zero_outputs(size, dtype, batch_size):
def dynamic_decode_rnn(decoder,
output_time_major=False,
parallel_iterations=32,
- swap_memory=False):
+ swap_memory=False,
+ scope=None):
"""Perform dynamic decoding with `decoder`.
Args:
@@ -143,6 +145,7 @@ def dynamic_decode_rnn(decoder,
time to the computation).
parallel_iterations: Argument passed to `tf.while_loop`.
swap_memory: Argument passed to `tf.while_loop`.
+ scope: Optional variable scope to use. Defaults to `"decoder"`.
Returns:
`(final_outputs, final_state)`.
@@ -154,84 +157,92 @@ def dynamic_decode_rnn(decoder,
raise TypeError("Expected decoder to be type Decoder, but saw: %s" %
type(decoder))
- zero_outputs = _create_zero_outputs(decoder.output_size, decoder.output_dtype,
- decoder.batch_size)
-
- initial_finished, initial_inputs, initial_state = decoder.initialize()
- initial_time = constant_op.constant(0, dtype=dtypes.int32)
-
- def _shape(batch_size, from_shape):
- if not isinstance(from_shape, tensor_shape.TensorShape):
- return tensor_shape.TensorShape(None)
- else:
- batch_size = tensor_util.constant_value(
- ops.convert_to_tensor(
- batch_size, name="batch_size"))
- return tensor_shape.TensorShape([batch_size]).concatenate(from_shape)
-
- def _create_ta(s, d):
- return tensor_array_ops.TensorArray(
- dtype=d, size=0, dynamic_size=True,
- element_shape=_shape(decoder.batch_size, s))
-
- initial_outputs_ta = nest.map_structure(
- _create_ta, decoder.output_size, decoder.output_dtype)
-
- def condition(unused_time, unused_outputs_ta, unused_state, unused_inputs,
- finished):
- return math_ops.logical_not(math_ops.reduce_all(finished))
-
- def body(time, outputs_ta, state, inputs, finished):
- """Internal while_loop body.
-
- Args:
- time: scalar int32 tensor.
- outputs_ta: structure of TensorArray.
- state: (structure of) state tensors and TensorArrays.
- inputs: (structure of) input tensors.
- finished: 1-D bool tensor.
-
- Returns:
- `(time + 1, outputs_ta, next_state, next_inputs, next_finished)`.
- """
- (next_outputs, decoder_state, next_inputs, decoder_finished) = decoder.step(
- time, inputs, state)
- next_finished = math_ops.logical_or(decoder_finished, finished)
-
- nest.assert_same_structure(state, decoder_state)
- nest.assert_same_structure(outputs_ta, next_outputs)
- nest.assert_same_structure(inputs, next_inputs)
-
- # Zero out output values past finish
- emit = nest.map_structure(
- lambda out, zero: array_ops.where(finished, zero, out), next_outputs,
- zero_outputs)
-
- # Copy through states past finish
- def _maybe_copy_state(new, cur):
- return (new if isinstance(cur, tensor_array_ops.TensorArray) else
- array_ops.where(finished, cur, new))
-
- next_state = nest.map_structure(_maybe_copy_state, decoder_state, state)
- outputs_ta = nest.map_structure(lambda ta, out: ta.write(time, out),
- outputs_ta, emit)
- return (time + 1, outputs_ta, next_state, next_inputs, next_finished)
-
- res = control_flow_ops.while_loop(
- condition,
- body,
- loop_vars=[
- initial_time, initial_outputs_ta, initial_state, initial_inputs,
- initial_finished
- ],
- parallel_iterations=parallel_iterations,
- swap_memory=swap_memory)
-
- final_outputs_ta = res[1]
- final_state = res[2]
-
- final_outputs = nest.map_structure(lambda ta: ta.stack(), final_outputs_ta)
- if not output_time_major:
- final_outputs = nest.map_structure(_transpose_batch_time, final_outputs)
+ with variable_scope.variable_scope(scope or "decoder") as varscope:
+ # Properly cache variable values inside the while_loop
+ if varscope.caching_device is None:
+ varscope.set_caching_device(lambda op: op.device)
+
+ zero_outputs = _create_zero_outputs(decoder.output_size,
+ decoder.output_dtype,
+ decoder.batch_size)
+
+ initial_finished, initial_inputs, initial_state = decoder.initialize()
+ initial_time = constant_op.constant(0, dtype=dtypes.int32)
+
+ def _shape(batch_size, from_shape):
+ if not isinstance(from_shape, tensor_shape.TensorShape):
+ return tensor_shape.TensorShape(None)
+ else:
+ batch_size = tensor_util.constant_value(
+ ops.convert_to_tensor(
+ batch_size, name="batch_size"))
+ return tensor_shape.TensorShape([batch_size]).concatenate(from_shape)
+
+ def _create_ta(s, d):
+ return tensor_array_ops.TensorArray(
+ dtype=d,
+ size=0,
+ dynamic_size=True,
+ element_shape=_shape(decoder.batch_size, s))
+
+ initial_outputs_ta = nest.map_structure(_create_ta, decoder.output_size,
+ decoder.output_dtype)
+
+ def condition(unused_time, unused_outputs_ta, unused_state, unused_inputs,
+ finished):
+ return math_ops.logical_not(math_ops.reduce_all(finished))
+
+ def body(time, outputs_ta, state, inputs, finished):
+ """Internal while_loop body.
+
+ Args:
+ time: scalar int32 tensor.
+ outputs_ta: structure of TensorArray.
+ state: (structure of) state tensors and TensorArrays.
+ inputs: (structure of) input tensors.
+ finished: 1-D bool tensor.
+
+ Returns:
+ `(time + 1, outputs_ta, next_state, next_inputs, next_finished)`.
+ """
+ (next_outputs, decoder_state, next_inputs,
+ decoder_finished) = decoder.step(time, inputs, state)
+ next_finished = math_ops.logical_or(decoder_finished, finished)
+
+ nest.assert_same_structure(state, decoder_state)
+ nest.assert_same_structure(outputs_ta, next_outputs)
+ nest.assert_same_structure(inputs, next_inputs)
+
+ # Zero out output values past finish
+ emit = nest.map_structure(
+ lambda out, zero: array_ops.where(finished, zero, out), next_outputs,
+ zero_outputs)
+
+ # Copy through states past finish
+ def _maybe_copy_state(new, cur):
+ return (new if isinstance(cur, tensor_array_ops.TensorArray) else
+ array_ops.where(finished, cur, new))
+
+ next_state = nest.map_structure(_maybe_copy_state, decoder_state, state)
+ outputs_ta = nest.map_structure(lambda ta, out: ta.write(time, out),
+ outputs_ta, emit)
+ return (time + 1, outputs_ta, next_state, next_inputs, next_finished)
+
+ res = control_flow_ops.while_loop(
+ condition,
+ body,
+ loop_vars=[
+ initial_time, initial_outputs_ta, initial_state, initial_inputs,
+ initial_finished
+ ],
+ parallel_iterations=parallel_iterations,
+ swap_memory=swap_memory)
+
+ final_outputs_ta = res[1]
+ final_state = res[2]
+
+ final_outputs = nest.map_structure(lambda ta: ta.stack(), final_outputs_ta)
+ if not output_time_major:
+ final_outputs = nest.map_structure(_transpose_batch_time, final_outputs)
return final_outputs, final_state
diff --git a/tensorflow/contrib/seq2seq/python/ops/sampling_decoder.py b/tensorflow/contrib/seq2seq/python/ops/sampling_decoder.py
index c4654e535d..fc36c3eae0 100644
--- a/tensorflow/contrib/seq2seq/python/ops/sampling_decoder.py
+++ b/tensorflow/contrib/seq2seq/python/ops/sampling_decoder.py
@@ -32,6 +32,7 @@ from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_shape
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
+from tensorflow.python.ops import embedding_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import tensor_array_ops
from tensorflow.python.util import nest
@@ -184,7 +185,88 @@ class BasicTrainingSampler(Sampler):
finished = (next_time >= self._sequence_length)
all_finished = math_ops.reduce_all(finished)
sample_id = array_ops.tile([constant_op.constant(-1)], [self._batch_size])
+ def read_from_ta(inp):
+ return inp.read(next_time)
next_inputs = control_flow_ops.cond(
all_finished, lambda: self._zero_inputs,
- lambda: nest.map_structure(lambda inp: inp.read(next_time), self._input_tas))
+ lambda: nest.map_structure(read_from_ta, self._input_tas))
return (sample_id, finished, next_inputs)
+
+
+class ArgmaxEmbeddingInferenceSampler(Sampler):
+  """A (non-)sampler for use during inference.
+
+  Uses the argmax of the output (treated as logits) and passes the
+  result through an embedding layer to get the next input.
+  """
+
+  def __init__(self, embedding, start_tokens, end_token, max_time=None):
+    """Initializer.
+
+    Args:
+      embedding: Either a callable that maps a vector tensor of `ids`
+        (argmax ids) to the corresponding inputs, or the `params` argument
+        for `embedding_lookup`.
+      start_tokens: `int32` vector shaped `[batch_size]`, the start tokens.
+      end_token: `int32` scalar, the token that marks end of decoding.
+      max_time: `int32` scalar, maximum allowed number of decoding steps.
+        Default is `None` (decode until `end_token` is seen).
+
+    Raises:
+      ValueError: if `start_tokens` is not a vector, or if `end_token` or
+        `max_time` is not a scalar.
+    """
+    if callable(embedding):
+      self._embedding_fn = embedding
+    else:
+
+      def embedding_fn(ids):
+        return embedding_ops.embedding_lookup(embedding, ids)
+
+      self._embedding_fn = embedding_fn
+
+    self._start_tokens = ops.convert_to_tensor(
+        start_tokens, dtype=dtypes.int32, name="start_tokens")
+    self._end_token = ops.convert_to_tensor(
+        end_token, dtype=dtypes.int32, name="end_token")
+    if self._start_tokens.get_shape().ndims != 1:
+      raise ValueError("start_tokens must be a vector")
+    # The batch size is taken from the number of start tokens.
+    self._batch_size = array_ops.size(self._start_tokens)
+    if self._end_token.get_shape().ndims != 0:
+      raise ValueError("end_token must be a scalar")
+    if max_time is not None:
+      self._max_time = ops.convert_to_tensor(
+          max_time, dtype=dtypes.int32, name="max_time")
+      if self._max_time.get_shape().ndims != 0:
+        raise ValueError("max_time must be a scalar")
+    else:
+      self._max_time = None
+    # Embed the start tokens once; reused as the first decoder input and as
+    # the placeholder next input once every sequence has finished.
+    self._start_inputs = self._embedding_fn(self._start_tokens)
+
+  @property
+  def batch_size(self):
+    """Scalar tensor: the number of entries in `start_tokens`."""
+    return self._batch_size
+
+  def initialize(self):
+    """Returns `(finished, first_inputs)` for the start of decoding.
+
+    `finished` is all `False` unless `max_time` was given and equals 0;
+    `first_inputs` are the embedded start tokens.
+    """
+    if self._max_time is not None:
+      finished = array_ops.tile([math_ops.equal(self._max_time, 0)],
+                                [self._batch_size])
+    else:
+      finished = array_ops.tile([False], [self._batch_size])
+    return (finished, self._start_inputs)
+
+  def sample(self, time, outputs, **unused_kwargs):
+    """Picks the argmax id of `outputs` and embeds it as the next input.
+
+    Args:
+      time: the current decoding step; only used when `max_time` was given.
+      outputs: a single `Tensor`, treated as logits along its last axis.
+      **unused_kwargs: ignored.
+
+    Returns:
+      `(sample_ids, finished, next_inputs)`.
+
+    Raises:
+      TypeError: if `outputs` is not a single `Tensor`.
+    """
+    # Outputs are logits, use argmax to get the most probable id
+    if not isinstance(outputs, ops.Tensor):
+      # Format the type, not the value: a non-Tensor `outputs` is usually a
+      # tuple, which `%` would unpack as the argument list and lose the message.
+      raise TypeError("Expected outputs to be a single Tensor, got: %s" %
+                      type(outputs))
+    sample_ids = math_ops.cast(math_ops.argmax(outputs, axis=-1), dtypes.int32)
+    finished = math_ops.equal(sample_ids, self._end_token)
+    if self._max_time is not None:
+      finished = math_ops.logical_or(finished, time + 1 >= self._max_time)
+    all_finished = math_ops.reduce_all(finished)
+
+    next_inputs = control_flow_ops.cond(
+        all_finished,
+        # If we're finished, the next_inputs value doesn't matter
+        lambda: self._start_inputs,
+        lambda: self._embedding_fn(sample_ids))
+    return (sample_ids, finished, next_inputs)