diff options
Diffstat (limited to 'tensorflow/contrib/seq2seq/python/ops')
-rw-r--r-- | tensorflow/contrib/seq2seq/python/ops/__init__.py | 0 | ||||
-rw-r--r-- | tensorflow/contrib/seq2seq/python/ops/decoder_fn.py | 249 | ||||
-rw-r--r-- | tensorflow/contrib/seq2seq/python/ops/layers.py | 35 | ||||
-rw-r--r-- | tensorflow/contrib/seq2seq/python/ops/seq2seq.py | 208 |
4 files changed, 457 insertions, 35 deletions
# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

"""Seq2seq decoder functions for use with `dynamic_rnn_decoder`."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import rnn_cell
from tensorflow.python.util import nest

__all__ = ["simple_decoder_fn_train",
           "simple_decoder_fn_inference"]


def simple_decoder_fn_train(encoder_state, name=None):
  """Simple decoder function for a sequence-to-sequence model used in the
  `dynamic_rnn_decoder`.

  The `simple_decoder_fn_train` is a simple training function for a
  sequence-to-sequence model. It should be used when `dynamic_rnn_decoder` is
  in the training mode.

  The `simple_decoder_fn_train` is called with a set of the user arguments and
  returns the `decoder_fn`, which can be passed to the `dynamic_rnn_decoder`,
  such that

  ```
  dynamic_fn_train = simple_decoder_fn_train(encoder_state)
  outputs_train, state_train = dynamic_rnn_decoder(
      decoder_fn=dynamic_fn_train, ...)
  ```

  Further usage can be found in the `kernel_tests/seq2seq_test.py`.

  Args:
    encoder_state: The encoded state to initialize the `dynamic_rnn_decoder`.
    name: (default: `None`) NameScope for the decoder function;
      defaults to "simple_decoder_fn_train"

  Returns:
    A decoder function with the required interface of `dynamic_rnn_decoder`
    intended for training.
  """
  # NOTE(review): this name_scope only validates `encoder_state`; the scope is
  # re-entered inside `decoder_fn`, which is where the ops are actually built.
  with ops.name_scope(name, "simple_decoder_fn_train", [encoder_state]):
    pass

  def decoder_fn(time, cell_state, cell_input, cell_output, context_state):
    """Decoder function used in the `dynamic_rnn_decoder` for training.

    Args:
      time: positive integer constant reflecting the current timestep.
      cell_state: state of RNNCell.
      cell_input: input provided by `dynamic_rnn_decoder`.
      cell_output: output of RNNCell.
      context_state: context state provided by `dynamic_rnn_decoder`.

    Returns:
      A tuple (done, next state, next input, emit output, next context state)
      where:

      done: `None`, which is used by the `dynamic_rnn_decoder` to indicate
      that `sequence_lengths` in `dynamic_rnn_decoder` should be used.

      next state: `cell_state`, this decoder function does not modify the
      given state.

      next input: `cell_input`, this decoder function does not modify the
      given input. The input could be modified when applying e.g. attention.

      emit output: `cell_output`, this decoder function does not modify the
      given output.

      next context state: `context_state`, this decoder function does not
      modify the given context state. The context state could be modified when
      applying e.g. beam search.
    """
    with ops.name_scope(name, "simple_decoder_fn_train",
                        [time, cell_state, cell_input, cell_output,
                         context_state]):
      if cell_state is None:  # first call: seed the RNN with the encoder state
        return (None, encoder_state, cell_input, cell_output, context_state)
      else:
        return (None, cell_state, cell_input, cell_output, context_state)
  return decoder_fn


def simple_decoder_fn_inference(output_fn, encoder_state, embeddings,
                                start_of_sequence_id, end_of_sequence_id,
                                maximum_length, num_decoder_symbols,
                                dtype=dtypes.int32, name=None):
  """Simple decoder function for a sequence-to-sequence model used in the
  `dynamic_rnn_decoder`.

  The `simple_decoder_fn_inference` is a simple inference function for a
  sequence-to-sequence model. It should be used when `dynamic_rnn_decoder` is
  in the inference mode.

  The `simple_decoder_fn_inference` is called with a set of the user arguments
  and returns the `decoder_fn`, which can be passed to the
  `dynamic_rnn_decoder`, such that

  ```
  dynamic_fn_inference = simple_decoder_fn_inference(...)
  outputs_inference, state_inference = dynamic_rnn_decoder(
      decoder_fn=dynamic_fn_inference, ...)
  ```

  Further usage can be found in the `kernel_tests/seq2seq_test.py`.

  Args:
    output_fn: An output function to project your `cell_output` onto class
    logits.

    An example of an output function;

    ```
    tf.variable_scope("decoder") as varscope
      output_fn = lambda x: layers.linear(x, num_decoder_symbols,
                                          scope=varscope)

      outputs_train, state_train = seq2seq.dynamic_rnn_decoder(...)
      logits_train = output_fn(outputs_train)

      varscope.reuse_variables()
      logits_inference, state_inference = seq2seq.dynamic_rnn_decoder(
          output_fn=output_fn, ...)
    ```

    If `None` is supplied it will act as an identity function, which
    might be wanted when using the RNNCell `OutputProjectionWrapper`.

    encoder_state: The encoded state to initialize the `dynamic_rnn_decoder`.
    embeddings: The embeddings matrix used for the decoder sized
    `[num_decoder_symbols, embedding_size]`.
    start_of_sequence_id: The start of sequence ID in the decoder embeddings.
    end_of_sequence_id: The end of sequence ID in the decoder embeddings.
    maximum_length: The maximum allowed number of time steps to decode.
    num_decoder_symbols: The number of classes to decode at each time step.
    dtype: (default: `dtypes.int32`) The default data type to use when
    handling integer objects.
    name: (default: `None`) NameScope for the decoder function;
      defaults to "simple_decoder_fn_inference"

  Returns:
    A decoder function with the required interface of `dynamic_rnn_decoder`
    intended for inference.
  """
  with ops.name_scope(name, "simple_decoder_fn_inference",
                      [output_fn, encoder_state, embeddings,
                       start_of_sequence_id, end_of_sequence_id,
                       maximum_length, num_decoder_symbols, dtype]):
    start_of_sequence_id = ops.convert_to_tensor(start_of_sequence_id, dtype)
    end_of_sequence_id = ops.convert_to_tensor(end_of_sequence_id, dtype)
    maximum_length = ops.convert_to_tensor(maximum_length, dtype)
    num_decoder_symbols = ops.convert_to_tensor(num_decoder_symbols, dtype)
    encoder_info = nest.flatten(encoder_state)[0]
    batch_size = encoder_info.get_shape()[0].value
    if output_fn is None:
      output_fn = lambda x: x
    if batch_size is None:
      # Static batch size unknown; fall back to the dynamic shape.
      batch_size = array_ops.shape(encoder_info)[0]

  def decoder_fn(time, cell_state, cell_input, cell_output, context_state):
    """Decoder function used in the `dynamic_rnn_decoder` for inference.

    The main difference between this decoder function and the `decoder_fn` in
    `simple_decoder_fn_train` is how `next_cell_input` is calculated. In this
    decoder function we calculate the next input by applying an argmax across
    the feature dimension of the output from the decoder. This is a
    greedy-search approach. (Bahdanau et al., 2014) & (Sutskever et al., 2014)
    use beam-search instead.

    Args:
      time: positive integer constant reflecting the current timestep.
      cell_state: state of RNNCell.
      cell_input: input provided by `dynamic_rnn_decoder`.
      cell_output: output of RNNCell.
      context_state: context state provided by `dynamic_rnn_decoder`.

    Returns:
      A tuple (done, next state, next input, emit output, next context state)
      where:

      done: A boolean vector to indicate which sentences has reached a
      `end_of_sequence_id`. This is used for early stopping by the
      `dynamic_rnn_decoder`. When `time > maximum_length` a boolean vector
      with all elements as `true` is returned.

      next state: `cell_state`, this decoder function does not modify the
      given state.

      next input: The embedding from argmax of the `cell_output` is used as
      `next_input`.

      emit output: If `output_fn is None` the supplied `cell_output` is
      returned, else the `output_fn` is used to update the `cell_output`
      before calculating `next_input` and returning `cell_output`.

      next context state: `context_state`, this decoder function does not
      modify the given context state. The context state could be modified when
      applying e.g. beam search.

    Raises:
      ValueError: if `cell_input` is not `None` (inference infers the input
        from the previous output, so no external input is accepted).
    """
    with ops.name_scope(name, "simple_decoder_fn_inference",
                        [time, cell_state, cell_input, cell_output,
                         context_state]):
      if cell_input is not None:
        raise ValueError("Expected cell_input to be None, but saw: %s" %
                         cell_input)
      if cell_output is None:
        # Invariant: cell_output is None only at time == 0, before the cell
        # has run. Feed every sequence the start-of-sequence token and emit an
        # all-zeros structure template for raw_rnn (no batch dimension).
        next_input_id = array_ops.ones([batch_size,], dtype=dtype) * (
            start_of_sequence_id)
        done = array_ops.zeros([batch_size,], dtype=dtypes.bool)
        cell_state = encoder_state
        cell_output = array_ops.zeros([num_decoder_symbols],
                                      dtype=dtypes.float32)
      else:
        cell_output = output_fn(cell_output)
        # Greedy search: the most likely class becomes the next input token.
        next_input_id = math_ops.cast(
            math_ops.argmax(cell_output, 1), dtype=dtype)
        done = math_ops.equal(next_input_id, end_of_sequence_id)
      next_input = array_ops.gather(embeddings, next_input_id)
      # If time > maximum_length, force-stop every sequence in the batch.
      done = control_flow_ops.cond(math_ops.greater(time, maximum_length),
          lambda: array_ops.ones([batch_size,], dtype=dtypes.bool),
          lambda: done)
      return (done, cell_state, next_input, cell_output, context_state)
  return decoder_fn
-# ============================================================================== - -"""Seq2seq layer operations for use in neural networks. -""" - -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -from tensorflow.python.ops import array_ops - - -__all__ = ["rnn_decoder", - "rnn_decoder_attention"] - - -def rnn_decoder(*args, **kwargs): - pass - - -def rnn_decoder_attention(*args, **kwargs): - pass diff --git a/tensorflow/contrib/seq2seq/python/ops/seq2seq.py b/tensorflow/contrib/seq2seq/python/ops/seq2seq.py new file mode 100644 index 0000000000..4e15d669cb --- /dev/null +++ b/tensorflow/contrib/seq2seq/python/ops/seq2seq.py @@ -0,0 +1,208 @@ +# Copyright 2016 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +"""Seq2seq layer operations for use in neural networks. 
+""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.contrib import layers +from tensorflow.python.framework import ops +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import control_flow_ops +from tensorflow.python.ops import math_ops +from tensorflow.python.ops import rnn +from tensorflow.python.ops import tensor_array_ops +from tensorflow.python.ops import variable_scope as vs + +__all__ = ["dynamic_rnn_decoder"] + +def dynamic_rnn_decoder(cell, decoder_fn, inputs=None, sequence_length=None, + parallel_iterations=None, swap_memory=False, + time_major=False, scope=None, name=None): + """ Dynamic RNN decoder for a sequence-to-sequence model specified by + RNNCell and decoder function. + + The `dynamic_rnn_decoder` is similar to the `tf.python.ops.rnn.dynamic_rnn` + as the decoder does not make any assumptions of sequence length and batch + size of the input. + + The `dynamic_rnn_decoder` has two modes: training or inference and expects + the user to create seperate functions for each. + + Under both training and inference `cell` and `decoder_fn` is expected. Where + the `cell` performs computation at every timestep using the `raw_rnn` and + the `decoder_fn` allows modelling of early stopping, output, state, and next + input and context. + + When training the user is expected to supply `inputs`. At every time step a + slice of the supplied input is fed to the `decoder_fn`, which modifies and + returns the input for the next time step. + + `sequence_length` is needed at training time, i.e., when `inputs` is not + None, for dynamic unrolling. At test time, when `inputs` is None, + `sequence_length` is not needed. + + Under inference `inputs` is expected to be `None` and the input is inferred + solely from the `decoder_fn`. + + Args: + cell: An instance of RNNCell. 
+ decoder_fn: A function that takes time, cell state, cell input, + cell output and context state. It returns a early stopping vector, + cell state, next input, cell output and context state. + Examples of decoder_fn can be found in the decoder_fn.py folder. + inputs: The inputs for decoding (embedded format). + + If `time_major == False` (default), this must be a `Tensor` of shape: + `[batch_size, max_time, ...]`. + + If `time_major == True`, this must be a `Tensor` of shape: + `[max_time, batch_size, ...]`. + + The input to `cell` at each time step will be a `Tensor` with dimensions + `[batch_size, ...]`. + sequence_length: (optional) An int32/int64 vector sized `[batch_size]`. + if `inputs` is not None and `sequence_length` is None it is inferred + from the `inputs` as the maximal possible sequence length. + parallel_iterations: (Default: 32). The number of iterations to run in + parallel. Those operations which do not have any temporal dependency + and can be run in parallel, will be. This parameter trades off + time for space. Values >> 1 use more memory but take less time, + while smaller values use less memory but computations take longer. + swap_memory: Transparently swap the tensors produced in forward inference + but needed for back prop from GPU to CPU. This allows training RNNs + which would typically not fit on a single GPU, with very minimal (or no) + performance penalty. + time_major: The shape format of the `inputs` and `outputs` Tensors. + If true, these `Tensors` must be shaped `[max_time, batch_size, depth]`. + If false, these `Tensors` must be shaped `[batch_size, max_time, depth]`. + Using `time_major = True` is a bit more efficient because it avoids + transposes at the beginning and end of the RNN calculation. However, + most TensorFlow data is batch-major, so by default this function + accepts input and emits output in batch-major form. + scope: VariableScope for the `raw_rnn`; + defaults to None. 
+ name: NameScope for the decoder; + defaults to "dynamic_rnn_decoder" + + Returns: + A pair (outputs, state) where: + + outputs: the RNN output 'Tensor'. + + If time_major == False (default), this will be a `Tensor` shaped: + `[batch_size, max_time, cell.output_size]`. + + If time_major == True, this will be a `Tensor` shaped: + `[max_time, batch_size, cell.output_size]`. + + state: The final state and will be shaped + `[batch_size, cell.state_size]`. + + Raises: + ValueError: if inputs is not None and has less than three dimensions. + """ + with ops.name_scope(name, "dynamic_rnn_decoder", + [cell, decoder_fn, inputs, sequence_length, + parallel_iterations, swap_memory, time_major, scope]): + if inputs is not None: + # Convert to tensor + inputs = ops.convert_to_tensor(inputs) + + # Test input dimensions + if inputs.get_shape().ndims is not None and ( + inputs.get_shape().ndims < 2): + raise ValueError("Inputs must have at least two dimensions") + # Setup of RNN (dimensions, sizes, length, initial state, dtype) + if not time_major: + # [batch, seq, features] -> [seq, batch, features] + inputs = array_ops.transpose(inputs, perm=[1, 0, 2]) + + dtype = inputs.dtype + # Get data input information + input_depth = int(inputs.get_shape()[2]) + batch_depth = inputs.get_shape()[1].value + max_time = inputs.get_shape()[0].value + if max_time is None: + max_time = array_ops.shape(inputs)[0] + # Setup decoder inputs as TensorArray + inputs_ta = tensor_array_ops.TensorArray(dtype, size=max_time) + inputs_ta = inputs_ta.unpack(inputs) + + def loop_fn(time, cell_output, cell_state, loop_state): + if cell_state is None: # first call, before while loop (in raw_rnn) + if cell_output is not None: + raise ValueError("Expected cell_output to be None when cell_state " + "is None, but saw: %s" % cell_output) + if loop_state is not None: + raise ValueError("Expected loop_state to be None when cell_state " + "is None, but saw: %s" % loop_state) + context_state = None + else: # subsequent 
calls, inside while loop, after cell excution + if isinstance(loop_state, tuple): + (done, context_state) = loop_state + else: + done = loop_state + context_state = None + + # call decoder function + if inputs is not None: # training + # get next_cell_input + if cell_state is None: + next_cell_input = inputs_ta.read(0) + else: + if batch_depth is not None: + batch_size = batch_depth + else: + batch_size = array_ops.shape(done)[0] + next_cell_input = control_flow_ops.cond( + math_ops.equal(time, max_time), + lambda: array_ops.zeros([batch_size, input_depth], dtype=dtype), + lambda: inputs_ta.read(time)) + (next_done, next_cell_state, next_cell_input, emit_output, + next_context_state) = decoder_fn(time, cell_state, next_cell_input, + cell_output, context_state) + else: # inference + # next_cell_input is obtained through decoder_fn + (next_done, next_cell_state, next_cell_input, emit_output, + next_context_state) = decoder_fn(time, cell_state, None, cell_output, + context_state) + + # check if we are done + if next_done is None: # training + next_done = time >= sequence_length + + # build next_loop_state + if next_context_state is None: + next_loop_state = next_done + else: + next_loop_state = (next_done, next_context_state) + + return (next_done, next_cell_input, next_cell_state, + emit_output, next_loop_state) + + # Run raw_rnn function + outputs_ta, state, _ = rnn.raw_rnn( + cell, loop_fn, parallel_iterations=parallel_iterations, + swap_memory=swap_memory, scope=scope) + outputs = outputs_ta.pack() + + if not time_major: + # [seq, batch, features] -> [batch, seq, features] + outputs = array_ops.transpose(outputs, perm=[1, 0, 2]) + return outputs, state |