From 8bd3b38e662a1298bebcada676c7cc6e2ea49c0f Mon Sep 17 00:00:00 2001 From: Vijay Vasudevan Date: Fri, 6 Nov 2015 21:57:38 -0800 Subject: TensorFlow: Upstream changes to git. Changes: - Update a lot of documentation, installation instructions, requirements, etc. - Add RNN models directory for recurrent neural network examples to go along with the tutorials. Base CL: 107290480 --- CONTRIBUTING.md | 11 + README.md | 2 +- tensorflow/g3doc/api_docs/python/framework.md | 10 +- tensorflow/g3doc/api_docs/python/index.md | 1 - tensorflow/g3doc/api_docs/python/nn.md | 120 +--- tensorflow/g3doc/api_docs/python/state_ops.md | 4 +- tensorflow/g3doc/get_started/os_setup.md | 39 +- tensorflow/g3doc/how_tos/adding_an_op/index.md | 7 +- tensorflow/g3doc/how_tos/variables/index.md | 2 +- tensorflow/g3doc/resources/faq.md | 28 +- tensorflow/g3doc/resources/uses.md | 4 +- tensorflow/g3doc/tutorials/deep_cnn/index.md | 4 +- tensorflow/g3doc/tutorials/index.md | 4 +- tensorflow/g3doc/tutorials/mandelbrot/index.md | 2 +- .../g3doc/tutorials/mnist/beginners/index.md | 2 +- tensorflow/g3doc/tutorials/mnist/download/index.md | 2 +- tensorflow/g3doc/tutorials/mnist/pros/index.md | 2 +- tensorflow/g3doc/tutorials/mnist/tf/index.md | 2 +- tensorflow/g3doc/tutorials/pdes/index.md | 1 + tensorflow/g3doc/tutorials/seq2seq/index.md | 2 +- tensorflow/g3doc/tutorials/word2vec/index.md | 2 +- tensorflow/models/rnn/BUILD | 106 +++ tensorflow/models/rnn/README.md | 21 + tensorflow/models/rnn/__init__.py | 0 tensorflow/models/rnn/linear.py | 49 ++ tensorflow/models/rnn/linear_test.py | 35 + tensorflow/models/rnn/ptb/BUILD | 49 ++ tensorflow/models/rnn/ptb/__init__.py | 0 tensorflow/models/rnn/ptb/ptb_word_lm.py | 292 ++++++++ tensorflow/models/rnn/ptb/reader.py | 105 +++ tensorflow/models/rnn/ptb/reader_test.py | 47 ++ tensorflow/models/rnn/rnn.py | 128 ++++ tensorflow/models/rnn/rnn_cell.py | 605 +++++++++++++++++ tensorflow/models/rnn/rnn_cell_test.py | 154 +++++ tensorflow/models/rnn/rnn_test.py | 472 +++++++++++++ tensorflow/models/rnn/seq2seq.py | 749 +++++++++++++++++++++ tensorflow/models/rnn/seq2seq_test.py | 384 +++++++++++ tensorflow/models/rnn/translate/BUILD | 71 ++ tensorflow/models/rnn/translate/__init__.py | 0 tensorflow/models/rnn/translate/data_utils.py | 264 ++++++++ tensorflow/models/rnn/translate/seq2seq_model.py | 268 ++++++++ tensorflow/models/rnn/translate/translate.py | 260 +++++++ tensorflow/python/framework/docs.py | 8 + tensorflow/python/framework/gen_docs_combined.py | 3 +- tensorflow/python/ops/embedding_ops.py | 31 +- tensorflow/python/ops/nn.py | 27 +- tensorflow/tools/docker/Dockerfile.cpu | 4 + 47 files changed, 4223 insertions(+), 160 deletions(-) create mode 100644 tensorflow/models/rnn/BUILD create mode 100644 tensorflow/models/rnn/README.md create mode 100755 tensorflow/models/rnn/__init__.py create mode 100644 tensorflow/models/rnn/linear.py create mode 100644 tensorflow/models/rnn/linear_test.py create mode 100644 tensorflow/models/rnn/ptb/BUILD create mode 100755 tensorflow/models/rnn/ptb/__init__.py create mode 100644 tensorflow/models/rnn/ptb/ptb_word_lm.py create mode 100644 tensorflow/models/rnn/ptb/reader.py create mode 100644 tensorflow/models/rnn/ptb/reader_test.py create mode 100644 tensorflow/models/rnn/rnn.py create mode 100644 tensorflow/models/rnn/rnn_cell.py create mode 100644 tensorflow/models/rnn/rnn_cell_test.py create mode 100644 tensorflow/models/rnn/rnn_test.py create mode 100644 tensorflow/models/rnn/seq2seq.py create mode 100644 tensorflow/models/rnn/seq2seq_test.py 
create mode 100644 tensorflow/models/rnn/translate/BUILD create mode 100755 tensorflow/models/rnn/translate/__init__.py create mode 100644 tensorflow/models/rnn/translate/data_utils.py create mode 100644 tensorflow/models/rnn/translate/seq2seq_model.py create mode 100644 tensorflow/models/rnn/translate/translate.py diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index dbaf281844..98339fe54a 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -15,3 +15,14 @@ Follow either of the two links above to access the appropriate CLA and instructi ***NOTE***: Only original source code from you and other people that have signed the CLA can be accepted into the main repository. +## Contributing code + +We currently use Gerrit to host and handle code changes to TensorFlow. The main +site is +[https://tensorflow-review.googlesource.com/](https://tensorflow-review.googlesource.com/). +See Gerrit [docs](https://gerrit-review.googlesource.com/Documentation/) for +information on how Gerrit's code review system works. + +We are currently working on improving our external acceptance process, so +please be patient with us as we work out the details. + diff --git a/README.md b/README.md index 2f8dac4a62..05bfec8431 100644 --- a/README.md +++ b/README.md @@ -14,7 +14,7 @@ variety of other domains, as well. # Download and Setup For detailed installation instructions, see -[here](g3doc/get_started/os_setup.md). +[here](tensorflow/g3doc/get_started/os_setup.md). ## Binary Installation diff --git a/tensorflow/g3doc/api_docs/python/framework.md b/tensorflow/g3doc/api_docs/python/framework.md index 0614c68e15..1fc659ef0b 100644 --- a/tensorflow/g3doc/api_docs/python/framework.md +++ b/tensorflow/g3doc/api_docs/python/framework.md @@ -27,7 +27,7 @@ * [class tf.RegisterShape](#RegisterShape) * [class tf.TensorShape](#TensorShape) * [class tf.Dimension](#Dimension) - * [tf.op_scope(*args, **kwds)](#op_scope) + * [tf.op_scope(values, name, default_name)](#op_scope) * [tf.get_seed(op_seed)](#get_seed) @@ -235,7 +235,7 @@ def my_func(pred, tensor): - - - -#### tf.Graph.device(*args, **kwds) {#Graph.device} +#### tf.Graph.device(device_name_or_function) {#Graph.device} Returns a context manager that specifies the default device to use. @@ -287,7 +287,7 @@ with g.device(matmul_on_gpu): - - - -#### tf.Graph.name_scope(*args, **kwds) {#Graph.name_scope} +#### tf.Graph.name_scope(name) {#Graph.name_scope} Returns a context manager that creates hierarchical names for operations. @@ -611,7 +611,7 @@ the default graph. - - - -#### tf.Graph.gradient_override_map(*args, **kwds) {#Graph.gradient_override_map} +#### tf.Graph.gradient_override_map(op_type_map) {#Graph.gradient_override_map} EXPERIMENTAL: A context manager for overriding gradient functions. @@ -2023,7 +2023,7 @@ The value of this dimension, or None if it is unknown. - - - -### tf.op_scope(*args, **kwds)
{#op_scope}
+### tf.op_scope(values, name, default_name) {#op_scope}
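+
+For example, a minimal sketch of the documented pattern (the op name `my_op`
+and the inputs `a` and `b` are illustrative, not part of this change):
+
+```python
+def my_op(a, b, name=None):
+  with tf.op_scope([a, b], name, "MyOp") as scope:
+    a = tf.convert_to_tensor(a, name="a")
+    b = tf.convert_to_tensor(b, name="b")
+    # Compute the op's result here; `scope` carries the resolved name.
+    return tf.add(a, b, name=scope)
+```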
Returns a context manager for use when defining a Python op. diff --git a/tensorflow/g3doc/api_docs/python/index.md b/tensorflow/g3doc/api_docs/python/index.md index 8750d0aadd..dd47b703fc 100644 --- a/tensorflow/g3doc/api_docs/python/index.md +++ b/tensorflow/g3doc/api_docs/python/index.md @@ -267,7 +267,6 @@ * [`depthwise_conv2d`](nn.md#depthwise_conv2d) * [`dropout`](nn.md#dropout) * [`embedding_lookup`](nn.md#embedding_lookup) - * [`embedding_lookup_sparse`](nn.md#embedding_lookup_sparse) * [`fixed_unigram_candidate_sampler`](nn.md#fixed_unigram_candidate_sampler) * [`in_top_k`](nn.md#in_top_k) * [`l2_loss`](nn.md#l2_loss) diff --git a/tensorflow/g3doc/api_docs/python/nn.md b/tensorflow/g3doc/api_docs/python/nn.md index 50c460b68c..b129506107 100644 --- a/tensorflow/g3doc/api_docs/python/nn.md +++ b/tensorflow/g3doc/api_docs/python/nn.md @@ -35,7 +35,6 @@ accepted by [`tf.convert_to_tensor`](framework.md#convert_to_tensor). * [tf.nn.softmax_cross_entropy_with_logits(logits, labels, name=None)](#softmax_cross_entropy_with_logits) * [Embeddings](#AUTOGENERATED-embeddings) * [tf.nn.embedding_lookup(params, ids, name=None)](#embedding_lookup) - * [tf.nn.embedding_lookup_sparse(params, sp_ids, sp_weights, name=None, combiner='mean')](#embedding_lookup_sparse) * [Evaluation](#AUTOGENERATED-evaluation) * [tf.nn.top_k(input, k, name=None)](#top_k) * [tf.nn.in_top_k(predictions, targets, k, name=None)](#in_top_k) @@ -130,17 +129,18 @@ sum is unchanged. By default, each element is kept or dropped independently. If `noise_shape` is specified, it must be [broadcastable](http://docs.scipy.org/doc/numpy/user/basics.broadcasting.html) -to the shape of `x`, and only dimensions with `noise_shape[i] == x.shape[i]` -will make independent decisions. For example, if `x.shape = [b, x, y, c]` and -`noise_shape = [b, 1, 1, c]`, each batch and channel component will be +to the shape of `x`, and only dimensions with `noise_shape[i] == shape(x)[i]` +will make independent decisions. For example, if `shape(x) = [k, l, m, n]` +and `noise_shape = [k, 1, 1, n]`, each batch and channel component will be kept independently and each row and column will be kept or not kept together. ##### Args: * x: A tensor. -* keep_prob: Float probability that each element is kept. -* noise_shape: Shape for randomly generated keep/drop flags. +* keep_prob: A Python float. The probability that each element is kept. +* noise_shape: A 1-D `Tensor` of type `int32`, representing the + shape for randomly generated keep/drop flags. * seed: A Python integer. Used to create a random seed. See [`set_random_seed`](constant_op.md#set_random_seed) for behavior. * name: A name for this operation (optional). @@ -247,10 +247,10 @@ are as follows. If the 4-D `input` has shape `[batch, in_height, in_width, ...]` and the 4-D `filter` has shape `[filter_height, filter_width, ...]`, then - output.shape = [batch, - (in_height - filter_height + 1) / strides[1], - (in_width - filter_width + 1) / strides[2], - ...] + shape(output) = [batch, + (in_height - filter_height + 1) / strides[1], + (in_width - filter_width + 1) / strides[2], + ...] output[b, i, j, :] = sum_{di, dj} input[b, strides[1] * i + di, strides[2] * j + dj, ...] * @@ -262,7 +262,7 @@ vectors. For `depthwise_conv_2d`, each scalar component `input[b, i, j, k]` is multiplied by a vector `filter[di, dj, k]`, and all the vectors are concatenated. 
-In the formula for `output.shape`, the rounding direction depends on padding: +In the formula for `shape(output)`, the rounding direction depends on padding: * `padding = 'SAME'`: Round down (only full size windows are considered). * `padding = 'VALID'`: Round up (partial windows are included). @@ -411,7 +411,7 @@ In detail, the output is for each tuple of indices `i`. The output shape is - output.shape = (value.shape - ksize + 1) / strides + shape(output) = (shape(value) - ksize + 1) / strides where the rounding direction depends on padding: @@ -722,103 +722,43 @@ and the same dtype (either `float32` or `float64`). ## Embeddings
{#AUTOGENERATED-embeddings}
-TensorFlow provides several operations that help you compute embeddings.
+TensorFlow provides library support for looking up values in embedding
+tensors.

- - -

### tf.nn.embedding_lookup(params, ids, name=None) {#embedding_lookup}
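+
+As a sketch of the partitioning rule described below (values invented for
+illustration): with two shards, id 2 maps to shard 2 % 2 = 0, row 2 // 2 = 1.
+
+```python
+import tensorflow as tf
+
+# Rows 0..3 of a 4 x 2 embedding table, split across two shards:
+# shard 0 holds rows 0 and 2; shard 1 holds rows 1 and 3.
+params = [tf.constant([[0.0, 0.0], [2.0, 2.0]]),
+          tf.constant([[1.0, 1.0], [3.0, 3.0]])]
+ids = tf.constant([2, 1])
+emb = tf.nn.embedding_lookup(params, ids)  # [[2.0, 2.0], [1.0, 1.0]]
+```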
-Return a tensor of embedding values by looking up "ids" in "params". +Looks up `ids` in a list of embedding tensors. -##### Args: - - -* params: List of tensors of the same shape. A single tensor is - treated as a singleton list. -* ids: Tensor of integers containing the ids to be looked up in - 'params'. Let P be len(params). If P > 1, then the ids are - partitioned by id % P, and we do separate lookups in params[p] - for 0 <= p < P, and then stitch the results back together into - a single result tensor. -* name: Optional name for the op. - -##### Returns: - - A tensor of shape ids.shape + params[0].shape[1:] containing the - values params[i % P][i] for each i in ids. - -##### Raises: - - -* ValueError: if some parameters are invalid. +This function is used to perform parallel lookups on the list of +tensors in `params`. It is a generalization of +[`tf.gather()`](array_ops.md#gather), where `params` is interpreted +as a partition of a larger embedding tensor. +If `len(params) > 1`, each element `id` of `ids` is partitioned between +the elements of `params` by computing `p = id % len(params)`, and is +then used to look up the slice `params[p][id // len(params), ...]`. -- - - - -### tf.nn.embedding_lookup_sparse(params, sp_ids, sp_weights, name=None, combiner='mean')
{#embedding_lookup_sparse}
- -Computes embeddings for the given ids and weights. - -This op assumes that there is at least one id for each row in the dense tensor -represented by sp_ids (i.e. there are no rows with empty features), and that -all the indices of sp_ids are in canonical row-major order. - -It also assumes that all id values lie in the range [0, p0), where p0 -is the sum of the size of params along dimension 0. +The results of the lookup are then concatenated into a dense +tensor. The returned tensor has shape `shape(ids) + shape(params)[1:]`. ##### Args: -* params: A single tensor representing the complete embedding tensor, - or a list of P tensors all of same shape except for the first dimension, - representing sharded embedding tensors. In the latter case, the ids are - partitioned by id % P, and we do separate lookups in params[p] for - 0 <= p < P, and then stitch the results back together into a single - result tensor. The first dimension is allowed to vary as the vocab - size is not necessarily a multiple of P. -* sp_ids: N x M SparseTensor of int64 ids (typically from FeatureValueToId), - where N is typically batch size and M is arbitrary. -* sp_weights: either a SparseTensor of float / double weights, or None to - indicate all weights should be taken to be 1. If specified, sp_weights - must have exactly the same shape and indices as sp_ids. -* name: Optional name for the op. -* combiner: A string specifying the reduction op. Currently "mean" and "sum" - are supported. - "sum" computes the weighted sum of the embedding results for each row. - "mean" is the weighted sum divided by the total weight. +* params: A list of tensors with the same shape and type. +* ids: A `Tensor` with type `int32` containing the ids to be looked + up in `params`. +* name: A name for the operation (optional). ##### Returns: - A dense tensor representing the combined embeddings for the - sparse ids. For each row in the dense tensor represented by sp_ids, the op - looks up the embeddings for all ids in that row, multiplies them by the - corresponding weight, and combines these embeddings as specified. - - In other words, if - shape(combined params) = [p0, p1, ..., pm] - and - shape(sp_ids) = shape(sp_weights) = [d0, d1, ..., dn] - then - shape(output) = [d0, d1, ..., dn-1, p1, ..., pm]. - - For instance, if params is a 10x20 matrix, and sp_ids / sp_weights are - - [0, 0]: id 1, weight 2.0 - [0, 1]: id 3, weight 0.5 - [1, 0]: id 0, weight 1.0 - [2, 3]: id 1, weight 3.0 - - with combiner="mean", then the output will be a 3x20 matrix where - output[0, :] = (params[1, :] * 2.0 + params[3, :] * 0.5) / (2.0 + 0.5) - output[1, :] = params[0, :] * 1.0 - output[2, :] = params[1, :] * 3.0 + A `Tensor` with the same type as the tensors in `params`. ##### Raises: -* TypeError: If sp_ids is not a SparseTensor, or if sp_weights is neither - None nor SparseTensor. -* ValueError: If combiner is not one of {"mean", "sum"}. +* ValueError: If `params` is empty. diff --git a/tensorflow/g3doc/api_docs/python/state_ops.md b/tensorflow/g3doc/api_docs/python/state_ops.md index 25bbe55f7b..70685a65bc 100644 --- a/tensorflow/g3doc/api_docs/python/state_ops.md +++ b/tensorflow/g3doc/api_docs/python/state_ops.md @@ -23,7 +23,7 @@ accepted by [`tf.convert_to_tensor`](framework.md#convert_to_tensor). 
* [Sharing Variables](#AUTOGENERATED-sharing-variables) * [tf.get_variable(name, shape=None, dtype=tf.float32, initializer=None, trainable=True, collections=None)](#get_variable) * [tf.get_variable_scope()](#get_variable_scope) - * [tf.variable_scope(*args, **kwds)](#variable_scope) + * [tf.variable_scope(name_or_scope, reuse=None, initializer=None)](#variable_scope) * [tf.constant_initializer(value=0.0)](#constant_initializer) * [tf.random_normal_initializer(mean=0.0, stddev=1.0, seed=None)](#random_normal_initializer) * [tf.truncated_normal_initializer(mean=0.0, stddev=1.0, seed=None)](#truncated_normal_initializer) @@ -896,7 +896,7 @@ Returns the current variable scope. - - - -### tf.variable_scope(*args, **kwds)
{#variable_scope}
+### tf.variable_scope(name_or_scope, reuse=None, initializer=None) {#variable_scope}
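+
+A short sharing sketch consistent with the signature above (the scope and
+variable names are illustrative):
+
+```python
+import tensorflow as tf
+
+with tf.variable_scope("foo"):
+  v = tf.get_variable("v", [1])       # creates foo/v
+with tf.variable_scope("foo", reuse=True):
+  v1 = tf.get_variable("v", [1])      # reuses foo/v
+assert v1 is v
+```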
Returns a context for variable scope. diff --git a/tensorflow/g3doc/get_started/os_setup.md b/tensorflow/g3doc/get_started/os_setup.md index f8113bcaec..0917f16832 100644 --- a/tensorflow/g3doc/get_started/os_setup.md +++ b/tensorflow/g3doc/get_started/os_setup.md @@ -36,7 +36,33 @@ Install TensorFlow (only CPU binary version is currently available). $ sudo pip install https://storage.googleapis.com/tensorflow/mac/tensorflow-0.5.0-py2-none-any.whl ``` -### Try your first TensorFlow program +## Docker-based installation + +We also support running TensorFlow via [Docker](http://docker.com/), which lets +you avoid worrying about setting up dependencies. + +First, [install Docker](http://docs.docker.com/engine/installation/). Once +Docker is up and running, you can start a container with one command: + +```sh +$ docker run -it b.gcr.io/tensorflow/tensorflow +``` + +This will start a container with TensorFlow and all its dependencies already +installed. + +### Additional images + +The default Docker image above contains just a minimal set of libraries for +getting up and running with TensorFlow. We also have several other containers, +which you can use in the `docker run` command above: + +* `b.gcr.io/tensorflow/tensorflow-full`: Contains a complete TensorFlow source + installation, including all utilities needed to build and run TensorFlow. This + makes it easy to experiment directly with the source, without needing to + install any of the dependencies described above. + +## Try your first TensorFlow program ```sh $ python @@ -133,6 +159,13 @@ $ sudo apt-get install python-numpy swig python-dev In order to build TensorFlow with GPU support, both Cuda Toolkit 7.0 and CUDNN 6.5 V2 from NVIDIA need to be installed. +TensorFlow GPU support requires having a GPU card with NVidia Compute Capability >= 3.5. Supported cards include but are not limited to: + +* NVidia Titan +* NVidia Titan X +* NVidia K20 +* NVidia K40 + ##### Download and install Cuda Toolkit 7.0 https://developer.nvidia.com/cuda-toolkit-70 @@ -227,7 +260,7 @@ Notes : You need to install Follow installation instructions [here](http://docs.scipy.org/doc/numpy/user/install.html). -### Create the pip package and install +### Create the pip package and install {#create-pip} ```sh $ bazel build -c opt //tensorflow/tools/pip_package:build_pip_package @@ -238,7 +271,7 @@ $ bazel-bin/tensorflow/tools/pip_package/build_pip_package /tmp/tensorflow_pkg $ pip install /tmp/tensorflow_pkg/tensorflow-0.5.0-cp27-none-linux_x86_64.whl ``` -### Train your first TensorFlow neural net model +## Train your first TensorFlow neural net model From the root of your source tree, run: diff --git a/tensorflow/g3doc/how_tos/adding_an_op/index.md b/tensorflow/g3doc/how_tos/adding_an_op/index.md index 5c6243cd9c..8702569f75 100644 --- a/tensorflow/g3doc/how_tos/adding_an_op/index.md +++ b/tensorflow/g3doc/how_tos/adding_an_op/index.md @@ -127,10 +127,9 @@ To do this for the `ZeroOut` op, add the following to `zero_out.cc`: REGISTER_KERNEL_BUILDER(Name("ZeroOut").Device(DEVICE_CPU), ZeroOutOp); ``` -TODO: instructions or pointer to building TF - -At this point, the Tensorflow system can reference and use the Op when -requested. +Once you +[build and reinstall TensorFlow](../../get_started/os_setup.md#create-pip), the +Tensorflow system can reference and use the Op when requested. ## Generate the client wrapper
{#AUTOGENERATED-generate-the-client-wrapper}
### The Python Op wrapper {#AUTOGENERATED-the-python-op-wrapper}
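+
+A hedged sketch of exercising the op once TensorFlow has been rebuilt; where
+the generated wrapper surfaces (assumed here to be `tf.user_ops.zero_out`)
+depends on how the op is built into your installation:
+
+```python
+import tensorflow as tf
+
+with tf.Session() as sess:
+  # Assumes ZeroOut keeps the first element and zeroes the rest.
+  print(sess.run(tf.user_ops.zero_out([1, 2, 3, 4])))
+```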
diff --git a/tensorflow/g3doc/how_tos/variables/index.md b/tensorflow/g3doc/how_tos/variables/index.md index 4ad8f8a266..26b19b3ae1 100644 --- a/tensorflow/g3doc/how_tos/variables/index.md +++ b/tensorflow/g3doc/how_tos/variables/index.md @@ -101,7 +101,7 @@ w_twice = tf.Variable(weights.initialized_value() * 0.2, name="w_twice") The convenience function `tf.initialize_all_variables()` adds an Op to initialize *all variables* in the model. You can also pass it an explicit list of variables to initialize. See the -[Variables Documentation](../../api_docs/python/state_op.md) for more options, +[Variables Documentation](../../api_docs/python/state_ops.md) for more options, including checking if variables are initialized. ## Saving and Restoring diff --git a/tensorflow/g3doc/resources/faq.md b/tensorflow/g3doc/resources/faq.md index 2bd485e7f9..a2b9a58e08 100644 --- a/tensorflow/g3doc/resources/faq.md +++ b/tensorflow/g3doc/resources/faq.md @@ -6,18 +6,18 @@ answer on one of the TensorFlow [community resources](index.md). ## Contents - * [Building a TensorFlow graph](#AUTOGENERATED-building-a-tensorflow-graph) - * [Running a TensorFlow computation](#AUTOGENERATED-running-a-tensorflow-computation) - * [Variables](#AUTOGENERATED-variables) - * [Tensor shapes](#AUTOGENERATED-tensor-shapes) - * [TensorBoard](#AUTOGENERATED-tensorboard) - * [Extending TensorFlow](#AUTOGENERATED-extending-tensorflow) - * [Miscellaneous](#AUTOGENERATED-miscellaneous) +* [Building a TensorFlow graph](#AUTOGENERATED-building-a-tensorflow-graph) +* [Running a TensorFlow computation](#AUTOGENERATED-running-a-tensorflow-computation) +* [Variables](#AUTOGENERATED-variables) +* [Tensor shapes](#AUTOGENERATED-tensor-shapes) +* [TensorBoard](#AUTOGENERATED-tensorboard) +* [Extending TensorFlow](#AUTOGENERATED-extending-tensorflow) +* [Miscellaneous](#AUTOGENERATED-miscellaneous) -### Building a TensorFlow graph
{#AUTOGENERATED-building-a-tensorflow-graph}
+## Building a TensorFlow graph {#AUTOGENERATED-building-a-tensorflow-graph}
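+
+For orientation, a minimal graph-construction sketch (values invented for
+illustration); building a graph runs no computation by itself:
+
+```python
+import tensorflow as tf
+
+a = tf.constant([[1.0, 2.0]])
+b = tf.constant([[3.0], [4.0]])
+product = tf.matmul(a, b)  # a node in the default graph, not a value
+```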
See also the [API documentation on building graphs](../api_docs/python/framework.md). @@ -55,7 +55,7 @@ uses multiple GPUs. TensorFlow supports a variety of different data types and tensor shapes. See the [ranks, shapes, and types reference](dims_types.md) for more details. -### Running a TensorFlow computation
{#AUTOGENERATED-running-a-tensorflow-computation}
+## Running a TensorFlow computation {#AUTOGENERATED-running-a-tensorflow-computation}
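+
+A minimal feeding sketch (shapes and values invented for illustration):
+
+```python
+import tensorflow as tf
+
+x = tf.placeholder(tf.float32, shape=[None, 2])  # None: any batch size
+y = tf.reduce_sum(x, 1)
+with tf.Session() as sess:
+  # feed_dict supplies a concrete value for the placeholder at run time.
+  print(sess.run(y, feed_dict={x: [[1.0, 2.0], [3.0, 4.0]]}))
+```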
See also the [API documentation on running graphs](../api_docs/python/client.md). @@ -175,7 +175,7 @@ for [using `QueueRunner` objects to drive queues and readers](../how_tos/reading_data/index.md#QueueRunners) for more information on how to use them. -### Variables
{#AUTOGENERATED-variables}
+## Variables {#AUTOGENERATED-variables}
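+
+A small sketch of initialization and a serialized update (the counter is
+illustrative; `use_locking` is discussed below):
+
+```python
+import tensorflow as tf
+
+counter = tf.Variable(0, name="counter")
+increment = counter.assign(counter + 1, use_locking=True)
+with tf.Session() as sess:
+  sess.run(tf.initialize_all_variables())
+  sess.run(increment)  # counter is now 1
+```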
See also the how-to documentation on [variables](../how_tos/variables/index.md) and [variable scopes](../how_tos/variable_scope/index.md), and @@ -196,7 +196,7 @@ operations to a variable are allowed to run with no mutual exclusion. To acquire a lock when assigning to a variable, pass `use_locking=True` to [`Variable.assign()`](../api_docs/python/state_ops.md#Variable.assign). -### Tensor shapes
{#AUTOGENERATED-tensor-shapes}
+## Tensor shapes {#AUTOGENERATED-tensor-shapes}
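+
+A sketch of the static versus dynamic shape distinction covered in this
+section (the 784 is illustrative):
+
+```python
+import tensorflow as tf
+
+x = tf.placeholder(tf.float32, shape=[None, 784])
+print(x.get_shape())    # static shape (?, 784), known while building
+batch = tf.shape(x)[0]  # dynamic batch size, known only when run
+```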
See also the [`TensorShape` API documentation](../api_docs/python/framework.md#TensorShape). @@ -248,7 +248,7 @@ to encode the batch size as a Python constant, but instead to use a symbolic [`tf.placeholder(..., shape=[None, ...])`](../api_docs/python/io_ops.md#placeholder). The `None` element of the shape corresponds to a variable-sized dimension. -### TensorBoard
{#AUTOGENERATED-tensorboard}
+## TensorBoard {#AUTOGENERATED-tensorboard}
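+
+A minimal summary-writing sketch (the log directory is illustrative):
+
+```python
+import tensorflow as tf
+
+merged = tf.merge_all_summaries()
+writer = tf.train.SummaryWriter("/tmp/tf_logs")
+# In a training loop: writer.add_summary(sess.run(merged), global_step)
+# Then point TensorBoard at the same directory via --logdir=/tmp/tf_logs.
+```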
See also the [how-to documentation on TensorBoard](../how_tos/graph_viz/index.md). @@ -260,7 +260,7 @@ of these summaries to a log directory. Then, startup TensorBoard using and pass the --logdir flag so that it points to your log directory. For more details, see . -### Extending TensorFlow
{#AUTOGENERATED-extending-tensorflow}
+## Extending TensorFlow {#AUTOGENERATED-extending-tensorflow}
See also the how-to documentation for [adding a new operation to TensorFlow](../how_tos/adding_an_op/index.md). @@ -293,7 +293,7 @@ how-to documentation for [adding an op with a list of inputs or outputs](../how_tos/adding_an_op/index.md#list-input-output) for more details of how to define these different input types. -### Miscellaneous
{#AUTOGENERATED-miscellaneous}
+## Miscellaneous {#AUTOGENERATED-miscellaneous}
#### Does TensorFlow work with Python 3? diff --git a/tensorflow/g3doc/resources/uses.md b/tensorflow/g3doc/resources/uses.md index f73d4e92d7..cc212886c5 100644 --- a/tensorflow/g3doc/resources/uses.md +++ b/tensorflow/g3doc/resources/uses.md @@ -17,7 +17,7 @@ Listed below are some of the many uses of TensorFlow. * **Inception Image Classification Model** * **Organization**: Google - * **Description**: Baseline model and follow on research into highly accurate computer vision models, starting with the model that won the 2014 Imagenet iamge classification challenge + * **Description**: Baseline model and follow on research into highly accurate computer vision models, starting with the model that won the 2014 Imagenet image classification challenge * **More Info**: Baseline model described in [Arxiv paper](http://arxiv.org/abs/1409.4842) * **SmartReply** @@ -33,6 +33,6 @@ Listed below are some of the many uses of TensorFlow. * **On-Device Computer Vision for OCR** * **Organization**: Google - * **Description**: On-device computer vision model to do optical character recoignition to enable real-time translation. + * **Description**: On-device computer vision model to do optical character recognition to enable real-time translation. * **More info**: [Google Research blog post](http://googleresearch.blogspot.com/2015/07/how-google-translate-squeezes-deep.html) } diff --git a/tensorflow/g3doc/tutorials/deep_cnn/index.md b/tensorflow/g3doc/tutorials/deep_cnn/index.md index 0bbba2cb40..929e1b3047 100644 --- a/tensorflow/g3doc/tutorials/deep_cnn/index.md +++ b/tensorflow/g3doc/tutorials/deep_cnn/index.md @@ -1,4 +1,4 @@ -# Convolutional Neural Networks for Object Recognition +# Convolutional Neural Networks **NOTE:** This tutorial is intended for *advanced* users of TensorFlow and assumes expertise and experience in machine learning. @@ -264,7 +264,7 @@ in `cifar10.py`. `cifar10_train.py` periodically [saves](../../api_docs/python/state_ops.md#Saver) all model parameters in -[checkpoint files](../../how_tos/variables.md#saving-and-restoring) +[checkpoint files](../../how_tos/variables/index.md#saving-and-restoring) but it does *not* evaluate the model. The checkpoint file will be used by `cifar10_eval.py` to measure the predictive performance (see [Evaluating a Model](#evaluating-a-model) below). diff --git a/tensorflow/g3doc/tutorials/index.md b/tensorflow/g3doc/tutorials/index.md index 470645db51..202b87c73c 100644 --- a/tensorflow/g3doc/tutorials/index.md +++ b/tensorflow/g3doc/tutorials/index.md @@ -1,7 +1,7 @@ # Overview -## ML for Beginners +## MNIST For ML Beginners If you're new to machine learning, we recommend starting here. You'll learn about a classic problem, handwritten digit classification (MNIST), and get a @@ -10,7 +10,7 @@ gentle introduction to multiclass classification. 
[View Tutorial](mnist/beginners/index.md) -## MNIST for Pros +## Deep MNIST for Experts If you're already familiar with other deep learning software packages, and are already familiar with MNIST, this tutorial with give you a very brief primer on diff --git a/tensorflow/g3doc/tutorials/mandelbrot/index.md b/tensorflow/g3doc/tutorials/mandelbrot/index.md index 7c6adcb4e8..b3d5a185f9 100755 --- a/tensorflow/g3doc/tutorials/mandelbrot/index.md +++ b/tensorflow/g3doc/tutorials/mandelbrot/index.md @@ -1,4 +1,4 @@ - +# Mandelbrot Set ``` #Import libraries for simulation diff --git a/tensorflow/g3doc/tutorials/mnist/beginners/index.md b/tensorflow/g3doc/tutorials/mnist/beginners/index.md index 398eca5f18..fff7484959 100644 --- a/tensorflow/g3doc/tutorials/mnist/beginners/index.md +++ b/tensorflow/g3doc/tutorials/mnist/beginners/index.md @@ -1,4 +1,4 @@ -# MNIST Softmax Regression (For Beginners) +# MNIST For ML Beginners *This tutorial is intended for readers who are new to both machine learning and TensorFlow. If you already diff --git a/tensorflow/g3doc/tutorials/mnist/download/index.md b/tensorflow/g3doc/tutorials/mnist/download/index.md index dc11e727d8..df6245df78 100644 --- a/tensorflow/g3doc/tutorials/mnist/download/index.md +++ b/tensorflow/g3doc/tutorials/mnist/download/index.md @@ -1,4 +1,4 @@ -# Downloading MNIST +# MNIST Data Download Code: [tensorflow/g3doc/tutorials/mnist/](https://tensorflow.googlesource.com/tensorflow/+/master/tensorflow/g3doc/tutorials/mnist/) diff --git a/tensorflow/g3doc/tutorials/mnist/pros/index.md b/tensorflow/g3doc/tutorials/mnist/pros/index.md index cb0292586b..34853ccf66 100644 --- a/tensorflow/g3doc/tutorials/mnist/pros/index.md +++ b/tensorflow/g3doc/tutorials/mnist/pros/index.md @@ -1,4 +1,4 @@ -# MNIST Deep Learning Example (For Experts) +# Deep MNIST for Experts TensorFlow is a powerful library for doing large-scale numerical computation. 
One of the tasks at which it excels is implementing and training deep neural diff --git a/tensorflow/g3doc/tutorials/mnist/tf/index.md b/tensorflow/g3doc/tutorials/mnist/tf/index.md index 86f3296287..5ce996af12 100644 --- a/tensorflow/g3doc/tutorials/mnist/tf/index.md +++ b/tensorflow/g3doc/tutorials/mnist/tf/index.md @@ -1,4 +1,4 @@ -# Handwritten Digit Classification +# TensorFlow Mechanics 101 Code: [tensorflow/g3doc/tutorials/mnist/](https://tensorflow.googlesource.com/tensorflow/+/master/tensorflow/g3doc/tutorials/mnist/) diff --git a/tensorflow/g3doc/tutorials/pdes/index.md b/tensorflow/g3doc/tutorials/pdes/index.md index 1f29e4037c..26f36d5536 100755 --- a/tensorflow/g3doc/tutorials/pdes/index.md +++ b/tensorflow/g3doc/tutorials/pdes/index.md @@ -1,3 +1,4 @@ +# Partial Differential Equations ## Basic Setup diff --git a/tensorflow/g3doc/tutorials/seq2seq/index.md b/tensorflow/g3doc/tutorials/seq2seq/index.md index e421c814aa..3eec2a2ba8 100644 --- a/tensorflow/g3doc/tutorials/seq2seq/index.md +++ b/tensorflow/g3doc/tutorials/seq2seq/index.md @@ -1,4 +1,4 @@ -# Sequence-to-Sequence Models: Learning to Translate +# Sequence-to-Sequence Models Recurrent neural networks can learn to model language, as already discussed in the [RNN Tutorial](../recurrent/index.md) diff --git a/tensorflow/g3doc/tutorials/word2vec/index.md b/tensorflow/g3doc/tutorials/word2vec/index.md index 832c7c166b..290ff3627f 100644 --- a/tensorflow/g3doc/tutorials/word2vec/index.md +++ b/tensorflow/g3doc/tutorials/word2vec/index.md @@ -1,4 +1,4 @@ -# Learning Vector Representations of Words +# Vector Representations of Words In this tutorial we look at the word2vec model by [Mikolov et al.](http://papers.nips.cc/paper/5021-distributed-representations-of-words-and-phrases-and-their-compositionality.pdf). diff --git a/tensorflow/models/rnn/BUILD b/tensorflow/models/rnn/BUILD new file mode 100644 index 0000000000..a88d48fd42 --- /dev/null +++ b/tensorflow/models/rnn/BUILD @@ -0,0 +1,106 @@ +# Description: +# Example RNN models, including language models and sequence-to-sequence models. 
+ +package(default_visibility = ["//visibility:public"]) + +licenses(["notice"]) # Apache 2.0 + +exports_files(["LICENSE"]) + +load("/tensorflow/tensorflow", "cuda_py_tests") + +py_library( + name = "linear", + srcs = [ + "linear.py", + ], + deps = [ + "//tensorflow:tensorflow_py", + ], +) + +py_test( + name = "linear_test", + size = "small", + srcs = ["linear_test.py"], + deps = [ + ":linear", + "//tensorflow:tensorflow_py", + ], +) + +py_library( + name = "rnn_cell", + srcs = [ + "rnn_cell.py", + ], + deps = [ + ":linear", + "//tensorflow:tensorflow_py", + ], +) + +py_test( + name = "rnn_cell_test", + size = "small", + srcs = ["rnn_cell_test.py"], + deps = [ + ":rnn_cell", + "//tensorflow:tensorflow_py", + ], +) + +py_library( + name = "rnn", + srcs = [ + "rnn.py", + ], + deps = [ + ":rnn_cell", + "//tensorflow:tensorflow_py", + ], +) + +cuda_py_tests( + name = "rnn_tests", + srcs = [ + "rnn_test.py", + ], + additional_deps = [ + ":rnn", + ], +) + +py_library( + name = "seq2seq", + srcs = [ + "seq2seq.py", + ], + deps = [ + ":rnn", + "//tensorflow:tensorflow_py", + ], +) + +py_test( + name = "seq2seq_test", + srcs = [ + "seq2seq_test.py", + ], + deps = [ + ":seq2seq", + "//tensorflow:tensorflow_py", + ], +) + +filegroup( + name = "all_files", + srcs = glob( + ["**/*"], + exclude = [ + "**/METADATA", + "**/OWNERS", + ], + ), + visibility = ["//tensorflow:__subpackages__"], +) diff --git a/tensorflow/models/rnn/README.md b/tensorflow/models/rnn/README.md new file mode 100644 index 0000000000..227c226c3a --- /dev/null +++ b/tensorflow/models/rnn/README.md @@ -0,0 +1,21 @@ +This directory contains functions for creating recurrent neural networks +and sequence-to-sequence models. Detailed instructions on how to get started +and use them are available in the tutorials. + +* [RNN Tutorial](http://tensorflow.org/tutorials/recurrent/) +* [Sequence-to-Sequence Tutorial](http://tensorflow.org/tutorials/seq2seq/) + +Here is a short overview of what is in this directory. + +File | What's in it? +--- | --- +`linear.py` | Basic helper functions for creating linear layers. +`linear_test.py` | Unit tests for `linear.py`. +`rnn_cell.py` | Cells for recurrent neural networks, e.g., LSTM. +`rnn_cell_test.py` | Unit tests for `rnn_cell.py`. +`rnn.py` | Functions for building recurrent neural networks. +`rnn_test.py` | Unit tests for `rnn.py`. +`seq2seq.py` | Functions for building sequence-to-sequence models. +`seq2seq_test.py` | Unit tests for `seq2seq.py`. +`ptb/` | PTB language model, see the [RNN Tutorial](http://tensorflow.org/tutorials/recurrent/) +`translate/` | Translation model, see the [Sequence-to-Sequence Tutorial](http://tensorflow.org/tutorials/seq2seq/) diff --git a/tensorflow/models/rnn/__init__.py b/tensorflow/models/rnn/__init__.py new file mode 100755 index 0000000000..e69de29bb2 diff --git a/tensorflow/models/rnn/linear.py b/tensorflow/models/rnn/linear.py new file mode 100644 index 0000000000..96278e73e4 --- /dev/null +++ b/tensorflow/models/rnn/linear.py @@ -0,0 +1,49 @@ +"""Basic linear combinations that implicitly generate variables.""" + +import tensorflow as tf + + +def linear(args, output_size, bias, bias_start=0.0, scope=None): + """Linear map: sum_i(args[i] * W[i]), where W[i] is a variable. + + Args: + args: a 2D Tensor or a list of 2D, batch x n, Tensors. + output_size: int, second dimension of W[i]. + bias: boolean, whether to add a bias term or not. + bias_start: starting value to initialize the bias; 0 by default. 
+ scope: VariableScope for the created subgraph; defaults to "Linear". + + Returns: + A 2D Tensor with shape [batch x output_size] equal to + sum_i(args[i] * W[i]), where W[i]s are newly created matrices. + + Raises: + ValueError: if some of the arguments has unspecified or wrong shape. + """ + assert args + if not isinstance(args, (list, tuple)): + args = [args] + + # Calculate the total size of arguments on dimension 1. + total_arg_size = 0 + shapes = [a.get_shape().as_list() for a in args] + for shape in shapes: + if len(shape) != 2: + raise ValueError("Linear is expecting 2D arguments: %s" % str(shapes)) + if not shape[1]: + raise ValueError("Linear expects shape[1] of arguments: %s" % str(shapes)) + else: + total_arg_size += shape[1] + + # Now the computation. + with tf.variable_scope(scope or "Linear"): + matrix = tf.get_variable("Matrix", [total_arg_size, output_size]) + if len(args) == 1: + res = tf.matmul(args[0], matrix) + else: + res = tf.matmul(tf.concat(1, args), matrix) + if not bias: + return res + bias_term = tf.get_variable("Bias", [output_size], + initializer=tf.constant_initializer(bias_start)) + return res + bias_term diff --git a/tensorflow/models/rnn/linear_test.py b/tensorflow/models/rnn/linear_test.py new file mode 100644 index 0000000000..93ef10144f --- /dev/null +++ b/tensorflow/models/rnn/linear_test.py @@ -0,0 +1,35 @@ +# pylint: disable=g-bad-import-order,unused-import +import tensorflow.python.platform + +import numpy as np +import tensorflow as tf + +from tensorflow.models.rnn import linear + + +class LinearTest(tf.test.TestCase): + + def testLinear(self): + with self.test_session() as sess: + with tf.variable_scope("root", initializer=tf.constant_initializer(1.0)): + x = tf.zeros([1, 2]) + l = linear.linear([x], 2, False) + sess.run([tf.variables.initialize_all_variables()]) + res = sess.run([l], {x.name: np.array([[1., 2.]])}) + self.assertAllClose(res[0], [[3.0, 3.0]]) + + # Checks prevent you from accidentally creating a shared function. + with self.assertRaises(ValueError) as exc: + l1 = linear.linear([x], 2, False) + self.assertEqual(exc.exception.message[:12], "Over-sharing") + + # But you can create a new one in a new scope and share the variables. + with tf.variable_scope("l1") as new_scope: + l1 = linear.linear([x], 2, False) + with tf.variable_scope(new_scope, reuse=True): + linear.linear([l1], 2, False) + self.assertEqual(len(tf.trainable_variables()), 2) + + +if __name__ == "__main__": + tf.test.main() diff --git a/tensorflow/models/rnn/ptb/BUILD b/tensorflow/models/rnn/ptb/BUILD new file mode 100644 index 0000000000..56d459a0f1 --- /dev/null +++ b/tensorflow/models/rnn/ptb/BUILD @@ -0,0 +1,49 @@ +# Description: +# Python support for TensorFlow. 
+ +package(default_visibility = ["//tensorflow:internal"]) + +licenses(["notice"]) # Apache 2.0 + +exports_files(["LICENSE"]) + +py_library( + name = "reader", + srcs = ["reader.py"], + deps = ["//tensorflow:tensorflow_py"], +) + +py_test( + name = "reader_test", + srcs = ["reader_test.py"], + deps = [ + ":reader", + "//tensorflow:tensorflow_py", + ], +) + +py_binary( + name = "ptb_word_lm", + srcs = [ + "ptb_word_lm.py", + ], + deps = [ + ":reader", + "//tensorflow:tensorflow_py", + "//tensorflow/models/rnn", + "//tensorflow/models/rnn:rnn_cell", + "//tensorflow/models/rnn:seq2seq", + ], +) + +filegroup( + name = "all_files", + srcs = glob( + ["**/*"], + exclude = [ + "**/METADATA", + "**/OWNERS", + ], + ), + visibility = ["//tensorflow:__subpackages__"], +) diff --git a/tensorflow/models/rnn/ptb/__init__.py b/tensorflow/models/rnn/ptb/__init__.py new file mode 100755 index 0000000000..e69de29bb2 diff --git a/tensorflow/models/rnn/ptb/ptb_word_lm.py b/tensorflow/models/rnn/ptb/ptb_word_lm.py new file mode 100644 index 0000000000..e28d3bf78c --- /dev/null +++ b/tensorflow/models/rnn/ptb/ptb_word_lm.py @@ -0,0 +1,292 @@ +"""Example / benchmark for building a PTB LSTM model. + +Trains the model described in: +(Zaremba, et. al.) Recurrent Neural Network Regularization +http://arxiv.org/abs/1409.2329 + +The data required for this example is in the data/ dir of the +PTB dataset from Tomas Mikolov's webpage: + +http://www.fit.vutbr.cz/~imikolov/rnnlm/simple-examples.tgz + +There are 3 supported model configurations: +=========================================== +| config | epochs | train | valid | test +=========================================== +| small | 13 | 37.99 | 121.39 | 115.91 +| medium | 39 | 48.45 | 86.16 | 82.07 +| large | 55 | 37.87 | 82.62 | 78.29 +The exact results may vary depending on the random initialization. + +The hyperparameters used in the model: +- init_scale - the initial scale of the weights +- learning_rate - the initial value of the learning rate +- max_grad_norm - the maximum permissible norm of the gradient +- num_layers - the number of LSTM layers +- num_steps - the number of unrolled steps of LSTM +- hidden_size - the number of LSTM units +- max_epoch - the number of epochs trained with the initial learning rate +- max_max_epoch - the total number of epochs for training +- keep_prob - the probability of keeping weights in the dropout layer +- lr_decay - the decay of the learning rate for each epoch after "max_epoch" +- batch_size - the batch size + +To compile on CPU: + bazel build -c opt tensorflow/models/rnn/ptb:ptb_word_lm +To compile on GPU: + bazel build -c opt tensorflow --config=cuda \ + tensorflow/models/rnn/ptb:ptb_word_lm +To run: + ./bazel-bin/.../ptb_word_lm \ + --data_path=/tmp/simple-examples/data/ --alsologtostderr + +""" + +import time + +import tensorflow.python.platform + +import numpy as np +import tensorflow as tf + +from tensorflow.models.rnn import rnn_cell +from tensorflow.models.rnn import seq2seq +from tensorflow.models.rnn.ptb import reader + +flags = tf.flags +logging = tf.logging + +flags.DEFINE_string( + "model", "small", + "A type of model. 
Possible options are: small, medium, large.") +flags.DEFINE_string("data_path", None, "data_path") + +FLAGS = flags.FLAGS + + +class PTBModel(object): + """The PTB model.""" + + def __init__(self, is_training, config): + self.batch_size = batch_size = config.batch_size + self.num_steps = num_steps = config.num_steps + size = config.hidden_size + vocab_size = config.vocab_size + + self._input_data = tf.placeholder(tf.int32, [batch_size, num_steps]) + self._targets = tf.placeholder(tf.int32, [batch_size, num_steps]) + + # Slightly better results can be obtained with forget gate biases + # initialized to 1 but the hyperparameters of the model would need to be + # different than reported in the paper. + lstm_cell = rnn_cell.BasicLSTMCell(size, forget_bias=0.0) + if is_training and config.keep_prob < 1: + lstm_cell = rnn_cell.DropoutWrapper( + lstm_cell, output_keep_prob=config.keep_prob) + cell = rnn_cell.MultiRNNCell([lstm_cell] * config.num_layers) + + self._initial_state = cell.zero_state(batch_size, tf.float32) + + with tf.device("/cpu:0"): + embedding = tf.get_variable("embedding", [vocab_size, size]) + inputs = tf.split( + 1, num_steps, tf.nn.embedding_lookup(embedding, self._input_data)) + inputs = [tf.squeeze(input_, [1]) for input_ in inputs] + + if is_training and config.keep_prob < 1: + inputs = [tf.nn.dropout(input_, config.keep_prob) for input_ in inputs] + + # Simplified version of tensorflow.models.rnn.rnn.py's rnn(). + # This builds an unrolled LSTM for tutorial purposes only. + # In general, use the rnn() or state_saving_rnn() from rnn.py. + # + # The alternative version of the code below is: + # + # from tensorflow.models.rnn import rnn + # outputs, states = rnn.rnn(cell, inputs, initial_state=self._initial_state) + outputs = [] + states = [] + state = self._initial_state + with tf.variable_scope("RNN"): + for time_step, input_ in enumerate(inputs): + if time_step > 0: tf.get_variable_scope().reuse_variables() + (cell_output, state) = cell(input_, state) + outputs.append(cell_output) + states.append(state) + + output = tf.reshape(tf.concat(1, outputs), [-1, size]) + logits = tf.nn.xw_plus_b(output, + tf.get_variable("softmax_w", [size, vocab_size]), + tf.get_variable("softmax_b", [vocab_size])) + loss = seq2seq.sequence_loss_by_example([logits], + [tf.reshape(self._targets, -1)], + [tf.ones([batch_size * num_steps])], + vocab_size) + self._cost = cost = tf.reduce_sum(loss) / batch_size + self._final_state = states[-1] + + if not is_training: + return + + self._lr = tf.Variable(0.0, trainable=False) + tvars = tf.trainable_variables() + grads, _ = tf.clip_by_global_norm(tf.gradients(cost, tvars), + config.max_grad_norm) + optimizer = tf.train.GradientDescentOptimizer(self.lr) + self._train_op = optimizer.apply_gradients(zip(grads, tvars)) + + def assign_lr(self, session, lr_value): + session.run(tf.assign(self.lr, lr_value)) + + @property + def input_data(self): + return self._input_data + + @property + def targets(self): + return self._targets + + @property + def initial_state(self): + return self._initial_state + + @property + def cost(self): + return self._cost + + @property + def final_state(self): + return self._final_state + + @property + def lr(self): + return self._lr + + @property + def train_op(self): + return self._train_op + + +class SmallConfig(object): + """Small config.""" + init_scale = 0.1 + learning_rate = 1.0 + max_grad_norm = 5 + num_layers = 2 + num_steps = 20 + hidden_size = 200 + max_epoch = 4 + max_max_epoch = 13 + keep_prob = 1.0 + lr_decay = 0.5 + 
batch_size = 20 + vocab_size = 10000 + + +class MediumConfig(object): + """Medium config.""" + init_scale = 0.05 + learning_rate = 1.0 + max_grad_norm = 5 + num_layers = 2 + num_steps = 35 + hidden_size = 650 + max_epoch = 6 + max_max_epoch = 39 + keep_prob = 0.5 + lr_decay = 0.8 + batch_size = 20 + vocab_size = 10000 + + +class LargeConfig(object): + """Large config.""" + init_scale = 0.04 + learning_rate = 1.0 + max_grad_norm = 10 + num_layers = 2 + num_steps = 35 + hidden_size = 1500 + max_epoch = 14 + max_max_epoch = 55 + keep_prob = 0.35 + lr_decay = 1 / 1.15 + batch_size = 20 + vocab_size = 10000 + + +def run_epoch(session, m, data, eval_op, verbose=False): + """Runs the model on the given data.""" + epoch_size = ((len(data) / m.batch_size) - 1) / m.num_steps + start_time = time.time() + costs = 0.0 + iters = 0 + state = m.initial_state.eval() + for step, (x, y) in enumerate(reader.ptb_iterator(data, m.batch_size, + m.num_steps)): + cost, state, _ = session.run([m.cost, m.final_state, eval_op], + {m.input_data: x, + m.targets: y, + m.initial_state: state}) + costs += cost + iters += m.num_steps + + if verbose and step % (epoch_size / 10) == 10: + print("%.3f perplexity: %.3f speed: %.0f wps" % + (step * 1.0 / epoch_size, np.exp(costs / iters), + iters * m.batch_size / (time.time() - start_time))) + + return np.exp(costs / iters) + + +def get_config(): + if FLAGS.model == "small": + return SmallConfig() + elif FLAGS.model == "medium": + return MediumConfig() + elif FLAGS.model == "large": + return LargeConfig() + else: + raise ValueError("Invalid model: %s", FLAGS.model) + + +def main(unused_args): + if not FLAGS.data_path: + raise ValueError("Must set --data_path to PTB data directory") + + raw_data = reader.ptb_raw_data(FLAGS.data_path) + train_data, valid_data, test_data, _ = raw_data + + config = get_config() + eval_config = get_config() + eval_config.batch_size = 1 + eval_config.num_steps = 1 + + with tf.Graph().as_default(), tf.Session() as session: + initializer = tf.random_uniform_initializer(-config.init_scale, + config.init_scale) + with tf.variable_scope("model", reuse=None, initializer=initializer): + m = PTBModel(is_training=True, config=config) + with tf.variable_scope("model", reuse=True, initializer=initializer): + mvalid = PTBModel(is_training=False, config=config) + mtest = PTBModel(is_training=False, config=eval_config) + + tf.initialize_all_variables().run() + + for i in range(config.max_max_epoch): + lr_decay = config.lr_decay ** max(i - config.max_epoch, 0.0) + m.assign_lr(session, config.learning_rate * lr_decay) + + print("Epoch: %d Learning rate: %.3f" % (i + 1, session.run(m.lr))) + train_perplexity = run_epoch(session, m, train_data, m.train_op, + verbose=True) + print("Epoch: %d Train Perplexity: %.3f" % (i + 1, train_perplexity)) + valid_perplexity = run_epoch(session, mvalid, valid_data, tf.no_op()) + print("Epoch: %d Valid Perplexity: %.3f" % (i + 1, valid_perplexity)) + + test_perplexity = run_epoch(session, mtest, test_data, tf.no_op()) + print("Test Perplexity: %.3f" % test_perplexity) + + +if __name__ == "__main__": + tf.app.run() diff --git a/tensorflow/models/rnn/ptb/reader.py b/tensorflow/models/rnn/ptb/reader.py new file mode 100644 index 0000000000..9a0db9c525 --- /dev/null +++ b/tensorflow/models/rnn/ptb/reader.py @@ -0,0 +1,105 @@ +# pylint: disable=unused-import,g-bad-import-order + +"""Utilities for parsing PTB text files.""" +import collections +import os +import sys +import time + +import tensorflow.python.platform + +import numpy as np 
+import tensorflow as tf + +from tensorflow.python.platform import gfile + + +def _read_words(filename): + with gfile.GFile(filename, "r") as f: + return f.read().replace("\n", "").split() + + +def _build_vocab(filename): + data = _read_words(filename) + + counter = collections.Counter(data) + count_pairs = sorted(counter.items(), key=lambda x: -x[1]) + + words, _ = zip(*count_pairs) + word_to_id = dict(zip(words, range(len(words)))) + + return word_to_id + + +def _file_to_word_ids(filename, word_to_id): + data = _read_words(filename) + return [word_to_id[word] for word in data] + + +def ptb_raw_data(data_path=None): + """Load PTB raw data from data directory "data_path". + + Reads PTB text files, converts strings to integer ids, + and performs mini-batching of the inputs. + + The PTB dataset comes from Tomas Mikolov's webpage: + + http://www.fit.vutbr.cz/~imikolov/rnnlm/simple-examples.tgz + + Args: + data_path: string path to the directory where simple-examples.tgz has + been extracted. + + Returns: + tuple (train_data, valid_data, test_data, vocabulary) + where each of the data objects can be passed to PTBIterator. + """ + + train_path = os.path.join(data_path, "ptb.train.txt") + valid_path = os.path.join(data_path, "ptb.valid.txt") + test_path = os.path.join(data_path, "ptb.test.txt") + + word_to_id = _build_vocab(train_path) + train_data = _file_to_word_ids(train_path, word_to_id) + valid_data = _file_to_word_ids(valid_path, word_to_id) + test_data = _file_to_word_ids(test_path, word_to_id) + vocabulary = len(word_to_id) + return train_data, valid_data, test_data, vocabulary + + +def ptb_iterator(raw_data, batch_size, num_steps): + """Iterate on the raw PTB data. + + This generates batch_size pointers into the raw PTB data, and allows + minibatch iteration along these pointers. + + Args: + raw_data: one of the raw data outputs from ptb_raw_data. + batch_size: int, the batch size. + num_steps: int, the number of unrolls. + + Yields: + Pairs of the batched data, each a matrix of shape [batch_size, num_steps]. + The second element of the tuple is the same data time-shifted to the + right by one. + + Raises: + ValueError: if batch_size or num_steps are too high. 
+ """ + raw_data = np.array(raw_data, dtype=np.int32) + + data_len = len(raw_data) + batch_len = data_len / batch_size + data = np.zeros([batch_size, batch_len], dtype=np.int32) + for i in range(batch_size): + data[i] = raw_data[batch_len * i:batch_len * (i + 1)] + + epoch_size = (batch_len - 1) / num_steps + + if epoch_size == 0: + raise ValueError("epoch_size == 0, decrease batch_size or num_steps") + + for i in range(epoch_size): + x = data[:, i*num_steps:(i+1)*num_steps] + y = data[:, i*num_steps+1:(i+1)*num_steps+1] + yield (x, y) diff --git a/tensorflow/models/rnn/ptb/reader_test.py b/tensorflow/models/rnn/ptb/reader_test.py new file mode 100644 index 0000000000..c722cdb939 --- /dev/null +++ b/tensorflow/models/rnn/ptb/reader_test.py @@ -0,0 +1,47 @@ +"""Tests for tensorflow.models.ptb_lstm.ptb_reader.""" + +import os.path + +# pylint: disable=g-bad-import-order,unused-import +import tensorflow.python.platform + +import numpy as np +import tensorflow as tf + +from tensorflow.models.rnn.ptb import reader +from tensorflow.python.platform import gfile + + +class PtbReaderTest(tf.test.TestCase): + + def setUp(self): + self._string_data = "\n".join( + [" hello there i am", + " rain as day", + " want some cheesy puffs ?"]) + + def testPtbRawData(self): + tmpdir = tf.test.get_temp_dir() + for suffix in "train", "valid", "test": + filename = os.path.join(tmpdir, "ptb.%s.txt" % suffix) + with gfile.GFile(filename, "w") as fh: + fh.write(self._string_data) + # Smoke test + output = reader.ptb_raw_data(tmpdir) + self.assertEqual(len(output), 4) + + def testPtbIterator(self): + raw_data = [4, 3, 2, 1, 0, 5, 6, 1, 1, 1, 1, 0, 3, 4, 1] + batch_size = 3 + num_steps = 2 + output = list(reader.ptb_iterator(raw_data, batch_size, num_steps)) + self.assertEqual(len(output), 2) + o1, o2 = (output[0], output[1]) + self.assertEqual(o1[0].shape, (batch_size, num_steps)) + self.assertEqual(o1[1].shape, (batch_size, num_steps)) + self.assertEqual(o2[0].shape, (batch_size, num_steps)) + self.assertEqual(o2[1].shape, (batch_size, num_steps)) + + +if __name__ == "__main__": + tf.test.main() diff --git a/tensorflow/models/rnn/rnn.py b/tensorflow/models/rnn/rnn.py new file mode 100644 index 0000000000..24582bcae7 --- /dev/null +++ b/tensorflow/models/rnn/rnn.py @@ -0,0 +1,128 @@ +"""RNN helpers for TensorFlow models.""" + +import tensorflow as tf + +from tensorflow.models.rnn import rnn_cell +from tensorflow.python.ops import control_flow_ops + + +def rnn(cell, inputs, initial_state=None, dtype=None, + sequence_length=None, scope=None): + """Creates a recurrent neural network specified by RNNCell "cell". + + The simplest form of RNN network generated is: + state = cell.zero_state(...) + outputs = [] + states = [] + for input_ in inputs: + output, state = cell(input_, state) + outputs.append(output) + states.append(state) + return (outputs, states) + + However, a few other options are available: + + An initial state can be provided. + If sequence_length is provided, dynamic calculation is performed. + + Dynamic calculation returns, at time t: + (t >= max(sequence_length) + ? (zeros(output_shape), zeros(state_shape)) + : cell(input, state) + + Thus saving computational time when unrolling past the max sequence length. + + Args: + cell: An instance of RNNCell. + inputs: A length T list of inputs, each a vector with shape [batch_size]. + initial_state: (optional) An initial state for the RNN. This must be + a tensor of appropriate type and shape [batch_size x cell.state_size]. 
+ dtype: (optional) The data type for the initial state. Required if + initial_state is not provided. + sequence_length: An int64 vector (tensor) size [batch_size]. + scope: VariableScope for the created subgraph; defaults to "RNN". + + Returns: + A pair (outputs, states) where: + outputs is a length T list of outputs (one for each input) + states is a length T list of states (one state following each input) + + Raises: + TypeError: If "cell" is not an instance of RNNCell. + ValueError: If inputs is None or an empty list. + """ + + if not isinstance(cell, rnn_cell.RNNCell): + raise TypeError("cell must be an instance of RNNCell") + if not isinstance(inputs, list): + raise TypeError("inputs must be a list") + if not inputs: + raise ValueError("inputs must not be empty") + + outputs = [] + states = [] + with tf.variable_scope(scope or "RNN"): + batch_size = tf.shape(inputs[0])[0] + if initial_state is not None: + state = initial_state + else: + if not dtype: + raise ValueError("If no initial_state is provided, dtype must be.") + state = cell.zero_state(batch_size, dtype) + + if sequence_length: # Prepare variables + zero_output_state = ( + tf.zeros(tf.pack([batch_size, cell.output_size]), + inputs[0].dtype), + tf.zeros(tf.pack([batch_size, cell.state_size]), + state.dtype)) + max_sequence_length = tf.reduce_max(sequence_length) + + output_state = (None, None) + for time, input_ in enumerate(inputs): + if time > 0: + tf.get_variable_scope().reuse_variables() + output_state = cell(input_, state) + if sequence_length: + (output, state) = control_flow_ops.cond( + time >= max_sequence_length, + lambda: zero_output_state, lambda: output_state) + else: + (output, state) = output_state + + outputs.append(output) + states.append(state) + + return (outputs, states) + + +def state_saving_rnn(cell, inputs, state_saver, state_name, + sequence_length=None, scope=None): + """RNN that accepts a state saver for time-truncated RNN calculation. + + Args: + cell: An instance of RNNCell. + inputs: A length T list of inputs, each a vector with shape [batch_size]. + state_saver: A StateSaver object. + state_name: The name to use with the state_saver. + sequence_length: (optional) An int64 vector (tensor) size [batch_size]. + See the documentation for rnn() for more details about sequence_length. + scope: VariableScope for the created subgraph; defaults to "RNN". + + Returns: + A pair (outputs, states) where: + outputs is a length T list of outputs (one for each input) + states is a length T list of states (one state following each input) + + Raises: + TypeError: If "cell" is not an instance of RNNCell. + ValueError: If inputs is None or an empty list. + """ + initial_state = state_saver.State(state_name) + (outputs, states) = rnn(cell, inputs, initial_state=initial_state, + sequence_length=sequence_length, scope=scope) + save_state = state_saver.SaveState(state_name, states[-1]) + with tf.control_dependencies([save_state]): + outputs[-1] = tf.identity(outputs[-1]) + + return (outputs, states) diff --git a/tensorflow/models/rnn/rnn_cell.py b/tensorflow/models/rnn/rnn_cell.py new file mode 100644 index 0000000000..55d417fc2b --- /dev/null +++ b/tensorflow/models/rnn/rnn_cell.py @@ -0,0 +1,605 @@ +"""Module for constructing RNN Cells.""" + +import math + +import tensorflow as tf + +from tensorflow.models.rnn import linear + + +class RNNCell(object): + """Abstract object representing an RNN cell. 
+
+  An RNN cell, in the most abstract setting, is anything that has
+  a state -- a vector of floats of size self.state_size -- and performs some
+  operation that takes inputs of size self.input_size. This operation
+  results in an output of size self.output_size and a new state.
+
+  This module provides a number of basic commonly used RNN cells, such as
+  LSTM (Long Short Term Memory) or GRU (Gated Recurrent Unit), and a number
+  of operators that allow adding dropouts, projections, or embeddings for
+  inputs. Constructing multi-layer cells is supported by a super-class,
+  MultiRNNCell, defined later. Every RNNCell must have the properties below
+  and implement __call__ with the following signature.
+  """
+
+  def __call__(self, inputs, state, scope=None):
+    """Run this RNN cell on inputs, starting from the given state.
+
+    Args:
+      inputs: 2D Tensor with shape [batch_size x self.input_size].
+      state: 2D Tensor with shape [batch_size x self.state_size].
+      scope: VariableScope for the created subgraph; defaults to class name.
+
+    Returns:
+      A pair containing:
+      - Output: A 2D Tensor with shape [batch_size x self.output_size]
+      - New state: A 2D Tensor with shape [batch_size x self.state_size].
+    """
+    raise NotImplementedError("Abstract method")
+
+  @property
+  def input_size(self):
+    """Integer: size of inputs accepted by this cell."""
+    raise NotImplementedError("Abstract method")
+
+  @property
+  def output_size(self):
+    """Integer: size of outputs produced by this cell."""
+    raise NotImplementedError("Abstract method")
+
+  @property
+  def state_size(self):
+    """Integer: size of state used by this cell."""
+    raise NotImplementedError("Abstract method")
+
+  def zero_state(self, batch_size, dtype):
+    """Return state tensor (shape [batch_size x state_size]) filled with 0.
+
+    Args:
+      batch_size: int, float, or unit Tensor representing the batch size.
+      dtype: the data type to use for the state.
+
+    Returns:
+      A 2D Tensor of shape [batch_size x state_size] filled with zeros.
+    """
+    zeros = tf.zeros(tf.pack([batch_size, self.state_size]), dtype=dtype)
+    # The reshape below is a no-op, but it allows shape inference of shape[1].
+    return tf.reshape(zeros, [-1, self.state_size])
+
+
+class BasicRNNCell(RNNCell):
+  """The most basic RNN cell."""
+
+  def __init__(self, num_units):
+    self._num_units = num_units
+
+  @property
+  def input_size(self):
+    return self._num_units
+
+  @property
+  def output_size(self):
+    return self._num_units
+
+  @property
+  def state_size(self):
+    return self._num_units
+
+  def __call__(self, inputs, state, scope=None):
+    """Most basic RNN: output = new_state = tanh(W * input + U * state + B)."""
+    with tf.variable_scope(scope or type(self).__name__):  # "BasicRNNCell"
+      output = tf.tanh(linear.linear([inputs, state], self._num_units, True))
+    return output, output
+
+
+class GRUCell(RNNCell):
+  """Gated Recurrent Unit cell (cf. http://arxiv.org/abs/1406.1078)."""
+
+  def __init__(self, num_units):
+    self._num_units = num_units
+
+  @property
+  def input_size(self):
+    return self._num_units
+
+  @property
+  def output_size(self):
+    return self._num_units
+
+  @property
+  def state_size(self):
+    return self._num_units
+
+  def __call__(self, inputs, state, scope=None):
+    """Gated recurrent unit (GRU) with nunits cells."""
+    with tf.variable_scope(scope or type(self).__name__):  # "GRUCell"
+      with tf.variable_scope("Gates"):  # Reset gate and update gate.
+        # We start with bias of 1.0 to not reset and not update.
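+        # The single linear map below produces both gate pre-activations in
+        # one matmul; tf.split then separates the reset gate r from the
+        # update gate u, and tf.sigmoid squashes each into (0, 1).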
+        r, u = tf.split(1, 2, linear.linear([inputs, state],
+                                            2 * self._num_units, True, 1.0))
+        r, u = tf.sigmoid(r), tf.sigmoid(u)
+      with tf.variable_scope("Candidate"):
+        c = tf.tanh(linear.linear([inputs, r * state], self._num_units, True))
+      new_h = u * state + (1 - u) * c
+    return new_h, new_h
+
+
+class BasicLSTMCell(RNNCell):
+  """Basic LSTM recurrent network cell.
+
+  The implementation is based on: http://arxiv.org/pdf/1409.2329v5.pdf.
+
+  It does not allow cell clipping or a projection layer, and it does not
+  use peep-hole connections: it is the basic baseline.
+
+  Biases of the forget gate are initialized by default to 1 in order to reduce
+  the scale of forgetting at the beginning of training.
+  """
+
+  def __init__(self, num_units, forget_bias=1.0):
+    self._num_units = num_units
+    self._forget_bias = forget_bias
+
+  @property
+  def input_size(self):
+    return self._num_units
+
+  @property
+  def output_size(self):
+    return self._num_units
+
+  @property
+  def state_size(self):
+    return 2 * self._num_units
+
+  def __call__(self, inputs, state, scope=None):
+    """Long short-term memory cell (LSTM)."""
+    with tf.variable_scope(scope or type(self).__name__):  # "BasicLSTMCell"
+      # Parameters of gates are concatenated into one multiply for efficiency.
+      c, h = tf.split(1, 2, state)
+      concat = linear.linear([inputs, h], 4 * self._num_units, True)
+
+      # i = input_gate, j = new_input, f = forget_gate, o = output_gate
+      i, j, f, o = tf.split(1, 4, concat)
+
+      new_c = c * tf.sigmoid(f + self._forget_bias) + tf.sigmoid(i) * tf.tanh(j)
+      new_h = tf.tanh(new_c) * tf.sigmoid(o)
+
+    return new_h, tf.concat(1, [new_c, new_h])
+
+
+class LSTMCell(RNNCell):
+  """Long short-term memory unit (LSTM) recurrent network cell.
+
+  This implementation is based on:
+
+    https://research.google.com/pubs/archive/43905.pdf
+
+  Hasim Sak, Andrew Senior, and Francoise Beaufays.
+  "Long short-term memory recurrent neural network architectures for
+   large scale acoustic modeling." INTERSPEECH, 2014.
+
+  It uses peep-hole connections, optional cell clipping, and an optional
+  projection layer.
+  """
+
+  def __init__(self, num_units, input_size,
+               use_peepholes=False, cell_clip=None,
+               initializer=None, num_proj=None,
+               num_unit_shards=1, num_proj_shards=1):
+    """Initialize the parameters for an LSTM cell.
+
+    Args:
+      num_units: int, The number of units in the LSTM cell.
+      input_size: int, The dimensionality of the inputs into the LSTM cell.
+      use_peepholes: bool, set True to enable diagonal/peephole connections.
+      cell_clip: (optional) A float value, if provided the cell state is
+        clipped by this value prior to the cell output activation.
+      initializer: (optional) The initializer to use for the weight and
+        projection matrices.
+      num_proj: (optional) int, The output dimensionality for the projection
+        matrices.  If None, no projection is performed.
+      num_unit_shards: How to split the weight matrix.  If >1, the weight
+        matrix is stored across num_unit_shards.
+        Note that num_unit_shards must evenly divide num_units * 4.
+      num_proj_shards: How to split the projection matrix.  If >1, the
+        projection matrix is stored across num_proj_shards.
+        Note that num_proj_shards must evenly divide num_proj
+        (if num_proj is not None).
+
+    Raises:
+      ValueError: if num_unit_shards doesn't divide 4 * num_units or
+        num_proj_shards doesn't divide num_proj
+    """
+    self._num_units = num_units
+    self._input_size = input_size
+    self._use_peepholes = use_peepholes
+    self._cell_clip = cell_clip
+    self._initializer = initializer
+    self._num_proj = num_proj
+    self._num_unit_shards = num_unit_shards
+    self._num_proj_shards = num_proj_shards
+
+    if (num_units * 4) % num_unit_shards != 0:
+      raise ValueError("num_unit_shards must evenly divide 4 * num_units")
+    if num_proj and num_proj % num_proj_shards != 0:
+      raise ValueError("num_proj_shards must evenly divide num_proj")
+
+    if num_proj:
+      self._state_size = num_units + num_proj
+      self._output_size = num_proj
+    else:
+      self._state_size = 2 * num_units
+      self._output_size = num_units
+
+  @property
+  def input_size(self):
+    return self._input_size
+
+  @property
+  def output_size(self):
+    return self._output_size
+
+  @property
+  def state_size(self):
+    return self._state_size
+
+  def __call__(self, input_, state, scope=None):
+    """Run one step of LSTM.
+
+    Args:
+      input_: input Tensor, 2D, batch x input_size.
+      state: state Tensor, 2D, batch x state_size.
+      scope: VariableScope for the created subgraph; defaults to "LSTMCell".
+
+    Returns:
+      A tuple containing:
+      - A 2D, batch x output_dim, Tensor representing the output of the LSTM
+        after reading "input_" when previous state was "state".
+        Here output_dim is:
+           num_proj if num_proj was set,
+           num_units otherwise.
+      - A 2D, batch x state_size, Tensor representing the new state of LSTM
+        after reading "input_" when previous state was "state".
+    """
+    num_proj = self._num_units if self._num_proj is None else self._num_proj
+
+    c_prev = tf.slice(state, [0, 0], [-1, self._num_units])
+    m_prev = tf.slice(state, [0, self._num_units], [-1, num_proj])
+
+    dtype = input_.dtype
+
+    unit_shard_size = (4 * self._num_units) / self._num_unit_shards
+
+    with tf.variable_scope(scope or type(self).__name__):  # "LSTMCell"
+      w = tf.concat(
+          1, [tf.get_variable("W_%d" % i,
+                              shape=[self.input_size + num_proj,
+                                     unit_shard_size],
+                              initializer=self._initializer,
+                              dtype=dtype)
+              for i in range(self._num_unit_shards)])
+
+      b = tf.get_variable(
+          "B", shape=[4 * self._num_units],
+          initializer=tf.zeros_initializer, dtype=dtype)
+
+      # i = input_gate, j = new_input, f = forget_gate, o = output_gate
+      cell_inputs = tf.concat(1, [input_, m_prev])
+      i, j, f, o = tf.split(1, 4, tf.nn.bias_add(tf.matmul(cell_inputs, w), b))
+
+      # Diagonal connections
+      if self._use_peepholes:
+        w_f_diag = tf.get_variable(
+            "W_F_diag", shape=[self._num_units], dtype=dtype)
+        w_i_diag = tf.get_variable(
+            "W_I_diag", shape=[self._num_units], dtype=dtype)
+        w_o_diag = tf.get_variable(
+            "W_O_diag", shape=[self._num_units], dtype=dtype)
+
+      if self._use_peepholes:
+        c = (tf.sigmoid(f + 1 + w_f_diag * c_prev) * c_prev +
+             tf.sigmoid(i + w_i_diag * c_prev) * tf.tanh(j))
+      else:
+        c = (tf.sigmoid(f + 1) * c_prev + tf.sigmoid(i) * tf.tanh(j))
+
+      if self._cell_clip is not None:
+        c = tf.clip_by_value(c, -self._cell_clip, self._cell_clip)
+
+      if self._use_peepholes:
+        m = tf.sigmoid(o + w_o_diag * c) * tf.tanh(c)
+      else:
+        m = tf.sigmoid(o) * tf.tanh(c)
+
+      if self._num_proj is not None:
+        proj_shard_size = self._num_proj / self._num_proj_shards
+        w_proj = tf.concat(
+            1, [tf.get_variable("W_P_%d" % i,
+                                shape=[self._num_units, proj_shard_size],
+                                initializer=self._initializer, dtype=dtype)
+                for i in range(self._num_proj_shards)])
+        # TODO(ebrevdo), use matmulsum
+        m = tf.matmul(m, w_proj)
+
+    return m, tf.concat(1, [c, m])
+
+
+class OutputProjectionWrapper(RNNCell):
+  """Operator adding an output projection to the given cell.
+
+  Note: in many cases it may be more efficient to not use this wrapper,
+  but instead concatenate the whole sequence of your outputs in time,
+  do the projection on this batch-concatenated sequence, then split it
+  if needed or directly feed into a softmax.
+  """
+
+  def __init__(self, cell, output_size):
+    """Create a cell with output projection.
+
+    Args:
+      cell: an RNNCell, a projection to output_size is added to it.
+      output_size: integer, the size of the output after projection.
+
+    Raises:
+      TypeError: if cell is not an RNNCell.
+      ValueError: if output_size is not positive.
+    """
+    if not isinstance(cell, RNNCell):
+      raise TypeError("The parameter cell is not an RNNCell.")
+    if output_size < 1:
+      raise ValueError("Parameter output_size must be > 0: %d." % output_size)
+    self._cell = cell
+    self._output_size = output_size
+
+  @property
+  def input_size(self):
+    return self._cell.input_size
+
+  @property
+  def output_size(self):
+    return self._output_size
+
+  @property
+  def state_size(self):
+    return self._cell.state_size
+
+  def __call__(self, inputs, state, scope=None):
+    """Run the cell and output projection on inputs, starting from state."""
+    output, res_state = self._cell(inputs, state)
+    # Default scope: "OutputProjectionWrapper"
+    with tf.variable_scope(scope or type(self).__name__):
+      projected = linear.linear(output, self._output_size, True)
+    return projected, res_state
+
+
+class InputProjectionWrapper(RNNCell):
+  """Operator adding an input projection to the given cell.
+
+  Note: in many cases it may be more efficient to not use this wrapper,
+  but instead concatenate the whole sequence of your inputs in time,
+  do the projection on this batch-concatenated sequence, then split it.
+  """
+
+  def __init__(self, cell, input_size):
+    """Create a cell with input projection.
+
+    Args:
+      cell: an RNNCell, a projection of inputs is added before it.
+      input_size: integer, the size of the inputs before projection.
+
+    Raises:
+      TypeError: if cell is not an RNNCell.
+      ValueError: if input_size is not positive.
+    """
+    if not isinstance(cell, RNNCell):
+      raise TypeError("The parameter cell is not an RNNCell.")
+    if input_size < 1:
+      raise ValueError("Parameter input_size must be > 0: %d." % input_size)
+    self._cell = cell
+    self._input_size = input_size
+
+  @property
+  def input_size(self):
+    return self._input_size
+
+  @property
+  def output_size(self):
+    return self._cell.output_size
+
+  @property
+  def state_size(self):
+    return self._cell.state_size
+
+  def __call__(self, inputs, state, scope=None):
+    """Run the input projection and then the cell."""
+    # Default scope: "InputProjectionWrapper"
+    with tf.variable_scope(scope or type(self).__name__):
+      projected = linear.linear(inputs, self._cell.input_size, True)
+    return self._cell(projected, state)
+
+
+class DropoutWrapper(RNNCell):
+  """Operator adding dropout to inputs and outputs of the given cell."""
+
+  def __init__(self, cell, input_keep_prob=1.0, output_keep_prob=1.0,
+               seed=None):
+    """Create a cell with added input and/or output dropout.
+
+    Dropout is never used on the state.
+
+    Args:
+      cell: an RNNCell, the input and/or output of which will have dropout
+        applied to it.
+      input_keep_prob: unit Tensor or float between 0 and 1, input keep
+        probability; if it is float and 1, no input dropout will be added.
+      output_keep_prob: unit Tensor or float between 0 and 1, output keep
+        probability; if it is float and 1, no output dropout will be added.
+      seed: (optional) integer, the randomness seed.
+
+    Raises:
+      TypeError: if cell is not an RNNCell.
+      ValueError: if keep_prob is not between 0 and 1.
+    """
+    if not isinstance(cell, RNNCell):
+      raise TypeError("The parameter cell is not an RNNCell.")
+    if (isinstance(input_keep_prob, float) and
+        not (input_keep_prob >= 0.0 and input_keep_prob <= 1.0)):
+      raise ValueError("Parameter input_keep_prob must be between 0 and 1: %f"
+                       % input_keep_prob)
+    if (isinstance(output_keep_prob, float) and
+        not (output_keep_prob >= 0.0 and output_keep_prob <= 1.0)):
+      raise ValueError("Parameter output_keep_prob must be between 0 and 1: %f"
+                       % output_keep_prob)
+    self._cell = cell
+    self._input_keep_prob = input_keep_prob
+    self._output_keep_prob = output_keep_prob
+    self._seed = seed
+
+  @property
+  def input_size(self):
+    return self._cell.input_size
+
+  @property
+  def output_size(self):
+    return self._cell.output_size
+
+  @property
+  def state_size(self):
+    return self._cell.state_size
+
+  def __call__(self, inputs, state):
+    """Run the cell with the declared dropouts."""
+    if (not isinstance(self._input_keep_prob, float) or
+        self._input_keep_prob < 1):
+      inputs = tf.nn.dropout(inputs, self._input_keep_prob, seed=self._seed)
+    output, new_state = self._cell(inputs, state)
+    if (not isinstance(self._output_keep_prob, float) or
+        self._output_keep_prob < 1):
+      output = tf.nn.dropout(output, self._output_keep_prob, seed=self._seed)
+    return output, new_state
+
+
+class EmbeddingWrapper(RNNCell):
+  """Operator adding input embedding to the given cell.
+
+  Note: in many cases it may be more efficient to not use this wrapper,
+  but instead concatenate the whole sequence of your inputs in time,
+  do the embedding on this batch-concatenated sequence, then split it and
+  feed into your RNN.
+  """
+
+  def __init__(self, cell, embedding_classes=0, embedding=None,
+               initializer=None):
+    """Create a cell with an added input embedding.
+
+    Args:
+      cell: an RNNCell, an embedding will be put before its inputs.
+      embedding_classes: integer, how many symbols will be embedded.
+      embedding: Variable, the embedding to use; if None, a new embedding
+        will be created; if set, then embedding_classes is not required.
+      initializer: an initializer to use when creating the embedding;
+        if None, the initializer from variable scope or a default one is used.
+
+    Raises:
+      TypeError: if cell is not an RNNCell.
+      ValueError: if embedding_classes is not positive.
+    """
+    if not isinstance(cell, RNNCell):
+      raise TypeError("The parameter cell is not an RNNCell.")
+    if embedding_classes < 1 and embedding is None:
+      raise ValueError("Either pass an embedding or set embedding_classes "
+                       "> 0: %d." % embedding_classes)
+    if embedding_classes > 0 and embedding is not None:
+      if embedding.size[0] != embedding_classes:
+        raise ValueError("You declared embedding_classes=%d but passed an "
+                         "embedding for %d classes." % (embedding_classes,
+                                                        embedding.size[0]))
+      if embedding.size[1] != cell.input_size:
+        raise ValueError("You passed embedding with output size %d and a cell"
+                         " that accepts size %d."
% (embedding.size[1], + cell.input_size)) + self._cell = cell + self._embedding_classes = embedding_classes + self._embedding = embedding + self._initializer = initializer + + @property + def input_size(self): + return 1 + + @property + def output_size(self): + return self._cell.output_size + + @property + def state_size(self): + return self._cell.state_size + + def __call__(self, inputs, state, scope=None): + """Run the cell on embedded inputs.""" + with tf.variable_scope(scope or type(self).__name__): # "EmbeddingWrapper" + with tf.device("/cpu:0"): + if self._embedding: + embedding = self._embedding + else: + if self._initializer: + initializer = self._initializer + elif tf.get_variable_scope().initializer: + initializer = tf.get_variable_scope().initializer + else: + # Default initializer for embeddings should have variance=1. + sqrt3 = math.sqrt(3) # Uniform(-sqrt(3), sqrt(3)) has variance=1. + initializer = tf.random_uniform_initializer(-sqrt3, sqrt3) + embedding = tf.get_variable("embedding", [self._embedding_classes, + self._cell.input_size], + initializer=initializer) + embedded = tf.nn.embedding_lookup(embedding, tf.reshape(inputs, [-1])) + return self._cell(embedded, state) + + +class MultiRNNCell(RNNCell): + """RNN cell composed sequentially of multiple simple cells.""" + + def __init__(self, cells): + """Create a RNN cell composed sequentially of a number of RNNCells. + + Args: + cells: list of RNNCells that will be composed in this order. + + Raises: + ValueError: if cells is empty (not allowed) or if their sizes don't match. + """ + if not cells: + raise ValueError("Must specify at least one cell for MultiRNNCell.") + for i in xrange(len(cells) - 1): + if cells[i + 1].input_size != cells[i].output_size: + raise ValueError("In MultiRNNCell, the input size of each next" + " cell must match the output size of the previous one." + " Mismatched output size in cell %d." 
% i) + self._cells = cells + + @property + def input_size(self): + return self._cells[0].input_size + + @property + def output_size(self): + return self._cells[-1].output_size + + @property + def state_size(self): + return sum([cell.state_size for cell in self._cells]) + + def __call__(self, inputs, state, scope=None): + """Run this multi-layer cell on inputs, starting from state.""" + with tf.variable_scope(scope or type(self).__name__): # "MultiRNNCell" + cur_state_pos = 0 + cur_inp = inputs + new_states = [] + for i, cell in enumerate(self._cells): + with tf.variable_scope("Cell%d" % i): + cur_state = tf.slice(state, [0, cur_state_pos], [-1, cell.state_size]) + cur_state_pos += cell.state_size + cur_inp, new_state = cell(cur_inp, cur_state) + new_states.append(new_state) + return cur_inp, tf.concat(1, new_states) diff --git a/tensorflow/models/rnn/rnn_cell_test.py b/tensorflow/models/rnn/rnn_cell_test.py new file mode 100644 index 0000000000..8b4b209028 --- /dev/null +++ b/tensorflow/models/rnn/rnn_cell_test.py @@ -0,0 +1,154 @@ +"""Tests for RNN cells.""" + +# pylint: disable=g-bad-import-order,unused-import +import tensorflow.python.platform + +import numpy as np +import tensorflow as tf + +from tensorflow.models.rnn import rnn_cell + + +class RNNCellTest(tf.test.TestCase): + + def testBasicRNNCell(self): + with self.test_session() as sess: + with tf.variable_scope("root", initializer=tf.constant_initializer(0.5)): + x = tf.zeros([1, 2]) + m = tf.zeros([1, 2]) + g, _ = rnn_cell.BasicRNNCell(2)(x, m) + sess.run([tf.variables.initialize_all_variables()]) + res = sess.run([g], {x.name: np.array([[1., 1.]]), + m.name: np.array([[0.1, 0.1]])}) + self.assertEqual(res[0].shape, (1, 2)) + + def testGRUCell(self): + with self.test_session() as sess: + with tf.variable_scope("root", initializer=tf.constant_initializer(0.5)): + x = tf.zeros([1, 2]) + m = tf.zeros([1, 2]) + g, _ = rnn_cell.GRUCell(2)(x, m) + sess.run([tf.variables.initialize_all_variables()]) + res = sess.run([g], {x.name: np.array([[1., 1.]]), + m.name: np.array([[0.1, 0.1]])}) + # Smoke test + self.assertAllClose(res[0], [[0.175991, 0.175991]]) + + def testBasicLSTMCell(self): + with self.test_session() as sess: + with tf.variable_scope("root", initializer=tf.constant_initializer(0.5)): + x = tf.zeros([1, 2]) + m = tf.zeros([1, 8]) + g, out_m = rnn_cell.MultiRNNCell([rnn_cell.BasicLSTMCell(2)] * 2)(x, m) + sess.run([tf.variables.initialize_all_variables()]) + res = sess.run([g, out_m], {x.name: np.array([[1., 1.]]), + m.name: 0.1 * np.ones([1, 8])}) + self.assertEqual(len(res), 2) + # The numbers in results were not calculated, this is just a smoke test. 
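+        # Layout note: MultiRNNCell concatenates its sub-cells' states, and a
+        # BasicLSTMCell state is itself [c, h], so expected_mem below is laid
+        # out as [c_0, h_0, c_1, h_1] for the two layers; the output res[0]
+        # equals the last layer's h (the final two columns).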
+ self.assertAllClose(res[0], [[0.24024698, 0.24024698]]) + expected_mem = np.array([[0.68967271, 0.68967271, + 0.44848421, 0.44848421, + 0.39897051, 0.39897051, + 0.24024698, 0.24024698]]) + self.assertAllClose(res[1], expected_mem) + + def testLSTMCell(self): + with self.test_session() as sess: + num_units = 8 + num_proj = 6 + state_size = num_units + num_proj + batch_size = 3 + input_size = 2 + with tf.variable_scope("root", initializer=tf.constant_initializer(0.5)): + x = tf.zeros([batch_size, input_size]) + m = tf.zeros([batch_size, state_size]) + output, state = rnn_cell.LSTMCell( + num_units=num_units, input_size=input_size, num_proj=num_proj)(x, m) + sess.run([tf.variables.initialize_all_variables()]) + res = sess.run([output, state], + {x.name: np.array([[1., 1.], [2., 2.], [3., 3.]]), + m.name: 0.1 * np.ones((batch_size, state_size))}) + self.assertEqual(len(res), 2) + # The numbers in results were not calculated, this is mostly just a + # smoke test. + self.assertEqual(res[0].shape, (batch_size, num_proj)) + self.assertEqual(res[1].shape, (batch_size, state_size)) + # Different inputs so different outputs and states + for i in range(1, batch_size): + self.assertTrue( + float(np.linalg.norm((res[0][0,:] - res[0][i,:]))) > 1e-6) + self.assertTrue( + float(np.linalg.norm((res[1][0,:] - res[1][i,:]))) > 1e-6) + + def testOutputProjectionWrapper(self): + with self.test_session() as sess: + with tf.variable_scope("root", initializer=tf.constant_initializer(0.5)): + x = tf.zeros([1, 3]) + m = tf.zeros([1, 3]) + cell = rnn_cell.OutputProjectionWrapper(rnn_cell.GRUCell(3), 2) + g, new_m = cell(x, m) + sess.run([tf.variables.initialize_all_variables()]) + res = sess.run([g, new_m], {x.name: np.array([[1., 1., 1.]]), + m.name: np.array([[0.1, 0.1, 0.1]])}) + self.assertEqual(res[1].shape, (1, 3)) + # The numbers in results were not calculated, this is just a smoke test. + self.assertAllClose(res[0], [[0.231907, 0.231907]]) + + def testInputProjectionWrapper(self): + with self.test_session() as sess: + with tf.variable_scope("root", initializer=tf.constant_initializer(0.5)): + x = tf.zeros([1, 2]) + m = tf.zeros([1, 3]) + cell = rnn_cell.InputProjectionWrapper(rnn_cell.GRUCell(3), 2) + g, new_m = cell(x, m) + sess.run([tf.variables.initialize_all_variables()]) + res = sess.run([g, new_m], {x.name: np.array([[1., 1.]]), + m.name: np.array([[0.1, 0.1, 0.1]])}) + self.assertEqual(res[1].shape, (1, 3)) + # The numbers in results were not calculated, this is just a smoke test. + self.assertAllClose(res[0], [[0.154605, 0.154605, 0.154605]]) + + def testDropoutWrapper(self): + with self.test_session() as sess: + with tf.variable_scope("root", initializer=tf.constant_initializer(0.5)): + x = tf.zeros([1, 3]) + m = tf.zeros([1, 3]) + keep = tf.zeros([1]) + 1 + g, new_m = rnn_cell.DropoutWrapper(rnn_cell.GRUCell(3), + keep, keep)(x, m) + sess.run([tf.variables.initialize_all_variables()]) + res = sess.run([g, new_m], {x.name: np.array([[1., 1., 1.]]), + m.name: np.array([[0.1, 0.1, 0.1]])}) + self.assertEqual(res[1].shape, (1, 3)) + # The numbers in results were not calculated, this is just a smoke test. 
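+        # With a keep probability of 1.0, tf.nn.dropout is an identity
+        # transform, so this exercises the wrapper plumbing deterministically;
+        # the values should match a bare GRUCell(3) under the same initializer.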
+ self.assertAllClose(res[0], [[0.154605, 0.154605, 0.154605]]) + + def testEmbeddingWrapper(self): + with self.test_session() as sess: + with tf.variable_scope("root", initializer=tf.constant_initializer(0.5)): + x = tf.zeros([1, 1], dtype=tf.int32) + m = tf.zeros([1, 2]) + g, new_m = rnn_cell.EmbeddingWrapper(rnn_cell.GRUCell(2), 3)(x, m) + sess.run([tf.variables.initialize_all_variables()]) + res = sess.run([g, new_m], {x.name: np.array([[1]]), + m.name: np.array([[0.1, 0.1]])}) + self.assertEqual(res[1].shape, (1, 2)) + # The numbers in results were not calculated, this is just a smoke test. + self.assertAllClose(res[0], [[0.17139, 0.17139]]) + + def testMultiRNNCell(self): + with self.test_session() as sess: + with tf.variable_scope("root", initializer=tf.constant_initializer(0.5)): + x = tf.zeros([1, 2]) + m = tf.zeros([1, 4]) + _, ml = rnn_cell.MultiRNNCell([rnn_cell.GRUCell(2)] * 2)(x, m) + sess.run([tf.variables.initialize_all_variables()]) + res = sess.run(ml, {x.name: np.array([[1., 1.]]), + m.name: np.array([[0.1, 0.1, 0.1, 0.1]])}) + # The numbers in results were not calculated, this is just a smoke test. + self.assertAllClose(res, [[0.175991, 0.175991, + 0.13248, 0.13248]]) + + +if __name__ == "__main__": + tf.test.main() diff --git a/tensorflow/models/rnn/rnn_test.py b/tensorflow/models/rnn/rnn_test.py new file mode 100644 index 0000000000..378315d296 --- /dev/null +++ b/tensorflow/models/rnn/rnn_test.py @@ -0,0 +1,472 @@ +"""Tests for rnn module.""" + +# pylint: disable=g-bad-import-order,unused-import +import tensorflow.python.platform + +import numpy as np +import tensorflow as tf + +from tensorflow.models.rnn import rnn +from tensorflow.models.rnn import rnn_cell + + +class Plus1RNNCell(rnn_cell.RNNCell): + """RNN Cell generating (output, new_state) = (input + 1, state + 1).""" + + @property + def output_size(self): + return 5 + + @property + def state_size(self): + return 5 + + def __call__(self, input_, state): + return (input_ + 1, state + 1) + + +class TestStateSaver(object): + + def __init__(self, batch_size, state_size): + self._batch_size = batch_size + self._state_size = state_size + + def State(self, _): + return tf.zeros(tf.pack([self._batch_size, self._state_size])) + + def SaveState(self, _, state): + self.saved_state = state + return tf.identity(state) + + +class RNNTest(tf.test.TestCase): + + def setUp(self): + self._seed = 23489 + np.random.seed(self._seed) + + def testRNN(self): + cell = Plus1RNNCell() + batch_size = 2 + inputs = [tf.placeholder(tf.float32, shape=(batch_size, 5))] * 10 + outputs, states = rnn.rnn(cell, inputs, dtype=tf.float32) + self.assertEqual(len(outputs), len(inputs)) + for out, inp in zip(outputs, inputs): + self.assertEqual(out.get_shape(), inp.get_shape()) + self.assertEqual(out.dtype, inp.dtype) + + with self.test_session(use_gpu=False) as sess: + input_value = np.random.randn(batch_size, 5) + values = sess.run(outputs + [states[-1]], + feed_dict={inputs[0]: input_value}) + + # Outputs + for v in values[:-1]: + self.assertAllClose(v, input_value + 1.0) + + # Final state + self.assertAllClose( + values[-1], 10.0*np.ones((batch_size, 5), dtype=np.float32)) + + def testDropout(self): + cell = Plus1RNNCell() + full_dropout_cell = rnn_cell.DropoutWrapper( + cell, input_keep_prob=1e-12, seed=0) + batch_size = 2 + inputs = [tf.placeholder(tf.float32, shape=(batch_size, 5))] * 10 + with tf.variable_scope("share_scope"): + outputs, states = rnn.rnn(cell, inputs, dtype=tf.float32) + with tf.variable_scope("drop_scope"): + dropped_outputs, 
_ = rnn.rnn(full_dropout_cell, inputs, dtype=tf.float32) + self.assertEqual(len(outputs), len(inputs)) + for out, inp in zip(outputs, inputs): + self.assertEqual(out.get_shape().as_list(), inp.get_shape().as_list()) + self.assertEqual(out.dtype, inp.dtype) + + with self.test_session(use_gpu=False) as sess: + input_value = np.random.randn(batch_size, 5) + values = sess.run(outputs + [states[-1]], + feed_dict={inputs[0]: input_value}) + full_dropout_values = sess.run(dropped_outputs, + feed_dict={inputs[0]: input_value}) + + for v in values[:-1]: + self.assertAllClose(v, input_value + 1.0) + for d_v in full_dropout_values[:-1]: # Add 1.0 to dropped_out (all zeros) + self.assertAllClose(d_v, np.ones_like(input_value)) + + def testDynamicCalculation(self): + cell = Plus1RNNCell() + sequence_length = tf.placeholder(tf.int64) + batch_size = 2 + inputs = [tf.placeholder(tf.float32, shape=(batch_size, 5))] * 10 + with tf.variable_scope("drop_scope"): + dynamic_outputs, dynamic_states = rnn.rnn( + cell, inputs, sequence_length=sequence_length, dtype=tf.float32) + self.assertEqual(len(dynamic_outputs), len(inputs)) + self.assertEqual(len(dynamic_states), len(inputs)) + + with self.test_session(use_gpu=False) as sess: + input_value = np.random.randn(batch_size, 5) + dynamic_values = sess.run(dynamic_outputs, + feed_dict={inputs[0]: input_value, + sequence_length: [2, 3]}) + dynamic_state_values = sess.run(dynamic_states, + feed_dict={inputs[0]: input_value, + sequence_length: [2, 3]}) + + # fully calculated for t = 0, 1, 2 + for v in dynamic_values[:3]: + self.assertAllClose(v, input_value + 1.0) + for vi, v in enumerate(dynamic_state_values[:3]): + self.assertAllEqual(v, 1.0 * (vi + 1) * np.ones((batch_size, 5))) + # zeros for t = 3+ + for v in dynamic_values[3:]: + self.assertAllEqual(v, np.zeros_like(input_value)) + for v in dynamic_state_values[3:]: + self.assertAllEqual(v, np.zeros_like(input_value)) + + +class LSTMTest(tf.test.TestCase): + + def setUp(self): + self._seed = 23489 + np.random.seed(self._seed) + + def _testNoProjNoSharding(self, use_gpu): + num_units = 3 + input_size = 5 + batch_size = 2 + with self.test_session(use_gpu=use_gpu, graph=tf.Graph()) as sess: + initializer = tf.random_uniform_initializer(-0.01, 0.01, seed=self._seed) + cell = rnn_cell.LSTMCell( + num_units, input_size, initializer=initializer) + inputs = 10 * [ + tf.placeholder(tf.float32, shape=(batch_size, input_size))] + outputs, _ = rnn.rnn(cell, inputs, dtype=tf.float32) + self.assertEqual(len(outputs), len(inputs)) + for out in outputs: + self.assertEqual(out.get_shape().as_list(), [batch_size, num_units]) + + tf.initialize_all_variables().run() + input_value = np.random.randn(batch_size, input_size) + sess.run(outputs, feed_dict={inputs[0]: input_value}) + + def _testCellClipping(self, use_gpu): + num_units = 3 + input_size = 5 + batch_size = 2 + with self.test_session(use_gpu=use_gpu, graph=tf.Graph()) as sess: + initializer = tf.random_uniform_initializer(-0.01, 0.01, seed=self._seed) + cell = rnn_cell.LSTMCell( + num_units, input_size, use_peepholes=True, + cell_clip=0.0, initializer=initializer) + inputs = 10 * [ + tf.placeholder(tf.float32, shape=(batch_size, input_size))] + outputs, _ = rnn.rnn(cell, inputs, dtype=tf.float32) + self.assertEqual(len(outputs), len(inputs)) + for out in outputs: + self.assertEqual(out.get_shape().as_list(), [batch_size, num_units]) + + tf.initialize_all_variables().run() + input_value = np.random.randn(batch_size, input_size) + values = sess.run(outputs, 
feed_dict={inputs[0]: input_value}) + + for value in values: + # if cell c is clipped to 0, tanh(c) = 0 => m==0 + self.assertAllEqual(value, np.zeros((batch_size, num_units))) + + def _testNoProjNoShardingSimpleStateSaver(self, use_gpu): + num_units = 3 + input_size = 5 + batch_size = 2 + with self.test_session(use_gpu=use_gpu, graph=tf.Graph()) as sess: + initializer = tf.random_uniform_initializer(-0.01, 0.01, seed=self._seed) + state_saver = TestStateSaver(batch_size, 2*num_units) + cell = rnn_cell.LSTMCell( + num_units, input_size, use_peepholes=False, initializer=initializer) + inputs = 10 * [ + tf.placeholder(tf.float32, shape=(batch_size, input_size))] + with tf.variable_scope("share_scope"): + outputs, states = rnn.state_saving_rnn( + cell, inputs, state_saver=state_saver, state_name="save_lstm") + self.assertEqual(len(outputs), len(inputs)) + for out in outputs: + self.assertEqual(out.get_shape().as_list(), [batch_size, num_units]) + + tf.initialize_all_variables().run() + input_value = np.random.randn(batch_size, input_size) + (last_state_value, saved_state_value) = sess.run( + [states[-1], state_saver.saved_state], + feed_dict={inputs[0]: input_value}) + self.assertAllEqual(last_state_value, saved_state_value) + + def _testProjNoSharding(self, use_gpu): + num_units = 3 + input_size = 5 + batch_size = 2 + num_proj = 4 + with self.test_session(use_gpu=use_gpu, graph=tf.Graph()) as sess: + initializer = tf.random_uniform_initializer(-0.01, 0.01, seed=self._seed) + inputs = 10 * [ + tf.placeholder(tf.float32, shape=(None, input_size))] + cell = rnn_cell.LSTMCell( + num_units, input_size, use_peepholes=True, + num_proj=num_proj, initializer=initializer) + outputs, _ = rnn.rnn(cell, inputs, dtype=tf.float32) + self.assertEqual(len(outputs), len(inputs)) + + tf.initialize_all_variables().run() + input_value = np.random.randn(batch_size, input_size) + sess.run(outputs, feed_dict={inputs[0]: input_value}) + + def _testProjSharding(self, use_gpu): + num_units = 3 + input_size = 5 + batch_size = 2 + num_proj = 4 + num_proj_shards = 4 + num_unit_shards = 2 + with self.test_session(use_gpu=use_gpu, graph=tf.Graph()) as sess: + initializer = tf.random_uniform_initializer(-0.01, 0.01, seed=self._seed) + + inputs = 10 * [ + tf.placeholder(tf.float32, shape=(None, input_size))] + + cell = rnn_cell.LSTMCell( + num_units, + input_size=input_size, + use_peepholes=True, + num_proj=num_proj, + num_unit_shards=num_unit_shards, + num_proj_shards=num_proj_shards, + initializer=initializer) + + outputs, _ = rnn.rnn(cell, inputs, dtype=tf.float32) + + self.assertEqual(len(outputs), len(inputs)) + + tf.initialize_all_variables().run() + input_value = np.random.randn(batch_size, input_size) + sess.run(outputs, feed_dict={inputs[0]: input_value}) + + def _testDoubleInput(self, use_gpu): + num_units = 3 + input_size = 5 + batch_size = 2 + num_proj = 4 + num_proj_shards = 4 + num_unit_shards = 2 + with self.test_session(use_gpu=use_gpu, graph=tf.Graph()) as sess: + initializer = tf.random_uniform_initializer(-1, 1, seed=self._seed) + inputs = 10 * [tf.placeholder(tf.float64)] + + cell = rnn_cell.LSTMCell( + num_units, + input_size=input_size, + use_peepholes=True, + num_proj=num_proj, + num_unit_shards=num_unit_shards, + num_proj_shards=num_proj_shards, + initializer=initializer) + + outputs, _ = rnn.rnn( + cell, inputs, initial_state=cell.zero_state(batch_size, tf.float64)) + + self.assertEqual(len(outputs), len(inputs)) + + tf.initialize_all_variables().run() + input_value = 
np.asarray(np.random.randn(batch_size, input_size), + dtype=np.float64) + values = sess.run(outputs, feed_dict={inputs[0]: input_value}) + self.assertEqual(values[0].dtype, input_value.dtype) + + def _testShardNoShardEquivalentOutput(self, use_gpu): + num_units = 3 + input_size = 5 + batch_size = 2 + num_proj = 4 + num_proj_shards = 4 + num_unit_shards = 2 + with self.test_session(use_gpu=use_gpu, graph=tf.Graph()) as sess: + inputs = 10 * [tf.placeholder(tf.float32)] + initializer = tf.constant_initializer(0.001) + + cell_noshard = rnn_cell.LSTMCell( + num_units, input_size, + num_proj=num_proj, + use_peepholes=True, + initializer=initializer, + num_unit_shards=num_unit_shards, + num_proj_shards=num_proj_shards) + + cell_shard = rnn_cell.LSTMCell( + num_units, input_size, use_peepholes=True, + initializer=initializer, num_proj=num_proj) + + with tf.variable_scope("noshard_scope"): + outputs_noshard, states_noshard = rnn.rnn( + cell_noshard, inputs, dtype=tf.float32) + with tf.variable_scope("shard_scope"): + outputs_shard, states_shard = rnn.rnn( + cell_shard, inputs, dtype=tf.float32) + + self.assertEqual(len(outputs_noshard), len(inputs)) + self.assertEqual(len(outputs_noshard), len(outputs_shard)) + + tf.initialize_all_variables().run() + input_value = np.random.randn(batch_size, input_size) + feeds = dict((x, input_value) for x in inputs) + values_noshard = sess.run(outputs_noshard, feed_dict=feeds) + values_shard = sess.run(outputs_shard, feed_dict=feeds) + state_values_noshard = sess.run(states_noshard, feed_dict=feeds) + state_values_shard = sess.run(states_shard, feed_dict=feeds) + self.assertEqual(len(values_noshard), len(values_shard)) + self.assertEqual(len(state_values_noshard), len(state_values_shard)) + for (v_noshard, v_shard) in zip(values_noshard, values_shard): + self.assertAllClose(v_noshard, v_shard, atol=1e-3) + for (s_noshard, s_shard) in zip(state_values_noshard, state_values_shard): + self.assertAllClose(s_noshard, s_shard, atol=1e-3) + + def _testDoubleInputWithDropoutAndDynamicCalculation( + self, use_gpu): + """Smoke test for using LSTM with doubles, dropout, dynamic calculation.""" + + num_units = 3 + input_size = 5 + batch_size = 2 + num_proj = 4 + num_proj_shards = 4 + num_unit_shards = 2 + with self.test_session(use_gpu=use_gpu, graph=tf.Graph()) as sess: + sequence_length = tf.placeholder(tf.int64) + initializer = tf.random_uniform_initializer(-0.01, 0.01, seed=self._seed) + inputs = 10 * [tf.placeholder(tf.float64)] + + cell = rnn_cell.LSTMCell( + num_units, + input_size=input_size, + use_peepholes=True, + num_proj=num_proj, + num_unit_shards=num_unit_shards, + num_proj_shards=num_proj_shards, + initializer=initializer) + dropout_cell = rnn_cell.DropoutWrapper(cell, 0.5, seed=0) + + outputs, states = rnn.rnn( + dropout_cell, inputs, sequence_length=sequence_length, + initial_state=cell.zero_state(batch_size, tf.float64)) + + self.assertEqual(len(outputs), len(inputs)) + self.assertEqual(len(outputs), len(states)) + + tf.initialize_all_variables().run() + input_value = np.asarray(np.random.randn(batch_size, input_size), + dtype=np.float64) + values = sess.run(outputs, feed_dict={inputs[0]: input_value, + sequence_length: [2, 3]}) + state_values = sess.run(states, feed_dict={inputs[0]: input_value, + sequence_length: [2, 3]}) + self.assertEqual(values[0].dtype, input_value.dtype) + self.assertEqual(state_values[0].dtype, input_value.dtype) + + def testSharingWeightsWithReuse(self): + num_units = 3 + input_size = 5 + batch_size = 2 + num_proj = 4 + with 
self.test_session(graph=tf.Graph()) as sess: + initializer = tf.random_uniform_initializer(-1, 1, seed=self._seed) + inputs = 10 * [ + tf.placeholder(tf.float32, shape=(None, input_size))] + cell = rnn_cell.LSTMCell( + num_units, input_size, use_peepholes=True, + num_proj=num_proj, initializer=initializer) + + with tf.variable_scope("share_scope"): + outputs0, _ = rnn.rnn(cell, inputs, dtype=tf.float32) + with tf.variable_scope("share_scope", reuse=True): + outputs1, _ = rnn.rnn(cell, inputs, dtype=tf.float32) + with tf.variable_scope("diff_scope"): + outputs2, _ = rnn.rnn(cell, inputs, dtype=tf.float32) + + tf.initialize_all_variables().run() + input_value = np.random.randn(batch_size, input_size) + output_values = sess.run( + outputs0 + outputs1 + outputs2, feed_dict={inputs[0]: input_value}) + outputs0_values = output_values[:10] + outputs1_values = output_values[10:20] + outputs2_values = output_values[20:] + self.assertEqual(len(outputs0_values), len(outputs1_values)) + self.assertEqual(len(outputs0_values), len(outputs2_values)) + for o1, o2, o3 in zip(outputs0_values, outputs1_values, outputs2_values): + # Same weights used by both RNNs so outputs should be the same. + self.assertAllEqual(o1, o2) + # Different weights used so outputs should be different. + self.assertTrue(np.linalg.norm(o1-o3) > 1e-6) + + def testSharingWeightsWithDifferentNamescope(self): + num_units = 3 + input_size = 5 + batch_size = 2 + num_proj = 4 + with self.test_session(graph=tf.Graph()) as sess: + initializer = tf.random_uniform_initializer(-1, 1, seed=self._seed) + inputs = 10 * [ + tf.placeholder(tf.float32, shape=(None, input_size))] + cell = rnn_cell.LSTMCell( + num_units, input_size, use_peepholes=True, + num_proj=num_proj, initializer=initializer) + + with tf.name_scope("scope0"): + with tf.variable_scope("share_scope"): + outputs0, _ = rnn.rnn(cell, inputs, dtype=tf.float32) + with tf.name_scope("scope1"): + with tf.variable_scope("share_scope", reuse=True): + outputs1, _ = rnn.rnn(cell, inputs, dtype=tf.float32) + + tf.initialize_all_variables().run() + input_value = np.random.randn(batch_size, input_size) + output_values = sess.run( + outputs0 + outputs1, feed_dict={inputs[0]: input_value}) + outputs0_values = output_values[:10] + outputs1_values = output_values[10:] + self.assertEqual(len(outputs0_values), len(outputs1_values)) + for out0, out1 in zip(outputs0_values, outputs1_values): + self.assertAllEqual(out0, out1) + + def testNoProjNoShardingSimpleStateSaver(self): + self._testNoProjNoShardingSimpleStateSaver(False) + self._testNoProjNoShardingSimpleStateSaver(True) + + def testNoProjNoSharding(self): + self._testNoProjNoSharding(False) + self._testNoProjNoSharding(True) + + def testCellClipping(self): + self._testCellClipping(False) + self._testCellClipping(True) + + def testProjNoSharding(self): + self._testProjNoSharding(False) + self._testProjNoSharding(True) + + def testProjSharding(self): + self._testProjSharding(False) + self._testProjSharding(True) + + def testShardNoShardEquivalentOutput(self): + self._testShardNoShardEquivalentOutput(False) + self._testShardNoShardEquivalentOutput(True) + + def testDoubleInput(self): + self._testDoubleInput(False) + self._testDoubleInput(True) + + def testDoubleInputWithDropoutAndDynamicCalculation(self): + self._testDoubleInputWithDropoutAndDynamicCalculation(False) + self._testDoubleInputWithDropoutAndDynamicCalculation(True) + + +if __name__ == "__main__": + tf.test.main() diff --git a/tensorflow/models/rnn/seq2seq.py 
b/tensorflow/models/rnn/seq2seq.py new file mode 100644 index 0000000000..a3b6a838ca --- /dev/null +++ b/tensorflow/models/rnn/seq2seq.py @@ -0,0 +1,749 @@ +"""Library for creating sequence-to-sequence models.""" + +import tensorflow.python.platform + +import tensorflow as tf + +from tensorflow.models.rnn import linear +from tensorflow.models.rnn import rnn +from tensorflow.models.rnn import rnn_cell + + +def rnn_decoder(decoder_inputs, initial_state, cell, loop_function=None, + scope=None): + """RNN decoder for the sequence-to-sequence model. + + Args: + decoder_inputs: a list of 2D Tensors [batch_size x cell.input_size]. + initial_state: 2D Tensor with shape [batch_size x cell.state_size]. + cell: rnn_cell.RNNCell defining the cell function and size. + loop_function: if not None, this function will be applied to i-th output + in order to generate i+1-th input, and decoder_inputs will be ignored, + except for the first element ("GO" symbol). This can be used for decoding, + but also for training to emulate http://arxiv.org/pdf/1506.03099v2.pdf. + Signature -- loop_function(prev, i) = next + * prev is a 2D Tensor of shape [batch_size x cell.output_size], + * i is an integer, the step number (when advanced control is needed), + * next is a 2D Tensor of shape [batch_size x cell.input_size]. + scope: VariableScope for the created subgraph; defaults to "rnn_decoder". + + Returns: + outputs: A list of the same length as decoder_inputs of 2D Tensors with + shape [batch_size x cell.output_size] containing generated outputs. + states: The state of each cell in each time-step. This is a list with + length len(decoder_inputs) -- one item for each time-step. + Each item is a 2D Tensor of shape [batch_size x cell.state_size]. + (Note that in some cases, like basic RNN cell or GRU cell, outputs and + states can be the same. They are different for LSTM cells though.) + """ + with tf.variable_scope(scope or "rnn_decoder"): + states = [initial_state] + outputs = [] + prev = None + for i in xrange(len(decoder_inputs)): + inp = decoder_inputs[i] + if loop_function is not None and prev is not None: + with tf.variable_scope("loop_function", reuse=True): + # We do not propagate gradients over the loop function. + inp = tf.stop_gradient(loop_function(prev, i)) + if i > 0: + tf.get_variable_scope().reuse_variables() + output, new_state = cell(inp, states[-1]) + outputs.append(output) + states.append(new_state) + if loop_function is not None: + prev = tf.stop_gradient(output) + return outputs, states + + +def basic_rnn_seq2seq( + encoder_inputs, decoder_inputs, cell, dtype=tf.float32, scope=None): + """Basic RNN sequence-to-sequence model. + + This model first runs an RNN to encode encoder_inputs into a state vector, and + then runs decoder, initialized with the last encoder state, on decoder_inputs. + Encoder and decoder use the same RNN cell type, but don't share parameters. + + Args: + encoder_inputs: a list of 2D Tensors [batch_size x cell.input_size]. + decoder_inputs: a list of 2D Tensors [batch_size x cell.input_size]. + cell: rnn_cell.RNNCell defining the cell function and size. + dtype: The dtype of the initial state of the RNN cell (default: tf.float32). + scope: VariableScope for the created subgraph; default: "basic_rnn_seq2seq". + + Returns: + outputs: A list of the same length as decoder_inputs of 2D Tensors with + shape [batch_size x cell.output_size] containing the generated outputs. + states: The state of each decoder cell in each time-step. 
This is a list + with length len(decoder_inputs) -- one item for each time-step. + Each item is a 2D Tensor of shape [batch_size x cell.state_size]. + """ + with tf.variable_scope(scope or "basic_rnn_seq2seq"): + _, enc_states = rnn.rnn(cell, encoder_inputs, dtype=dtype) + return rnn_decoder(decoder_inputs, enc_states[-1], cell) + + +def tied_rnn_seq2seq(encoder_inputs, decoder_inputs, cell, + loop_function=None, dtype=tf.float32, scope=None): + """RNN sequence-to-sequence model with tied encoder and decoder parameters. + + This model first runs an RNN to encode encoder_inputs into a state vector, and + then runs decoder, initialized with the last encoder state, on decoder_inputs. + Encoder and decoder use the same RNN cell and share parameters. + + Args: + encoder_inputs: a list of 2D Tensors [batch_size x cell.input_size]. + decoder_inputs: a list of 2D Tensors [batch_size x cell.input_size]. + cell: rnn_cell.RNNCell defining the cell function and size. + loop_function: if not None, this function will be applied to i-th output + in order to generate i+1-th input, and decoder_inputs will be ignored, + except for the first element ("GO" symbol), see rnn_decoder for details. + dtype: The dtype of the initial state of the rnn cell (default: tf.float32). + scope: VariableScope for the created subgraph; default: "tied_rnn_seq2seq". + + Returns: + outputs: A list of the same length as decoder_inputs of 2D Tensors with + shape [batch_size x cell.output_size] containing the generated outputs. + states: The state of each decoder cell in each time-step. This is a list + with length len(decoder_inputs) -- one item for each time-step. + Each item is a 2D Tensor of shape [batch_size x cell.state_size]. + """ + with tf.variable_scope("combined_tied_rnn_seq2seq"): + scope = scope or "tied_rnn_seq2seq" + _, enc_states = rnn.rnn( + cell, encoder_inputs, dtype=dtype, scope=scope) + tf.get_variable_scope().reuse_variables() + return rnn_decoder(decoder_inputs, enc_states[-1], cell, + loop_function=loop_function, scope=scope) + + +def embedding_rnn_decoder(decoder_inputs, initial_state, cell, num_symbols, + output_projection=None, feed_previous=False, + scope=None): + """RNN decoder with embedding and a pure-decoding option. + + Args: + decoder_inputs: a list of 1D batch-sized int32-Tensors (decoder inputs). + initial_state: 2D Tensor [batch_size x cell.state_size]. + cell: rnn_cell.RNNCell defining the cell function. + num_symbols: integer, how many symbols come into the embedding. + output_projection: None or a pair (W, B) of output projection weights and + biases; W has shape [cell.output_size x num_symbols] and B has + shape [num_symbols]; if provided and feed_previous=True, each fed + previous output will first be multiplied by W and added B. + feed_previous: Boolean; if True, only the first of decoder_inputs will be + used (the "GO" symbol), and all other decoder inputs will be generated by: + next = embedding_lookup(embedding, argmax(previous_output)), + In effect, this implements a greedy decoder. It can also be used + during training to emulate http://arxiv.org/pdf/1506.03099v2.pdf. + If False, decoder_inputs are used as given (the standard decoder case). + scope: VariableScope for the created subgraph; defaults to + "embedding_rnn_decoder". + + Returns: + outputs: A list of the same length as decoder_inputs of 2D Tensors with + shape [batch_size x cell.output_size] containing the generated outputs. + states: The state of each decoder cell in each time-step. 
This is a list
+      with length len(decoder_inputs) -- one item for each time-step.
+      Each item is a 2D Tensor of shape [batch_size x cell.state_size].
+
+  Raises:
+    ValueError: when output_projection has the wrong shape.
+  """
+  if output_projection is not None:
+    proj_weights = tf.convert_to_tensor(output_projection[0], dtype=tf.float32)
+    proj_weights.get_shape().assert_is_compatible_with([cell.output_size,
+                                                        num_symbols])
+    proj_biases = tf.convert_to_tensor(output_projection[1], dtype=tf.float32)
+    proj_biases.get_shape().assert_is_compatible_with([num_symbols])
+
+  with tf.variable_scope(scope or "embedding_rnn_decoder"):
+    with tf.device("/cpu:0"):
+      embedding = tf.get_variable("embedding", [num_symbols, cell.input_size])
+
+    def extract_argmax_and_embed(prev, _):
+      """Loop_function that extracts the symbol from prev and embeds it."""
+      if output_projection is not None:
+        prev = tf.nn.xw_plus_b(prev, output_projection[0], output_projection[1])
+      prev_symbol = tf.stop_gradient(tf.argmax(prev, 1))
+      return tf.nn.embedding_lookup(embedding, prev_symbol)
+
+    loop_function = None
+    if feed_previous:
+      loop_function = extract_argmax_and_embed
+
+    emb_inp = [tf.nn.embedding_lookup(embedding, i) for i in decoder_inputs]
+    return rnn_decoder(emb_inp, initial_state, cell,
+                       loop_function=loop_function)
+
+
+def embedding_rnn_seq2seq(encoder_inputs, decoder_inputs, cell,
+                          num_encoder_symbols, num_decoder_symbols,
+                          output_projection=None, feed_previous=False,
+                          dtype=tf.float32, scope=None):
+  """Embedding RNN sequence-to-sequence model.
+
+  This model first embeds encoder_inputs by a newly created embedding (of shape
+  [num_encoder_symbols x cell.input_size]). Then it runs an RNN to encode
+  embedded encoder_inputs into a state vector. Next, it embeds decoder_inputs
+  by another newly created embedding (of shape [num_decoder_symbols x
+  cell.input_size]). Then it runs an RNN decoder, initialized with the last
+  encoder state, on embedded decoder_inputs.
+
+  Args:
+    encoder_inputs: a list of 1D int32 Tensors of shape [batch_size] holding
+      the encoder symbol ids (this model embeds them itself).
+    decoder_inputs: a list of 1D int32 Tensors of shape [batch_size] holding
+      the decoder symbol ids.
+    cell: rnn_cell.RNNCell defining the cell function and size.
+    num_encoder_symbols: integer; number of symbols on the encoder side.
+    num_decoder_symbols: integer; number of symbols on the decoder side.
+    output_projection: None or a pair (W, B) of output projection weights and
+      biases; W has shape [cell.output_size x num_decoder_symbols] and B has
+      shape [num_decoder_symbols]; if provided and feed_previous=True, each
+      fed previous output will first be multiplied by W and added B.
+    feed_previous: Boolean or scalar Boolean Tensor; if True, only the first
+      of decoder_inputs will be used (the "GO" symbol), and all other decoder
+      inputs will be taken from previous outputs (as in embedding_rnn_decoder).
+      If False, decoder_inputs are used as given (the standard decoder case).
+    dtype: The dtype of the initial state for both the encoder and decoder
+      rnn cells (default: tf.float32).
+    scope: VariableScope for the created subgraph; defaults to
+      "embedding_rnn_seq2seq".
+
+  Returns:
+    outputs: A list of the same length as decoder_inputs of 2D Tensors with
+      shape [batch_size x num_decoder_symbols] containing the generated outputs.
+    states: The state of each decoder cell in each time-step. This is a list
+      with length len(decoder_inputs) -- one item for each time-step.
+      Each item is a 2D Tensor of shape [batch_size x cell.state_size].
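+
+  Example (an illustrative sketch with toy sizes, not a tested snippet):
+    cell = rnn_cell.GRUCell(32)
+    # Symbol ids; one 1D int32 Tensor of shape [batch_size] per time step.
+    enc_inp = [tf.placeholder(tf.int32, shape=[None]) for _ in range(8)]
+    dec_inp = [tf.placeholder(tf.int32, shape=[None]) for _ in range(8)]
+    outputs, states = embedding_rnn_seq2seq(
+        enc_inp, dec_inp, cell, num_encoder_symbols=1000,
+        num_decoder_symbols=1000)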
+ """ + with tf.variable_scope(scope or "embedding_rnn_seq2seq"): + # Encoder. + encoder_cell = rnn_cell.EmbeddingWrapper(cell, num_encoder_symbols) + _, encoder_states = rnn.rnn(encoder_cell, encoder_inputs, dtype=dtype) + + # Decoder. + if output_projection is None: + cell = rnn_cell.OutputProjectionWrapper(cell, num_decoder_symbols) + + if isinstance(feed_previous, bool): + return embedding_rnn_decoder(decoder_inputs, encoder_states[-1], cell, + num_decoder_symbols, output_projection, + feed_previous) + else: # If feed_previous is a Tensor, we construct 2 graphs and use cond. + outputs1, states1 = embedding_rnn_decoder( + decoder_inputs, encoder_states[-1], cell, num_decoder_symbols, + output_projection, True) + tf.get_variable_scope().reuse_variables() + outputs2, states2 = embedding_rnn_decoder( + decoder_inputs, encoder_states[-1], cell, num_decoder_symbols, + output_projection, False) + + outputs = tf.control_flow_ops.cond(feed_previous, + lambda: outputs1, lambda: outputs2) + states = tf.control_flow_ops.cond(feed_previous, + lambda: states1, lambda: states2) + return outputs, states + + +def embedding_tied_rnn_seq2seq(encoder_inputs, decoder_inputs, cell, + num_symbols, output_projection=None, + feed_previous=False, dtype=tf.float32, + scope=None): + """Embedding RNN sequence-to-sequence model with tied (shared) parameters. + + This model first embeds encoder_inputs by a newly created embedding (of shape + [num_symbols x cell.input_size]). Then it runs an RNN to encode embedded + encoder_inputs into a state vector. Next, it embeds decoder_inputs using + the same embedding. Then it runs RNN decoder, initialized with the last + encoder state, on embedded decoder_inputs. + + Args: + encoder_inputs: a list of 2D Tensors [batch_size x cell.input_size]. + decoder_inputs: a list of 2D Tensors [batch_size x cell.input_size]. + cell: rnn_cell.RNNCell defining the cell function and size. + num_symbols: integer; number of symbols for both encoder and decoder. + output_projection: None or a pair (W, B) of output projection weights and + biases; W has shape [cell.output_size x num_symbols] and B has + shape [num_symbols]; if provided and feed_previous=True, each + fed previous output will first be multiplied by W and added B. + feed_previous: Boolean or scalar Boolean Tensor; if True, only the first + of decoder_inputs will be used (the "GO" symbol), and all other decoder + inputs will be taken from previous outputs (as in embedding_rnn_decoder). + If False, decoder_inputs are used as given (the standard decoder case). + dtype: The dtype to use for the initial RNN states (default: tf.float32). + scope: VariableScope for the created subgraph; defaults to + "embedding_tied_rnn_seq2seq". + + Returns: + outputs: A list of the same length as decoder_inputs of 2D Tensors with + shape [batch_size x num_decoder_symbols] containing the generated outputs. + states: The state of each decoder cell in each time-step. This is a list + with length len(decoder_inputs) -- one item for each time-step. + Each item is a 2D Tensor of shape [batch_size x cell.state_size]. + + Raises: + ValueError: when output_projection has the wrong shape. 
+ """ + if output_projection is not None: + proj_weights = tf.convert_to_tensor(output_projection[0], dtype=dtype) + proj_weights.get_shape().assert_is_compatible_with([cell.output_size, + num_symbols]) + proj_biases = tf.convert_to_tensor(output_projection[1], dtype=dtype) + proj_biases.get_shape().assert_is_compatible_with([num_symbols]) + + with tf.variable_scope(scope or "embedding_tied_rnn_seq2seq"): + with tf.device("/cpu:0"): + embedding = tf.get_variable("embedding", [num_symbols, cell.input_size]) + + emb_encoder_inputs = [tf.nn.embedding_lookup(embedding, x) + for x in encoder_inputs] + emb_decoder_inputs = [tf.nn.embedding_lookup(embedding, x) + for x in decoder_inputs] + + def extract_argmax_and_embed(prev, _): + """Loop_function that extracts the symbol from prev and embeds it.""" + if output_projection is not None: + prev = tf.nn.xw_plus_b(prev, output_projection[0], output_projection[1]) + prev_symbol = tf.stop_gradient(tf.argmax(prev, 1)) + return tf.nn.embedding_lookup(embedding, prev_symbol) + + if output_projection is None: + cell = rnn_cell.OutputProjectionWrapper(cell, num_symbols) + + if isinstance(feed_previous, bool): + loop_function = extract_argmax_and_embed if feed_previous else None + return tied_rnn_seq2seq(emb_encoder_inputs, emb_decoder_inputs, cell, + loop_function=loop_function, dtype=dtype) + else: # If feed_previous is a Tensor, we construct 2 graphs and use cond. + outputs1, states1 = tied_rnn_seq2seq( + emb_encoder_inputs, emb_decoder_inputs, cell, + loop_function=extract_argmax_and_embed, dtype=dtype) + tf.get_variable_scope().reuse_variables() + outputs2, states2 = tied_rnn_seq2seq( + emb_encoder_inputs, emb_decoder_inputs, cell, dtype=dtype) + + outputs = tf.control_flow_ops.cond(feed_previous, + lambda: outputs1, lambda: outputs2) + states = tf.control_flow_ops.cond(feed_previous, + lambda: states1, lambda: states2) + return outputs, states + + +def attention_decoder(decoder_inputs, initial_state, attention_states, cell, + output_size=None, num_heads=1, loop_function=None, + dtype=tf.float32, scope=None): + """RNN decoder with attention for the sequence-to-sequence model. + + Args: + decoder_inputs: a list of 2D Tensors [batch_size x cell.input_size]. + initial_state: 2D Tensor [batch_size x cell.state_size]. + attention_states: 3D Tensor [batch_size x attn_length x attn_size]. + cell: rnn_cell.RNNCell defining the cell function and size. + output_size: size of the output vectors; if None, we use cell.output_size. + num_heads: number of attention heads that read from attention_states. + loop_function: if not None, this function will be applied to i-th output + in order to generate i+1-th input, and decoder_inputs will be ignored, + except for the first element ("GO" symbol). This can be used for decoding, + but also for training to emulate http://arxiv.org/pdf/1506.03099v2.pdf. + Signature -- loop_function(prev, i) = next + * prev is a 2D Tensor of shape [batch_size x cell.output_size], + * i is an integer, the step number (when advanced control is needed), + * next is a 2D Tensor of shape [batch_size x cell.input_size]. + dtype: The dtype to use for the RNN initial state (default: tf.float32). + scope: VariableScope for the created subgraph; default: "attention_decoder". + + Returns: + outputs: A list of the same length as decoder_inputs of 2D Tensors of shape + [batch_size x output_size]. These represent the generated outputs. + Output i is computed from input i (which is either i-th decoder_inputs or + loop_function(output {i-1}, i)) as follows. 
First, we run the cell
+      on a combination of the input and previous attention masks:
+        cell_output, new_state = cell(linear(input, prev_attn), prev_state).
+      Then, we calculate new attention masks:
+        new_attn = softmax(V^T * tanh(W * attention_states + U * new_state))
+      and then we calculate the output:
+        output = linear(cell_output, new_attn).
+    states: The state of each decoder cell in each time-step. This is a list
+      with length len(decoder_inputs) -- one item for each time-step.
+      Each item is a 2D Tensor of shape [batch_size x cell.state_size].
+
+  Raises:
+    ValueError: when num_heads is not positive, there are no inputs, or shapes
+      of attention_states are not set.
+  """
+  if not decoder_inputs:
+    raise ValueError("Must provide at least 1 input to attention decoder.")
+  if num_heads < 1:
+    raise ValueError("With fewer than one head, use a non-attention decoder.")
+  if not attention_states.get_shape()[1:3].is_fully_defined():
+    raise ValueError("Shape[1] and [2] of attention_states must be known: %s"
+                     % attention_states.get_shape())
+  if output_size is None:
+    output_size = cell.output_size
+
+  with tf.variable_scope(scope or "attention_decoder"):
+    batch_size = tf.shape(decoder_inputs[0])[0]  # Needed for reshaping.
+    attn_length = attention_states.get_shape()[1].value
+    attn_size = attention_states.get_shape()[2].value
+
+    # To calculate W1 * h_t we use a 1-by-1 convolution, need to reshape before.
+    hidden = tf.reshape(attention_states, [-1, attn_length, 1, attn_size])
+    hidden_features = []
+    v = []
+    attention_vec_size = attn_size  # Size of query vectors for attention.
+    for a in xrange(num_heads):
+      k = tf.get_variable("AttnW_%d" % a, [1, 1, attn_size, attention_vec_size])
+      hidden_features.append(tf.nn.conv2d(hidden, k, [1, 1, 1, 1], "SAME"))
+      v.append(tf.get_variable("AttnV_%d" % a, [attention_vec_size]))
+
+    states = [initial_state]
+
+    def attention(query):
+      """Put attention masks on hidden using hidden_features and query."""
+      ds = []  # Results of attention reads will be stored here.
+      for a in xrange(num_heads):
+        with tf.variable_scope("Attention_%d" % a):
+          y = linear.linear(query, attention_vec_size, True)
+          y = tf.reshape(y, [-1, 1, 1, attention_vec_size])
+          # Attention mask is a softmax of v^T * tanh(...).
+          s = tf.reduce_sum(v[a] * tf.tanh(hidden_features[a] + y), [2, 3])
+          a = tf.nn.softmax(s)
+          # Now calculate the attention-weighted vector d.
+          d = tf.reduce_sum(tf.reshape(a, [-1, attn_length, 1, 1]) * hidden,
+                            [1, 2])
+          ds.append(tf.reshape(d, [-1, attn_size]))
+      return ds
+
+    outputs = []
+    prev = None
+    batch_attn_size = tf.pack([batch_size, attn_size])
+    attns = [tf.zeros(batch_attn_size, dtype=dtype)
+             for _ in xrange(num_heads)]
+    for a in attns:  # Ensure the second shape of attention vectors is set.
+      a.set_shape([None, attn_size])
+    for i in xrange(len(decoder_inputs)):
+      if i > 0:
+        tf.get_variable_scope().reuse_variables()
+      inp = decoder_inputs[i]
+      # If loop_function is set, we use it instead of decoder_inputs.
+      if loop_function is not None and prev is not None:
+        with tf.variable_scope("loop_function", reuse=True):
+          inp = tf.stop_gradient(loop_function(prev, i))
+      # Merge input and previous attentions into one vector of the right size.
+      x = linear.linear([inp] + attns, cell.input_size, True)
+      # Run the RNN.
+      cell_output, new_state = cell(x, states[-1])
+      states.append(new_state)
+      # Run the attention mechanism.
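+      # The updated cell state serves as the attention query; each head
+      # returns a weighted sum over the encoder states packed into
+      # attention_states.  These attns feed both the next step's input merge
+      # and the output projection below.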
+ attns = attention(new_state) + with tf.variable_scope("AttnOutputProjection"): + output = linear.linear([cell_output] + attns, output_size, True) + if loop_function is not None: + # We do not propagate gradients over the loop function. + prev = tf.stop_gradient(output) + outputs.append(output) + + return outputs, states + + +def embedding_attention_decoder(decoder_inputs, initial_state, attention_states, + cell, num_symbols, num_heads=1, + output_size=None, output_projection=None, + feed_previous=False, dtype=tf.float32, + scope=None): + """RNN decoder with embedding and attention and a pure-decoding option. + + Args: + decoder_inputs: a list of 1D batch-sized int32-Tensors (decoder inputs). + initial_state: 2D Tensor [batch_size x cell.state_size]. + attention_states: 3D Tensor [batch_size x attn_length x attn_size]. + cell: rnn_cell.RNNCell defining the cell function. + num_symbols: integer, how many symbols come into the embedding. + num_heads: number of attention heads that read from attention_states. + output_size: size of the output vectors; if None, use cell.output_size. + output_projection: None or a pair (W, B) of output projection weights and + biases; W has shape [output_size x num_symbols] and B has shape + [num_symbols]; if provided and feed_previous=True, each fed previous + output will first be multiplied by W and added B. + feed_previous: Boolean; if True, only the first of decoder_inputs will be + used (the "GO" symbol), and all other decoder inputs will be generated by: + next = embedding_lookup(embedding, argmax(previous_output)), + In effect, this implements a greedy decoder. It can also be used + during training to emulate http://arxiv.org/pdf/1506.03099v2.pdf. + If False, decoder_inputs are used as given (the standard decoder case). + dtype: The dtype to use for the RNN initial states (default: tf.float32). + scope: VariableScope for the created subgraph; defaults to + "embedding_attention_decoder". + + Returns: + outputs: A list of the same length as decoder_inputs of 2D Tensors with + shape [batch_size x output_size] containing the generated outputs. + states: The state of each decoder cell in each time-step. This is a list + with length len(decoder_inputs) -- one item for each time-step. + Each item is a 2D Tensor of shape [batch_size x cell.state_size]. + + Raises: + ValueError: when output_projection has the wrong shape. 
+ """ + if output_size is None: + output_size = cell.output_size + if output_projection is not None: + proj_weights = tf.convert_to_tensor(output_projection[0], dtype=dtype) + proj_weights.get_shape().assert_is_compatible_with([cell.output_size, + num_symbols]) + proj_biases = tf.convert_to_tensor(output_projection[1], dtype=dtype) + proj_biases.get_shape().assert_is_compatible_with([num_symbols]) + + with tf.variable_scope(scope or "embedding_attention_decoder"): + with tf.device("/cpu:0"): + embedding = tf.get_variable("embedding", [num_symbols, cell.input_size]) + + def extract_argmax_and_embed(prev, _): + """Loop_function that extracts the symbol from prev and embeds it.""" + if output_projection is not None: + prev = tf.nn.xw_plus_b(prev, output_projection[0], output_projection[1]) + prev_symbol = tf.stop_gradient(tf.argmax(prev, 1)) + emb_prev = tf.nn.embedding_lookup(embedding, prev_symbol) + return emb_prev + + loop_function = None + if feed_previous: + loop_function = extract_argmax_and_embed + + emb_inp = [tf.nn.embedding_lookup(embedding, i) for i in decoder_inputs] + return attention_decoder( + emb_inp, initial_state, attention_states, cell, output_size=output_size, + num_heads=num_heads, loop_function=loop_function) + + +def embedding_attention_seq2seq(encoder_inputs, decoder_inputs, cell, + num_encoder_symbols, num_decoder_symbols, + num_heads=1, output_projection=None, + feed_previous=False, dtype=tf.float32, + scope=None): + """Embedding sequence-to-sequence model with attention. + + This model first embeds encoder_inputs by a newly created embedding (of shape + [num_encoder_symbols x cell.input_size]). Then it runs an RNN to encode + embedded encoder_inputs into a state vector. It keeps the outputs of this + RNN at every step to use for attention later. Next, it embeds decoder_inputs + by another newly created embedding (of shape [num_decoder_symbols x + cell.input_size]). Then it runs attention decoder, initialized with the last + encoder state, on embedded decoder_inputs and attending to encoder outputs. + + Args: + encoder_inputs: a list of 2D Tensors [batch_size x cell.input_size]. + decoder_inputs: a list of 2D Tensors [batch_size x cell.input_size]. + cell: rnn_cell.RNNCell defining the cell function and size. + num_encoder_symbols: integer; number of symbols on the encoder side. + num_decoder_symbols: integer; number of symbols on the decoder side. + num_heads: number of attention heads that read from attention_states. + output_projection: None or a pair (W, B) of output projection weights and + biases; W has shape [cell.output_size x num_decoder_symbols] and B has + shape [num_decoder_symbols]; if provided and feed_previous=True, each + fed previous output will first be multiplied by W and added B. + feed_previous: Boolean or scalar Boolean Tensor; if True, only the first + of decoder_inputs will be used (the "GO" symbol), and all other decoder + inputs will be taken from previous outputs (as in embedding_rnn_decoder). + If False, decoder_inputs are used as given (the standard decoder case). + dtype: The dtype of the initial RNN state (default: tf.float32). + scope: VariableScope for the created subgraph; defaults to + "embedding_attention_seq2seq". + + Returns: + outputs: A list of the same length as decoder_inputs of 2D Tensors with + shape [batch_size x num_decoder_symbols] containing the generated outputs. + states: The state of each decoder cell in each time-step. This is a list + with length len(decoder_inputs) -- one item for each time-step. 
+ Each item is a 2D Tensor of shape [batch_size x cell.state_size]. + """ + with tf.variable_scope(scope or "embedding_attention_seq2seq"): + # Encoder. + encoder_cell = rnn_cell.EmbeddingWrapper(cell, num_encoder_symbols) + encoder_outputs, encoder_states = rnn.rnn( + encoder_cell, encoder_inputs, dtype=dtype) + + # First calculate a concatenation of encoder outputs to put attention on. + top_states = [tf.reshape(e, [-1, 1, cell.output_size]) + for e in encoder_outputs] + attention_states = tf.concat(1, top_states) + + # Decoder. + output_size = None + if output_projection is None: + cell = rnn_cell.OutputProjectionWrapper(cell, num_decoder_symbols) + output_size = num_decoder_symbols + + if isinstance(feed_previous, bool): + return embedding_attention_decoder( + decoder_inputs, encoder_states[-1], attention_states, cell, + num_decoder_symbols, num_heads, output_size, output_projection, + feed_previous) + else: # If feed_previous is a Tensor, we construct 2 graphs and use cond. + outputs1, states1 = embedding_attention_decoder( + decoder_inputs, encoder_states[-1], attention_states, cell, + num_decoder_symbols, num_heads, output_size, output_projection, True) + tf.get_variable_scope().reuse_variables() + outputs2, states2 = embedding_attention_decoder( + decoder_inputs, encoder_states[-1], attention_states, cell, + num_decoder_symbols, num_heads, output_size, output_projection, False) + + outputs = tf.control_flow_ops.cond(feed_previous, + lambda: outputs1, lambda: outputs2) + states = tf.control_flow_ops.cond(feed_previous, + lambda: states1, lambda: states2) + return outputs, states + + +def sequence_loss_by_example(logits, targets, weights, num_decoder_symbols, + average_across_timesteps=True, + softmax_loss_function=None, name=None): + """Weighted cross-entropy loss for a sequence of logits (per example). + + Args: + logits: list of 2D Tensors of shape [batch_size x num_decoder_symbols]. + targets: list of 1D batch-sized int32-Tensors of the same length as logits. + weights: list of 1D batch-sized float-Tensors of the same length as logits. + num_decoder_symbols: integer, number of decoder symbols (output classes). + average_across_timesteps: If set, divide the returned cost by the total + label weight. + softmax_loss_function: function (inputs-batch, labels-batch) -> loss-batch + to be used instead of the standard softmax (the default if this is None). + name: optional name for this operation, default: "sequence_loss_by_example". + + Returns: + 1D batch-sized float Tensor: the log-perplexity for each sequence. + + Raises: + ValueError: if len(logits) is different from len(targets) or len(weights). + """ + if len(targets) != len(logits) or len(weights) != len(logits): + raise ValueError("Lengths of logits, weights, and targets must be the same " + "%d, %d, %d." % (len(logits), len(weights), len(targets))) + with tf.op_scope(logits + targets + weights, name, + "sequence_loss_by_example"): + batch_size = tf.shape(targets[0])[0] + log_perp_list = [] + length = batch_size * num_decoder_symbols + for i in xrange(len(logits)): + if softmax_loss_function is None: + # TODO(lukaszkaiser): There is no SparseCrossEntropy in TensorFlow, so + # we need to first cast targets into a dense representation, and as + # SparseToDense does not accept batched inputs, we need to do this by + # re-indexing and re-sizing. When TensorFlow adds SparseCrossEntropy, + # rewrite this method. 
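+ # Worked example of the re-indexing below (illustrative numbers): with
+ # batch_size=2 and num_decoder_symbols=5, targets[i] = [3, 1] yields
+ # indices = [3, 6]; the dense length-10 vector is then reshaped to a
+ # [2, 5] one-hot matrix with ones at positions (0, 3) and (1, 1).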
+ indices = targets[i] + num_decoder_symbols * tf.range(0, batch_size)
+ with tf.device("/cpu:0"): # Sparse-to-dense must happen on CPU for now.
+ dense = tf.sparse_to_dense(indices, tf.expand_dims(length, 0), 1.0,
+ 0.0)
+ target = tf.reshape(dense, [-1, num_decoder_symbols])
+ crossent = tf.nn.softmax_cross_entropy_with_logits(
+ logits[i], target, name="SequenceLoss/CrossEntropy{0}".format(i))
+ else:
+ crossent = softmax_loss_function(logits[i], targets[i])
+ log_perp_list.append(crossent * weights[i])
+ log_perps = tf.add_n(log_perp_list)
+ if average_across_timesteps:
+ total_size = tf.add_n(weights)
+ total_size += 1e-12 # Just to avoid division by 0 for all-0 weights.
+ log_perps /= total_size
+ return log_perps
+
+
+def sequence_loss(logits, targets, weights, num_decoder_symbols,
+ average_across_timesteps=True, average_across_batch=True,
+ softmax_loss_function=None, name=None):
+ """Weighted cross-entropy loss for a sequence of logits, batch-collapsed.
+
+ Args:
+ logits: list of 2D Tensors of shape [batch_size x num_decoder_symbols].
+ targets: list of 1D batch-sized int32-Tensors of the same length as logits.
+ weights: list of 1D batch-sized float-Tensors of the same length as logits.
+ num_decoder_symbols: integer, number of decoder symbols (output classes).
+ average_across_timesteps: If set, divide the returned cost by the total
+ label weight.
+ average_across_batch: If set, divide the returned cost by the batch size.
+ softmax_loss_function: function (inputs-batch, labels-batch) -> loss-batch
+ to be used instead of the standard softmax (the default if this is None).
+ name: optional name for this operation, defaults to "sequence_loss".
+
+ Returns:
+ A scalar float Tensor: the average log-perplexity per symbol (weighted).
+
+ Raises:
+ ValueError: if len(logits) is different from len(targets) or len(weights).
+ """
+ with tf.op_scope(logits + targets + weights, name, "sequence_loss"):
+ cost = tf.reduce_sum(sequence_loss_by_example(
+ logits, targets, weights, num_decoder_symbols,
+ average_across_timesteps=average_across_timesteps,
+ softmax_loss_function=softmax_loss_function))
+ if average_across_batch:
+ batch_size = tf.shape(targets[0])[0]
+ return cost / tf.cast(batch_size, tf.float32)
+ else:
+ return cost
+
+
+def model_with_buckets(encoder_inputs, decoder_inputs, targets, weights,
+ buckets, num_decoder_symbols, seq2seq,
+ softmax_loss_function=None, name=None):
+ """Create a sequence-to-sequence model with support for bucketing.
+
+ The seq2seq argument is a function that defines a sequence-to-sequence model,
+ e.g., seq2seq = lambda x, y: basic_rnn_seq2seq(x, y, rnn_cell.GRUCell(24))
+
+ Args:
+ encoder_inputs: a list of Tensors to feed the encoder; first seq2seq input.
+ decoder_inputs: a list of Tensors to feed the decoder; second seq2seq input.
+ targets: a list of 1D batch-sized int32-Tensors (desired output sequence).
+ weights: list of 1D batch-sized float-Tensors to weight the targets.
+ buckets: a list of pairs of (input size, output size) for each bucket.
+ num_decoder_symbols: integer, number of decoder symbols (output classes).
+ seq2seq: a sequence-to-sequence model function; it takes two inputs that
+ agree with encoder_inputs and decoder_inputs, and returns a pair
+ consisting of outputs and states (as, e.g., basic_rnn_seq2seq).
+ softmax_loss_function: function (inputs-batch, labels-batch) -> loss-batch
+ to be used instead of the standard softmax (the default if this is None).
+ name: optional name for this operation, defaults to "model_with_buckets".
+
+ Returns:
+ outputs: The outputs for each bucket. Its j'th element consists of a list
+ of 2D Tensors of shape [batch_size x num_decoder_symbols] (j'th outputs).
+ losses: List of scalar Tensors, representing losses for each bucket.
+ Raises:
+ ValueError: if length of encoder_inputs, targets, or weights is smaller
+ than the largest (last) bucket.
+ """
+ if len(encoder_inputs) < buckets[-1][0]:
+ raise ValueError("Length of encoder_inputs (%d) must be at least that of last "
+ "bucket (%d)." % (len(encoder_inputs), buckets[-1][0]))
+ if len(targets) < buckets[-1][1]:
+ raise ValueError("Length of targets (%d) must be at least that of last "
+ "bucket (%d)." % (len(targets), buckets[-1][1]))
+ if len(weights) < buckets[-1][1]:
+ raise ValueError("Length of weights (%d) must be at least that of last "
+ "bucket (%d)." % (len(weights), buckets[-1][1]))
+
+ all_inputs = encoder_inputs + decoder_inputs + targets + weights
+ losses = []
+ outputs = []
+ with tf.op_scope(all_inputs, name, "model_with_buckets"):
+ for j in xrange(len(buckets)):
+ if j > 0:
+ tf.get_variable_scope().reuse_variables()
+ bucket_encoder_inputs = [encoder_inputs[i]
+ for i in xrange(buckets[j][0])]
+ bucket_decoder_inputs = [decoder_inputs[i]
+ for i in xrange(buckets[j][1])]
+ bucket_outputs, _ = seq2seq(bucket_encoder_inputs,
+ bucket_decoder_inputs)
+ outputs.append(bucket_outputs)
+
+ bucket_targets = [targets[i] for i in xrange(buckets[j][1])]
+ bucket_weights = [weights[i] for i in xrange(buckets[j][1])]
+ losses.append(sequence_loss(
+ outputs[-1], bucket_targets, bucket_weights, num_decoder_symbols,
+ softmax_loss_function=softmax_loss_function))
+
+ return outputs, losses
diff --git a/tensorflow/models/rnn/seq2seq_test.py b/tensorflow/models/rnn/seq2seq_test.py
new file mode 100644
index 0000000000..c5125acc21
--- /dev/null
+++ b/tensorflow/models/rnn/seq2seq_test.py
@@ -0,0 +1,384 @@
+"""Tests for functional style sequence-to-sequence models."""
+import math
+import random
+
+import tensorflow.python.platform
+
+import numpy as np
+import tensorflow as tf
+
+from tensorflow.models.rnn import rnn
+from tensorflow.models.rnn import rnn_cell
+from tensorflow.models.rnn import seq2seq
+
+
+class Seq2SeqTest(tf.test.TestCase):
+
+ def testRNNDecoder(self):
+ with self.test_session() as sess:
+ with tf.variable_scope("root", initializer=tf.constant_initializer(0.5)):
+ inp = [tf.constant(0.5, shape=[2, 2]) for _ in xrange(2)]
+ _, enc_states = rnn.rnn(rnn_cell.GRUCell(2), inp, dtype=tf.float32)
+ dec_inp = [tf.constant(0.4, shape=[2, 2]) for _ in xrange(3)]
+ cell = rnn_cell.OutputProjectionWrapper(rnn_cell.GRUCell(2), 4)
+ dec, mem = seq2seq.rnn_decoder(dec_inp, enc_states[-1], cell)
+ sess.run([tf.initialize_all_variables()])
+ res = sess.run(dec)
+ self.assertEqual(len(res), 3)
+ self.assertEqual(res[0].shape, (2, 4))
+
+ res = sess.run(mem)
+ self.assertEqual(len(res), 4)
+ self.assertEqual(res[0].shape, (2, 2))
+
+ def testBasicRNNSeq2Seq(self):
+ with self.test_session() as sess:
+ with tf.variable_scope("root", initializer=tf.constant_initializer(0.5)):
+ inp = [tf.constant(0.5, shape=[2, 2]) for _ in xrange(2)]
+ dec_inp = [tf.constant(0.4, shape=[2, 2]) for _ in xrange(3)]
+ cell = rnn_cell.OutputProjectionWrapper(rnn_cell.GRUCell(2), 4)
+ dec, mem = seq2seq.basic_rnn_seq2seq(inp, dec_inp, cell)
+ sess.run([tf.initialize_all_variables()])
+ res = sess.run(dec)
+ self.assertEqual(len(res), 3)
+ self.assertEqual(res[0].shape,
(2, 4)) + + res = sess.run(mem) + self.assertEqual(len(res), 4) + self.assertEqual(res[0].shape, (2, 2)) + + def testTiedRNNSeq2Seq(self): + with self.test_session() as sess: + with tf.variable_scope("root", initializer=tf.constant_initializer(0.5)): + inp = [tf.constant(0.5, shape=[2, 2]) for _ in xrange(2)] + dec_inp = [tf.constant(0.4, shape=[2, 2]) for _ in xrange(3)] + cell = rnn_cell.OutputProjectionWrapper(rnn_cell.GRUCell(2), 4) + dec, mem = seq2seq.tied_rnn_seq2seq(inp, dec_inp, cell) + sess.run([tf.initialize_all_variables()]) + res = sess.run(dec) + self.assertEqual(len(res), 3) + self.assertEqual(res[0].shape, (2, 4)) + + res = sess.run(mem) + self.assertEqual(len(res), 4) + self.assertEqual(res[0].shape, (2, 2)) + + def testEmbeddingRNNDecoder(self): + with self.test_session() as sess: + with tf.variable_scope("root", initializer=tf.constant_initializer(0.5)): + inp = [tf.constant(0.5, shape=[2, 2]) for _ in xrange(2)] + cell = rnn_cell.BasicLSTMCell(2) + _, enc_states = rnn.rnn(cell, inp, dtype=tf.float32) + dec_inp = [tf.constant(i, tf.int32, shape=[2]) for i in xrange(3)] + dec, mem = seq2seq.embedding_rnn_decoder(dec_inp, enc_states[-1], + cell, 4) + sess.run([tf.initialize_all_variables()]) + res = sess.run(dec) + self.assertEqual(len(res), 3) + self.assertEqual(res[0].shape, (2, 2)) + + res = sess.run(mem) + self.assertEqual(len(res), 4) + self.assertEqual(res[0].shape, (2, 4)) + + def testEmbeddingRNNSeq2Seq(self): + with self.test_session() as sess: + with tf.variable_scope("root", initializer=tf.constant_initializer(0.5)): + enc_inp = [tf.constant(1, tf.int32, shape=[2]) for i in xrange(2)] + dec_inp = [tf.constant(i, tf.int32, shape=[2]) for i in xrange(3)] + cell = rnn_cell.BasicLSTMCell(2) + dec, mem = seq2seq.embedding_rnn_seq2seq(enc_inp, dec_inp, cell, 2, 5) + sess.run([tf.variables.initialize_all_variables()]) + res = sess.run(dec) + self.assertEqual(len(res), 3) + self.assertEqual(res[0].shape, (2, 5)) + + res = sess.run(mem) + self.assertEqual(len(res), 4) + self.assertEqual(res[0].shape, (2, 4)) + + # Test externally provided output projection. + w = tf.get_variable("proj_w", [2, 5]) + b = tf.get_variable("proj_b", [5]) + with tf.variable_scope("proj_seq2seq"): + dec, _ = seq2seq.embedding_rnn_seq2seq( + enc_inp, dec_inp, cell, 2, 5, output_projection=(w, b)) + sess.run([tf.variables.initialize_all_variables()]) + res = sess.run(dec) + self.assertEqual(len(res), 3) + self.assertEqual(res[0].shape, (2, 2)) + + # Test that previous-feeding model ignores inputs after the first. 
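+ # dec_inp and dec_inp2 below differ at every step except the first ("GO")
+ # symbol, so with feed_previous=True all three decoders should agree;
+ # d3 additionally exercises the Tensor-valued feed_previous path, which
+ # builds both graphs and selects one with a cond.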
+ dec_inp2 = [tf.constant(0, tf.int32, shape=[2]) for _ in xrange(3)] + tf.get_variable_scope().reuse_variables() + d1, _ = seq2seq.embedding_rnn_seq2seq(enc_inp, dec_inp, cell, 2, 5, + feed_previous=True) + d2, _ = seq2seq.embedding_rnn_seq2seq(enc_inp, dec_inp2, cell, 2, 5, + feed_previous=True) + d3, _ = seq2seq.embedding_rnn_seq2seq(enc_inp, dec_inp2, cell, 2, 5, + feed_previous=tf.constant(True)) + res1 = sess.run(d1) + res2 = sess.run(d2) + res3 = sess.run(d3) + self.assertAllClose(res1, res2) + self.assertAllClose(res1, res3) + + def testEmbeddingTiedRNNSeq2Seq(self): + with self.test_session() as sess: + with tf.variable_scope("root", initializer=tf.constant_initializer(0.5)): + enc_inp = [tf.constant(1, tf.int32, shape=[2]) for i in xrange(2)] + dec_inp = [tf.constant(i, tf.int32, shape=[2]) for i in xrange(3)] + cell = rnn_cell.BasicLSTMCell(2) + dec, mem = seq2seq.embedding_tied_rnn_seq2seq(enc_inp, dec_inp, cell, 5) + sess.run([tf.variables.initialize_all_variables()]) + res = sess.run(dec) + self.assertEqual(len(res), 3) + self.assertEqual(res[0].shape, (2, 5)) + + res = sess.run(mem) + self.assertEqual(len(res), 4) + self.assertEqual(res[0].shape, (2, 4)) + + # Test externally provided output projection. + w = tf.get_variable("proj_w", [2, 5]) + b = tf.get_variable("proj_b", [5]) + with tf.variable_scope("proj_seq2seq"): + dec, _ = seq2seq.embedding_tied_rnn_seq2seq( + enc_inp, dec_inp, cell, 5, output_projection=(w, b)) + sess.run([tf.variables.initialize_all_variables()]) + res = sess.run(dec) + self.assertEqual(len(res), 3) + self.assertEqual(res[0].shape, (2, 2)) + + # Test that previous-feeding model ignores inputs after the first. + dec_inp2 = [tf.constant(0, tf.int32, shape=[2]) for _ in xrange(3)] + tf.get_variable_scope().reuse_variables() + d1, _ = seq2seq.embedding_tied_rnn_seq2seq(enc_inp, dec_inp, cell, 5, + feed_previous=True) + d2, _ = seq2seq.embedding_tied_rnn_seq2seq(enc_inp, dec_inp2, cell, 5, + feed_previous=True) + d3, _ = seq2seq.embedding_tied_rnn_seq2seq( + enc_inp, dec_inp2, cell, 5, feed_previous=tf.constant(True)) + res1 = sess.run(d1) + res2 = sess.run(d2) + res3 = sess.run(d3) + self.assertAllClose(res1, res2) + self.assertAllClose(res1, res3) + + def testAttentionDecoder1(self): + with self.test_session() as sess: + with tf.variable_scope("root", initializer=tf.constant_initializer(0.5)): + cell = rnn_cell.GRUCell(2) + inp = [tf.constant(0.5, shape=[2, 2]) for _ in xrange(2)] + enc_outputs, enc_states = rnn.rnn(cell, inp, dtype=tf.float32) + attn_states = tf.concat(1, [tf.reshape(e, [-1, 1, cell.output_size]) + for e in enc_outputs]) + dec_inp = [tf.constant(0.4, shape=[2, 2]) for _ in xrange(3)] + dec, mem = seq2seq.attention_decoder(dec_inp, enc_states[-1], + attn_states, cell, output_size=4) + sess.run([tf.initialize_all_variables()]) + res = sess.run(dec) + self.assertEqual(len(res), 3) + self.assertEqual(res[0].shape, (2, 4)) + + res = sess.run(mem) + self.assertEqual(len(res), 4) + self.assertEqual(res[0].shape, (2, 2)) + + def testAttentionDecoder2(self): + with self.test_session() as sess: + with tf.variable_scope("root", initializer=tf.constant_initializer(0.5)): + cell = rnn_cell.GRUCell(2) + inp = [tf.constant(0.5, shape=[2, 2]) for _ in xrange(2)] + enc_outputs, enc_states = rnn.rnn(cell, inp, dtype=tf.float32) + attn_states = tf.concat(1, [tf.reshape(e, [-1, 1, cell.output_size]) + for e in enc_outputs]) + dec_inp = [tf.constant(0.4, shape=[2, 2]) for _ in xrange(3)] + dec, mem = seq2seq.attention_decoder(dec_inp, enc_states[-1], + 
attn_states, cell, output_size=4, + num_heads=2) + sess.run([tf.initialize_all_variables()]) + res = sess.run(dec) + self.assertEqual(len(res), 3) + self.assertEqual(res[0].shape, (2, 4)) + + res = sess.run(mem) + self.assertEqual(len(res), 4) + self.assertEqual(res[0].shape, (2, 2)) + + def testEmbeddingAttentionDecoder(self): + with self.test_session() as sess: + with tf.variable_scope("root", initializer=tf.constant_initializer(0.5)): + inp = [tf.constant(0.5, shape=[2, 2]) for _ in xrange(2)] + cell = rnn_cell.GRUCell(2) + enc_outputs, enc_states = rnn.rnn(cell, inp, dtype=tf.float32) + attn_states = tf.concat(1, [tf.reshape(e, [-1, 1, cell.output_size]) + for e in enc_outputs]) + dec_inp = [tf.constant(i, tf.int32, shape=[2]) for i in xrange(3)] + dec, mem = seq2seq.embedding_attention_decoder(dec_inp, enc_states[-1], + attn_states, cell, 4, + output_size=3) + sess.run([tf.initialize_all_variables()]) + res = sess.run(dec) + self.assertEqual(len(res), 3) + self.assertEqual(res[0].shape, (2, 3)) + + res = sess.run(mem) + self.assertEqual(len(res), 4) + self.assertEqual(res[0].shape, (2, 2)) + + def testEmbeddingAttentionSeq2Seq(self): + with self.test_session() as sess: + with tf.variable_scope("root", initializer=tf.constant_initializer(0.5)): + enc_inp = [tf.constant(1, tf.int32, shape=[2]) for i in xrange(2)] + dec_inp = [tf.constant(i, tf.int32, shape=[2]) for i in xrange(3)] + cell = rnn_cell.BasicLSTMCell(2) + dec, mem = seq2seq.embedding_attention_seq2seq( + enc_inp, dec_inp, cell, 2, 5) + sess.run([tf.initialize_all_variables()]) + res = sess.run(dec) + self.assertEqual(len(res), 3) + self.assertEqual(res[0].shape, (2, 5)) + + res = sess.run(mem) + self.assertEqual(len(res), 4) + self.assertEqual(res[0].shape, (2, 4)) + + # Test externally provided output projection. + w = tf.get_variable("proj_w", [2, 5]) + b = tf.get_variable("proj_b", [5]) + with tf.variable_scope("proj_seq2seq"): + dec, _ = seq2seq.embedding_attention_seq2seq( + enc_inp, dec_inp, cell, 2, 5, output_projection=(w, b)) + sess.run([tf.variables.initialize_all_variables()]) + res = sess.run(dec) + self.assertEqual(len(res), 3) + self.assertEqual(res[0].shape, (2, 2)) + + # Test that previous-feeding model ignores inputs after the first. 
+ dec_inp2 = [tf.constant(0, tf.int32, shape=[2]) for _ in xrange(3)] + tf.get_variable_scope().reuse_variables() + d1, _ = seq2seq.embedding_attention_seq2seq( + enc_inp, dec_inp, cell, 2, 5, feed_previous=True) + d2, _ = seq2seq.embedding_attention_seq2seq( + enc_inp, dec_inp2, cell, 2, 5, feed_previous=True) + d3, _ = seq2seq.embedding_attention_seq2seq( + enc_inp, dec_inp2, cell, 2, 5, feed_previous=tf.constant(True)) + res1 = sess.run(d1) + res2 = sess.run(d2) + res3 = sess.run(d3) + self.assertAllClose(res1, res2) + self.assertAllClose(res1, res3) + + def testSequenceLoss(self): + with self.test_session() as sess: + output_classes = 5 + logits = [tf.constant(i + 0.5, shape=[2, 5]) for i in xrange(3)] + targets = [tf.constant(i, tf.int32, shape=[2]) for i in xrange(3)] + weights = [tf.constant(1.0, shape=[2]) for i in xrange(3)] + + average_loss_per_example = seq2seq.sequence_loss( + logits, targets, weights, output_classes, + average_across_timesteps=True, + average_across_batch=True) + res = sess.run(average_loss_per_example) + self.assertAllClose(res, 1.60944) + + average_loss_per_sequence = seq2seq.sequence_loss( + logits, targets, weights, output_classes, + average_across_timesteps=False, + average_across_batch=True) + res = sess.run(average_loss_per_sequence) + self.assertAllClose(res, 4.828314) + + total_loss = seq2seq.sequence_loss( + logits, targets, weights, output_classes, + average_across_timesteps=False, + average_across_batch=False) + res = sess.run(total_loss) + self.assertAllClose(res, 9.656628) + + def testSequenceLossByExample(self): + with self.test_session() as sess: + output_classes = 5 + logits = [tf.constant(i + 0.5, shape=[2, output_classes]) + for i in xrange(3)] + targets = [tf.constant(i, tf.int32, shape=[2]) for i in xrange(3)] + weights = [tf.constant(1.0, shape=[2]) for i in xrange(3)] + + average_loss_per_example = seq2seq.sequence_loss_by_example( + logits, targets, weights, output_classes, + average_across_timesteps=True) + res = sess.run(average_loss_per_example) + self.assertAllClose(res, np.asarray([1.609438, 1.609438])) + + loss_per_sequence = seq2seq.sequence_loss_by_example( + logits, targets, weights, output_classes, + average_across_timesteps=False) + res = sess.run(loss_per_sequence) + self.assertAllClose(res, np.asarray([4.828314, 4.828314])) + + def testModelWithBuckets(self): + """Larger tests that does full sequence-to-sequence model training.""" + # We learn to copy 10 symbols in 2 buckets: length 4 and length 8. + classes = 10 + buckets = [(4, 4), (8, 8)] + # We use sampled softmax so we keep output projection separate. + w = tf.get_variable("proj_w", [24, classes]) + w_t = tf.transpose(w) + b = tf.get_variable("proj_b", [classes]) + # Here comes a sample Seq2Seq model using GRU cells. + def SampleGRUSeq2Seq(enc_inp, dec_inp, weights): + """Example sequence-to-sequence model that uses GRU cells.""" + def GRUSeq2Seq(enc_inp, dec_inp): + cell = rnn_cell.MultiRNNCell([rnn_cell.GRUCell(24)] * 2) + return seq2seq.embedding_attention_seq2seq( + enc_inp, dec_inp, cell, classes, classes, output_projection=(w, b)) + targets = [dec_inp[i+1] for i in xrange(len(dec_inp) - 1)] + [0] + def SampledLoss(inputs, labels): + labels = tf.reshape(labels, [-1, 1]) + return tf.nn.sampled_softmax_loss(w_t, b, inputs, labels, 8, classes) + return seq2seq.model_with_buckets(enc_inp, dec_inp, targets, weights, + buckets, classes, GRUSeq2Seq, + softmax_loss_function=SampledLoss) + # Now we construct the copy model. 
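+ # The copy task set up below: encoder inputs are random symbols in 1..9,
+ # 0 doubles as the "GO" symbol, and the decoder is trained to reproduce
+ # the encoder sequence one step after "GO".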
+ with self.test_session() as sess: + tf.set_random_seed(111) + batch_size = 32 + inp = [tf.placeholder(tf.int32, shape=[None]) for _ in xrange(8)] + out = [tf.placeholder(tf.int32, shape=[None]) for _ in xrange(8)] + weights = [tf.ones_like(inp[0], dtype=tf.float32) for _ in xrange(8)] + with tf.variable_scope("root"): + _, losses = SampleGRUSeq2Seq(inp, out, weights) + updates = [] + params = tf.all_variables() + optimizer = tf.train.AdamOptimizer(0.03, epsilon=1e-5) + for i in xrange(len(buckets)): + full_grads = tf.gradients(losses[i], params) + grads, _ = tf.clip_by_global_norm(full_grads, 30.0) + update = optimizer.apply_gradients(zip(grads, params)) + updates.append(update) + sess.run([tf.initialize_all_variables()]) + for ep in xrange(3): + log_perp = 0.0 + for _ in xrange(50): + bucket = random.choice(range(len(buckets))) + length = buckets[bucket][0] + i = [np.array([np.random.randint(9) + 1 for _ in xrange(batch_size)], + dtype=np.int32) for _ in xrange(length)] + # 0 is our "GO" symbol here. + o = [np.array([0 for _ in xrange(batch_size)], dtype=np.int32)] + i + feed = {} + for l in xrange(length): + feed[inp[l].name] = i[l] + feed[out[l].name] = o[l] + if length < 8: # For the 4-bucket, we need the 5th as target. + feed[out[length].name] = o[length] + res = sess.run([updates[bucket], losses[bucket]], feed) + log_perp += float(res[1]) + perp = math.exp(log_perp / 100) + print "step %d avg. perp %f" % ((ep + 1)*50, perp) + self.assertLess(perp, 2.5) + +if __name__ == "__main__": + tf.test.main() diff --git a/tensorflow/models/rnn/translate/BUILD b/tensorflow/models/rnn/translate/BUILD new file mode 100644 index 0000000000..0899bf689e --- /dev/null +++ b/tensorflow/models/rnn/translate/BUILD @@ -0,0 +1,71 @@ +# Description: +# Example neural translation models. + +package(default_visibility = ["//visibility:public"]) + +licenses(["notice"]) # Apache 2.0 + +exports_files(["LICENSE"]) + +py_library( + name = "data_utils", + srcs = [ + "data_utils.py", + ], + deps = [ + "//tensorflow:tensorflow_py", + ], +) + +py_library( + name = "seq2seq_model", + srcs = [ + "seq2seq_model.py", + ], + deps = [ + ":data_utils", + "//tensorflow:tensorflow_py", + "//tensorflow/models/rnn:seq2seq", + ], +) + +py_binary( + name = "translate", + srcs = [ + "translate.py", + ], + deps = [ + ":data_utils", + ":seq2seq_model", + "//tensorflow:tensorflow_py", + ], +) + +py_test( + name = "translate_test", + size = "medium", + srcs = [ + "translate.py", + ], + args = [ + "--self_test=True", + ], + main = "translate.py", + deps = [ + ":data_utils", + ":seq2seq_model", + "//tensorflow:tensorflow_py", + ], +) + +filegroup( + name = "all_files", + srcs = glob( + ["**/*"], + exclude = [ + "**/METADATA", + "**/OWNERS", + ], + ), + visibility = ["//tensorflow:__subpackages__"], +) diff --git a/tensorflow/models/rnn/translate/__init__.py b/tensorflow/models/rnn/translate/__init__.py new file mode 100755 index 0000000000..e69de29bb2 diff --git a/tensorflow/models/rnn/translate/data_utils.py b/tensorflow/models/rnn/translate/data_utils.py new file mode 100644 index 0000000000..28bc54354c --- /dev/null +++ b/tensorflow/models/rnn/translate/data_utils.py @@ -0,0 +1,264 @@ +"""Utilities for downloading data from WMT, tokenizing, vocabularies.""" + +import gzip +import os +import re +import tarfile +import urllib + +from tensorflow.python.platform import gfile + +# Special vocabulary symbols - we always put them at the start. 
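+# Keeping these four first means the ids below are stable for any corpus:
+# create_vocabulary prepends _START_VOCAB to the frequency-sorted words, so
+# line 0 of every vocabulary file is _PAD, line 1 is _GO, and so on.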
+_PAD = "_PAD" +_GO = "_GO" +_EOS = "_EOS" +_UNK = "_UNK" +_START_VOCAB = [_PAD, _GO, _EOS, _UNK] + +PAD_ID = 0 +GO_ID = 1 +EOS_ID = 2 +UNK_ID = 3 + +# Regular expressions used to tokenize. +_WORD_SPLIT = re.compile("([.,!?\"':;)(])") +_DIGIT_RE = re.compile(r"\d") + +# URLs for WMT data. +_WMT_ENFR_TRAIN_URL = "http://www.statmt.org/wmt10/training-giga-fren.tar" +_WMT_ENFR_DEV_URL = "http://www.statmt.org/wmt15/dev-v2.tgz" + + +def maybe_download(directory, filename, url): + """Download filename from url unless it's already in directory.""" + if not os.path.exists(directory): + print "Creating directory %s" % directory + os.mkdir(directory) + filepath = os.path.join(directory, filename) + if not os.path.exists(filepath): + print "Downloading %s to %s" % (url, filepath) + filepath, _ = urllib.urlretrieve(url, filepath) + statinfo = os.stat(filepath) + print "Succesfully downloaded", filename, statinfo.st_size, "bytes" + return filepath + + +def gunzip_file(gz_path, new_path): + """Unzips from gz_path into new_path.""" + print "Unpacking %s to %s" % (gz_path, new_path) + with gzip.open(gz_path, "rb") as gz_file: + with open(new_path, "w") as new_file: + for line in gz_file: + new_file.write(line) + + +def get_wmt_enfr_train_set(directory): + """Download the WMT en-fr training corpus to directory unless it's there.""" + train_path = os.path.join(directory, "giga-fren.release2") + if not (gfile.Exists(train_path +".fr") and gfile.Exists(train_path +".en")): + corpus_file = maybe_download(directory, "training-giga-fren.tar", + _WMT_ENFR_TRAIN_URL) + print "Extracting tar file %s" % corpus_file + with tarfile.open(corpus_file, "r") as corpus_tar: + corpus_tar.extractall(directory) + gunzip_file(train_path + ".fr.gz", train_path + ".fr") + gunzip_file(train_path + ".en.gz", train_path + ".en") + return train_path + + +def get_wmt_enfr_dev_set(directory): + """Download the WMT en-fr training corpus to directory unless it's there.""" + dev_name = "newstest2013" + dev_path = os.path.join(directory, dev_name) + if not (gfile.Exists(dev_path + ".fr") and gfile.Exists(dev_path + ".en")): + dev_file = maybe_download(directory, "dev-v2.tgz", _WMT_ENFR_DEV_URL) + print "Extracting tgz file %s" % dev_file + with tarfile.open(dev_file, "r:gz") as dev_tar: + fr_dev_file = dev_tar.getmember("dev/" + dev_name + ".fr") + en_dev_file = dev_tar.getmember("dev/" + dev_name + ".en") + fr_dev_file.name = dev_name + ".fr" # Extract without "dev/" prefix. + en_dev_file.name = dev_name + ".en" + dev_tar.extract(fr_dev_file, directory) + dev_tar.extract(en_dev_file, directory) + return dev_path + + +def basic_tokenizer(sentence): + """Very basic tokenizer: split the sentence into a list of tokens.""" + words = [] + for space_separated_fragment in sentence.strip().split(): + words.extend(re.split(_WORD_SPLIT, space_separated_fragment)) + return [w for w in words if w] + + +def create_vocabulary(vocabulary_path, data_path, max_vocabulary_size, + tokenizer=None, normalize_digits=True): + """Create vocabulary file (if it does not exist yet) from data file. + + Data file is assumed to contain one sentence per line. Each sentence is + tokenized and digits are normalized (if normalize_digits is set). + Vocabulary contains the most-frequent tokens up to max_vocabulary_size. + We write it to vocabulary_path in a one-token-per-line format, so that later + token in the first line gets id=0, second line gets id=1, and so on. + + Args: + vocabulary_path: path where the vocabulary will be created. 
+ data_path: data file that will be used to create vocabulary. + max_vocabulary_size: limit on the size of the created vocabulary. + tokenizer: a function to use to tokenize each data sentence; + if None, basic_tokenizer will be used. + normalize_digits: Boolean; if true, all digits are replaced by 0s. + """ + if not gfile.Exists(vocabulary_path): + print "Creating vocabulary %s from data %s" % (vocabulary_path, data_path) + vocab = {} + with gfile.GFile(data_path, mode="r") as f: + counter = 0 + for line in f: + counter += 1 + if counter % 100000 == 0: print " processing line %d" % counter + tokens = tokenizer(line) if tokenizer else basic_tokenizer(line) + for w in tokens: + word = re.sub(_DIGIT_RE, "0", w) if normalize_digits else w + if word in vocab: + vocab[word] += 1 + else: + vocab[word] = 1 + vocab_list = _START_VOCAB + sorted(vocab, key=vocab.get, reverse=True) + if len(vocab_list) > max_vocabulary_size: + vocab_list = vocab_list[:max_vocabulary_size] + with gfile.GFile(vocabulary_path, mode="w") as vocab_file: + for w in vocab_list: + vocab_file.write(w + "\n") + + +def initialize_vocabulary(vocabulary_path): + """Initialize vocabulary from file. + + We assume the vocabulary is stored one-item-per-line, so a file: + dog + cat + will result in a vocabulary {"dog": 0, "cat": 1}, and this function will + also return the reversed-vocabulary ["dog", "cat"]. + + Args: + vocabulary_path: path to the file containing the vocabulary. + + Returns: + a pair: the vocabulary (a dictionary mapping string to integers), and + the reversed vocabulary (a list, which reverses the vocabulary mapping). + + Raises: + ValueError: if the provided vocabulary_path does not exist. + """ + if gfile.Exists(vocabulary_path): + rev_vocab = [] + with gfile.GFile(vocabulary_path, mode="r") as f: + rev_vocab.extend(f.readlines()) + rev_vocab = [line.strip() for line in rev_vocab] + vocab = dict([(x, y) for (y, x) in enumerate(rev_vocab)]) + return vocab, rev_vocab + else: + raise ValueError("Vocabulary file %s not found.", vocabulary_path) + + +def sentence_to_token_ids(sentence, vocabulary, + tokenizer=None, normalize_digits=True): + """Convert a string to list of integers representing token-ids. + + For example, a sentence "I have a dog" may become tokenized into + ["I", "have", "a", "dog"] and with vocabulary {"I": 1, "have": 2, + "a": 4, "dog": 7"} this function will return [1, 2, 4, 7]. + + Args: + sentence: a string, the sentence to convert to token-ids. + vocabulary: a dictionary mapping tokens to integers. + tokenizer: a function to use to tokenize each sentence; + if None, basic_tokenizer will be used. + normalize_digits: Boolean; if true, all digits are replaced by 0s. + + Returns: + a list of integers, the token-ids for the sentence. + """ + if tokenizer: + words = tokenizer(sentence) + else: + words = basic_tokenizer(sentence) + if not normalize_digits: + return [vocabulary.get(w, UNK_ID) for w in words] + # Normalize digits by 0 before looking words up in the vocabulary. + return [vocabulary.get(re.sub(_DIGIT_RE, "0", w), UNK_ID) for w in words] + + +def data_to_token_ids(data_path, target_path, vocabulary_path, + tokenizer=None, normalize_digits=True): + """Tokenize data file and turn into token-ids using given vocabulary file. + + This function loads data line-by-line from data_path, calls the above + sentence_to_token_ids, and saves the result to target_path. See comment + for sentence_to_token_ids on the details of token-ids format. 
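+ The output file has one line per input line, containing the token-ids
+ joined by spaces (for instance "1 2 4 7", to reuse the illustrative
+ vocabulary above).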
+
+ Args:
+ data_path: path to the data file in one-sentence-per-line format.
+ target_path: path where the file with token-ids will be created.
+ vocabulary_path: path to the vocabulary file.
+ tokenizer: a function to use to tokenize each sentence;
+ if None, basic_tokenizer will be used.
+ normalize_digits: Boolean; if true, all digits are replaced by 0s.
+ """
+ if not gfile.Exists(target_path):
+ print "Tokenizing data in %s" % data_path
+ vocab, _ = initialize_vocabulary(vocabulary_path)
+ with gfile.GFile(data_path, mode="r") as data_file:
+ with gfile.GFile(target_path, mode="w") as tokens_file:
+ counter = 0
+ for line in data_file:
+ counter += 1
+ if counter % 100000 == 0: print " tokenizing line %d" % counter
+ token_ids = sentence_to_token_ids(line, vocab, tokenizer,
+ normalize_digits)
+ tokens_file.write(" ".join([str(tok) for tok in token_ids]) + "\n")
+
+
+def prepare_wmt_data(data_dir, en_vocabulary_size, fr_vocabulary_size):
+ """Get WMT data into data_dir, create vocabularies and tokenize data.
+
+ Args:
+ data_dir: directory in which the data sets will be stored.
+ en_vocabulary_size: size of the English vocabulary to create and use.
+ fr_vocabulary_size: size of the French vocabulary to create and use.
+
+ Returns:
+ A tuple of 6 elements:
+ (1) path to the token-ids for English training data-set,
+ (2) path to the token-ids for French training data-set,
+ (3) path to the token-ids for English development data-set,
+ (4) path to the token-ids for French development data-set,
+ (5) path to the English vocabulary file,
+ (6) path to the French vocabulary file.
+ """
+ # Get wmt data to the specified directory.
+ train_path = get_wmt_enfr_train_set(data_dir)
+ dev_path = get_wmt_enfr_dev_set(data_dir)
+
+ # Create vocabularies of the appropriate sizes.
+ fr_vocab_path = os.path.join(data_dir, "vocab%d.fr" % fr_vocabulary_size)
+ en_vocab_path = os.path.join(data_dir, "vocab%d.en" % en_vocabulary_size)
+ create_vocabulary(fr_vocab_path, train_path + ".fr", fr_vocabulary_size)
+ create_vocabulary(en_vocab_path, train_path + ".en", en_vocabulary_size)
+
+ # Create token ids for the training data.
+ fr_train_ids_path = train_path + (".ids%d.fr" % fr_vocabulary_size)
+ en_train_ids_path = train_path + (".ids%d.en" % en_vocabulary_size)
+ data_to_token_ids(train_path + ".fr", fr_train_ids_path, fr_vocab_path)
+ data_to_token_ids(train_path + ".en", en_train_ids_path, en_vocab_path)
+
+ # Create token ids for the development data.
+ fr_dev_ids_path = dev_path + (".ids%d.fr" % fr_vocabulary_size)
+ en_dev_ids_path = dev_path + (".ids%d.en" % en_vocabulary_size)
+ data_to_token_ids(dev_path + ".fr", fr_dev_ids_path, fr_vocab_path)
+ data_to_token_ids(dev_path + ".en", en_dev_ids_path, en_vocab_path)
+
+ return (en_train_ids_path, fr_train_ids_path,
+ en_dev_ids_path, fr_dev_ids_path,
+ en_vocab_path, fr_vocab_path)
diff --git a/tensorflow/models/rnn/translate/seq2seq_model.py b/tensorflow/models/rnn/translate/seq2seq_model.py
new file mode 100644
index 0000000000..3c9cfb007f
--- /dev/null
+++ b/tensorflow/models/rnn/translate/seq2seq_model.py
@@ -0,0 +1,268 @@
+"""Sequence-to-sequence model with an attention mechanism."""
+
+import random
+
+import numpy as np
+import tensorflow as tf
+
+from tensorflow.models.rnn import rnn_cell
+from tensorflow.models.rnn import seq2seq
+
+from tensorflow.models.rnn.translate import data_utils
+
+
+class Seq2SeqModel(object):
+ """Sequence-to-sequence model with attention and for multiple buckets.
+ + This class implements a multi-layer recurrent neural network as encoder, + and an attention-based decoder. This is the same as the model described in + this paper: http://arxiv.org/abs/1412.7449 - please look there for details, + or into the seq2seq library for complete model implementation. + This class also allows to use GRU cells in addition to LSTM cells, and + sampled softmax to handle large output vocabulary size. A single-layer + version of this model, but with bi-directional encoder, was presented in + http://arxiv.org/abs/1409.0473 + and sampled softmax is described in Section 3 of the following paper. + http://arxiv.org/pdf/1412.2007v2.pdf + """ + + def __init__(self, source_vocab_size, target_vocab_size, buckets, size, + num_layers, max_gradient_norm, batch_size, learning_rate, + learning_rate_decay_factor, use_lstm=False, + num_samples=512, forward_only=False): + """Create the model. + + Args: + source_vocab_size: size of the source vocabulary. + target_vocab_size: size of the target vocabulary. + buckets: a list of pairs (I, O), where I specifies maximum input length + that will be processed in that bucket, and O specifies maximum output + length. Training instances that have inputs longer than I or outputs + longer than O will be pushed to the next bucket and padded accordingly. + We assume that the list is sorted, e.g., [(2, 4), (8, 16)]. + size: number of units in each layer of the model. + num_layers: number of layers in the model. + max_gradient_norm: gradients will be clipped to maximally this norm. + batch_size: the size of the batches used during training; + the model construction is independent of batch_size, so it can be + changed after initialization if this is convenient, e.g., for decoding. + learning_rate: learning rate to start with. + learning_rate_decay_factor: decay learning rate by this much when needed. + use_lstm: if true, we use LSTM cells instead of GRU cells. + num_samples: number of samples for sampled softmax. + forward_only: if set, we do not construct the backward pass in the model. + """ + self.source_vocab_size = source_vocab_size + self.target_vocab_size = target_vocab_size + self.buckets = buckets + self.batch_size = batch_size + self.learning_rate = tf.Variable(float(learning_rate), trainable=False) + self.learning_rate_decay_op = self.learning_rate.assign( + self.learning_rate * learning_rate_decay_factor) + self.global_step = tf.Variable(0, trainable=False) + + # If we use sampled softmax, we need an output projection. + output_projection = None + softmax_loss_function = None + # Sampled softmax only makes sense if we sample less than vocabulary size. + if num_samples > 0 and num_samples < self.target_vocab_size: + with tf.device("/cpu:0"): + w = tf.get_variable("proj_w", [size, self.target_vocab_size]) + w_t = tf.transpose(w) + b = tf.get_variable("proj_b", [self.target_vocab_size]) + output_projection = (w, b) + + def sampled_loss(inputs, labels): + with tf.device("/cpu:0"): + labels = tf.reshape(labels, [-1, 1]) + return tf.nn.sampled_softmax_loss(w_t, b, inputs, labels, num_samples, + self.target_vocab_size) + softmax_loss_function = sampled_loss + + # Create the internal multi-layer cell for our RNN. + single_cell = rnn_cell.GRUCell(size) + if use_lstm: + single_cell = rnn_cell.BasicLSTMCell(size) + cell = single_cell + if num_layers > 1: + cell = rnn_cell.MultiRNNCell([single_cell] * num_layers) + + # The seq2seq function: we use embedding for the input and attention. 
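+ # Note: seq2seq_f below closes over `cell` and the output projection; its
+ # do_decode flag only toggles feed_previous, so the same weights serve
+ # both training (ground-truth decoder inputs) and greedy decoding (the
+ # model feeding its own previous output back in).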
+ def seq2seq_f(encoder_inputs, decoder_inputs, do_decode):
+ return seq2seq.embedding_attention_seq2seq(
+ encoder_inputs, decoder_inputs, cell, source_vocab_size,
+ target_vocab_size, output_projection=output_projection,
+ feed_previous=do_decode)
+
+ # Feeds for inputs.
+ self.encoder_inputs = []
+ self.decoder_inputs = []
+ self.target_weights = []
+ for i in xrange(buckets[-1][0]): # Last bucket is the biggest one.
+ self.encoder_inputs.append(tf.placeholder(tf.int32, shape=[None],
+ name="encoder{0}".format(i)))
+ for i in xrange(buckets[-1][1] + 1):
+ self.decoder_inputs.append(tf.placeholder(tf.int32, shape=[None],
+ name="decoder{0}".format(i)))
+ self.target_weights.append(tf.placeholder(tf.float32, shape=[None],
+ name="weight{0}".format(i)))
+
+ # Our targets are decoder inputs shifted by one.
+ targets = [self.decoder_inputs[i + 1]
+ for i in xrange(len(self.decoder_inputs) - 1)]
+
+ # Training outputs and losses.
+ if forward_only:
+ self.outputs, self.losses = seq2seq.model_with_buckets(
+ self.encoder_inputs, self.decoder_inputs, targets,
+ self.target_weights, buckets, self.target_vocab_size,
+ lambda x, y: seq2seq_f(x, y, True),
+ softmax_loss_function=softmax_loss_function)
+ # If we use output projection, we need to project outputs for decoding.
+ if output_projection is not None:
+ for b in xrange(len(buckets)):
+ self.outputs[b] = [tf.nn.xw_plus_b(output, output_projection[0],
+ output_projection[1])
+ for output in self.outputs[b]]
+ else:
+ self.outputs, self.losses = seq2seq.model_with_buckets(
+ self.encoder_inputs, self.decoder_inputs, targets,
+ self.target_weights, buckets, self.target_vocab_size,
+ lambda x, y: seq2seq_f(x, y, False),
+ softmax_loss_function=softmax_loss_function)
+
+ # Gradients and SGD update operation for training the model.
+ params = tf.trainable_variables()
+ if not forward_only:
+ self.gradient_norms = []
+ self.updates = []
+ opt = tf.train.GradientDescentOptimizer(self.learning_rate)
+ for b in xrange(len(buckets)):
+ gradients = tf.gradients(self.losses[b], params)
+ clipped_gradients, norm = tf.clip_by_global_norm(gradients,
+ max_gradient_norm)
+ self.gradient_norms.append(norm)
+ self.updates.append(opt.apply_gradients(
+ zip(clipped_gradients, params), global_step=self.global_step))
+
+ self.saver = tf.train.Saver(tf.all_variables())
+
+ def step(self, session, encoder_inputs, decoder_inputs, target_weights,
+ bucket_id, forward_only):
+ """Run a step of the model feeding the given inputs.
+
+ Args:
+ session: tensorflow session to use.
+ encoder_inputs: list of numpy int vectors to feed as encoder inputs.
+ decoder_inputs: list of numpy int vectors to feed as decoder inputs.
+ target_weights: list of numpy float vectors to feed as target weights.
+ bucket_id: which bucket of the model to use.
+ forward_only: whether to do the backward step or only forward.
+
+ Returns:
+ A triple consisting of gradient norm (or None if we did not do backward),
+ average perplexity, and the outputs.
+
+ Raises:
+ ValueError: if length of encoder_inputs, decoder_inputs, or
+ target_weights disagrees with bucket size for the specified bucket_id.
+ """
+ # Check if the sizes match.
+ encoder_size, decoder_size = self.buckets[bucket_id]
+ if len(encoder_inputs) != encoder_size:
+ raise ValueError("Encoder length must be equal to the one in bucket,"
+ " %d != %d." % (len(encoder_inputs), encoder_size))
+ if len(decoder_inputs) != decoder_size:
+ raise ValueError("Decoder length must be equal to the one in bucket,"
+ " %d != %d."
% (len(decoder_inputs), decoder_size)) + if len(target_weights) != decoder_size: + raise ValueError("Weights length must be equal to the one in bucket," + " %d != %d." % (len(target_weights), decoder_size)) + + # Input feed: encoder inputs, decoder inputs, target_weights, as provided. + input_feed = {} + for l in xrange(encoder_size): + input_feed[self.encoder_inputs[l].name] = encoder_inputs[l] + for l in xrange(decoder_size): + input_feed[self.decoder_inputs[l].name] = decoder_inputs[l] + input_feed[self.target_weights[l].name] = target_weights[l] + + # Since our targets are decoder inputs shifted by one, we need one more. + last_target = self.decoder_inputs[decoder_size].name + input_feed[last_target] = np.zeros([self.batch_size], dtype=np.int32) + + # Output feed: depends on whether we do a backward step or not. + if not forward_only: + output_feed = [self.updates[bucket_id], # Update Op that does SGD. + self.gradient_norms[bucket_id], # Gradient norm. + self.losses[bucket_id]] # Loss for this batch. + else: + output_feed = [self.losses[bucket_id]] # Loss for this batch. + for l in xrange(decoder_size): # Output logits. + output_feed.append(self.outputs[bucket_id][l]) + + outputs = session.run(output_feed, input_feed) + if not forward_only: + return outputs[1], outputs[2], None # Gradient norm, loss, no outputs. + else: + return None, outputs[0], outputs[1:] # No gradient norm, loss, outputs. + + def get_batch(self, data, bucket_id): + """Get a random batch of data from the specified bucket, prepare for step. + + To feed data in step(..) it must be a list of batch-major vectors, while + data here contains single length-major cases. So the main logic of this + function is to re-index data cases to be in the proper format for feeding. + + Args: + data: a tuple of size len(self.buckets) in which each element contains + lists of pairs of input and output data that we use to create a batch. + bucket_id: integer, which bucket to get the batch for. + + Returns: + The triple (encoder_inputs, decoder_inputs, target_weights) for + the constructed batch that has the proper format to call step(...) later. + """ + encoder_size, decoder_size = self.buckets[bucket_id] + encoder_inputs, decoder_inputs = [], [] + + # Get a random batch of encoder and decoder inputs from data, + # pad them if needed, reverse encoder inputs and add GO to decoder. + for _ in xrange(self.batch_size): + encoder_input, decoder_input = random.choice(data[bucket_id]) + + # Encoder inputs are padded and then reversed. + encoder_pad = [data_utils.PAD_ID] * (encoder_size - len(encoder_input)) + encoder_inputs.append(list(reversed(encoder_input + encoder_pad))) + + # Decoder inputs get an extra "GO" symbol, and are padded then. + decoder_pad_size = decoder_size - len(decoder_input) - 1 + decoder_inputs.append([data_utils.GO_ID] + decoder_input + + [data_utils.PAD_ID] * decoder_pad_size) + + # Now we create batch-major vectors from the data selected above. + batch_encoder_inputs, batch_decoder_inputs, batch_weights = [], [], [] + + # Batch encoder inputs are just re-indexed encoder_inputs. + for length_idx in xrange(encoder_size): + batch_encoder_inputs.append( + np.array([encoder_inputs[batch_idx][length_idx] + for batch_idx in xrange(self.batch_size)], dtype=np.int32)) + + # Batch decoder inputs are re-indexed decoder_inputs, we create weights. 
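+ # Illustrative transpose (numbers not from the code): with batch_size=2
+ # and per-case decoder inputs [GO, 5, 6] and [GO, 7, 8], the batch-major
+ # result is [[GO, GO], [5, 7], [6, 8]] -- one length-batch_size vector
+ # per decoder time-step.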
+ for length_idx in xrange(decoder_size): + batch_decoder_inputs.append( + np.array([decoder_inputs[batch_idx][length_idx] + for batch_idx in xrange(self.batch_size)], dtype=np.int32)) + + # Create target_weights to be 0 for targets that are padding. + batch_weight = np.ones(self.batch_size, dtype=np.float32) + for batch_idx in xrange(self.batch_size): + # We set weight to 0 if the corresponding target is a PAD symbol. + # The corresponding target is decoder_input shifted by 1 forward. + if length_idx < decoder_size - 1: + target = decoder_inputs[batch_idx][length_idx + 1] + if length_idx == decoder_size - 1 or target == data_utils.PAD_ID: + batch_weight[batch_idx] = 0.0 + batch_weights.append(batch_weight) + return batch_encoder_inputs, batch_decoder_inputs, batch_weights diff --git a/tensorflow/models/rnn/translate/translate.py b/tensorflow/models/rnn/translate/translate.py new file mode 100644 index 0000000000..abf4c7c57b --- /dev/null +++ b/tensorflow/models/rnn/translate/translate.py @@ -0,0 +1,260 @@ +"""Binary for training translation models and decoding from them. + +Running this program without --decode will download the WMT corpus into +the directory specified as --data_dir and tokenize it in a very basic way, +and then start training a model saving checkpoints to --train_dir. + +Running with --decode starts an interactive loop so you can see how +the current checkpoint translates English sentences into French. + +See the following papers for more information on neural translation models. + * http://arxiv.org/abs/1409.3215 + * http://arxiv.org/abs/1409.0473 + * http://arxiv.org/pdf/1412.2007v2.pdf +""" + +import math +import os +import random +import sys +import time + +import tensorflow.python.platform + +import numpy as np +import tensorflow as tf + +from tensorflow.models.rnn.translate import data_utils +from tensorflow.models.rnn.translate import seq2seq_model +from tensorflow.python.platform import gfile + + +tf.app.flags.DEFINE_float("learning_rate", 0.5, "Learning rate.") +tf.app.flags.DEFINE_float("learning_rate_decay_factor", 0.99, + "Learning rate decays by this much.") +tf.app.flags.DEFINE_float("max_gradient_norm", 5.0, + "Clip gradients to this norm.") +tf.app.flags.DEFINE_integer("batch_size", 64, + "Batch size to use during training.") +tf.app.flags.DEFINE_integer("size", 1024, "Size of each model layer.") +tf.app.flags.DEFINE_integer("num_layers", 3, "Number of layers in the model.") +tf.app.flags.DEFINE_integer("en_vocab_size", 40000, "English vocabulary size.") +tf.app.flags.DEFINE_integer("fr_vocab_size", 40000, "French vocabulary size.") +tf.app.flags.DEFINE_string("data_dir", "/tmp", "Data directory") +tf.app.flags.DEFINE_string("train_dir", "/tmp", "Training directory.") +tf.app.flags.DEFINE_integer("max_train_data_size", 0, + "Limit on the size of training data (0: no limit).") +tf.app.flags.DEFINE_integer("steps_per_checkpoint", 200, + "How many training steps to do per checkpoint.") +tf.app.flags.DEFINE_boolean("decode", False, + "Set to True for interactive decoding.") +tf.app.flags.DEFINE_boolean("self_test", False, + "Run a self-test if this is set to True.") + +FLAGS = tf.app.flags.FLAGS + +# We use a number of buckets and pad to the closest one for efficiency. +# See seq2seq_model.Seq2SeqModel for details of how they work. +_buckets = [(5, 10), (10, 15), (20, 25), (40, 50)] + + +def read_data(source_path, target_path, max_size=None): + """Read data from source and target files and put into buckets. 
+
+ Args:
+ source_path: path to the files with token-ids for the source language.
+ target_path: path to the file with token-ids for the target language;
+ it must be aligned with the source file: n-th line contains the desired
+ output for n-th line from the source_path.
+ max_size: maximum number of lines to read; all others will be ignored;
+ if 0 or None, data files will be read completely (no limit).
+
+ Returns:
+ data_set: a list of length len(_buckets); data_set[n] contains a list of
+ (source, target) pairs read from the provided data files that fit
+ into the n-th bucket, i.e., such that len(source) < _buckets[n][0] and
+ len(target) < _buckets[n][1]; source and target are lists of token-ids.
+ """
+ data_set = [[] for _ in _buckets]
+ with gfile.GFile(source_path, mode="r") as source_file:
+ with gfile.GFile(target_path, mode="r") as target_file:
+ source, target = source_file.readline(), target_file.readline()
+ counter = 0
+ while source and target and (not max_size or counter < max_size):
+ counter += 1
+ if counter % 100000 == 0:
+ print " reading data line %d" % counter
+ sys.stdout.flush()
+ source_ids = [int(x) for x in source.split()]
+ target_ids = [int(x) for x in target.split()]
+ target_ids.append(data_utils.EOS_ID)
+ for bucket_id, (source_size, target_size) in enumerate(_buckets):
+ if len(source_ids) < source_size and len(target_ids) < target_size:
+ data_set[bucket_id].append([source_ids, target_ids])
+ break
+ source, target = source_file.readline(), target_file.readline()
+ return data_set
+
+
+def create_model(session, forward_only):
+ """Create translation model and initialize or load parameters in session."""
+ model = seq2seq_model.Seq2SeqModel(
+ FLAGS.en_vocab_size, FLAGS.fr_vocab_size, _buckets,
+ FLAGS.size, FLAGS.num_layers, FLAGS.max_gradient_norm, FLAGS.batch_size,
+ FLAGS.learning_rate, FLAGS.learning_rate_decay_factor,
+ forward_only=forward_only)
+ ckpt = tf.train.get_checkpoint_state(FLAGS.train_dir)
+ if ckpt and gfile.Exists(ckpt.model_checkpoint_path):
+ print "Reading model parameters from %s" % ckpt.model_checkpoint_path
+ model.saver.restore(session, ckpt.model_checkpoint_path)
+ else:
+ print "Created model with fresh parameters."
+ session.run(tf.variables.initialize_all_variables())
+ return model
+
+
+def train():
+ """Train an en->fr translation model using WMT data."""
+ # Prepare WMT data.
+ print "Preparing WMT data in %s" % FLAGS.data_dir
+ en_train, fr_train, en_dev, fr_dev, _, _ = data_utils.prepare_wmt_data(
+ FLAGS.data_dir, FLAGS.en_vocab_size, FLAGS.fr_vocab_size)
+
+ with tf.Session() as sess:
+ # Create model.
+ print "Creating %d layers of %d units." % (FLAGS.num_layers, FLAGS.size)
+ model = create_model(sess, False)
+
+ # Read data into buckets and compute their sizes.
+ print ("Reading development and training data (limit: %d)."
+ % FLAGS.max_train_data_size)
+ dev_set = read_data(en_dev, fr_dev)
+ train_set = read_data(en_train, fr_train, FLAGS.max_train_data_size)
+ train_bucket_sizes = [len(train_set[b]) for b in xrange(len(_buckets))]
+ train_total_size = float(sum(train_bucket_sizes))
+
+ # A bucket scale is a list of increasing numbers from 0 to 1 that we'll use
+ # to select a bucket. Length of [scale[i], scale[i+1]] is proportional to
+ # the size of the i-th training bucket, as used later.
+ train_buckets_scale = [sum(train_bucket_sizes[:i + 1]) / train_total_size
+ for i in xrange(len(train_bucket_sizes))]
+
+ # This is the training loop.
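+ # For example (illustrative sizes): train_bucket_sizes = [300, 100] gives
+ # train_buckets_scale = [0.75, 1.0], so a uniform draw below 0.75 picks
+ # bucket 0 and anything above picks bucket 1 -- buckets are visited in
+ # proportion to how much data they hold.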
+    step_time, loss = 0.0, 0.0
+    current_step = 0
+    previous_losses = []
+    while True:
+      # Choose a bucket according to data distribution. We pick a random number
+      # in [0, 1] and use the corresponding interval in train_buckets_scale.
+      random_number_01 = np.random.random_sample()
+      bucket_id = min([i for i in xrange(len(train_buckets_scale))
+                       if train_buckets_scale[i] > random_number_01])
+
+      # Get a batch and make a step.
+      start_time = time.time()
+      encoder_inputs, decoder_inputs, target_weights = model.get_batch(
+          train_set, bucket_id)
+      _, step_loss, _ = model.step(sess, encoder_inputs, decoder_inputs,
+                                   target_weights, bucket_id, False)
+      step_time += (time.time() - start_time) / FLAGS.steps_per_checkpoint
+      loss += step_loss / FLAGS.steps_per_checkpoint
+      current_step += 1
+
+      # Once in a while, we save a checkpoint, print statistics, and run evals.
+      if current_step % FLAGS.steps_per_checkpoint == 0:
+        # Print statistics for the previous steps_per_checkpoint steps.
+        perplexity = math.exp(loss) if loss < 300 else float('inf')
+        print ("global step %d learning rate %.4f step-time %.2f perplexity "
+               "%.2f" % (model.global_step.eval(), model.learning_rate.eval(),
+                         step_time, perplexity))
+        # Decrease learning rate if no improvement was seen over the last 3
+        # checkpoints.
+        if len(previous_losses) > 2 and loss > max(previous_losses[-3:]):
+          sess.run(model.learning_rate_decay_op)
+        previous_losses.append(loss)
+        # Save checkpoint and zero timer and loss.
+        checkpoint_path = os.path.join(FLAGS.train_dir, "translate.ckpt")
+        model.saver.save(sess, checkpoint_path, global_step=model.global_step)
+        step_time, loss = 0.0, 0.0
+        # Run evals on development set and print their perplexity.
+        for bucket_id in xrange(len(_buckets)):
+          encoder_inputs, decoder_inputs, target_weights = model.get_batch(
+              dev_set, bucket_id)
+          _, eval_loss, _ = model.step(sess, encoder_inputs, decoder_inputs,
+                                       target_weights, bucket_id, True)
+          eval_ppx = math.exp(eval_loss) if eval_loss < 300 else float('inf')
+          print "  eval: bucket %d perplexity %.2f" % (bucket_id, eval_ppx)
+        sys.stdout.flush()
+
+
+def decode():
+  with tf.Session() as sess:
+    # Create model and load parameters.
+    model = create_model(sess, True)
+    model.batch_size = 1  # We decode one sentence at a time.
+
+    # Load vocabularies.
+    en_vocab_path = os.path.join(FLAGS.data_dir,
+                                 "vocab%d.en" % FLAGS.en_vocab_size)
+    fr_vocab_path = os.path.join(FLAGS.data_dir,
+                                 "vocab%d.fr" % FLAGS.fr_vocab_size)
+    en_vocab, _ = data_utils.initialize_vocabulary(en_vocab_path)
+    _, rev_fr_vocab = data_utils.initialize_vocabulary(fr_vocab_path)
+
+    # Decode from standard input.
+    sys.stdout.write("> ")
+    sys.stdout.flush()
+    sentence = sys.stdin.readline()
+    while sentence:
+      # Get token-ids for the input sentence.
+      token_ids = data_utils.sentence_to_token_ids(sentence, en_vocab)
+      # Which bucket does it belong to?
+      bucket_id = min([b for b in xrange(len(_buckets))
+                       if _buckets[b][0] > len(token_ids)])
+      # Get a 1-element batch to feed the sentence to the model.
+      encoder_inputs, decoder_inputs, target_weights = model.get_batch(
+          {bucket_id: [(token_ids, [])]}, bucket_id)
+      # Get output logits for the sentence.
+      _, _, output_logits = model.step(sess, encoder_inputs, decoder_inputs,
+                                       target_weights, bucket_id, True)
+      # This is a greedy decoder - outputs are just argmaxes of output_logits.
+      outputs = [int(np.argmax(logit, axis=1)) for logit in output_logits]
+      # If there is an EOS symbol in outputs, cut them at that point.
+      if data_utils.EOS_ID in outputs:
+        outputs = outputs[:outputs.index(data_utils.EOS_ID)]
+      # Print out French sentence corresponding to outputs.
+      print " ".join([rev_fr_vocab[output] for output in outputs])
+      print "> ",
+      sys.stdout.flush()
+      sentence = sys.stdin.readline()
+
+
+def self_test():
+  """Test the translation model."""
+  with tf.Session() as sess:
+    print "Self-test for neural translation model."
+    # Create model with vocabularies of 10, 2 small buckets, 2 layers of 32.
+    model = seq2seq_model.Seq2SeqModel(10, 10, [(3, 3), (6, 6)], 32, 2,
+                                       5.0, 32, 0.3, 0.99, num_samples=8)
+    sess.run(tf.variables.initialize_all_variables())
+
+    # Fake data set for both the (3, 3) and (6, 6) buckets.
+    data_set = ([([1, 1], [2, 2]), ([3, 3], [4]), ([5], [6])],
+                [([1, 1, 1, 1, 1], [2, 2, 2, 2, 2]), ([3, 3, 3], [5, 6])])
+    for _ in xrange(5):  # Train the fake model for 5 steps.
+      bucket_id = random.choice([0, 1])
+      encoder_inputs, decoder_inputs, target_weights = model.get_batch(
+          data_set, bucket_id)
+      model.step(sess, encoder_inputs, decoder_inputs, target_weights,
+                 bucket_id, False)
+
+
+def main(_):
+  if FLAGS.self_test:
+    self_test()
+  elif FLAGS.decode:
+    decode()
+  else:
+    train()
+
+if __name__ == "__main__":
+  tf.app.run()
diff --git a/tensorflow/python/framework/docs.py b/tensorflow/python/framework/docs.py
index fcf269c717..99d02a5380 100644
--- a/tensorflow/python/framework/docs.py
+++ b/tensorflow/python/framework/docs.py
@@ -233,6 +233,14 @@ class Library(Document):
         # signatures.
         continue
       args_list.append(arg)
+
+    # TODO(mrry): This is a workaround for documenting the signature of
+    # functions that have the @contextlib.contextmanager decorator.
+    # We should do something better.
+    if argspec.varargs == "args" and argspec.keywords == "kwds":
+      original_func = func.func_closure[0].cell_contents
+      return self._generate_signature_for_function(original_func)
+
     if argspec.defaults:
       for arg, default in zip(
           argspec.args[first_arg_with_default:], argspec.defaults):
diff --git a/tensorflow/python/framework/gen_docs_combined.py b/tensorflow/python/framework/gen_docs_combined.py
index e82c45d95f..8256acc514 100644
--- a/tensorflow/python/framework/gen_docs_combined.py
+++ b/tensorflow/python/framework/gen_docs_combined.py
@@ -76,7 +76,8 @@ def all_libraries(module_to_name, members, documented):
                       "xw_plus_b", "relu_layer", "lrn",
                       "batch_norm_with_global_normalization",
                       "batch_norm_with_global_normalization_grad",
-                      "all_candidate_sampler"],
+                      "all_candidate_sampler",
+                      "embedding_lookup_sparse"],
             prefix=PREFIX_TEXT),
     library('client', "Running Graphs", client_lib,
             exclude_symbols=["InteractiveSession"]),
diff --git a/tensorflow/python/ops/embedding_ops.py b/tensorflow/python/ops/embedding_ops.py
index bc64593d23..d9a377bab5 100644
--- a/tensorflow/python/ops/embedding_ops.py
+++ b/tensorflow/python/ops/embedding_ops.py
@@ -8,24 +8,31 @@ from tensorflow.python.ops import math_ops
 def embedding_lookup(params, ids, name=None):
-  """Return a tensor of embedding values by looking up "ids" in "params".
+  """Looks up `ids` in a list of embedding tensors.
+
+  This function is used to perform parallel lookups on the list of
+  tensors in `params`. It is a generalization of
+  [`tf.gather()`](array_ops.md#gather), where `params` is interpreted
+  as a partition of a larger embedding tensor.
+
+  If `len(params) > 1`, each element `id` of `ids` is partitioned between
+  the elements of `params` by computing `p = id % len(params)`, and is
+  then used to look up the slice `params[p][id // len(params), ...]`.
+
+  The results of the lookup are then concatenated into a dense
+  tensor. The returned tensor has shape `shape(ids) + shape(params)[1:]`.
 
   Args:
-    params: List of tensors of the same shape. A single tensor is
-      treated as a singleton list.
-    ids: Tensor of integers containing the ids to be looked up in
-      'params'. Let P be len(params). If P > 1, then the ids are
-      partitioned by id % P, and we do separate lookups in params[p]
-      for 0 <= p < P, and then stitch the results back together into
-      a single result tensor.
-    name: Optional name for the op.
+    params: A list of tensors with the same shape and type.
+    ids: A `Tensor` with type `int32` containing the ids to be looked
+      up in `params`.
+    name: A name for the operation (optional).
 
   Returns:
-    A tensor of shape ids.shape + params[0].shape[1:] containing the
-    values params[i % P][i] for each i in ids.
+    A `Tensor` with the same type as the tensors in `params`.
 
   Raises:
-    ValueError: if some parameters are invalid.
+    ValueError: If `params` is empty.
   """
   if not isinstance(params, list):
     params = [params]
diff --git a/tensorflow/python/ops/nn.py b/tensorflow/python/ops/nn.py
index 8eccc63eb1..004ceb51c6 100644
--- a/tensorflow/python/ops/nn.py
+++ b/tensorflow/python/ops/nn.py
@@ -43,10 +43,10 @@ are as follows. If the 4-D `input` has shape
 `[batch, in_height, in_width, ...]` and the 4-D `filter` has shape
 `[filter_height, filter_width, ...]`, then
 
-    output.shape = [batch,
-                    (in_height - filter_height + 1) / strides[1],
-                    (in_width - filter_width + 1) / strides[2],
-                    ...]
+    shape(output) = [batch,
+                     (in_height - filter_height + 1) / strides[1],
+                     (in_width - filter_width + 1) / strides[2],
+                     ...]
 
     output[b, i, j, :] =
         sum_{di, dj} input[b, strides[1] * i + di, strides[2] * j + dj, ...] *
                      filter[di, dj, ...]
@@ -58,7 +58,7 @@ vectors. For `depthwise_conv_2d`, each scalar component `input[b, i, j, k]`
 is multiplied by a vector `filter[di, dj, k]`, and all the vectors are
 concatenated.
 
-In the formula for `output.shape`, the rounding direction depends on padding:
+In the formula for `shape(output)`, the rounding direction depends on padding:
 
 * `padding = 'SAME'`: Round down (only full size windows are considered).
 * `padding = 'VALID'`: Round up (partial windows are included).
@@ -81,7 +81,7 @@ In detail, the output is
 
 for each tuple of indices `i`. The output shape is
 
-    output.shape = (value.shape - ksize + 1) / strides
+    shape(output) = (shape(value) - ksize + 1) / strides
 
 where the rounding direction depends on padding:
 
@@ -119,10 +119,10 @@ TensorFlow provides several operations that help you perform classification.
 
 ## Embeddings
 
-TensorFlow provides several operations that help you compute embeddings.
+TensorFlow provides library support for looking up values in embedding
+tensors.
 
 @@embedding_lookup
-@@embedding_lookup_sparse
 
 ## Evaluation
 
@@ -336,15 +336,16 @@ def dropout(x, keep_prob, noise_shape=None, seed=None, name=None):
   By default, each element is kept or dropped independently. If `noise_shape`
   is specified, it must be
   [broadcastable](http://docs.scipy.org/doc/numpy/user/basics.broadcasting.html)
-  to the shape of `x`, and only dimensions with `noise_shape[i] == x.shape[i]`
-  will make independent decisions. For example, if `x.shape = [b, x, y, c]` and
-  `noise_shape = [b, 1, 1, c]`, each batch and channel component will be
+  to the shape of `x`, and only dimensions with `noise_shape[i] == shape(x)[i]`
+  will make independent decisions.
+  For example, if `shape(x) = [k, l, m, n]`
+  and `noise_shape = [k, 1, 1, n]`, each batch and channel component will be
   kept independently and each row and column will be kept or not kept together.
 
   Args:
     x: A tensor.
-    keep_prob: Float probability that each element is kept.
-    noise_shape: Shape for randomly generated keep/drop flags.
+    keep_prob: A Python float. The probability that each element is kept.
+    noise_shape: A 1-D `Tensor` of type `int32`, representing the
+      shape for randomly generated keep/drop flags.
     seed: A Python integer. Used to create a random seed.
       See [`set_random_seed`](constant_op.md#set_random_seed) for behavior.
     name: A name for this operation (optional).
diff --git a/tensorflow/tools/docker/Dockerfile.cpu b/tensorflow/tools/docker/Dockerfile.cpu
index c93a6e8bd2..da0a7abb18 100644
--- a/tensorflow/tools/docker/Dockerfile.cpu
+++ b/tensorflow/tools/docker/Dockerfile.cpu
@@ -66,3 +66,7 @@ RUN bazel clean && \
     bazel build -c opt tensorflow/tools/docker:simple_console
 
 ENV PYTHONPATH=/tensorflow/bazel-bin/tensorflow/tools/docker/simple_console.runfiles/:$PYTHONPATH
+
+# We want to start Jupyter in the directory with our getting-started
+# tutorials.
+WORKDIR /notebooks
-- 
cgit v1.2.3
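
Editor's note: the partitioning rule described in the new embedding_lookup
docstring above can be sanity-checked with a few lines of plain numpy. This
is an illustrative sketch only; the shard layout and the lookup() helper are
hypothetical stand-ins, not TensorFlow API:

    import numpy as np

    # A full embedding table of 6 rows, split across 2 shards the way
    # embedding_lookup partitions ids: shard p holds rows with id % 2 == p.
    full = np.arange(18, dtype=np.float32).reshape(6, 3)
    params = [full[0::2], full[1::2]]

    def lookup(params, ids):
      # For each id, read shard id % len(params) at row id // len(params),
      # then stack the rows into one dense result, as the docstring says.
      p = len(params)
      return np.stack([params[i % p][i // p] for i in ids])

    ids = [0, 5, 3]
    assert (lookup(params, ids) == full[ids]).all()

For example, id 5 with two shards lands in partition 5 % 2 = 1 at row
5 // 2 = 2, which is exactly row 5 of the full table.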