diff options
author | 2017-01-25 14:54:36 -0800 | |
---|---|---|
committer | 2017-01-25 15:04:56 -0800 | |
commit | 8291536b277379ff37dbb8d3198d947776aaff2f (patch) | |
tree | d0362b7aa3a204fbfb6438080c7740f934b35caa | |
parent | 239493a6825f33c96d64b6a36be6616fbb41e42b (diff) |
Add ArgmaxEmbeddingInferenceSampler to new tf.contrib.seq2seq API
Also add a properly caching varscope to the decoder fn.
Change: 145600420
4 files changed, 250 insertions, 85 deletions
diff --git a/tensorflow/contrib/seq2seq/python/kernel_tests/decoder_test.py b/tensorflow/contrib/seq2seq/python/kernel_tests/decoder_test.py index b3c6c593c5..f884a83095 100644 --- a/tensorflow/contrib/seq2seq/python/kernel_tests/decoder_test.py +++ b/tensorflow/contrib/seq2seq/python/kernel_tests/decoder_test.py @@ -125,16 +125,17 @@ class DynamicDecodeRNNTest(test.TestCase): # Match the variable scope of dynamic_rnn below so we end up # using the same variables - with vs.variable_scope("rnn"): + with vs.variable_scope("root") as scope: final_decoder_outputs, final_decoder_state = decoder.dynamic_decode_rnn( - my_decoder) + my_decoder, scope=scope) - with vs.variable_scope(vs.get_variable_scope(), reuse=True): + with vs.variable_scope(scope, reuse=True) as scope: final_rnn_outputs, final_rnn_state = rnn.dynamic_rnn( cell, inputs, sequence_length=sequence_length, - initial_state=zero_state) + initial_state=zero_state, + scope=scope) sess.run(variables.global_variables_initializer()) sess_results = sess.run({ diff --git a/tensorflow/contrib/seq2seq/python/kernel_tests/sampling_decoder_test.py b/tensorflow/contrib/seq2seq/python/kernel_tests/sampling_decoder_test.py index ba945a0ecb..616fdcf554 100644 --- a/tensorflow/contrib/seq2seq/python/kernel_tests/sampling_decoder_test.py +++ b/tensorflow/contrib/seq2seq/python/kernel_tests/sampling_decoder_test.py @@ -104,6 +104,77 @@ class BasicSamplingDecoderTest(test.TestCase): sess_results["step_finished"]) self.assertAllEqual([-1] * 5, sess_results["step_outputs"].sample_id) + def testStepWithArgmaxEmbeddingInferenceSampler(self): + batch_size = 5 + max_time = 8 + vocabulary_size = 7 + cell_depth = vocabulary_size # cell's logits must match vocabulary size + input_depth = 10 + start_tokens = [0] * batch_size + end_token = 1 + + with self.test_session() as sess: + embeddings = np.random.randn(vocabulary_size, + input_depth).astype(np.float32) + cell = core_rnn_cell.LSTMCell(vocabulary_size) + sampler = sampling_decoder.ArgmaxEmbeddingInferenceSampler( + embeddings, start_tokens, end_token, max_time=max_time) + my_decoder = sampling_decoder.BasicSamplingDecoder( + cell=cell, + sampler=sampler, + initial_state=cell.zero_state( + dtype=dtypes.float32, batch_size=batch_size)) + output_size = my_decoder.output_size + output_dtype = my_decoder.output_dtype + batch_size_t = my_decoder.batch_size + self.assertEqual( + sampling_decoder.SamplingDecoderOutput(cell_depth, + tensor_shape.TensorShape([])), + output_size) + self.assertEqual( + sampling_decoder.SamplingDecoderOutput(dtypes.float32, dtypes.int32), + output_dtype) + + (first_finished, first_inputs, first_state) = my_decoder.initialize() + (step_outputs, step_state, step_next_inputs, + step_finished) = my_decoder.step( + constant_op.constant(0), first_inputs, first_state) + + self.assertTrue(isinstance(first_state, core_rnn_cell.LSTMStateTuple)) + self.assertTrue(isinstance(step_state, core_rnn_cell.LSTMStateTuple)) + self.assertTrue( + isinstance(step_outputs, sampling_decoder.SamplingDecoderOutput)) + self.assertEqual((batch_size, cell_depth), step_outputs[0].get_shape()) + self.assertEqual((batch_size,), step_outputs[1].get_shape()) + self.assertEqual((batch_size, cell_depth), first_state[0].get_shape()) + self.assertEqual((batch_size, cell_depth), first_state[1].get_shape()) + self.assertEqual((batch_size, cell_depth), step_state[0].get_shape()) + self.assertEqual((batch_size, cell_depth), step_state[1].get_shape()) + + sess.run(variables.global_variables_initializer()) + sess_results = sess.run({ + "batch_size": batch_size_t, + "first_finished": first_finished, + "first_inputs": first_inputs, + "first_state": first_state, + "step_outputs": step_outputs, + "step_state": step_state, + "step_next_inputs": step_next_inputs, + "step_finished": step_finished + }) + + expected_sample_ids = np.argmax(sess_results["step_outputs"].rnn_output, + -1) + expected_step_finished = (expected_sample_ids == end_token) + expected_step_next_inputs = embeddings[expected_sample_ids] + self.assertAllEqual([False, False, False, False, False], + sess_results["first_finished"]) + self.assertAllEqual(expected_step_finished, sess_results["step_finished"]) + self.assertAllEqual(expected_sample_ids, + sess_results["step_outputs"].sample_id) + self.assertAllEqual(expected_step_next_inputs, + sess_results["step_next_inputs"]) + if __name__ == "__main__": test.main() diff --git a/tensorflow/contrib/seq2seq/python/ops/decoder.py b/tensorflow/contrib/seq2seq/python/ops/decoder.py index 3ab6cb0e8c..7afe79e1f0 100644 --- a/tensorflow/contrib/seq2seq/python/ops/decoder.py +++ b/tensorflow/contrib/seq2seq/python/ops/decoder.py @@ -32,6 +32,7 @@ from tensorflow.python.ops import array_ops from tensorflow.python.ops import control_flow_ops from tensorflow.python.ops import math_ops from tensorflow.python.ops import tensor_array_ops +from tensorflow.python.ops import variable_scope from tensorflow.python.util import nest __all__ = ["Decoder", "dynamic_decode_rnn"] @@ -132,7 +133,8 @@ def _create_zero_outputs(size, dtype, batch_size): def dynamic_decode_rnn(decoder, output_time_major=False, parallel_iterations=32, - swap_memory=False): + swap_memory=False, + scope=None): """Perform dynamic decoding with `decoder`. Args: @@ -143,6 +145,7 @@ def dynamic_decode_rnn(decoder, time to the computation). parallel_iterations: Argument passed to `tf.while_loop`. swap_memory: Argument passed to `tf.while_loop`. + scope: Optional variable scope to use. Returns: `(final_outputs, final_state)`. @@ -154,84 +157,92 @@ def dynamic_decode_rnn(decoder, raise TypeError("Expected decoder to be type Decoder, but saw: %s" % type(decoder)) - zero_outputs = _create_zero_outputs(decoder.output_size, decoder.output_dtype, - decoder.batch_size) - - initial_finished, initial_inputs, initial_state = decoder.initialize() - initial_time = constant_op.constant(0, dtype=dtypes.int32) - - def _shape(batch_size, from_shape): - if not isinstance(from_shape, tensor_shape.TensorShape): - return tensor_shape.TensorShape(None) - else: - batch_size = tensor_util.constant_value( - ops.convert_to_tensor( - batch_size, name="batch_size")) - return tensor_shape.TensorShape([batch_size]).concatenate(from_shape) - - def _create_ta(s, d): - return tensor_array_ops.TensorArray( - dtype=d, size=0, dynamic_size=True, - element_shape=_shape(decoder.batch_size, s)) - - initial_outputs_ta = nest.map_structure( - _create_ta, decoder.output_size, decoder.output_dtype) - - def condition(unused_time, unused_outputs_ta, unused_state, unused_inputs, - finished): - return math_ops.logical_not(math_ops.reduce_all(finished)) - - def body(time, outputs_ta, state, inputs, finished): - """Internal while_loop body. - - Args: - time: scalar int32 tensor. - outputs_ta: structure of TensorArray. - state: (structure of) state tensors and TensorArrays. - inputs: (structure of) input tensors. - finished: 1-D bool tensor. - - Returns: - `(time + 1, outputs_ta, next_state, next_inputs, next_finished)`. - """ - (next_outputs, decoder_state, next_inputs, decoder_finished) = decoder.step( - time, inputs, state) - next_finished = math_ops.logical_or(decoder_finished, finished) - - nest.assert_same_structure(state, decoder_state) - nest.assert_same_structure(outputs_ta, next_outputs) - nest.assert_same_structure(inputs, next_inputs) - - # Zero out output values past finish - emit = nest.map_structure( - lambda out, zero: array_ops.where(finished, zero, out), next_outputs, - zero_outputs) - - # Copy through states past finish - def _maybe_copy_state(new, cur): - return (new if isinstance(cur, tensor_array_ops.TensorArray) else - array_ops.where(finished, cur, new)) - - next_state = nest.map_structure(_maybe_copy_state, decoder_state, state) - outputs_ta = nest.map_structure(lambda ta, out: ta.write(time, out), - outputs_ta, emit) - return (time + 1, outputs_ta, next_state, next_inputs, next_finished) - - res = control_flow_ops.while_loop( - condition, - body, - loop_vars=[ - initial_time, initial_outputs_ta, initial_state, initial_inputs, - initial_finished - ], - parallel_iterations=parallel_iterations, - swap_memory=swap_memory) - - final_outputs_ta = res[1] - final_state = res[2] - - final_outputs = nest.map_structure(lambda ta: ta.stack(), final_outputs_ta) - if not output_time_major: - final_outputs = nest.map_structure(_transpose_batch_time, final_outputs) + with variable_scope.variable_scope(scope or "decoder") as varscope: + # Properly cache variable values inside the while_loop + if varscope.caching_device is None: + varscope.set_caching_device(lambda op: op.device) + + zero_outputs = _create_zero_outputs(decoder.output_size, + decoder.output_dtype, + decoder.batch_size) + + initial_finished, initial_inputs, initial_state = decoder.initialize() + initial_time = constant_op.constant(0, dtype=dtypes.int32) + + def _shape(batch_size, from_shape): + if not isinstance(from_shape, tensor_shape.TensorShape): + return tensor_shape.TensorShape(None) + else: + batch_size = tensor_util.constant_value( + ops.convert_to_tensor( + batch_size, name="batch_size")) + return tensor_shape.TensorShape([batch_size]).concatenate(from_shape) + + def _create_ta(s, d): + return tensor_array_ops.TensorArray( + dtype=d, + size=0, + dynamic_size=True, + element_shape=_shape(decoder.batch_size, s)) + + initial_outputs_ta = nest.map_structure(_create_ta, decoder.output_size, + decoder.output_dtype) + + def condition(unused_time, unused_outputs_ta, unused_state, unused_inputs, + finished): + return math_ops.logical_not(math_ops.reduce_all(finished)) + + def body(time, outputs_ta, state, inputs, finished): + """Internal while_loop body. + + Args: + time: scalar int32 tensor. + outputs_ta: structure of TensorArray. + state: (structure of) state tensors and TensorArrays. + inputs: (structure of) input tensors. + finished: 1-D bool tensor. + + Returns: + `(time + 1, outputs_ta, next_state, next_inputs, next_finished)`. + """ + (next_outputs, decoder_state, next_inputs, + decoder_finished) = decoder.step(time, inputs, state) + next_finished = math_ops.logical_or(decoder_finished, finished) + + nest.assert_same_structure(state, decoder_state) + nest.assert_same_structure(outputs_ta, next_outputs) + nest.assert_same_structure(inputs, next_inputs) + + # Zero out output values past finish + emit = nest.map_structure( + lambda out, zero: array_ops.where(finished, zero, out), next_outputs, + zero_outputs) + + # Copy through states past finish + def _maybe_copy_state(new, cur): + return (new if isinstance(cur, tensor_array_ops.TensorArray) else + array_ops.where(finished, cur, new)) + + next_state = nest.map_structure(_maybe_copy_state, decoder_state, state) + outputs_ta = nest.map_structure(lambda ta, out: ta.write(time, out), + outputs_ta, emit) + return (time + 1, outputs_ta, next_state, next_inputs, next_finished) + + res = control_flow_ops.while_loop( + condition, + body, + loop_vars=[ + initial_time, initial_outputs_ta, initial_state, initial_inputs, + initial_finished + ], + parallel_iterations=parallel_iterations, + swap_memory=swap_memory) + + final_outputs_ta = res[1] + final_state = res[2] + + final_outputs = nest.map_structure(lambda ta: ta.stack(), final_outputs_ta) + if not output_time_major: + final_outputs = nest.map_structure(_transpose_batch_time, final_outputs) return final_outputs, final_state diff --git a/tensorflow/contrib/seq2seq/python/ops/sampling_decoder.py b/tensorflow/contrib/seq2seq/python/ops/sampling_decoder.py index c4654e535d..fc36c3eae0 100644 --- a/tensorflow/contrib/seq2seq/python/ops/sampling_decoder.py +++ b/tensorflow/contrib/seq2seq/python/ops/sampling_decoder.py @@ -32,6 +32,7 @@ from tensorflow.python.framework import ops from tensorflow.python.framework import tensor_shape from tensorflow.python.ops import array_ops from tensorflow.python.ops import control_flow_ops +from tensorflow.python.ops import embedding_ops from tensorflow.python.ops import math_ops from tensorflow.python.ops import tensor_array_ops from tensorflow.python.util import nest @@ -184,7 +185,88 @@ class BasicTrainingSampler(Sampler): finished = (next_time >= self._sequence_length) all_finished = math_ops.reduce_all(finished) sample_id = array_ops.tile([constant_op.constant(-1)], [self._batch_size]) + def read_from_ta(inp): + return inp.read(next_time) next_inputs = control_flow_ops.cond( all_finished, lambda: self._zero_inputs, - lambda: nest.map_structure(lambda inp: inp.read(next_time), self._input_tas)) + lambda: nest.map_structure(read_from_ta, self._input_tas)) return (sample_id, finished, next_inputs) + + +class ArgmaxEmbeddingInferenceSampler(Sampler): + """A (non-)sampler for use during inference. + + Uses the argmax of the output (treated as logits) and passes the + result through an embedding layer to get the next input. + """ + + def __init__(self, embedding, start_tokens, end_token, max_time=None): + """Initializer. + + Args: + embedding: A callable that takes a vector tensor of `ids` (argmax ids), + or the `params` argument for `embedding_lookup`. + start_tokens: `int32` vector shaped `[batch_size]`, the start tokens. + end_token: `int32` scalar, the token that marks end of decoding. + max_time: `int32` scalar, maximum allowed number of decoding steps. + Default is `None` (decode until `end_token` is seen). + + Raises: + ValueError: if `sequence_length` is not a 1D tensor. + """ + if callable(embedding): + self._embedding_fn = embedding + else: + + def embedding_fn(ids): + return embedding_ops.embedding_lookup(embedding, ids) + + self._embedding_fn = embedding_fn + + self._start_tokens = ops.convert_to_tensor( + start_tokens, dtype=dtypes.int32, name="start_tokens") + self._end_token = ops.convert_to_tensor( + end_token, dtype=dtypes.int32, name="end_token") + if self._start_tokens.get_shape().ndims != 1: + raise ValueError("start_tokens must be a vector") + self._batch_size = array_ops.size(self._start_tokens) + if self._end_token.get_shape().ndims != 0: + raise ValueError("end_token must be a scalar") + if max_time is not None: + self._max_time = ops.convert_to_tensor( + max_time, dtype=dtypes.int32, name="max_time") + if self._max_time.get_shape().ndims != 0: + raise ValueError("max_time must be a scalar") + else: + self._max_time = None + self._start_inputs = self._embedding_fn(self._start_tokens) + + @property + def batch_size(self): + return self._batch_size + + def initialize(self): + if self._max_time is not None: + finished = array_ops.tile([math_ops.equal(self._max_time, 0)], + [self._batch_size]) + else: + finished = array_ops.tile([False], [self._batch_size]) + return (finished, self._start_inputs) + + def sample(self, time, outputs, **unused_kwargs): + # Outputs are logits, use argmax to get the most probable id + if not isinstance(outputs, ops.Tensor): + raise TypeError("Expected outputs to be a single Tensor, got: %s" % + outputs) + sample_ids = math_ops.cast(math_ops.argmax(outputs, axis=-1), dtypes.int32) + finished = math_ops.equal(sample_ids, self._end_token) + if self._max_time is not None: + finished = math_ops.logical_or(finished, time + 1 >= self._max_time) + all_finished = math_ops.reduce_all(finished) + + next_inputs = control_flow_ops.cond( + all_finished, + # If we're finished, the next_inputs value doesn't matter + lambda: self._start_inputs, + lambda: self._embedding_fn(sample_ids)) + return (sample_ids, finished, next_inputs) |