aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorGravatar Vijay Vasudevan <vrv@google.com>2015-11-06 21:57:38 -0800
committerGravatar Vijay Vasudevan <vrv@google.com>2015-11-06 21:57:38 -0800
commit8bd3b38e662a1298bebcada676c7cc6e2ea49c0f (patch)
tree394e314734a9deb744800843367f4107c64e84fe
parentcd9e60c1cd8afef6e39b4b73525d64aee33b656b (diff)
TensorFlow: Upstream changes to git.
Changes: - Update a lot of documentation, installation instructions, requirements, etc. - Add RNN models directory for recurrent neural network examples to go along with the tutorials. Base CL: 107290480
-rw-r--r--CONTRIBUTING.md11
-rw-r--r--README.md2
-rw-r--r--tensorflow/g3doc/api_docs/python/framework.md10
-rw-r--r--tensorflow/g3doc/api_docs/python/index.md1
-rw-r--r--tensorflow/g3doc/api_docs/python/nn.md120
-rw-r--r--tensorflow/g3doc/api_docs/python/state_ops.md4
-rw-r--r--tensorflow/g3doc/get_started/os_setup.md39
-rw-r--r--tensorflow/g3doc/how_tos/adding_an_op/index.md7
-rw-r--r--tensorflow/g3doc/how_tos/variables/index.md2
-rw-r--r--tensorflow/g3doc/resources/faq.md28
-rw-r--r--tensorflow/g3doc/resources/uses.md4
-rw-r--r--tensorflow/g3doc/tutorials/deep_cnn/index.md4
-rw-r--r--tensorflow/g3doc/tutorials/index.md4
-rwxr-xr-xtensorflow/g3doc/tutorials/mandelbrot/index.md2
-rw-r--r--tensorflow/g3doc/tutorials/mnist/beginners/index.md2
-rw-r--r--tensorflow/g3doc/tutorials/mnist/download/index.md2
-rw-r--r--tensorflow/g3doc/tutorials/mnist/pros/index.md2
-rw-r--r--tensorflow/g3doc/tutorials/mnist/tf/index.md2
-rwxr-xr-xtensorflow/g3doc/tutorials/pdes/index.md1
-rw-r--r--tensorflow/g3doc/tutorials/seq2seq/index.md2
-rw-r--r--tensorflow/g3doc/tutorials/word2vec/index.md2
-rw-r--r--tensorflow/models/rnn/BUILD106
-rw-r--r--tensorflow/models/rnn/README.md21
-rwxr-xr-xtensorflow/models/rnn/__init__.py0
-rw-r--r--tensorflow/models/rnn/linear.py49
-rw-r--r--tensorflow/models/rnn/linear_test.py35
-rw-r--r--tensorflow/models/rnn/ptb/BUILD49
-rwxr-xr-xtensorflow/models/rnn/ptb/__init__.py0
-rw-r--r--tensorflow/models/rnn/ptb/ptb_word_lm.py292
-rw-r--r--tensorflow/models/rnn/ptb/reader.py105
-rw-r--r--tensorflow/models/rnn/ptb/reader_test.py47
-rw-r--r--tensorflow/models/rnn/rnn.py128
-rw-r--r--tensorflow/models/rnn/rnn_cell.py605
-rw-r--r--tensorflow/models/rnn/rnn_cell_test.py154
-rw-r--r--tensorflow/models/rnn/rnn_test.py472
-rw-r--r--tensorflow/models/rnn/seq2seq.py749
-rw-r--r--tensorflow/models/rnn/seq2seq_test.py384
-rw-r--r--tensorflow/models/rnn/translate/BUILD71
-rwxr-xr-xtensorflow/models/rnn/translate/__init__.py0
-rw-r--r--tensorflow/models/rnn/translate/data_utils.py264
-rw-r--r--tensorflow/models/rnn/translate/seq2seq_model.py268
-rw-r--r--tensorflow/models/rnn/translate/translate.py260
-rw-r--r--tensorflow/python/framework/docs.py8
-rw-r--r--tensorflow/python/framework/gen_docs_combined.py3
-rw-r--r--tensorflow/python/ops/embedding_ops.py31
-rw-r--r--tensorflow/python/ops/nn.py27
-rw-r--r--tensorflow/tools/docker/Dockerfile.cpu4
47 files changed, 4223 insertions, 160 deletions
diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md
index dbaf281844..98339fe54a 100644
--- a/CONTRIBUTING.md
+++ b/CONTRIBUTING.md
@@ -15,3 +15,14 @@ Follow either of the two links above to access the appropriate CLA and instructi
***NOTE***: Only original source code from you and other people that have signed the CLA can be accepted into the main repository.
+## Contributing code
+
+We currently use Gerrit to host and handle code changes to TensorFlow. The main
+site is
+[https://tensorflow-review.googlesource.com/](https://tensorflow-review.googlesource.com/).
+See Gerrit [docs](https://gerrit-review.googlesource.com/Documentation/) for
+information on how Gerrit's code review system works.
+
+We are currently working on improving our external acceptance process, so
+please be patient with us as we work out the details.
+
diff --git a/README.md b/README.md
index 2f8dac4a62..05bfec8431 100644
--- a/README.md
+++ b/README.md
@@ -14,7 +14,7 @@ variety of other domains, as well.
# Download and Setup
For detailed installation instructions, see
-[here](g3doc/get_started/os_setup.md).
+[here](tensorflow/g3doc/get_started/os_setup.md).
## Binary Installation
diff --git a/tensorflow/g3doc/api_docs/python/framework.md b/tensorflow/g3doc/api_docs/python/framework.md
index 0614c68e15..1fc659ef0b 100644
--- a/tensorflow/g3doc/api_docs/python/framework.md
+++ b/tensorflow/g3doc/api_docs/python/framework.md
@@ -27,7 +27,7 @@
* [class tf.RegisterShape](#RegisterShape)
* [class tf.TensorShape](#TensorShape)
* [class tf.Dimension](#Dimension)
- * [tf.op_scope(*args, **kwds)](#op_scope)
+ * [tf.op_scope(values, name, default_name)](#op_scope)
* [tf.get_seed(op_seed)](#get_seed)
@@ -235,7 +235,7 @@ def my_func(pred, tensor):
- - -
-#### tf.Graph.device(*args, **kwds) {#Graph.device}
+#### tf.Graph.device(device_name_or_function) {#Graph.device}
Returns a context manager that specifies the default device to use.
@@ -287,7 +287,7 @@ with g.device(matmul_on_gpu):
- - -
-#### tf.Graph.name_scope(*args, **kwds) {#Graph.name_scope}
+#### tf.Graph.name_scope(name) {#Graph.name_scope}
Returns a context manager that creates hierarchical names for operations.
@@ -611,7 +611,7 @@ the default graph.
- - -
-#### tf.Graph.gradient_override_map(*args, **kwds) {#Graph.gradient_override_map}
+#### tf.Graph.gradient_override_map(op_type_map) {#Graph.gradient_override_map}
EXPERIMENTAL: A context manager for overriding gradient functions.
@@ -2023,7 +2023,7 @@ The value of this dimension, or None if it is unknown.
- - -
-### tf.op_scope(*args, **kwds) <div class="md-anchor" id="op_scope">{#op_scope}</div>
+### tf.op_scope(values, name, default_name) <div class="md-anchor" id="op_scope">{#op_scope}</div>
Returns a context manager for use when defining a Python op.
diff --git a/tensorflow/g3doc/api_docs/python/index.md b/tensorflow/g3doc/api_docs/python/index.md
index 8750d0aadd..dd47b703fc 100644
--- a/tensorflow/g3doc/api_docs/python/index.md
+++ b/tensorflow/g3doc/api_docs/python/index.md
@@ -267,7 +267,6 @@
* [`depthwise_conv2d`](nn.md#depthwise_conv2d)
* [`dropout`](nn.md#dropout)
* [`embedding_lookup`](nn.md#embedding_lookup)
- * [`embedding_lookup_sparse`](nn.md#embedding_lookup_sparse)
* [`fixed_unigram_candidate_sampler`](nn.md#fixed_unigram_candidate_sampler)
* [`in_top_k`](nn.md#in_top_k)
* [`l2_loss`](nn.md#l2_loss)
diff --git a/tensorflow/g3doc/api_docs/python/nn.md b/tensorflow/g3doc/api_docs/python/nn.md
index 50c460b68c..b129506107 100644
--- a/tensorflow/g3doc/api_docs/python/nn.md
+++ b/tensorflow/g3doc/api_docs/python/nn.md
@@ -35,7 +35,6 @@ accepted by [`tf.convert_to_tensor`](framework.md#convert_to_tensor).
* [tf.nn.softmax_cross_entropy_with_logits(logits, labels, name=None)](#softmax_cross_entropy_with_logits)
* [Embeddings](#AUTOGENERATED-embeddings)
* [tf.nn.embedding_lookup(params, ids, name=None)](#embedding_lookup)
- * [tf.nn.embedding_lookup_sparse(params, sp_ids, sp_weights, name=None, combiner='mean')](#embedding_lookup_sparse)
* [Evaluation](#AUTOGENERATED-evaluation)
* [tf.nn.top_k(input, k, name=None)](#top_k)
* [tf.nn.in_top_k(predictions, targets, k, name=None)](#in_top_k)
@@ -130,17 +129,18 @@ sum is unchanged.
By default, each element is kept or dropped independently. If `noise_shape`
is specified, it must be
[broadcastable](http://docs.scipy.org/doc/numpy/user/basics.broadcasting.html)
-to the shape of `x`, and only dimensions with `noise_shape[i] == x.shape[i]`
-will make independent decisions. For example, if `x.shape = [b, x, y, c]` and
-`noise_shape = [b, 1, 1, c]`, each batch and channel component will be
+to the shape of `x`, and only dimensions with `noise_shape[i] == shape(x)[i]`
+will make independent decisions. For example, if `shape(x) = [k, l, m, n]`
+and `noise_shape = [k, 1, 1, n]`, each batch and channel component will be
kept independently and each row and column will be kept or not kept together.
##### Args:
* <b>x</b>: A tensor.
-* <b>keep_prob</b>: Float probability that each element is kept.
-* <b>noise_shape</b>: Shape for randomly generated keep/drop flags.
+* <b>keep_prob</b>: A Python float. The probability that each element is kept.
+* <b>noise_shape</b>: A 1-D `Tensor` of type `int32`, representing the
+ shape for randomly generated keep/drop flags.
* <b>seed</b>: A Python integer. Used to create a random seed.
See [`set_random_seed`](constant_op.md#set_random_seed) for behavior.
* <b>name</b>: A name for this operation (optional).
@@ -247,10 +247,10 @@ are as follows. If the 4-D `input` has shape
`[batch, in_height, in_width, ...]` and the 4-D `filter` has shape
`[filter_height, filter_width, ...]`, then
- output.shape = [batch,
- (in_height - filter_height + 1) / strides[1],
- (in_width - filter_width + 1) / strides[2],
- ...]
+ shape(output) = [batch,
+ (in_height - filter_height + 1) / strides[1],
+ (in_width - filter_width + 1) / strides[2],
+ ...]
output[b, i, j, :] =
sum_{di, dj} input[b, strides[1] * i + di, strides[2] * j + dj, ...] *
@@ -262,7 +262,7 @@ vectors. For `depthwise_conv_2d`, each scalar component `input[b, i, j, k]`
is multiplied by a vector `filter[di, dj, k]`, and all the vectors are
concatenated.
-In the formula for `output.shape`, the rounding direction depends on padding:
+In the formula for `shape(output)`, the rounding direction depends on padding:
* `padding = 'SAME'`: Round down (only full size windows are considered).
* `padding = 'VALID'`: Round up (partial windows are included).
@@ -411,7 +411,7 @@ In detail, the output is
for each tuple of indices `i`. The output shape is
- output.shape = (value.shape - ksize + 1) / strides
+ shape(output) = (shape(value) - ksize + 1) / strides
where the rounding direction depends on padding:
@@ -722,103 +722,43 @@ and the same dtype (either `float32` or `float64`).
## Embeddings <div class="md-anchor" id="AUTOGENERATED-embeddings">{#AUTOGENERATED-embeddings}</div>
-TensorFlow provides several operations that help you compute embeddings.
+TensorFlow provides library support for looking up values in embedding
+tensors.
- - -
### tf.nn.embedding_lookup(params, ids, name=None) <div class="md-anchor" id="embedding_lookup">{#embedding_lookup}</div>
-Return a tensor of embedding values by looking up "ids" in "params".
+Looks up `ids` in a list of embedding tensors.
-##### Args:
-
-
-* <b>params</b>: List of tensors of the same shape. A single tensor is
- treated as a singleton list.
-* <b>ids</b>: Tensor of integers containing the ids to be looked up in
- 'params'. Let P be len(params). If P > 1, then the ids are
- partitioned by id % P, and we do separate lookups in params[p]
- for 0 <= p < P, and then stitch the results back together into
- a single result tensor.
-* <b>name</b>: Optional name for the op.
-
-##### Returns:
-
- A tensor of shape ids.shape + params[0].shape[1:] containing the
- values params[i % P][i] for each i in ids.
-
-##### Raises:
-
-
-* <b>ValueError</b>: if some parameters are invalid.
+This function is used to perform parallel lookups on the list of
+tensors in `params`. It is a generalization of
+[`tf.gather()`](array_ops.md#gather), where `params` is interpreted
+as a partition of a larger embedding tensor.
+If `len(params) > 1`, each element `id` of `ids` is partitioned between
+the elements of `params` by computing `p = id % len(params)`, and is
+then used to look up the slice `params[p][id // len(params), ...]`.
-- - -
-
-### tf.nn.embedding_lookup_sparse(params, sp_ids, sp_weights, name=None, combiner='mean') <div class="md-anchor" id="embedding_lookup_sparse">{#embedding_lookup_sparse}</div>
-
-Computes embeddings for the given ids and weights.
-
-This op assumes that there is at least one id for each row in the dense tensor
-represented by sp_ids (i.e. there are no rows with empty features), and that
-all the indices of sp_ids are in canonical row-major order.
-
-It also assumes that all id values lie in the range [0, p0), where p0
-is the sum of the size of params along dimension 0.
+The results of the lookup are then concatenated into a dense
+tensor. The returned tensor has shape `shape(ids) + shape(params)[1:]`.
##### Args:
-* <b>params</b>: A single tensor representing the complete embedding tensor,
- or a list of P tensors all of same shape except for the first dimension,
- representing sharded embedding tensors. In the latter case, the ids are
- partitioned by id % P, and we do separate lookups in params[p] for
- 0 <= p < P, and then stitch the results back together into a single
- result tensor. The first dimension is allowed to vary as the vocab
- size is not necessarily a multiple of P.
-* <b>sp_ids</b>: N x M SparseTensor of int64 ids (typically from FeatureValueToId),
- where N is typically batch size and M is arbitrary.
-* <b>sp_weights</b>: either a SparseTensor of float / double weights, or None to
- indicate all weights should be taken to be 1. If specified, sp_weights
- must have exactly the same shape and indices as sp_ids.
-* <b>name</b>: Optional name for the op.
-* <b>combiner</b>: A string specifying the reduction op. Currently "mean" and "sum"
- are supported.
- "sum" computes the weighted sum of the embedding results for each row.
- "mean" is the weighted sum divided by the total weight.
+* <b>params</b>: A list of tensors with the same shape and type.
+* <b>ids</b>: A `Tensor` with type `int32` containing the ids to be looked
+ up in `params`.
+* <b>name</b>: A name for the operation (optional).
##### Returns:
- A dense tensor representing the combined embeddings for the
- sparse ids. For each row in the dense tensor represented by sp_ids, the op
- looks up the embeddings for all ids in that row, multiplies them by the
- corresponding weight, and combines these embeddings as specified.
-
- In other words, if
- shape(combined params) = [p0, p1, ..., pm]
- and
- shape(sp_ids) = shape(sp_weights) = [d0, d1, ..., dn]
- then
- shape(output) = [d0, d1, ..., dn-1, p1, ..., pm].
-
- For instance, if params is a 10x20 matrix, and sp_ids / sp_weights are
-
- [0, 0]: id 1, weight 2.0
- [0, 1]: id 3, weight 0.5
- [1, 0]: id 0, weight 1.0
- [2, 3]: id 1, weight 3.0
-
- with combiner="mean", then the output will be a 3x20 matrix where
- output[0, :] = (params[1, :] * 2.0 + params[3, :] * 0.5) / (2.0 + 0.5)
- output[1, :] = params[0, :] * 1.0
- output[2, :] = params[1, :] * 3.0
+ A `Tensor` with the same type as the tensors in `params`.
##### Raises:
-* <b>TypeError</b>: If sp_ids is not a SparseTensor, or if sp_weights is neither
- None nor SparseTensor.
-* <b>ValueError</b>: If combiner is not one of {"mean", "sum"}.
+* <b>ValueError</b>: If `params` is empty.
diff --git a/tensorflow/g3doc/api_docs/python/state_ops.md b/tensorflow/g3doc/api_docs/python/state_ops.md
index 25bbe55f7b..70685a65bc 100644
--- a/tensorflow/g3doc/api_docs/python/state_ops.md
+++ b/tensorflow/g3doc/api_docs/python/state_ops.md
@@ -23,7 +23,7 @@ accepted by [`tf.convert_to_tensor`](framework.md#convert_to_tensor).
* [Sharing Variables](#AUTOGENERATED-sharing-variables)
* [tf.get_variable(name, shape=None, dtype=tf.float32, initializer=None, trainable=True, collections=None)](#get_variable)
* [tf.get_variable_scope()](#get_variable_scope)
- * [tf.variable_scope(*args, **kwds)](#variable_scope)
+ * [tf.variable_scope(name_or_scope, reuse=None, initializer=None)](#variable_scope)
* [tf.constant_initializer(value=0.0)](#constant_initializer)
* [tf.random_normal_initializer(mean=0.0, stddev=1.0, seed=None)](#random_normal_initializer)
* [tf.truncated_normal_initializer(mean=0.0, stddev=1.0, seed=None)](#truncated_normal_initializer)
@@ -896,7 +896,7 @@ Returns the current variable scope.
- - -
-### tf.variable_scope(*args, **kwds) <div class="md-anchor" id="variable_scope">{#variable_scope}</div>
+### tf.variable_scope(name_or_scope, reuse=None, initializer=None) <div class="md-anchor" id="variable_scope">{#variable_scope}</div>
Returns a context for variable scope.
diff --git a/tensorflow/g3doc/get_started/os_setup.md b/tensorflow/g3doc/get_started/os_setup.md
index f8113bcaec..0917f16832 100644
--- a/tensorflow/g3doc/get_started/os_setup.md
+++ b/tensorflow/g3doc/get_started/os_setup.md
@@ -36,7 +36,33 @@ Install TensorFlow (only CPU binary version is currently available).
$ sudo pip install https://storage.googleapis.com/tensorflow/mac/tensorflow-0.5.0-py2-none-any.whl
```
-### Try your first TensorFlow program
+## Docker-based installation
+
+We also support running TensorFlow via [Docker](http://docker.com/), which lets
+you avoid worrying about setting up dependencies.
+
+First, [install Docker](http://docs.docker.com/engine/installation/). Once
+Docker is up and running, you can start a container with one command:
+
+```sh
+$ docker run -it b.gcr.io/tensorflow/tensorflow
+```
+
+This will start a container with TensorFlow and all its dependencies already
+installed.
+
+### Additional images
+
+The default Docker image above contains just a minimal set of libraries for
+getting up and running with TensorFlow. We also have several other containers,
+which you can use in the `docker run` command above:
+
+* `b.gcr.io/tensorflow/tensorflow-full`: Contains a complete TensorFlow source
+ installation, including all utilities needed to build and run TensorFlow. This
+ makes it easy to experiment directly with the source, without needing to
+ install any of the dependencies described above.
+
+## Try your first TensorFlow program
```sh
$ python
@@ -133,6 +159,13 @@ $ sudo apt-get install python-numpy swig python-dev
In order to build TensorFlow with GPU support, both Cuda Toolkit 7.0 and CUDNN
6.5 V2 from NVIDIA need to be installed.
+TensorFlow GPU support requires having a GPU card with NVidia Compute Capability >= 3.5. Supported cards include but are not limited to:
+
+* NVidia Titan
+* NVidia Titan X
+* NVidia K20
+* NVidia K40
+
##### Download and install Cuda Toolkit 7.0
https://developer.nvidia.com/cuda-toolkit-70
@@ -227,7 +260,7 @@ Notes : You need to install
Follow installation instructions [here](http://docs.scipy.org/doc/numpy/user/install.html).
-### Create the pip package and install
+### Create the pip package and install {#create-pip}
```sh
$ bazel build -c opt //tensorflow/tools/pip_package:build_pip_package
@@ -238,7 +271,7 @@ $ bazel-bin/tensorflow/tools/pip_package/build_pip_package /tmp/tensorflow_pkg
$ pip install /tmp/tensorflow_pkg/tensorflow-0.5.0-cp27-none-linux_x86_64.whl
```
-### Train your first TensorFlow neural net model
+## Train your first TensorFlow neural net model
From the root of your source tree, run:
diff --git a/tensorflow/g3doc/how_tos/adding_an_op/index.md b/tensorflow/g3doc/how_tos/adding_an_op/index.md
index 5c6243cd9c..8702569f75 100644
--- a/tensorflow/g3doc/how_tos/adding_an_op/index.md
+++ b/tensorflow/g3doc/how_tos/adding_an_op/index.md
@@ -127,10 +127,9 @@ To do this for the `ZeroOut` op, add the following to `zero_out.cc`:
REGISTER_KERNEL_BUILDER(Name("ZeroOut").Device(DEVICE_CPU), ZeroOutOp);
```
-TODO: instructions or pointer to building TF
-
-At this point, the Tensorflow system can reference and use the Op when
-requested.
+Once you
+[build and reinstall TensorFlow](../../get_started/os_setup.md#create-pip), the
+Tensorflow system can reference and use the Op when requested.
## Generate the client wrapper <div class="md-anchor" id="AUTOGENERATED-generate-the-client-wrapper">{#AUTOGENERATED-generate-the-client-wrapper}</div>
### The Python Op wrapper <div class="md-anchor" id="AUTOGENERATED-the-python-op-wrapper">{#AUTOGENERATED-the-python-op-wrapper}</div>
diff --git a/tensorflow/g3doc/how_tos/variables/index.md b/tensorflow/g3doc/how_tos/variables/index.md
index 4ad8f8a266..26b19b3ae1 100644
--- a/tensorflow/g3doc/how_tos/variables/index.md
+++ b/tensorflow/g3doc/how_tos/variables/index.md
@@ -101,7 +101,7 @@ w_twice = tf.Variable(weights.initialized_value() * 0.2, name="w_twice")
The convenience function `tf.initialize_all_variables()` adds an Op to
initialize *all variables* in the model. You can also pass it an explicit list
of variables to initialize. See the
-[Variables Documentation](../../api_docs/python/state_op.md) for more options,
+[Variables Documentation](../../api_docs/python/state_ops.md) for more options,
including checking if variables are initialized.
## Saving and Restoring
diff --git a/tensorflow/g3doc/resources/faq.md b/tensorflow/g3doc/resources/faq.md
index 2bd485e7f9..a2b9a58e08 100644
--- a/tensorflow/g3doc/resources/faq.md
+++ b/tensorflow/g3doc/resources/faq.md
@@ -6,18 +6,18 @@ answer on one of the TensorFlow [community resources](index.md).
<!-- TOC-BEGIN This section is generated by neural network: DO NOT EDIT! -->
## Contents
- * [Building a TensorFlow graph](#AUTOGENERATED-building-a-tensorflow-graph)
- * [Running a TensorFlow computation](#AUTOGENERATED-running-a-tensorflow-computation)
- * [Variables](#AUTOGENERATED-variables)
- * [Tensor shapes](#AUTOGENERATED-tensor-shapes)
- * [TensorBoard](#AUTOGENERATED-tensorboard)
- * [Extending TensorFlow](#AUTOGENERATED-extending-tensorflow)
- * [Miscellaneous](#AUTOGENERATED-miscellaneous)
+* [Building a TensorFlow graph](#AUTOGENERATED-building-a-tensorflow-graph)
+* [Running a TensorFlow computation](#AUTOGENERATED-running-a-tensorflow-computation)
+* [Variables](#AUTOGENERATED-variables)
+* [Tensor shapes](#AUTOGENERATED-tensor-shapes)
+* [TensorBoard](#AUTOGENERATED-tensorboard)
+* [Extending TensorFlow](#AUTOGENERATED-extending-tensorflow)
+* [Miscellaneous](#AUTOGENERATED-miscellaneous)
<!-- TOC-END This section was generated by neural network, THANKS FOR READING! -->
-### Building a TensorFlow graph <div class="md-anchor" id="AUTOGENERATED-building-a-tensorflow-graph">{#AUTOGENERATED-building-a-tensorflow-graph}</div>
+## Building a TensorFlow graph <div class="md-anchor" id="AUTOGENERATED-building-a-tensorflow-graph">{#AUTOGENERATED-building-a-tensorflow-graph}</div>
See also the
[API documentation on building graphs](../api_docs/python/framework.md).
@@ -55,7 +55,7 @@ uses multiple GPUs.
TensorFlow supports a variety of different data types and tensor shapes. See the
[ranks, shapes, and types reference](dims_types.md) for more details.
-### Running a TensorFlow computation <div class="md-anchor" id="AUTOGENERATED-running-a-tensorflow-computation">{#AUTOGENERATED-running-a-tensorflow-computation}</div>
+## Running a TensorFlow computation <div class="md-anchor" id="AUTOGENERATED-running-a-tensorflow-computation">{#AUTOGENERATED-running-a-tensorflow-computation}</div>
See also the
[API documentation on running graphs](../api_docs/python/client.md).
@@ -175,7 +175,7 @@ for
[using `QueueRunner` objects to drive queues and readers](../how_tos/reading_data/index.md#QueueRunners)
for more information on how to use them.
-### Variables <div class="md-anchor" id="AUTOGENERATED-variables">{#AUTOGENERATED-variables}</div>
+## Variables <div class="md-anchor" id="AUTOGENERATED-variables">{#AUTOGENERATED-variables}</div>
See also the how-to documentation on [variables](../how_tos/variables/index.md)
and [variable scopes](../how_tos/variable_scope/index.md), and
@@ -196,7 +196,7 @@ operations to a variable are allowed to run with no mutual exclusion. To acquire
a lock when assigning to a variable, pass `use_locking=True` to
[`Variable.assign()`](../api_docs/python/state_ops.md#Variable.assign).
-### Tensor shapes <div class="md-anchor" id="AUTOGENERATED-tensor-shapes">{#AUTOGENERATED-tensor-shapes}</div>
+## Tensor shapes <div class="md-anchor" id="AUTOGENERATED-tensor-shapes">{#AUTOGENERATED-tensor-shapes}</div>
See also the
[`TensorShape` API documentation](../api_docs/python/framework.md#TensorShape).
@@ -248,7 +248,7 @@ to encode the batch size as a Python constant, but instead to use a symbolic
[`tf.placeholder(..., shape=[None, ...])`](../api_docs/python/io_ops.md#placeholder). The
`None` element of the shape corresponds to a variable-sized dimension.
-### TensorBoard <div class="md-anchor" id="AUTOGENERATED-tensorboard">{#AUTOGENERATED-tensorboard}</div>
+## TensorBoard <div class="md-anchor" id="AUTOGENERATED-tensorboard">{#AUTOGENERATED-tensorboard}</div>
See also the
[how-to documentation on TensorBoard](../how_tos/graph_viz/index.md).
@@ -260,7 +260,7 @@ of these summaries to a log directory. Then, startup TensorBoard using
<SOME_COMMAND> and pass the --logdir flag so that it points to your
log directory. For more details, see <YET_UNWRITTEN_TENSORBOARD_TUTORIAL>.
-### Extending TensorFlow <div class="md-anchor" id="AUTOGENERATED-extending-tensorflow">{#AUTOGENERATED-extending-tensorflow}</div>
+## Extending TensorFlow <div class="md-anchor" id="AUTOGENERATED-extending-tensorflow">{#AUTOGENERATED-extending-tensorflow}</div>
See also the how-to documentation for
[adding a new operation to TensorFlow](../how_tos/adding_an_op/index.md).
@@ -293,7 +293,7 @@ how-to documentation for
[adding an op with a list of inputs or outputs](../how_tos/adding_an_op/index.md#list-input-output)
for more details of how to define these different input types.
-### Miscellaneous <div class="md-anchor" id="AUTOGENERATED-miscellaneous">{#AUTOGENERATED-miscellaneous}</div>
+## Miscellaneous <div class="md-anchor" id="AUTOGENERATED-miscellaneous">{#AUTOGENERATED-miscellaneous}</div>
#### Does TensorFlow work with Python 3?
diff --git a/tensorflow/g3doc/resources/uses.md b/tensorflow/g3doc/resources/uses.md
index f73d4e92d7..cc212886c5 100644
--- a/tensorflow/g3doc/resources/uses.md
+++ b/tensorflow/g3doc/resources/uses.md
@@ -17,7 +17,7 @@ Listed below are some of the many uses of TensorFlow.
* **Inception Image Classification Model**
* **Organization**: Google
- * **Description**: Baseline model and follow on research into highly accurate computer vision models, starting with the model that won the 2014 Imagenet iamge classification challenge
+ * **Description**: Baseline model and follow on research into highly accurate computer vision models, starting with the model that won the 2014 Imagenet image classification challenge
* **More Info**: Baseline model described in [Arxiv paper](http://arxiv.org/abs/1409.4842)
* **SmartReply**
@@ -33,6 +33,6 @@ Listed below are some of the many uses of TensorFlow.
* **On-Device Computer Vision for OCR**
* **Organization**: Google
- * **Description**: On-device computer vision model to do optical character recoignition to enable real-time translation.
+ * **Description**: On-device computer vision model to do optical character recognition to enable real-time translation.
* **More info**: [Google Research blog post](http://googleresearch.blogspot.com/2015/07/how-google-translate-squeezes-deep.html)
}
diff --git a/tensorflow/g3doc/tutorials/deep_cnn/index.md b/tensorflow/g3doc/tutorials/deep_cnn/index.md
index 0bbba2cb40..929e1b3047 100644
--- a/tensorflow/g3doc/tutorials/deep_cnn/index.md
+++ b/tensorflow/g3doc/tutorials/deep_cnn/index.md
@@ -1,4 +1,4 @@
-# Convolutional Neural Networks for Object Recognition
+# Convolutional Neural Networks
**NOTE:** This tutorial is intended for *advanced* users of TensorFlow
and assumes expertise and experience in machine learning.
@@ -264,7 +264,7 @@ in `cifar10.py`.
`cifar10_train.py` periodically [saves](../../api_docs/python/state_ops.md#Saver)
all model parameters in
-[checkpoint files](../../how_tos/variables.md#saving-and-restoring)
+[checkpoint files](../../how_tos/variables/index.md#saving-and-restoring)
but it does *not* evaluate the model. The checkpoint file
will be used by `cifar10_eval.py` to measure the predictive
performance (see [Evaluating a Model](#evaluating-a-model) below).
diff --git a/tensorflow/g3doc/tutorials/index.md b/tensorflow/g3doc/tutorials/index.md
index 470645db51..202b87c73c 100644
--- a/tensorflow/g3doc/tutorials/index.md
+++ b/tensorflow/g3doc/tutorials/index.md
@@ -1,7 +1,7 @@
# Overview
-## ML for Beginners
+## MNIST For ML Beginners
If you're new to machine learning, we recommend starting here. You'll learn
about a classic problem, handwritten digit classification (MNIST), and get a
@@ -10,7 +10,7 @@ gentle introduction to multiclass classification.
[View Tutorial](mnist/beginners/index.md)
-## MNIST for Pros
+## Deep MNIST for Experts
If you're already familiar with other deep learning software packages, and are
already familiar with MNIST, this tutorial with give you a very brief primer on
diff --git a/tensorflow/g3doc/tutorials/mandelbrot/index.md b/tensorflow/g3doc/tutorials/mandelbrot/index.md
index 7c6adcb4e8..b3d5a185f9 100755
--- a/tensorflow/g3doc/tutorials/mandelbrot/index.md
+++ b/tensorflow/g3doc/tutorials/mandelbrot/index.md
@@ -1,4 +1,4 @@
-
+# Mandelbrot Set
```
#Import libraries for simulation
diff --git a/tensorflow/g3doc/tutorials/mnist/beginners/index.md b/tensorflow/g3doc/tutorials/mnist/beginners/index.md
index 398eca5f18..fff7484959 100644
--- a/tensorflow/g3doc/tutorials/mnist/beginners/index.md
+++ b/tensorflow/g3doc/tutorials/mnist/beginners/index.md
@@ -1,4 +1,4 @@
-# MNIST Softmax Regression (For Beginners)
+# MNIST For ML Beginners
*This tutorial is intended for readers who are new to both machine learning and
TensorFlow. If you already
diff --git a/tensorflow/g3doc/tutorials/mnist/download/index.md b/tensorflow/g3doc/tutorials/mnist/download/index.md
index dc11e727d8..df6245df78 100644
--- a/tensorflow/g3doc/tutorials/mnist/download/index.md
+++ b/tensorflow/g3doc/tutorials/mnist/download/index.md
@@ -1,4 +1,4 @@
-# Downloading MNIST
+# MNIST Data Download
Code: [tensorflow/g3doc/tutorials/mnist/](https://tensorflow.googlesource.com/tensorflow/+/master/tensorflow/g3doc/tutorials/mnist/)
diff --git a/tensorflow/g3doc/tutorials/mnist/pros/index.md b/tensorflow/g3doc/tutorials/mnist/pros/index.md
index cb0292586b..34853ccf66 100644
--- a/tensorflow/g3doc/tutorials/mnist/pros/index.md
+++ b/tensorflow/g3doc/tutorials/mnist/pros/index.md
@@ -1,4 +1,4 @@
-# MNIST Deep Learning Example (For Experts)
+# Deep MNIST for Experts
TensorFlow is a powerful library for doing large-scale numerical computation.
One of the tasks at which it excels is implementing and training deep neural
diff --git a/tensorflow/g3doc/tutorials/mnist/tf/index.md b/tensorflow/g3doc/tutorials/mnist/tf/index.md
index 86f3296287..5ce996af12 100644
--- a/tensorflow/g3doc/tutorials/mnist/tf/index.md
+++ b/tensorflow/g3doc/tutorials/mnist/tf/index.md
@@ -1,4 +1,4 @@
-# Handwritten Digit Classification
+# TensorFlow Mechanics 101
Code: [tensorflow/g3doc/tutorials/mnist/](https://tensorflow.googlesource.com/tensorflow/+/master/tensorflow/g3doc/tutorials/mnist/)
diff --git a/tensorflow/g3doc/tutorials/pdes/index.md b/tensorflow/g3doc/tutorials/pdes/index.md
index 1f29e4037c..26f36d5536 100755
--- a/tensorflow/g3doc/tutorials/pdes/index.md
+++ b/tensorflow/g3doc/tutorials/pdes/index.md
@@ -1,3 +1,4 @@
+# Partial Differential Equations
## Basic Setup
diff --git a/tensorflow/g3doc/tutorials/seq2seq/index.md b/tensorflow/g3doc/tutorials/seq2seq/index.md
index e421c814aa..3eec2a2ba8 100644
--- a/tensorflow/g3doc/tutorials/seq2seq/index.md
+++ b/tensorflow/g3doc/tutorials/seq2seq/index.md
@@ -1,4 +1,4 @@
-# Sequence-to-Sequence Models: Learning to Translate
+# Sequence-to-Sequence Models
Recurrent neural networks can learn to model language, as already discussed
in the [RNN Tutorial](../recurrent/index.md)
diff --git a/tensorflow/g3doc/tutorials/word2vec/index.md b/tensorflow/g3doc/tutorials/word2vec/index.md
index 832c7c166b..290ff3627f 100644
--- a/tensorflow/g3doc/tutorials/word2vec/index.md
+++ b/tensorflow/g3doc/tutorials/word2vec/index.md
@@ -1,4 +1,4 @@
-# Learning Vector Representations of Words
+# Vector Representations of Words
In this tutorial we look at the word2vec model by
[Mikolov et al.](http://papers.nips.cc/paper/5021-distributed-representations-of-words-and-phrases-and-their-compositionality.pdf).
diff --git a/tensorflow/models/rnn/BUILD b/tensorflow/models/rnn/BUILD
new file mode 100644
index 0000000000..a88d48fd42
--- /dev/null
+++ b/tensorflow/models/rnn/BUILD
@@ -0,0 +1,106 @@
+# Description:
+# Example RNN models, including language models and sequence-to-sequence models.
+
+package(default_visibility = ["//visibility:public"])
+
+licenses(["notice"]) # Apache 2.0
+
+exports_files(["LICENSE"])
+
+load("/tensorflow/tensorflow", "cuda_py_tests")
+
+py_library(
+ name = "linear",
+ srcs = [
+ "linear.py",
+ ],
+ deps = [
+ "//tensorflow:tensorflow_py",
+ ],
+)
+
+py_test(
+ name = "linear_test",
+ size = "small",
+ srcs = ["linear_test.py"],
+ deps = [
+ ":linear",
+ "//tensorflow:tensorflow_py",
+ ],
+)
+
+py_library(
+ name = "rnn_cell",
+ srcs = [
+ "rnn_cell.py",
+ ],
+ deps = [
+ ":linear",
+ "//tensorflow:tensorflow_py",
+ ],
+)
+
+py_test(
+ name = "rnn_cell_test",
+ size = "small",
+ srcs = ["rnn_cell_test.py"],
+ deps = [
+ ":rnn_cell",
+ "//tensorflow:tensorflow_py",
+ ],
+)
+
+py_library(
+ name = "rnn",
+ srcs = [
+ "rnn.py",
+ ],
+ deps = [
+ ":rnn_cell",
+ "//tensorflow:tensorflow_py",
+ ],
+)
+
+cuda_py_tests(
+ name = "rnn_tests",
+ srcs = [
+ "rnn_test.py",
+ ],
+ additional_deps = [
+ ":rnn",
+ ],
+)
+
+py_library(
+ name = "seq2seq",
+ srcs = [
+ "seq2seq.py",
+ ],
+ deps = [
+ ":rnn",
+ "//tensorflow:tensorflow_py",
+ ],
+)
+
+py_test(
+ name = "seq2seq_test",
+ srcs = [
+ "seq2seq_test.py",
+ ],
+ deps = [
+ ":seq2seq",
+ "//tensorflow:tensorflow_py",
+ ],
+)
+
+filegroup(
+ name = "all_files",
+ srcs = glob(
+ ["**/*"],
+ exclude = [
+ "**/METADATA",
+ "**/OWNERS",
+ ],
+ ),
+ visibility = ["//tensorflow:__subpackages__"],
+)
diff --git a/tensorflow/models/rnn/README.md b/tensorflow/models/rnn/README.md
new file mode 100644
index 0000000000..227c226c3a
--- /dev/null
+++ b/tensorflow/models/rnn/README.md
@@ -0,0 +1,21 @@
+This directory contains functions for creating recurrent neural networks
+and sequence-to-sequence models. Detailed instructions on how to get started
+and use them are available in the tutorials.
+
+* [RNN Tutorial](http://tensorflow.org/tutorials/recurrent/)
+* [Sequence-to-Sequence Tutorial](http://tensorflow.org/tutorials/seq2seq/)
+
+Here is a short overview of what is in this directory.
+
+File | What's in it?
+--- | ---
+`linear.py` | Basic helper functions for creating linear layers.
+`linear_test.py` | Unit tests for `linear.py`.
+`rnn_cell.py` | Cells for recurrent neural networks, e.g., LSTM.
+`rnn_cell_test.py` | Unit tests for `rnn_cell.py`.
+`rnn.py` | Functions for building recurrent neural networks.
+`rnn_test.py` | Unit tests for `rnn.py`.
+`seq2seq.py` | Functions for building sequence-to-sequence models.
+`seq2seq_test.py` | Unit tests for `seq2seq.py`.
+`ptb/` | PTB language model, see the [RNN Tutorial](http://tensorflow.org/tutorials/recurrent/)
+`translate/` | Translation model, see the [Sequence-to-Sequence Tutorial](http://tensorflow.org/tutorials/seq2seq/)
diff --git a/tensorflow/models/rnn/__init__.py b/tensorflow/models/rnn/__init__.py
new file mode 100755
index 0000000000..e69de29bb2
--- /dev/null
+++ b/tensorflow/models/rnn/__init__.py
diff --git a/tensorflow/models/rnn/linear.py b/tensorflow/models/rnn/linear.py
new file mode 100644
index 0000000000..96278e73e4
--- /dev/null
+++ b/tensorflow/models/rnn/linear.py
@@ -0,0 +1,49 @@
+"""Basic linear combinations that implicitly generate variables."""
+
+import tensorflow as tf
+
+
+def linear(args, output_size, bias, bias_start=0.0, scope=None):
+ """Linear map: sum_i(args[i] * W[i]), where W[i] is a variable.
+
+ Args:
+ args: a 2D Tensor or a list of 2D, batch x n, Tensors.
+ output_size: int, second dimension of W[i].
+ bias: boolean, whether to add a bias term or not.
+ bias_start: starting value to initialize the bias; 0 by default.
+ scope: VariableScope for the created subgraph; defaults to "Linear".
+
+ Returns:
+ A 2D Tensor with shape [batch x output_size] equal to
+ sum_i(args[i] * W[i]), where W[i]s are newly created matrices.
+
+ Raises:
+ ValueError: if some of the arguments has unspecified or wrong shape.
+ """
+ assert args
+ if not isinstance(args, (list, tuple)):
+ args = [args]
+
+ # Calculate the total size of arguments on dimension 1.
+ total_arg_size = 0
+ shapes = [a.get_shape().as_list() for a in args]
+ for shape in shapes:
+ if len(shape) != 2:
+ raise ValueError("Linear is expecting 2D arguments: %s" % str(shapes))
+ if not shape[1]:
+ raise ValueError("Linear expects shape[1] of arguments: %s" % str(shapes))
+ else:
+ total_arg_size += shape[1]
+
+ # Now the computation.
+ with tf.variable_scope(scope or "Linear"):
+ matrix = tf.get_variable("Matrix", [total_arg_size, output_size])
+ if len(args) == 1:
+ res = tf.matmul(args[0], matrix)
+ else:
+ res = tf.matmul(tf.concat(1, args), matrix)
+ if not bias:
+ return res
+ bias_term = tf.get_variable("Bias", [output_size],
+ initializer=tf.constant_initializer(bias_start))
+ return res + bias_term
diff --git a/tensorflow/models/rnn/linear_test.py b/tensorflow/models/rnn/linear_test.py
new file mode 100644
index 0000000000..93ef10144f
--- /dev/null
+++ b/tensorflow/models/rnn/linear_test.py
@@ -0,0 +1,35 @@
+# pylint: disable=g-bad-import-order,unused-import
+import tensorflow.python.platform
+
+import numpy as np
+import tensorflow as tf
+
+from tensorflow.models.rnn import linear
+
+
+class LinearTest(tf.test.TestCase):
+
+ def testLinear(self):
+ with self.test_session() as sess:
+ with tf.variable_scope("root", initializer=tf.constant_initializer(1.0)):
+ x = tf.zeros([1, 2])
+ l = linear.linear([x], 2, False)
+ sess.run([tf.variables.initialize_all_variables()])
+ res = sess.run([l], {x.name: np.array([[1., 2.]])})
+ self.assertAllClose(res[0], [[3.0, 3.0]])
+
+ # Checks prevent you from accidentally creating a shared function.
+ with self.assertRaises(ValueError) as exc:
+ l1 = linear.linear([x], 2, False)
+ self.assertEqual(exc.exception.message[:12], "Over-sharing")
+
+ # But you can create a new one in a new scope and share the variables.
+ with tf.variable_scope("l1") as new_scope:
+ l1 = linear.linear([x], 2, False)
+ with tf.variable_scope(new_scope, reuse=True):
+ linear.linear([l1], 2, False)
+ self.assertEqual(len(tf.trainable_variables()), 2)
+
+
+if __name__ == "__main__":
+ tf.test.main()
diff --git a/tensorflow/models/rnn/ptb/BUILD b/tensorflow/models/rnn/ptb/BUILD
new file mode 100644
index 0000000000..56d459a0f1
--- /dev/null
+++ b/tensorflow/models/rnn/ptb/BUILD
@@ -0,0 +1,49 @@
+# Description:
+# Python support for TensorFlow.
+
+package(default_visibility = ["//tensorflow:internal"])
+
+licenses(["notice"]) # Apache 2.0
+
+exports_files(["LICENSE"])
+
+py_library(
+ name = "reader",
+ srcs = ["reader.py"],
+ deps = ["//tensorflow:tensorflow_py"],
+)
+
+py_test(
+ name = "reader_test",
+ srcs = ["reader_test.py"],
+ deps = [
+ ":reader",
+ "//tensorflow:tensorflow_py",
+ ],
+)
+
+py_binary(
+ name = "ptb_word_lm",
+ srcs = [
+ "ptb_word_lm.py",
+ ],
+ deps = [
+ ":reader",
+ "//tensorflow:tensorflow_py",
+ "//tensorflow/models/rnn",
+ "//tensorflow/models/rnn:rnn_cell",
+ "//tensorflow/models/rnn:seq2seq",
+ ],
+)
+
+filegroup(
+ name = "all_files",
+ srcs = glob(
+ ["**/*"],
+ exclude = [
+ "**/METADATA",
+ "**/OWNERS",
+ ],
+ ),
+ visibility = ["//tensorflow:__subpackages__"],
+)
diff --git a/tensorflow/models/rnn/ptb/__init__.py b/tensorflow/models/rnn/ptb/__init__.py
new file mode 100755
index 0000000000..e69de29bb2
--- /dev/null
+++ b/tensorflow/models/rnn/ptb/__init__.py
diff --git a/tensorflow/models/rnn/ptb/ptb_word_lm.py b/tensorflow/models/rnn/ptb/ptb_word_lm.py
new file mode 100644
index 0000000000..e28d3bf78c
--- /dev/null
+++ b/tensorflow/models/rnn/ptb/ptb_word_lm.py
@@ -0,0 +1,292 @@
+"""Example / benchmark for building a PTB LSTM model.
+
+Trains the model described in:
+(Zaremba, et. al.) Recurrent Neural Network Regularization
+http://arxiv.org/abs/1409.2329
+
+The data required for this example is in the data/ dir of the
+PTB dataset from Tomas Mikolov's webpage:
+
+http://www.fit.vutbr.cz/~imikolov/rnnlm/simple-examples.tgz
+
+There are 3 supported model configurations:
+===========================================
+| config | epochs | train | valid | test
+===========================================
+| small | 13 | 37.99 | 121.39 | 115.91
+| medium | 39 | 48.45 | 86.16 | 82.07
+| large | 55 | 37.87 | 82.62 | 78.29
+The exact results may vary depending on the random initialization.
+
+The hyperparameters used in the model:
+- init_scale - the initial scale of the weights
+- learning_rate - the initial value of the learning rate
+- max_grad_norm - the maximum permissible norm of the gradient
+- num_layers - the number of LSTM layers
+- num_steps - the number of unrolled steps of LSTM
+- hidden_size - the number of LSTM units
+- max_epoch - the number of epochs trained with the initial learning rate
+- max_max_epoch - the total number of epochs for training
+- keep_prob - the probability of keeping weights in the dropout layer
+- lr_decay - the decay of the learning rate for each epoch after "max_epoch"
+- batch_size - the batch size
+
+To compile on CPU:
+ bazel build -c opt tensorflow/models/rnn/ptb:ptb_word_lm
+To compile on GPU:
+ bazel build -c opt tensorflow --config=cuda \
+ tensorflow/models/rnn/ptb:ptb_word_lm
+To run:
+ ./bazel-bin/.../ptb_word_lm \
+ --data_path=/tmp/simple-examples/data/ --alsologtostderr
+
+"""
+
+import time
+
+import tensorflow.python.platform
+
+import numpy as np
+import tensorflow as tf
+
+from tensorflow.models.rnn import rnn_cell
+from tensorflow.models.rnn import seq2seq
+from tensorflow.models.rnn.ptb import reader
+
+flags = tf.flags
+logging = tf.logging
+
+flags.DEFINE_string(
+ "model", "small",
+ "A type of model. Possible options are: small, medium, large.")
+flags.DEFINE_string("data_path", None, "data_path")
+
+FLAGS = flags.FLAGS
+
+
+class PTBModel(object):
+ """The PTB model."""
+
+ def __init__(self, is_training, config):
+ self.batch_size = batch_size = config.batch_size
+ self.num_steps = num_steps = config.num_steps
+ size = config.hidden_size
+ vocab_size = config.vocab_size
+
+ self._input_data = tf.placeholder(tf.int32, [batch_size, num_steps])
+ self._targets = tf.placeholder(tf.int32, [batch_size, num_steps])
+
+ # Slightly better results can be obtained with forget gate biases
+ # initialized to 1 but the hyperparameters of the model would need to be
+ # different than reported in the paper.
+ lstm_cell = rnn_cell.BasicLSTMCell(size, forget_bias=0.0)
+ if is_training and config.keep_prob < 1:
+ lstm_cell = rnn_cell.DropoutWrapper(
+ lstm_cell, output_keep_prob=config.keep_prob)
+ cell = rnn_cell.MultiRNNCell([lstm_cell] * config.num_layers)
+
+ self._initial_state = cell.zero_state(batch_size, tf.float32)
+
+ with tf.device("/cpu:0"):
+ embedding = tf.get_variable("embedding", [vocab_size, size])
+ inputs = tf.split(
+ 1, num_steps, tf.nn.embedding_lookup(embedding, self._input_data))
+ inputs = [tf.squeeze(input_, [1]) for input_ in inputs]
+
+ if is_training and config.keep_prob < 1:
+ inputs = [tf.nn.dropout(input_, config.keep_prob) for input_ in inputs]
+
+ # Simplified version of tensorflow.models.rnn.rnn.py's rnn().
+ # This builds an unrolled LSTM for tutorial purposes only.
+ # In general, use the rnn() or state_saving_rnn() from rnn.py.
+ #
+ # The alternative version of the code below is:
+ #
+ # from tensorflow.models.rnn import rnn
+ # outputs, states = rnn.rnn(cell, inputs, initial_state=self._initial_state)
+ outputs = []
+ states = []
+ state = self._initial_state
+ with tf.variable_scope("RNN"):
+ for time_step, input_ in enumerate(inputs):
+ if time_step > 0: tf.get_variable_scope().reuse_variables()
+ (cell_output, state) = cell(input_, state)
+ outputs.append(cell_output)
+ states.append(state)
+
+ output = tf.reshape(tf.concat(1, outputs), [-1, size])
+ logits = tf.nn.xw_plus_b(output,
+ tf.get_variable("softmax_w", [size, vocab_size]),
+ tf.get_variable("softmax_b", [vocab_size]))
+ loss = seq2seq.sequence_loss_by_example([logits],
+ [tf.reshape(self._targets, -1)],
+ [tf.ones([batch_size * num_steps])],
+ vocab_size)
+ self._cost = cost = tf.reduce_sum(loss) / batch_size
+ self._final_state = states[-1]
+
+ if not is_training:
+ return
+
+ self._lr = tf.Variable(0.0, trainable=False)
+ tvars = tf.trainable_variables()
+ grads, _ = tf.clip_by_global_norm(tf.gradients(cost, tvars),
+ config.max_grad_norm)
+ optimizer = tf.train.GradientDescentOptimizer(self.lr)
+ self._train_op = optimizer.apply_gradients(zip(grads, tvars))
+
+ def assign_lr(self, session, lr_value):
+ session.run(tf.assign(self.lr, lr_value))
+
+ @property
+ def input_data(self):
+ return self._input_data
+
+ @property
+ def targets(self):
+ return self._targets
+
+ @property
+ def initial_state(self):
+ return self._initial_state
+
+ @property
+ def cost(self):
+ return self._cost
+
+ @property
+ def final_state(self):
+ return self._final_state
+
+ @property
+ def lr(self):
+ return self._lr
+
+ @property
+ def train_op(self):
+ return self._train_op
+
+
+class SmallConfig(object):
+ """Small config."""
+ init_scale = 0.1
+ learning_rate = 1.0
+ max_grad_norm = 5
+ num_layers = 2
+ num_steps = 20
+ hidden_size = 200
+ max_epoch = 4
+ max_max_epoch = 13
+ keep_prob = 1.0
+ lr_decay = 0.5
+ batch_size = 20
+ vocab_size = 10000
+
+
+class MediumConfig(object):
+ """Medium config."""
+ init_scale = 0.05
+ learning_rate = 1.0
+ max_grad_norm = 5
+ num_layers = 2
+ num_steps = 35
+ hidden_size = 650
+ max_epoch = 6
+ max_max_epoch = 39
+ keep_prob = 0.5
+ lr_decay = 0.8
+ batch_size = 20
+ vocab_size = 10000
+
+
+class LargeConfig(object):
+ """Large config."""
+ init_scale = 0.04
+ learning_rate = 1.0
+ max_grad_norm = 10
+ num_layers = 2
+ num_steps = 35
+ hidden_size = 1500
+ max_epoch = 14
+ max_max_epoch = 55
+ keep_prob = 0.35
+ lr_decay = 1 / 1.15
+ batch_size = 20
+ vocab_size = 10000
+
+
+def run_epoch(session, m, data, eval_op, verbose=False):
+ """Runs the model on the given data."""
+ epoch_size = ((len(data) / m.batch_size) - 1) / m.num_steps
+ start_time = time.time()
+ costs = 0.0
+ iters = 0
+ state = m.initial_state.eval()
+ for step, (x, y) in enumerate(reader.ptb_iterator(data, m.batch_size,
+ m.num_steps)):
+ cost, state, _ = session.run([m.cost, m.final_state, eval_op],
+ {m.input_data: x,
+ m.targets: y,
+ m.initial_state: state})
+ costs += cost
+ iters += m.num_steps
+
+ if verbose and step % (epoch_size / 10) == 10:
+ print("%.3f perplexity: %.3f speed: %.0f wps" %
+ (step * 1.0 / epoch_size, np.exp(costs / iters),
+ iters * m.batch_size / (time.time() - start_time)))
+
+ return np.exp(costs / iters)
+
+
+def get_config():
+ if FLAGS.model == "small":
+ return SmallConfig()
+ elif FLAGS.model == "medium":
+ return MediumConfig()
+ elif FLAGS.model == "large":
+ return LargeConfig()
+ else:
+ raise ValueError("Invalid model: %s", FLAGS.model)
+
+
+def main(unused_args):
+ if not FLAGS.data_path:
+ raise ValueError("Must set --data_path to PTB data directory")
+
+ raw_data = reader.ptb_raw_data(FLAGS.data_path)
+ train_data, valid_data, test_data, _ = raw_data
+
+ config = get_config()
+ eval_config = get_config()
+ eval_config.batch_size = 1
+ eval_config.num_steps = 1
+
+ with tf.Graph().as_default(), tf.Session() as session:
+ initializer = tf.random_uniform_initializer(-config.init_scale,
+ config.init_scale)
+ with tf.variable_scope("model", reuse=None, initializer=initializer):
+ m = PTBModel(is_training=True, config=config)
+ with tf.variable_scope("model", reuse=True, initializer=initializer):
+ mvalid = PTBModel(is_training=False, config=config)
+ mtest = PTBModel(is_training=False, config=eval_config)
+
+ tf.initialize_all_variables().run()
+
+ for i in range(config.max_max_epoch):
+ lr_decay = config.lr_decay ** max(i - config.max_epoch, 0.0)
+ m.assign_lr(session, config.learning_rate * lr_decay)
+
+ print("Epoch: %d Learning rate: %.3f" % (i + 1, session.run(m.lr)))
+ train_perplexity = run_epoch(session, m, train_data, m.train_op,
+ verbose=True)
+ print("Epoch: %d Train Perplexity: %.3f" % (i + 1, train_perplexity))
+ valid_perplexity = run_epoch(session, mvalid, valid_data, tf.no_op())
+ print("Epoch: %d Valid Perplexity: %.3f" % (i + 1, valid_perplexity))
+
+ test_perplexity = run_epoch(session, mtest, test_data, tf.no_op())
+ print("Test Perplexity: %.3f" % test_perplexity)
+
+
+if __name__ == "__main__":
+ tf.app.run()
diff --git a/tensorflow/models/rnn/ptb/reader.py b/tensorflow/models/rnn/ptb/reader.py
new file mode 100644
index 0000000000..9a0db9c525
--- /dev/null
+++ b/tensorflow/models/rnn/ptb/reader.py
@@ -0,0 +1,105 @@
+# pylint: disable=unused-import,g-bad-import-order
+
+"""Utilities for parsing PTB text files."""
+import collections
+import os
+import sys
+import time
+
+import tensorflow.python.platform
+
+import numpy as np
+import tensorflow as tf
+
+from tensorflow.python.platform import gfile
+
+
+def _read_words(filename):
+ with gfile.GFile(filename, "r") as f:
+ return f.read().replace("\n", "<eos>").split()
+
+
+def _build_vocab(filename):
+ data = _read_words(filename)
+
+ counter = collections.Counter(data)
+ count_pairs = sorted(counter.items(), key=lambda x: -x[1])
+
+ words, _ = zip(*count_pairs)
+ word_to_id = dict(zip(words, range(len(words))))
+
+ return word_to_id
+
+
+def _file_to_word_ids(filename, word_to_id):
+ data = _read_words(filename)
+ return [word_to_id[word] for word in data]
+
+
+def ptb_raw_data(data_path=None):
+ """Load PTB raw data from data directory "data_path".
+
+ Reads PTB text files, converts strings to integer ids,
+ and performs mini-batching of the inputs.
+
+ The PTB dataset comes from Tomas Mikolov's webpage:
+
+ http://www.fit.vutbr.cz/~imikolov/rnnlm/simple-examples.tgz
+
+ Args:
+ data_path: string path to the directory where simple-examples.tgz has
+ been extracted.
+
+ Returns:
+ tuple (train_data, valid_data, test_data, vocabulary)
+ where each of the data objects can be passed to PTBIterator.
+ """
+
+ train_path = os.path.join(data_path, "ptb.train.txt")
+ valid_path = os.path.join(data_path, "ptb.valid.txt")
+ test_path = os.path.join(data_path, "ptb.test.txt")
+
+ word_to_id = _build_vocab(train_path)
+ train_data = _file_to_word_ids(train_path, word_to_id)
+ valid_data = _file_to_word_ids(valid_path, word_to_id)
+ test_data = _file_to_word_ids(test_path, word_to_id)
+ vocabulary = len(word_to_id)
+ return train_data, valid_data, test_data, vocabulary
+
+
+def ptb_iterator(raw_data, batch_size, num_steps):
+ """Iterate on the raw PTB data.
+
+ This generates batch_size pointers into the raw PTB data, and allows
+ minibatch iteration along these pointers.
+
+ Args:
+ raw_data: one of the raw data outputs from ptb_raw_data.
+ batch_size: int, the batch size.
+ num_steps: int, the number of unrolls.
+
+ Yields:
+ Pairs of the batched data, each a matrix of shape [batch_size, num_steps].
+ The second element of the tuple is the same data time-shifted to the
+ right by one.
+
+ Raises:
+ ValueError: if batch_size or num_steps are too high.
+ """
+ raw_data = np.array(raw_data, dtype=np.int32)
+
+ data_len = len(raw_data)
+ batch_len = data_len / batch_size
+ data = np.zeros([batch_size, batch_len], dtype=np.int32)
+ for i in range(batch_size):
+ data[i] = raw_data[batch_len * i:batch_len * (i + 1)]
+
+ epoch_size = (batch_len - 1) / num_steps
+
+ if epoch_size == 0:
+ raise ValueError("epoch_size == 0, decrease batch_size or num_steps")
+
+ for i in range(epoch_size):
+ x = data[:, i*num_steps:(i+1)*num_steps]
+ y = data[:, i*num_steps+1:(i+1)*num_steps+1]
+ yield (x, y)
diff --git a/tensorflow/models/rnn/ptb/reader_test.py b/tensorflow/models/rnn/ptb/reader_test.py
new file mode 100644
index 0000000000..c722cdb939
--- /dev/null
+++ b/tensorflow/models/rnn/ptb/reader_test.py
@@ -0,0 +1,47 @@
+"""Tests for tensorflow.models.ptb_lstm.ptb_reader."""
+
+import os.path
+
+# pylint: disable=g-bad-import-order,unused-import
+import tensorflow.python.platform
+
+import numpy as np
+import tensorflow as tf
+
+from tensorflow.models.rnn.ptb import reader
+from tensorflow.python.platform import gfile
+
+
+class PtbReaderTest(tf.test.TestCase):
+
+ def setUp(self):
+ self._string_data = "\n".join(
+ [" hello there i am",
+ " rain as day",
+ " want some cheesy puffs ?"])
+
+ def testPtbRawData(self):
+ tmpdir = tf.test.get_temp_dir()
+ for suffix in "train", "valid", "test":
+ filename = os.path.join(tmpdir, "ptb.%s.txt" % suffix)
+ with gfile.GFile(filename, "w") as fh:
+ fh.write(self._string_data)
+ # Smoke test
+ output = reader.ptb_raw_data(tmpdir)
+ self.assertEqual(len(output), 4)
+
+ def testPtbIterator(self):
+ raw_data = [4, 3, 2, 1, 0, 5, 6, 1, 1, 1, 1, 0, 3, 4, 1]
+ batch_size = 3
+ num_steps = 2
+ output = list(reader.ptb_iterator(raw_data, batch_size, num_steps))
+ self.assertEqual(len(output), 2)
+ o1, o2 = (output[0], output[1])
+ self.assertEqual(o1[0].shape, (batch_size, num_steps))
+ self.assertEqual(o1[1].shape, (batch_size, num_steps))
+ self.assertEqual(o2[0].shape, (batch_size, num_steps))
+ self.assertEqual(o2[1].shape, (batch_size, num_steps))
+
+
+if __name__ == "__main__":
+ tf.test.main()
diff --git a/tensorflow/models/rnn/rnn.py b/tensorflow/models/rnn/rnn.py
new file mode 100644
index 0000000000..24582bcae7
--- /dev/null
+++ b/tensorflow/models/rnn/rnn.py
@@ -0,0 +1,128 @@
+"""RNN helpers for TensorFlow models."""
+
+import tensorflow as tf
+
+from tensorflow.models.rnn import rnn_cell
+from tensorflow.python.ops import control_flow_ops
+
+
+def rnn(cell, inputs, initial_state=None, dtype=None,
+ sequence_length=None, scope=None):
+ """Creates a recurrent neural network specified by RNNCell "cell".
+
+ The simplest form of RNN network generated is:
+ state = cell.zero_state(...)
+ outputs = []
+ states = []
+ for input_ in inputs:
+ output, state = cell(input_, state)
+ outputs.append(output)
+ states.append(state)
+ return (outputs, states)
+
+ However, a few other options are available:
+
+ An initial state can be provided.
+ If sequence_length is provided, dynamic calculation is performed.
+
+ Dynamic calculation returns, at time t:
+ (t >= max(sequence_length)
+ ? (zeros(output_shape), zeros(state_shape))
+ : cell(input, state)
+
+ Thus saving computational time when unrolling past the max sequence length.
+
+ Args:
+ cell: An instance of RNNCell.
+ inputs: A length T list of inputs, each a vector with shape [batch_size].
+ initial_state: (optional) An initial state for the RNN. This must be
+ a tensor of appropriate type and shape [batch_size x cell.state_size].
+ dtype: (optional) The data type for the initial state. Required if
+ initial_state is not provided.
+ sequence_length: An int64 vector (tensor) size [batch_size].
+ scope: VariableScope for the created subgraph; defaults to "RNN".
+
+ Returns:
+ A pair (outputs, states) where:
+ outputs is a length T list of outputs (one for each input)
+ states is a length T list of states (one state following each input)
+
+ Raises:
+ TypeError: If "cell" is not an instance of RNNCell.
+ ValueError: If inputs is None or an empty list.
+ """
+
+ if not isinstance(cell, rnn_cell.RNNCell):
+ raise TypeError("cell must be an instance of RNNCell")
+ if not isinstance(inputs, list):
+ raise TypeError("inputs must be a list")
+ if not inputs:
+ raise ValueError("inputs must not be empty")
+
+ outputs = []
+ states = []
+ with tf.variable_scope(scope or "RNN"):
+ batch_size = tf.shape(inputs[0])[0]
+ if initial_state is not None:
+ state = initial_state
+ else:
+ if not dtype:
+ raise ValueError("If no initial_state is provided, dtype must be.")
+ state = cell.zero_state(batch_size, dtype)
+
+ if sequence_length: # Prepare variables
+ zero_output_state = (
+ tf.zeros(tf.pack([batch_size, cell.output_size]),
+ inputs[0].dtype),
+ tf.zeros(tf.pack([batch_size, cell.state_size]),
+ state.dtype))
+ max_sequence_length = tf.reduce_max(sequence_length)
+
+ output_state = (None, None)
+ for time, input_ in enumerate(inputs):
+ if time > 0:
+ tf.get_variable_scope().reuse_variables()
+ output_state = cell(input_, state)
+ if sequence_length:
+ (output, state) = control_flow_ops.cond(
+ time >= max_sequence_length,
+ lambda: zero_output_state, lambda: output_state)
+ else:
+ (output, state) = output_state
+
+ outputs.append(output)
+ states.append(state)
+
+ return (outputs, states)
+
+
+def state_saving_rnn(cell, inputs, state_saver, state_name,
+ sequence_length=None, scope=None):
+ """RNN that accepts a state saver for time-truncated RNN calculation.
+
+ Args:
+ cell: An instance of RNNCell.
+ inputs: A length T list of inputs, each a vector with shape [batch_size].
+ state_saver: A StateSaver object.
+ state_name: The name to use with the state_saver.
+ sequence_length: (optional) An int64 vector (tensor) size [batch_size].
+ See the documentation for rnn() for more details about sequence_length.
+ scope: VariableScope for the created subgraph; defaults to "RNN".
+
+ Returns:
+ A pair (outputs, states) where:
+ outputs is a length T list of outputs (one for each input)
+ states is a length T list of states (one state following each input)
+
+ Raises:
+ TypeError: If "cell" is not an instance of RNNCell.
+ ValueError: If inputs is None or an empty list.
+ """
+ initial_state = state_saver.State(state_name)
+ (outputs, states) = rnn(cell, inputs, initial_state=initial_state,
+ sequence_length=sequence_length, scope=scope)
+ save_state = state_saver.SaveState(state_name, states[-1])
+ with tf.control_dependencies([save_state]):
+ outputs[-1] = tf.identity(outputs[-1])
+
+ return (outputs, states)
diff --git a/tensorflow/models/rnn/rnn_cell.py b/tensorflow/models/rnn/rnn_cell.py
new file mode 100644
index 0000000000..55d417fc2b
--- /dev/null
+++ b/tensorflow/models/rnn/rnn_cell.py
@@ -0,0 +1,605 @@
+"""Module for constructing RNN Cells."""
+
+import math
+
+import tensorflow as tf
+
+from tensorflow.models.rnn import linear
+
+
+class RNNCell(object):
+ """Abstract object representing an RNN cell.
+
+ An RNN cell, in the most abstract setting, is anything that has
+ a state -- a vector of floats of size self.state_size -- and performs some
+ operation that takes inputs of size self.input_size. This operation
+ results in an output of size self.output_size and a new state.
+
+ This module provides a number of basic commonly used RNN cells, such as
+ LSTM (Long Short Term Memory) or GRU (Gated Recurrent Unit), and a number
+ of operators that allow add dropouts, projections, or embeddings for inputs.
+ Constructing multi-layer cells is supported by a super-class, MultiRNNCell,
+ defined later. Every RNNCell must have the properties below and and
+ implement __call__ with the following signature.
+ """
+
+ def __call__(self, inputs, state, scope=None):
+ """Run this RNN cell on inputs, starting from the given state.
+
+ Args:
+ inputs: 2D Tensor with shape [batch_size x self.input_size].
+ state: 2D Tensor with shape [batch_size x self.state_size].
+ scope: VariableScope for the created subgraph; defaults to class name.
+
+ Returns:
+ A pair containing:
+ - Output: A 2D Tensor with shape [batch_size x self.output_size]
+ - New state: A 2D Tensor with shape [batch_size x self.state_size].
+ """
+ raise NotImplementedError("Abstract method")
+
+ @property
+ def input_size(self):
+ """Integer: size of inputs accepted by this cell."""
+ raise NotImplementedError("Abstract method")
+
+ @property
+ def output_size(self):
+ """Integer: size of outputs produced by this cell."""
+ raise NotImplementedError("Abstract method")
+
+ @property
+ def state_size(self):
+ """Integer: size of state used by this cell."""
+ raise NotImplementedError("Abstract method")
+
+ def zero_state(self, batch_size, dtype):
+ """Return state tensor (shape [batch_size x state_size]) filled with 0.
+
+ Args:
+ batch_size: int, float, or unit Tensor representing the batch size.
+ dtype: the data type to use for the state.
+
+ Returns:
+ A 2D Tensor of shape [batch_size x state_size] filled with zeros.
+ """
+ zeros = tf.zeros(tf.pack([batch_size, self.state_size]), dtype=dtype)
+ # The reshape below is a no-op, but it allows shape inference of shape[1].
+ return tf.reshape(zeros, [-1, self.state_size])
+
+
+class BasicRNNCell(RNNCell):
+ """The most basic RNN cell."""
+
+ def __init__(self, num_units):
+ self._num_units = num_units
+
+ @property
+ def input_size(self):
+ return self._num_units
+
+ @property
+ def output_size(self):
+ return self._num_units
+
+ @property
+ def state_size(self):
+ return self._num_units
+
+ def __call__(self, inputs, state, scope=None):
+ """Most basic RNN: output = new_state = tanh(W * input + U * state + B)."""
+ with tf.variable_scope(scope or type(self).__name__): # "BasicRNNCell"
+ output = tf.tanh(linear.linear([inputs, state], self._num_units, True))
+ return output, output
+
+
+class GRUCell(RNNCell):
+ """Gated Recurrent Unit cell (cf. http://arxiv.org/abs/1406.1078)."""
+
+ def __init__(self, num_units):
+ self._num_units = num_units
+
+ @property
+ def input_size(self):
+ return self._num_units
+
+ @property
+ def output_size(self):
+ return self._num_units
+
+ @property
+ def state_size(self):
+ return self._num_units
+
+ def __call__(self, inputs, state, scope=None):
+ """Gated recurrent unit (GRU) with nunits cells."""
+ with tf.variable_scope(scope or type(self).__name__): # "GRUCell"
+ with tf.variable_scope("Gates"): # Reset gate and update gate.
+ # We start with bias of 1.0 to not reset and not udpate.
+ r, u = tf.split(1, 2, linear.linear([inputs, state],
+ 2 * self._num_units, True, 1.0))
+ r, u = tf.sigmoid(r), tf.sigmoid(u)
+ with tf.variable_scope("Candidate"):
+ c = tf.tanh(linear.linear([inputs, r * state], self._num_units, True))
+ new_h = u * state + (1 - u) * c
+ return new_h, new_h
+
+
+class BasicLSTMCell(RNNCell):
+ """Basic LSTM recurrent network cell.
+
+ The implementation is based on: http://arxiv.org/pdf/1409.2329v5.pdf.
+
+ It does not allow cell clipping, a projection layer, and does not
+ use peep-hole connections: it is the basic baseline.
+
+ Biases of the forget gate are initialized by default to 1 in order to reduce
+ the scale of forgetting in the beginning of the training.
+ """
+
+ def __init__(self, num_units, forget_bias=1.0):
+ self._num_units = num_units
+ self._forget_bias = forget_bias
+
+ @property
+ def input_size(self):
+ return self._num_units
+
+ @property
+ def output_size(self):
+ return self._num_units
+
+ @property
+ def state_size(self):
+ return 2 * self._num_units
+
+ def __call__(self, inputs, state, scope=None):
+ """Long short-term memory cell (LSTM)."""
+ with tf.variable_scope(scope or type(self).__name__): # "BasicLSTMCell"
+ # Parameters of gates are concatenated into one multiply for efficiency.
+ c, h = tf.split(1, 2, state)
+ concat = linear.linear([inputs, h], 4 * self._num_units, True)
+
+ # i = input_gate, j = new_input, f = forget_gate, o = output_gate
+ i, j, f, o = tf.split(1, 4, concat)
+
+ new_c = c * tf.sigmoid(f + self._forget_bias) + tf.sigmoid(i) * tf.tanh(j)
+ new_h = tf.tanh(new_c) * tf.sigmoid(o)
+
+ return new_h, tf.concat(1, [new_c, new_h])
+
+
+class LSTMCell(RNNCell):
+ """Long short-term memory unit (LSTM) recurrent network cell.
+
+ This implementation is based on:
+
+ https://research.google.com/pubs/archive/43905.pdf
+
+ Hasim Sak, Andrew Senior, and Francoise Beaufays.
+ "Long short-term memory recurrent neural network architectures for
+ large scale acoustic modeling." INTERSPEECH, 2014.
+
+ It uses peep-hole connections, optional cell clipping, and an optional
+ projection layer.
+ """
+
+ def __init__(self, num_units, input_size,
+ use_peepholes=False, cell_clip=None,
+ initializer=None, num_proj=None,
+ num_unit_shards=1, num_proj_shards=1):
+ """Initialize the parameters for an LSTM cell.
+
+ Args:
+ num_units: int, The number of units in the LSTM cell
+ input_size: int, The dimensionality of the inputs into the LSTM cell
+ use_peepholes: bool, set True to enable diagonal/peephole connections.
+ cell_clip: (optional) A float value, if provided the cell state is clipped
+ by this value prior to the cell output activation.
+ initializer: (optional) The initializer to use for the weight and
+ projection matrices.
+ num_proj: (optional) int, The output dimensionality for the projection
+ matrices. If None, no projection is performed.
+ num_unit_shards: How to split the weight matrix. If >1, the weight
+ matrix is stored across num_unit_shards.
+ Note that num_unit_shards must evenly divide num_units * 4.
+ num_proj_shards: How to split the projection matrix. If >1, the
+ projection matrix is stored across num_proj_shards.
+ Note that num_proj_shards must evenly divide num_proj
+ (if num_proj is not None).
+
+ Raises:
+ ValueError: if num_unit_shards doesn't divide 4 * num_units or
+ num_proj_shards doesn't divide num_proj
+ """
+ self._num_units = num_units
+ self._input_size = input_size
+ self._use_peepholes = use_peepholes
+ self._cell_clip = cell_clip
+ self._initializer = initializer
+ self._num_proj = num_proj
+ self._num_unit_shards = num_unit_shards
+ self._num_proj_shards = num_proj_shards
+
+ if (num_units * 4) % num_unit_shards != 0:
+ raise ValueError("num_unit_shards must evently divide 4 * num_units")
+ if num_proj and num_proj % num_proj_shards != 0:
+ raise ValueError("num_proj_shards must evently divide num_proj")
+
+ if num_proj:
+ self._state_size = num_units + num_proj
+ self._output_size = num_proj
+ else:
+ self._state_size = 2 * num_units
+ self._output_size = num_units
+
+ @property
+ def input_size(self):
+ return self._input_size
+
+ @property
+ def output_size(self):
+ return self._output_size
+
+ @property
+ def state_size(self):
+ return self._state_size
+
+ def __call__(self, input_, state, scope=None):
+ """Run one step of LSTM.
+
+ Args:
+ input_: input Tensor, 2D, batch x num_units.
+ state: state Tensor, 2D, batch x state_size.
+ scope: VariableScope for the created subgraph; defaults to "LSTMCell".
+
+ Returns:
+ A tuple containing:
+ - A 2D, batch x output_dim, Tensor representing the output of the LSTM
+ after reading "input_" when previous state was "state".
+ Here output_dim is:
+ num_proj if num_proj was set,
+ num_units otherwise.
+ - A 2D, batch x state_size, Tensor representing the new state of LSTM
+ after reading "input_" when previous state was "state".
+ """
+ num_proj = self._num_units if self._num_proj is None else self._num_proj
+
+ c_prev = tf.slice(state, [0, 0], [-1, self._num_units])
+ m_prev = tf.slice(state, [0, self._num_units], [-1, num_proj])
+
+ dtype = input_.dtype
+
+ unit_shard_size = (4 * self._num_units) / self._num_unit_shards
+
+ with tf.variable_scope(scope or type(self).__name__): # "LSTMCell"
+ w = tf.concat(
+ 1, [tf.get_variable("W_%d" % i,
+ shape=[self.input_size + num_proj,
+ unit_shard_size],
+ initializer=self._initializer,
+ dtype=dtype)
+ for i in range(self._num_unit_shards)])
+
+ b = tf.get_variable(
+ "B", shape=[4 * self._num_units],
+ initializer=tf.zeros_initializer, dtype=dtype)
+
+ # i = input_gate, j = new_input, f = forget_gate, o = output_gate
+ cell_inputs = tf.concat(1, [input_, m_prev])
+ i, j, f, o = tf.split(1, 4, tf.nn.bias_add(tf.matmul(cell_inputs, w), b))
+
+ # Diagonal connections
+ if self._use_peepholes:
+ w_f_diag = tf.get_variable(
+ "W_F_diag", shape=[self._num_units], dtype=dtype)
+ w_i_diag = tf.get_variable(
+ "W_I_diag", shape=[self._num_units], dtype=dtype)
+ w_o_diag = tf.get_variable(
+ "W_O_diag", shape=[self._num_units], dtype=dtype)
+
+ if self._use_peepholes:
+ c = (tf.sigmoid(f + 1 + w_f_diag * c_prev) * c_prev +
+ tf.sigmoid(i + w_i_diag * c_prev) * tf.tanh(j))
+ else:
+ c = (tf.sigmoid(f + 1) * c_prev + tf.sigmoid(i) * tf.tanh(j))
+
+ if self._cell_clip is not None:
+ c = tf.clip_by_value(c, -self._cell_clip, self._cell_clip)
+
+ if self._use_peepholes:
+ m = tf.sigmoid(o + w_o_diag * c) * tf.tanh(c)
+ else:
+ m = tf.sigmoid(o) * tf.tanh(c)
+
+ if self._num_proj is not None:
+ proj_shard_size = self._num_proj / self._num_proj_shards
+ w_proj = tf.concat(
+ 1, [tf.get_variable("W_P_%d" % i,
+ shape=[self._num_units, proj_shard_size],
+ initializer=self._initializer, dtype=dtype)
+ for i in range(self._num_proj_shards)])
+ # TODO(ebrevdo), use matmulsum
+ m = tf.matmul(m, w_proj)
+
+ return m, tf.concat(1, [c, m])
+
+
+class OutputProjectionWrapper(RNNCell):
+ """Operator adding an output projection to the given cell.
+
+ Note: in many cases it may be more efficient to not use this wrapper,
+ but instead concatenate the whole sequence of your outputs in time,
+ do the projection on this batch-concated sequence, then split it
+ if needed or directly feed into a softmax.
+ """
+
+ def __init__(self, cell, output_size):
+ """Create a cell with output projection.
+
+ Args:
+ cell: an RNNCell, a projection to output_size is added to it.
+ output_size: integer, the size of the output after projection.
+
+ Raises:
+ TypeError: if cell is not an RNNCell.
+ ValueError: if output_size is not positive.
+ """
+ if not isinstance(cell, RNNCell):
+ raise TypeError("The parameter cell is not RNNCell.")
+ if output_size < 1:
+ raise ValueError("Parameter output_size must be > 0: %d." % output_size)
+ self._cell = cell
+ self._output_size = output_size
+
+ @property
+ def input_size(self):
+ return self._cell.input_size
+
+ @property
+ def output_size(self):
+ return self._output_size
+
+ @property
+ def state_size(self):
+ return self._cell.state_size
+
+ def __call__(self, inputs, state, scope=None):
+ """Run the cell and output projection on inputs, starting from state."""
+ output, res_state = self._cell(inputs, state)
+ # Default scope: "OutputProjectionWrapper"
+ with tf.variable_scope(scope or type(self).__name__):
+ projected = linear.linear(output, self._output_size, True)
+ return projected, res_state
+
+
+class InputProjectionWrapper(RNNCell):
+ """Operator adding an input projection to the given cell.
+
+ Note: in many cases it may be more efficient to not use this wrapper,
+ but instead concatenate the whole sequence of your inputs in time,
+ do the projection on this batch-concated sequence, then split it.
+ """
+
+ def __init__(self, cell, input_size):
+ """Create a cell with input projection.
+
+ Args:
+ cell: an RNNCell, a projection of inputs is added before it.
+ input_size: integer, the size of the inputs before projection.
+
+ Raises:
+ TypeError: if cell is not an RNNCell.
+ ValueError: if input_size is not positive.
+ """
+ if not isinstance(cell, RNNCell):
+ raise TypeError("The parameter cell is not RNNCell.")
+ if input_size < 1:
+ raise ValueError("Parameter input_size must be > 0: %d." % input_size)
+ self._cell = cell
+ self._input_size = input_size
+
+ @property
+ def input_size(self):
+ return self._input_size
+
+ @property
+ def output_size(self):
+ return self._cell.output_size
+
+ @property
+ def state_size(self):
+ return self._cell.state_size
+
+ def __call__(self, inputs, state, scope=None):
+ """Run the input projection and then the cell."""
+ # Default scope: "InputProjectionWrapper"
+ with tf.variable_scope(scope or type(self).__name__):
+ projected = linear.linear(inputs, self._cell.input_size, True)
+ return self._cell(projected, state)
+
+
+class DropoutWrapper(RNNCell):
+ """Operator adding dropout to inputs and outputs of the given cell."""
+
+ def __init__(self, cell, input_keep_prob=1.0, output_keep_prob=1.0,
+ seed=None):
+ """Create a cell with added input and/or output dropout.
+
+ Dropout is never used on the state.
+
+ Args:
+ cell: an RNNCell, a projection to output_size is added to it.
+ input_keep_prob: unit Tensor or float between 0 and 1, input keep
+ probability; if it is float and 1, no input dropout will be added.
+ output_keep_prob: unit Tensor or float between 0 and 1, output keep
+ probability; if it is float and 1, no output dropout will be added.
+ seed: (optional) integer, the randomness seed.
+
+ Raises:
+ TypeError: if cell is not an RNNCell.
+ ValueError: if keep_prob is not between 0 and 1.
+ """
+ if not isinstance(cell, RNNCell):
+ raise TypeError("The parameter cell is not a RNNCell.")
+ if (isinstance(input_keep_prob, float) and
+ not (input_keep_prob >= 0.0 and input_keep_prob <= 1.0)):
+ raise ValueError("Parameter input_keep_prob must be between 0 and 1: %d"
+ % input_keep_prob)
+ if (isinstance(output_keep_prob, float) and
+ not (output_keep_prob >= 0.0 and output_keep_prob <= 1.0)):
+ raise ValueError("Parameter input_keep_prob must be between 0 and 1: %d"
+ % output_keep_prob)
+ self._cell = cell
+ self._input_keep_prob = input_keep_prob
+ self._output_keep_prob = output_keep_prob
+ self._seed = seed
+
+ @property
+ def input_size(self):
+ return self._cell.input_size
+
+ @property
+ def output_size(self):
+ return self._cell.output_size
+
+ @property
+ def state_size(self):
+ return self._cell.state_size
+
+ def __call__(self, inputs, state):
+ """Run the cell with the declared dropouts."""
+ if (not isinstance(self._input_keep_prob, float) or
+ self._input_keep_prob < 1):
+ inputs = tf.nn.dropout(inputs, self._input_keep_prob, seed=self._seed)
+ output, new_state = self._cell(inputs, state)
+ if (not isinstance(self._output_keep_prob, float) or
+ self._output_keep_prob < 1):
+ output = tf.nn.dropout(output, self._output_keep_prob, seed=self._seed)
+ return output, new_state
+
+
+class EmbeddingWrapper(RNNCell):
+ """Operator adding input embedding to the given cell.
+
+ Note: in many cases it may be more efficient to not use this wrapper,
+ but instead concatenate the whole sequence of your inputs in time,
+ do the embedding on this batch-concated sequence, then split it and
+ feed into your RNN.
+ """
+
+ def __init__(self, cell, embedding_classes=0, embedding=None,
+ initializer=None):
+ """Create a cell with an added input embedding.
+
+ Args:
+ cell: an RNNCell, an embedding will be put before its inputs.
+ embedding_classes: integer, how many symbols will be embedded.
+ embedding: Variable, the embedding to use; if None, a new embedding
+ will be created; if set, then embedding_classes is not required.
+ initializer: an initializer to use when creating the embedding;
+ if None, the initializer from variable scope or a default one is used.
+
+ Raises:
+ TypeError: if cell is not an RNNCell.
+ ValueError: if embedding_classes is not positive.
+ """
+ if not isinstance(cell, RNNCell):
+ raise TypeError("The parameter cell is not RNNCell.")
+ if embedding_classes < 1 and embedding is None:
+ raise ValueError("Pass embedding or embedding_classes must be > 0: %d."
+ % embedding_classes)
+ if embedding_classes > 0 and embedding is not None:
+ if embedding.size[0] != embedding_classes:
+ raise ValueError("You declared embedding_classes=%d but passed an "
+ "embedding for %d classes." % (embedding.size[0],
+ embedding_classes))
+ if embedding.size[1] != cell.input_size:
+ raise ValueError("You passed embedding with output size %d and a cell"
+ " that accepts size %d." % (embedding.size[1],
+ cell.input_size))
+ self._cell = cell
+ self._embedding_classes = embedding_classes
+ self._embedding = embedding
+ self._initializer = initializer
+
+ @property
+ def input_size(self):
+ return 1
+
+ @property
+ def output_size(self):
+ return self._cell.output_size
+
+ @property
+ def state_size(self):
+ return self._cell.state_size
+
+ def __call__(self, inputs, state, scope=None):
+ """Run the cell on embedded inputs."""
+ with tf.variable_scope(scope or type(self).__name__): # "EmbeddingWrapper"
+ with tf.device("/cpu:0"):
+ if self._embedding:
+ embedding = self._embedding
+ else:
+ if self._initializer:
+ initializer = self._initializer
+ elif tf.get_variable_scope().initializer:
+ initializer = tf.get_variable_scope().initializer
+ else:
+ # Default initializer for embeddings should have variance=1.
+ sqrt3 = math.sqrt(3) # Uniform(-sqrt(3), sqrt(3)) has variance=1.
+ initializer = tf.random_uniform_initializer(-sqrt3, sqrt3)
+ embedding = tf.get_variable("embedding", [self._embedding_classes,
+ self._cell.input_size],
+ initializer=initializer)
+ embedded = tf.nn.embedding_lookup(embedding, tf.reshape(inputs, [-1]))
+ return self._cell(embedded, state)
+
+
+class MultiRNNCell(RNNCell):
+ """RNN cell composed sequentially of multiple simple cells."""
+
+ def __init__(self, cells):
+ """Create a RNN cell composed sequentially of a number of RNNCells.
+
+ Args:
+ cells: list of RNNCells that will be composed in this order.
+
+ Raises:
+ ValueError: if cells is empty (not allowed) or if their sizes don't match.
+ """
+ if not cells:
+ raise ValueError("Must specify at least one cell for MultiRNNCell.")
+ for i in xrange(len(cells) - 1):
+ if cells[i + 1].input_size != cells[i].output_size:
+ raise ValueError("In MultiRNNCell, the input size of each next"
+ " cell must match the output size of the previous one."
+ " Mismatched output size in cell %d." % i)
+ self._cells = cells
+
+ @property
+ def input_size(self):
+ return self._cells[0].input_size
+
+ @property
+ def output_size(self):
+ return self._cells[-1].output_size
+
+ @property
+ def state_size(self):
+ return sum([cell.state_size for cell in self._cells])
+
+ def __call__(self, inputs, state, scope=None):
+ """Run this multi-layer cell on inputs, starting from state."""
+ with tf.variable_scope(scope or type(self).__name__): # "MultiRNNCell"
+ cur_state_pos = 0
+ cur_inp = inputs
+ new_states = []
+ for i, cell in enumerate(self._cells):
+ with tf.variable_scope("Cell%d" % i):
+ cur_state = tf.slice(state, [0, cur_state_pos], [-1, cell.state_size])
+ cur_state_pos += cell.state_size
+ cur_inp, new_state = cell(cur_inp, cur_state)
+ new_states.append(new_state)
+ return cur_inp, tf.concat(1, new_states)
diff --git a/tensorflow/models/rnn/rnn_cell_test.py b/tensorflow/models/rnn/rnn_cell_test.py
new file mode 100644
index 0000000000..8b4b209028
--- /dev/null
+++ b/tensorflow/models/rnn/rnn_cell_test.py
@@ -0,0 +1,154 @@
+"""Tests for RNN cells."""
+
+# pylint: disable=g-bad-import-order,unused-import
+import tensorflow.python.platform
+
+import numpy as np
+import tensorflow as tf
+
+from tensorflow.models.rnn import rnn_cell
+
+
+class RNNCellTest(tf.test.TestCase):
+
+ def testBasicRNNCell(self):
+ with self.test_session() as sess:
+ with tf.variable_scope("root", initializer=tf.constant_initializer(0.5)):
+ x = tf.zeros([1, 2])
+ m = tf.zeros([1, 2])
+ g, _ = rnn_cell.BasicRNNCell(2)(x, m)
+ sess.run([tf.variables.initialize_all_variables()])
+ res = sess.run([g], {x.name: np.array([[1., 1.]]),
+ m.name: np.array([[0.1, 0.1]])})
+ self.assertEqual(res[0].shape, (1, 2))
+
+ def testGRUCell(self):
+ with self.test_session() as sess:
+ with tf.variable_scope("root", initializer=tf.constant_initializer(0.5)):
+ x = tf.zeros([1, 2])
+ m = tf.zeros([1, 2])
+ g, _ = rnn_cell.GRUCell(2)(x, m)
+ sess.run([tf.variables.initialize_all_variables()])
+ res = sess.run([g], {x.name: np.array([[1., 1.]]),
+ m.name: np.array([[0.1, 0.1]])})
+ # Smoke test
+ self.assertAllClose(res[0], [[0.175991, 0.175991]])
+
+ def testBasicLSTMCell(self):
+ with self.test_session() as sess:
+ with tf.variable_scope("root", initializer=tf.constant_initializer(0.5)):
+ x = tf.zeros([1, 2])
+ m = tf.zeros([1, 8])
+ g, out_m = rnn_cell.MultiRNNCell([rnn_cell.BasicLSTMCell(2)] * 2)(x, m)
+ sess.run([tf.variables.initialize_all_variables()])
+ res = sess.run([g, out_m], {x.name: np.array([[1., 1.]]),
+ m.name: 0.1 * np.ones([1, 8])})
+ self.assertEqual(len(res), 2)
+ # The numbers in results were not calculated, this is just a smoke test.
+ self.assertAllClose(res[0], [[0.24024698, 0.24024698]])
+ expected_mem = np.array([[0.68967271, 0.68967271,
+ 0.44848421, 0.44848421,
+ 0.39897051, 0.39897051,
+ 0.24024698, 0.24024698]])
+ self.assertAllClose(res[1], expected_mem)
+
+ def testLSTMCell(self):
+ with self.test_session() as sess:
+ num_units = 8
+ num_proj = 6
+ state_size = num_units + num_proj
+ batch_size = 3
+ input_size = 2
+ with tf.variable_scope("root", initializer=tf.constant_initializer(0.5)):
+ x = tf.zeros([batch_size, input_size])
+ m = tf.zeros([batch_size, state_size])
+ output, state = rnn_cell.LSTMCell(
+ num_units=num_units, input_size=input_size, num_proj=num_proj)(x, m)
+ sess.run([tf.variables.initialize_all_variables()])
+ res = sess.run([output, state],
+ {x.name: np.array([[1., 1.], [2., 2.], [3., 3.]]),
+ m.name: 0.1 * np.ones((batch_size, state_size))})
+ self.assertEqual(len(res), 2)
+ # The numbers in results were not calculated, this is mostly just a
+ # smoke test.
+ self.assertEqual(res[0].shape, (batch_size, num_proj))
+ self.assertEqual(res[1].shape, (batch_size, state_size))
+ # Different inputs so different outputs and states
+ for i in range(1, batch_size):
+ self.assertTrue(
+ float(np.linalg.norm((res[0][0,:] - res[0][i,:]))) > 1e-6)
+ self.assertTrue(
+ float(np.linalg.norm((res[1][0,:] - res[1][i,:]))) > 1e-6)
+
+ def testOutputProjectionWrapper(self):
+ with self.test_session() as sess:
+ with tf.variable_scope("root", initializer=tf.constant_initializer(0.5)):
+ x = tf.zeros([1, 3])
+ m = tf.zeros([1, 3])
+ cell = rnn_cell.OutputProjectionWrapper(rnn_cell.GRUCell(3), 2)
+ g, new_m = cell(x, m)
+ sess.run([tf.variables.initialize_all_variables()])
+ res = sess.run([g, new_m], {x.name: np.array([[1., 1., 1.]]),
+ m.name: np.array([[0.1, 0.1, 0.1]])})
+ self.assertEqual(res[1].shape, (1, 3))
+ # The numbers in results were not calculated, this is just a smoke test.
+ self.assertAllClose(res[0], [[0.231907, 0.231907]])
+
+ def testInputProjectionWrapper(self):
+ with self.test_session() as sess:
+ with tf.variable_scope("root", initializer=tf.constant_initializer(0.5)):
+ x = tf.zeros([1, 2])
+ m = tf.zeros([1, 3])
+ cell = rnn_cell.InputProjectionWrapper(rnn_cell.GRUCell(3), 2)
+ g, new_m = cell(x, m)
+ sess.run([tf.variables.initialize_all_variables()])
+ res = sess.run([g, new_m], {x.name: np.array([[1., 1.]]),
+ m.name: np.array([[0.1, 0.1, 0.1]])})
+ self.assertEqual(res[1].shape, (1, 3))
+ # The numbers in results were not calculated, this is just a smoke test.
+ self.assertAllClose(res[0], [[0.154605, 0.154605, 0.154605]])
+
+ def testDropoutWrapper(self):
+ with self.test_session() as sess:
+ with tf.variable_scope("root", initializer=tf.constant_initializer(0.5)):
+ x = tf.zeros([1, 3])
+ m = tf.zeros([1, 3])
+ keep = tf.zeros([1]) + 1
+ g, new_m = rnn_cell.DropoutWrapper(rnn_cell.GRUCell(3),
+ keep, keep)(x, m)
+ sess.run([tf.variables.initialize_all_variables()])
+ res = sess.run([g, new_m], {x.name: np.array([[1., 1., 1.]]),
+ m.name: np.array([[0.1, 0.1, 0.1]])})
+ self.assertEqual(res[1].shape, (1, 3))
+ # The numbers in results were not calculated, this is just a smoke test.
+ self.assertAllClose(res[0], [[0.154605, 0.154605, 0.154605]])
+
+ def testEmbeddingWrapper(self):
+ with self.test_session() as sess:
+ with tf.variable_scope("root", initializer=tf.constant_initializer(0.5)):
+ x = tf.zeros([1, 1], dtype=tf.int32)
+ m = tf.zeros([1, 2])
+ g, new_m = rnn_cell.EmbeddingWrapper(rnn_cell.GRUCell(2), 3)(x, m)
+ sess.run([tf.variables.initialize_all_variables()])
+ res = sess.run([g, new_m], {x.name: np.array([[1]]),
+ m.name: np.array([[0.1, 0.1]])})
+ self.assertEqual(res[1].shape, (1, 2))
+ # The numbers in results were not calculated, this is just a smoke test.
+ self.assertAllClose(res[0], [[0.17139, 0.17139]])
+
+ def testMultiRNNCell(self):
+ with self.test_session() as sess:
+ with tf.variable_scope("root", initializer=tf.constant_initializer(0.5)):
+ x = tf.zeros([1, 2])
+ m = tf.zeros([1, 4])
+ _, ml = rnn_cell.MultiRNNCell([rnn_cell.GRUCell(2)] * 2)(x, m)
+ sess.run([tf.variables.initialize_all_variables()])
+ res = sess.run(ml, {x.name: np.array([[1., 1.]]),
+ m.name: np.array([[0.1, 0.1, 0.1, 0.1]])})
+ # The numbers in results were not calculated, this is just a smoke test.
+ self.assertAllClose(res, [[0.175991, 0.175991,
+ 0.13248, 0.13248]])
+
+
+if __name__ == "__main__":
+ tf.test.main()
diff --git a/tensorflow/models/rnn/rnn_test.py b/tensorflow/models/rnn/rnn_test.py
new file mode 100644
index 0000000000..378315d296
--- /dev/null
+++ b/tensorflow/models/rnn/rnn_test.py
@@ -0,0 +1,472 @@
+"""Tests for rnn module."""
+
+# pylint: disable=g-bad-import-order,unused-import
+import tensorflow.python.platform
+
+import numpy as np
+import tensorflow as tf
+
+from tensorflow.models.rnn import rnn
+from tensorflow.models.rnn import rnn_cell
+
+
+class Plus1RNNCell(rnn_cell.RNNCell):
+ """RNN Cell generating (output, new_state) = (input + 1, state + 1)."""
+
+ @property
+ def output_size(self):
+ return 5
+
+ @property
+ def state_size(self):
+ return 5
+
+ def __call__(self, input_, state):
+ return (input_ + 1, state + 1)
+
+
+class TestStateSaver(object):
+
+ def __init__(self, batch_size, state_size):
+ self._batch_size = batch_size
+ self._state_size = state_size
+
+ def State(self, _):
+ return tf.zeros(tf.pack([self._batch_size, self._state_size]))
+
+ def SaveState(self, _, state):
+ self.saved_state = state
+ return tf.identity(state)
+
+
+class RNNTest(tf.test.TestCase):
+
+ def setUp(self):
+ self._seed = 23489
+ np.random.seed(self._seed)
+
+ def testRNN(self):
+ cell = Plus1RNNCell()
+ batch_size = 2
+ inputs = [tf.placeholder(tf.float32, shape=(batch_size, 5))] * 10
+ outputs, states = rnn.rnn(cell, inputs, dtype=tf.float32)
+ self.assertEqual(len(outputs), len(inputs))
+ for out, inp in zip(outputs, inputs):
+ self.assertEqual(out.get_shape(), inp.get_shape())
+ self.assertEqual(out.dtype, inp.dtype)
+
+ with self.test_session(use_gpu=False) as sess:
+ input_value = np.random.randn(batch_size, 5)
+ values = sess.run(outputs + [states[-1]],
+ feed_dict={inputs[0]: input_value})
+
+ # Outputs
+ for v in values[:-1]:
+ self.assertAllClose(v, input_value + 1.0)
+
+ # Final state
+ self.assertAllClose(
+ values[-1], 10.0*np.ones((batch_size, 5), dtype=np.float32))
+
+ def testDropout(self):
+ cell = Plus1RNNCell()
+ full_dropout_cell = rnn_cell.DropoutWrapper(
+ cell, input_keep_prob=1e-12, seed=0)
+ batch_size = 2
+ inputs = [tf.placeholder(tf.float32, shape=(batch_size, 5))] * 10
+ with tf.variable_scope("share_scope"):
+ outputs, states = rnn.rnn(cell, inputs, dtype=tf.float32)
+ with tf.variable_scope("drop_scope"):
+ dropped_outputs, _ = rnn.rnn(full_dropout_cell, inputs, dtype=tf.float32)
+ self.assertEqual(len(outputs), len(inputs))
+ for out, inp in zip(outputs, inputs):
+ self.assertEqual(out.get_shape().as_list(), inp.get_shape().as_list())
+ self.assertEqual(out.dtype, inp.dtype)
+
+ with self.test_session(use_gpu=False) as sess:
+ input_value = np.random.randn(batch_size, 5)
+ values = sess.run(outputs + [states[-1]],
+ feed_dict={inputs[0]: input_value})
+ full_dropout_values = sess.run(dropped_outputs,
+ feed_dict={inputs[0]: input_value})
+
+ for v in values[:-1]:
+ self.assertAllClose(v, input_value + 1.0)
+ for d_v in full_dropout_values[:-1]: # Add 1.0 to dropped_out (all zeros)
+ self.assertAllClose(d_v, np.ones_like(input_value))
+
+ def testDynamicCalculation(self):
+ cell = Plus1RNNCell()
+ sequence_length = tf.placeholder(tf.int64)
+ batch_size = 2
+ inputs = [tf.placeholder(tf.float32, shape=(batch_size, 5))] * 10
+ with tf.variable_scope("drop_scope"):
+ dynamic_outputs, dynamic_states = rnn.rnn(
+ cell, inputs, sequence_length=sequence_length, dtype=tf.float32)
+ self.assertEqual(len(dynamic_outputs), len(inputs))
+ self.assertEqual(len(dynamic_states), len(inputs))
+
+ with self.test_session(use_gpu=False) as sess:
+ input_value = np.random.randn(batch_size, 5)
+ dynamic_values = sess.run(dynamic_outputs,
+ feed_dict={inputs[0]: input_value,
+ sequence_length: [2, 3]})
+ dynamic_state_values = sess.run(dynamic_states,
+ feed_dict={inputs[0]: input_value,
+ sequence_length: [2, 3]})
+
+ # fully calculated for t = 0, 1, 2
+ for v in dynamic_values[:3]:
+ self.assertAllClose(v, input_value + 1.0)
+ for vi, v in enumerate(dynamic_state_values[:3]):
+ self.assertAllEqual(v, 1.0 * (vi + 1) * np.ones((batch_size, 5)))
+ # zeros for t = 3+
+ for v in dynamic_values[3:]:
+ self.assertAllEqual(v, np.zeros_like(input_value))
+ for v in dynamic_state_values[3:]:
+ self.assertAllEqual(v, np.zeros_like(input_value))
+
+
+class LSTMTest(tf.test.TestCase):
+
+ def setUp(self):
+ self._seed = 23489
+ np.random.seed(self._seed)
+
+ def _testNoProjNoSharding(self, use_gpu):
+ num_units = 3
+ input_size = 5
+ batch_size = 2
+ with self.test_session(use_gpu=use_gpu, graph=tf.Graph()) as sess:
+ initializer = tf.random_uniform_initializer(-0.01, 0.01, seed=self._seed)
+ cell = rnn_cell.LSTMCell(
+ num_units, input_size, initializer=initializer)
+ inputs = 10 * [
+ tf.placeholder(tf.float32, shape=(batch_size, input_size))]
+ outputs, _ = rnn.rnn(cell, inputs, dtype=tf.float32)
+ self.assertEqual(len(outputs), len(inputs))
+ for out in outputs:
+ self.assertEqual(out.get_shape().as_list(), [batch_size, num_units])
+
+ tf.initialize_all_variables().run()
+ input_value = np.random.randn(batch_size, input_size)
+ sess.run(outputs, feed_dict={inputs[0]: input_value})
+
+ def _testCellClipping(self, use_gpu):
+ num_units = 3
+ input_size = 5
+ batch_size = 2
+ with self.test_session(use_gpu=use_gpu, graph=tf.Graph()) as sess:
+ initializer = tf.random_uniform_initializer(-0.01, 0.01, seed=self._seed)
+ cell = rnn_cell.LSTMCell(
+ num_units, input_size, use_peepholes=True,
+ cell_clip=0.0, initializer=initializer)
+ inputs = 10 * [
+ tf.placeholder(tf.float32, shape=(batch_size, input_size))]
+ outputs, _ = rnn.rnn(cell, inputs, dtype=tf.float32)
+ self.assertEqual(len(outputs), len(inputs))
+ for out in outputs:
+ self.assertEqual(out.get_shape().as_list(), [batch_size, num_units])
+
+ tf.initialize_all_variables().run()
+ input_value = np.random.randn(batch_size, input_size)
+ values = sess.run(outputs, feed_dict={inputs[0]: input_value})
+
+ for value in values:
+ # if cell c is clipped to 0, tanh(c) = 0 => m==0
+ self.assertAllEqual(value, np.zeros((batch_size, num_units)))
+
+ def _testNoProjNoShardingSimpleStateSaver(self, use_gpu):
+ num_units = 3
+ input_size = 5
+ batch_size = 2
+ with self.test_session(use_gpu=use_gpu, graph=tf.Graph()) as sess:
+ initializer = tf.random_uniform_initializer(-0.01, 0.01, seed=self._seed)
+ state_saver = TestStateSaver(batch_size, 2*num_units)
+ cell = rnn_cell.LSTMCell(
+ num_units, input_size, use_peepholes=False, initializer=initializer)
+ inputs = 10 * [
+ tf.placeholder(tf.float32, shape=(batch_size, input_size))]
+ with tf.variable_scope("share_scope"):
+ outputs, states = rnn.state_saving_rnn(
+ cell, inputs, state_saver=state_saver, state_name="save_lstm")
+ self.assertEqual(len(outputs), len(inputs))
+ for out in outputs:
+ self.assertEqual(out.get_shape().as_list(), [batch_size, num_units])
+
+ tf.initialize_all_variables().run()
+ input_value = np.random.randn(batch_size, input_size)
+ (last_state_value, saved_state_value) = sess.run(
+ [states[-1], state_saver.saved_state],
+ feed_dict={inputs[0]: input_value})
+ self.assertAllEqual(last_state_value, saved_state_value)
+
+ def _testProjNoSharding(self, use_gpu):
+ num_units = 3
+ input_size = 5
+ batch_size = 2
+ num_proj = 4
+ with self.test_session(use_gpu=use_gpu, graph=tf.Graph()) as sess:
+ initializer = tf.random_uniform_initializer(-0.01, 0.01, seed=self._seed)
+ inputs = 10 * [
+ tf.placeholder(tf.float32, shape=(None, input_size))]
+ cell = rnn_cell.LSTMCell(
+ num_units, input_size, use_peepholes=True,
+ num_proj=num_proj, initializer=initializer)
+ outputs, _ = rnn.rnn(cell, inputs, dtype=tf.float32)
+ self.assertEqual(len(outputs), len(inputs))
+
+ tf.initialize_all_variables().run()
+ input_value = np.random.randn(batch_size, input_size)
+ sess.run(outputs, feed_dict={inputs[0]: input_value})
+
+ def _testProjSharding(self, use_gpu):
+ num_units = 3
+ input_size = 5
+ batch_size = 2
+ num_proj = 4
+ num_proj_shards = 4
+ num_unit_shards = 2
+ with self.test_session(use_gpu=use_gpu, graph=tf.Graph()) as sess:
+ initializer = tf.random_uniform_initializer(-0.01, 0.01, seed=self._seed)
+
+ inputs = 10 * [
+ tf.placeholder(tf.float32, shape=(None, input_size))]
+
+ cell = rnn_cell.LSTMCell(
+ num_units,
+ input_size=input_size,
+ use_peepholes=True,
+ num_proj=num_proj,
+ num_unit_shards=num_unit_shards,
+ num_proj_shards=num_proj_shards,
+ initializer=initializer)
+
+ outputs, _ = rnn.rnn(cell, inputs, dtype=tf.float32)
+
+ self.assertEqual(len(outputs), len(inputs))
+
+ tf.initialize_all_variables().run()
+ input_value = np.random.randn(batch_size, input_size)
+ sess.run(outputs, feed_dict={inputs[0]: input_value})
+
+ def _testDoubleInput(self, use_gpu):
+ num_units = 3
+ input_size = 5
+ batch_size = 2
+ num_proj = 4
+ num_proj_shards = 4
+ num_unit_shards = 2
+ with self.test_session(use_gpu=use_gpu, graph=tf.Graph()) as sess:
+ initializer = tf.random_uniform_initializer(-1, 1, seed=self._seed)
+ inputs = 10 * [tf.placeholder(tf.float64)]
+
+ cell = rnn_cell.LSTMCell(
+ num_units,
+ input_size=input_size,
+ use_peepholes=True,
+ num_proj=num_proj,
+ num_unit_shards=num_unit_shards,
+ num_proj_shards=num_proj_shards,
+ initializer=initializer)
+
+ outputs, _ = rnn.rnn(
+ cell, inputs, initial_state=cell.zero_state(batch_size, tf.float64))
+
+ self.assertEqual(len(outputs), len(inputs))
+
+ tf.initialize_all_variables().run()
+ input_value = np.asarray(np.random.randn(batch_size, input_size),
+ dtype=np.float64)
+ values = sess.run(outputs, feed_dict={inputs[0]: input_value})
+ self.assertEqual(values[0].dtype, input_value.dtype)
+
+ def _testShardNoShardEquivalentOutput(self, use_gpu):
+ num_units = 3
+ input_size = 5
+ batch_size = 2
+ num_proj = 4
+ num_proj_shards = 4
+ num_unit_shards = 2
+ with self.test_session(use_gpu=use_gpu, graph=tf.Graph()) as sess:
+ inputs = 10 * [tf.placeholder(tf.float32)]
+ initializer = tf.constant_initializer(0.001)
+
+ cell_noshard = rnn_cell.LSTMCell(
+ num_units, input_size,
+ num_proj=num_proj,
+ use_peepholes=True,
+ initializer=initializer,
+ num_unit_shards=num_unit_shards,
+ num_proj_shards=num_proj_shards)
+
+ cell_shard = rnn_cell.LSTMCell(
+ num_units, input_size, use_peepholes=True,
+ initializer=initializer, num_proj=num_proj)
+
+ with tf.variable_scope("noshard_scope"):
+ outputs_noshard, states_noshard = rnn.rnn(
+ cell_noshard, inputs, dtype=tf.float32)
+ with tf.variable_scope("shard_scope"):
+ outputs_shard, states_shard = rnn.rnn(
+ cell_shard, inputs, dtype=tf.float32)
+
+ self.assertEqual(len(outputs_noshard), len(inputs))
+ self.assertEqual(len(outputs_noshard), len(outputs_shard))
+
+ tf.initialize_all_variables().run()
+ input_value = np.random.randn(batch_size, input_size)
+ feeds = dict((x, input_value) for x in inputs)
+ values_noshard = sess.run(outputs_noshard, feed_dict=feeds)
+ values_shard = sess.run(outputs_shard, feed_dict=feeds)
+ state_values_noshard = sess.run(states_noshard, feed_dict=feeds)
+ state_values_shard = sess.run(states_shard, feed_dict=feeds)
+ self.assertEqual(len(values_noshard), len(values_shard))
+ self.assertEqual(len(state_values_noshard), len(state_values_shard))
+ for (v_noshard, v_shard) in zip(values_noshard, values_shard):
+ self.assertAllClose(v_noshard, v_shard, atol=1e-3)
+ for (s_noshard, s_shard) in zip(state_values_noshard, state_values_shard):
+ self.assertAllClose(s_noshard, s_shard, atol=1e-3)
+
+ def _testDoubleInputWithDropoutAndDynamicCalculation(
+ self, use_gpu):
+ """Smoke test for using LSTM with doubles, dropout, dynamic calculation."""
+
+ num_units = 3
+ input_size = 5
+ batch_size = 2
+ num_proj = 4
+ num_proj_shards = 4
+ num_unit_shards = 2
+ with self.test_session(use_gpu=use_gpu, graph=tf.Graph()) as sess:
+ sequence_length = tf.placeholder(tf.int64)
+ initializer = tf.random_uniform_initializer(-0.01, 0.01, seed=self._seed)
+ inputs = 10 * [tf.placeholder(tf.float64)]
+
+ cell = rnn_cell.LSTMCell(
+ num_units,
+ input_size=input_size,
+ use_peepholes=True,
+ num_proj=num_proj,
+ num_unit_shards=num_unit_shards,
+ num_proj_shards=num_proj_shards,
+ initializer=initializer)
+ dropout_cell = rnn_cell.DropoutWrapper(cell, 0.5, seed=0)
+
+ outputs, states = rnn.rnn(
+ dropout_cell, inputs, sequence_length=sequence_length,
+ initial_state=cell.zero_state(batch_size, tf.float64))
+
+ self.assertEqual(len(outputs), len(inputs))
+ self.assertEqual(len(outputs), len(states))
+
+ tf.initialize_all_variables().run()
+ input_value = np.asarray(np.random.randn(batch_size, input_size),
+ dtype=np.float64)
+ values = sess.run(outputs, feed_dict={inputs[0]: input_value,
+ sequence_length: [2, 3]})
+ state_values = sess.run(states, feed_dict={inputs[0]: input_value,
+ sequence_length: [2, 3]})
+ self.assertEqual(values[0].dtype, input_value.dtype)
+ self.assertEqual(state_values[0].dtype, input_value.dtype)
+
+ def testSharingWeightsWithReuse(self):
+ num_units = 3
+ input_size = 5
+ batch_size = 2
+ num_proj = 4
+ with self.test_session(graph=tf.Graph()) as sess:
+ initializer = tf.random_uniform_initializer(-1, 1, seed=self._seed)
+ inputs = 10 * [
+ tf.placeholder(tf.float32, shape=(None, input_size))]
+ cell = rnn_cell.LSTMCell(
+ num_units, input_size, use_peepholes=True,
+ num_proj=num_proj, initializer=initializer)
+
+ with tf.variable_scope("share_scope"):
+ outputs0, _ = rnn.rnn(cell, inputs, dtype=tf.float32)
+ with tf.variable_scope("share_scope", reuse=True):
+ outputs1, _ = rnn.rnn(cell, inputs, dtype=tf.float32)
+ with tf.variable_scope("diff_scope"):
+ outputs2, _ = rnn.rnn(cell, inputs, dtype=tf.float32)
+
+ tf.initialize_all_variables().run()
+ input_value = np.random.randn(batch_size, input_size)
+ output_values = sess.run(
+ outputs0 + outputs1 + outputs2, feed_dict={inputs[0]: input_value})
+ outputs0_values = output_values[:10]
+ outputs1_values = output_values[10:20]
+ outputs2_values = output_values[20:]
+ self.assertEqual(len(outputs0_values), len(outputs1_values))
+ self.assertEqual(len(outputs0_values), len(outputs2_values))
+ for o1, o2, o3 in zip(outputs0_values, outputs1_values, outputs2_values):
+ # Same weights used by both RNNs so outputs should be the same.
+ self.assertAllEqual(o1, o2)
+ # Different weights used so outputs should be different.
+ self.assertTrue(np.linalg.norm(o1-o3) > 1e-6)
+
+ def testSharingWeightsWithDifferentNamescope(self):
+ num_units = 3
+ input_size = 5
+ batch_size = 2
+ num_proj = 4
+ with self.test_session(graph=tf.Graph()) as sess:
+ initializer = tf.random_uniform_initializer(-1, 1, seed=self._seed)
+ inputs = 10 * [
+ tf.placeholder(tf.float32, shape=(None, input_size))]
+ cell = rnn_cell.LSTMCell(
+ num_units, input_size, use_peepholes=True,
+ num_proj=num_proj, initializer=initializer)
+
+ with tf.name_scope("scope0"):
+ with tf.variable_scope("share_scope"):
+ outputs0, _ = rnn.rnn(cell, inputs, dtype=tf.float32)
+ with tf.name_scope("scope1"):
+ with tf.variable_scope("share_scope", reuse=True):
+ outputs1, _ = rnn.rnn(cell, inputs, dtype=tf.float32)
+
+ tf.initialize_all_variables().run()
+ input_value = np.random.randn(batch_size, input_size)
+ output_values = sess.run(
+ outputs0 + outputs1, feed_dict={inputs[0]: input_value})
+ outputs0_values = output_values[:10]
+ outputs1_values = output_values[10:]
+ self.assertEqual(len(outputs0_values), len(outputs1_values))
+ for out0, out1 in zip(outputs0_values, outputs1_values):
+ self.assertAllEqual(out0, out1)
+
+ def testNoProjNoShardingSimpleStateSaver(self):
+ self._testNoProjNoShardingSimpleStateSaver(False)
+ self._testNoProjNoShardingSimpleStateSaver(True)
+
+ def testNoProjNoSharding(self):
+ self._testNoProjNoSharding(False)
+ self._testNoProjNoSharding(True)
+
+ def testCellClipping(self):
+ self._testCellClipping(False)
+ self._testCellClipping(True)
+
+ def testProjNoSharding(self):
+ self._testProjNoSharding(False)
+ self._testProjNoSharding(True)
+
+ def testProjSharding(self):
+ self._testProjSharding(False)
+ self._testProjSharding(True)
+
+ def testShardNoShardEquivalentOutput(self):
+ self._testShardNoShardEquivalentOutput(False)
+ self._testShardNoShardEquivalentOutput(True)
+
+ def testDoubleInput(self):
+ self._testDoubleInput(False)
+ self._testDoubleInput(True)
+
+ def testDoubleInputWithDropoutAndDynamicCalculation(self):
+ self._testDoubleInputWithDropoutAndDynamicCalculation(False)
+ self._testDoubleInputWithDropoutAndDynamicCalculation(True)
+
+
+if __name__ == "__main__":
+ tf.test.main()
diff --git a/tensorflow/models/rnn/seq2seq.py b/tensorflow/models/rnn/seq2seq.py
new file mode 100644
index 0000000000..a3b6a838ca
--- /dev/null
+++ b/tensorflow/models/rnn/seq2seq.py
@@ -0,0 +1,749 @@
+"""Library for creating sequence-to-sequence models."""
+
+import tensorflow.python.platform
+
+import tensorflow as tf
+
+from tensorflow.models.rnn import linear
+from tensorflow.models.rnn import rnn
+from tensorflow.models.rnn import rnn_cell
+
+
+def rnn_decoder(decoder_inputs, initial_state, cell, loop_function=None,
+ scope=None):
+ """RNN decoder for the sequence-to-sequence model.
+
+ Args:
+ decoder_inputs: a list of 2D Tensors [batch_size x cell.input_size].
+ initial_state: 2D Tensor with shape [batch_size x cell.state_size].
+ cell: rnn_cell.RNNCell defining the cell function and size.
+ loop_function: if not None, this function will be applied to i-th output
+ in order to generate i+1-th input, and decoder_inputs will be ignored,
+ except for the first element ("GO" symbol). This can be used for decoding,
+ but also for training to emulate http://arxiv.org/pdf/1506.03099v2.pdf.
+ Signature -- loop_function(prev, i) = next
+ * prev is a 2D Tensor of shape [batch_size x cell.output_size],
+ * i is an integer, the step number (when advanced control is needed),
+ * next is a 2D Tensor of shape [batch_size x cell.input_size].
+ scope: VariableScope for the created subgraph; defaults to "rnn_decoder".
+
+ Returns:
+ outputs: A list of the same length as decoder_inputs of 2D Tensors with
+ shape [batch_size x cell.output_size] containing generated outputs.
+ states: The state of each cell in each time-step. This is a list with
+ length len(decoder_inputs) -- one item for each time-step.
+ Each item is a 2D Tensor of shape [batch_size x cell.state_size].
+ (Note that in some cases, like basic RNN cell or GRU cell, outputs and
+ states can be the same. They are different for LSTM cells though.)
+ """
+ with tf.variable_scope(scope or "rnn_decoder"):
+ states = [initial_state]
+ outputs = []
+ prev = None
+ for i in xrange(len(decoder_inputs)):
+ inp = decoder_inputs[i]
+ if loop_function is not None and prev is not None:
+ with tf.variable_scope("loop_function", reuse=True):
+ # We do not propagate gradients over the loop function.
+ inp = tf.stop_gradient(loop_function(prev, i))
+ if i > 0:
+ tf.get_variable_scope().reuse_variables()
+ output, new_state = cell(inp, states[-1])
+ outputs.append(output)
+ states.append(new_state)
+ if loop_function is not None:
+ prev = tf.stop_gradient(output)
+ return outputs, states
+
+
+def basic_rnn_seq2seq(
+ encoder_inputs, decoder_inputs, cell, dtype=tf.float32, scope=None):
+ """Basic RNN sequence-to-sequence model.
+
+ This model first runs an RNN to encode encoder_inputs into a state vector, and
+ then runs decoder, initialized with the last encoder state, on decoder_inputs.
+ Encoder and decoder use the same RNN cell type, but don't share parameters.
+
+ Args:
+ encoder_inputs: a list of 2D Tensors [batch_size x cell.input_size].
+ decoder_inputs: a list of 2D Tensors [batch_size x cell.input_size].
+ cell: rnn_cell.RNNCell defining the cell function and size.
+ dtype: The dtype of the initial state of the RNN cell (default: tf.float32).
+ scope: VariableScope for the created subgraph; default: "basic_rnn_seq2seq".
+
+ Returns:
+ outputs: A list of the same length as decoder_inputs of 2D Tensors with
+ shape [batch_size x cell.output_size] containing the generated outputs.
+ states: The state of each decoder cell in each time-step. This is a list
+ with length len(decoder_inputs) -- one item for each time-step.
+ Each item is a 2D Tensor of shape [batch_size x cell.state_size].
+ """
+ with tf.variable_scope(scope or "basic_rnn_seq2seq"):
+ _, enc_states = rnn.rnn(cell, encoder_inputs, dtype=dtype)
+ return rnn_decoder(decoder_inputs, enc_states[-1], cell)
+
+
+def tied_rnn_seq2seq(encoder_inputs, decoder_inputs, cell,
+ loop_function=None, dtype=tf.float32, scope=None):
+ """RNN sequence-to-sequence model with tied encoder and decoder parameters.
+
+ This model first runs an RNN to encode encoder_inputs into a state vector, and
+ then runs decoder, initialized with the last encoder state, on decoder_inputs.
+ Encoder and decoder use the same RNN cell and share parameters.
+
+ Args:
+ encoder_inputs: a list of 2D Tensors [batch_size x cell.input_size].
+ decoder_inputs: a list of 2D Tensors [batch_size x cell.input_size].
+ cell: rnn_cell.RNNCell defining the cell function and size.
+ loop_function: if not None, this function will be applied to i-th output
+ in order to generate i+1-th input, and decoder_inputs will be ignored,
+ except for the first element ("GO" symbol), see rnn_decoder for details.
+ dtype: The dtype of the initial state of the rnn cell (default: tf.float32).
+ scope: VariableScope for the created subgraph; default: "tied_rnn_seq2seq".
+
+ Returns:
+ outputs: A list of the same length as decoder_inputs of 2D Tensors with
+ shape [batch_size x cell.output_size] containing the generated outputs.
+ states: The state of each decoder cell in each time-step. This is a list
+ with length len(decoder_inputs) -- one item for each time-step.
+ Each item is a 2D Tensor of shape [batch_size x cell.state_size].
+ """
+ with tf.variable_scope("combined_tied_rnn_seq2seq"):
+ scope = scope or "tied_rnn_seq2seq"
+ _, enc_states = rnn.rnn(
+ cell, encoder_inputs, dtype=dtype, scope=scope)
+ tf.get_variable_scope().reuse_variables()
+ return rnn_decoder(decoder_inputs, enc_states[-1], cell,
+ loop_function=loop_function, scope=scope)
+
+
+def embedding_rnn_decoder(decoder_inputs, initial_state, cell, num_symbols,
+ output_projection=None, feed_previous=False,
+ scope=None):
+ """RNN decoder with embedding and a pure-decoding option.
+
+ Args:
+ decoder_inputs: a list of 1D batch-sized int32-Tensors (decoder inputs).
+ initial_state: 2D Tensor [batch_size x cell.state_size].
+ cell: rnn_cell.RNNCell defining the cell function.
+ num_symbols: integer, how many symbols come into the embedding.
+ output_projection: None or a pair (W, B) of output projection weights and
+ biases; W has shape [cell.output_size x num_symbols] and B has
+ shape [num_symbols]; if provided and feed_previous=True, each fed
+ previous output will first be multiplied by W and added B.
+ feed_previous: Boolean; if True, only the first of decoder_inputs will be
+ used (the "GO" symbol), and all other decoder inputs will be generated by:
+ next = embedding_lookup(embedding, argmax(previous_output)),
+ In effect, this implements a greedy decoder. It can also be used
+ during training to emulate http://arxiv.org/pdf/1506.03099v2.pdf.
+ If False, decoder_inputs are used as given (the standard decoder case).
+ scope: VariableScope for the created subgraph; defaults to
+ "embedding_rnn_decoder".
+
+ Returns:
+ outputs: A list of the same length as decoder_inputs of 2D Tensors with
+ shape [batch_size x cell.output_size] containing the generated outputs.
+ states: The state of each decoder cell in each time-step. This is a list
+ with length len(decoder_inputs) -- one item for each time-step.
+ Each item is a 2D Tensor of shape [batch_size x cell.state_size].
+
+ Raises:
+ ValueError: when output_projection has the wrong shape.
+ """
+ if output_projection is not None:
+ proj_weights = tf.convert_to_tensor(output_projection[0], dtype=tf.float32)
+ proj_weights.get_shape().assert_is_compatible_with([cell.output_size,
+ num_symbols])
+ proj_biases = tf.convert_to_tensor(output_projection[1], dtype=tf.float32)
+ proj_biases.get_shape().assert_is_compatible_with([num_symbols])
+
+ with tf.variable_scope(scope or "embedding_rnn_decoder"):
+ with tf.device("/cpu:0"):
+ embedding = tf.get_variable("embedding", [num_symbols, cell.input_size])
+
+ def extract_argmax_and_embed(prev, _):
+ """Loop_function that extracts the symbol from prev and embeds it."""
+ if output_projection is not None:
+ prev = tf.nn.xw_plus_b(prev, output_projection[0], output_projection[1])
+ prev_symbol = tf.stop_gradient(tf.argmax(prev, 1))
+ return tf.nn.embedding_lookup(embedding, prev_symbol)
+
+ loop_function = None
+ if feed_previous:
+ loop_function = extract_argmax_and_embed
+
+ emb_inp = [tf.nn.embedding_lookup(embedding, i) for i in decoder_inputs]
+ return rnn_decoder(emb_inp, initial_state, cell,
+ loop_function=loop_function)
+
+
+def embedding_rnn_seq2seq(encoder_inputs, decoder_inputs, cell,
+ num_encoder_symbols, num_decoder_symbols,
+ output_projection=None, feed_previous=False,
+ dtype=tf.float32, scope=None):
+ """Embedding RNN sequence-to-sequence model.
+
+ This model first embeds encoder_inputs by a newly created embedding (of shape
+ [num_encoder_symbols x cell.input_size]). Then it runs an RNN to encode
+ embedded encoder_inputs into a state vector. Next, it embeds decoder_inputs
+ by another newly created embedding (of shape [num_decoder_symbols x
+ cell.input_size]). Then it runs RNN decoder, initialized with the last
+ encoder state, on embedded decoder_inputs.
+
+ Args:
+ encoder_inputs: a list of 2D Tensors [batch_size x cell.input_size].
+ decoder_inputs: a list of 2D Tensors [batch_size x cell.input_size].
+ cell: rnn_cell.RNNCell defining the cell function and size.
+ num_encoder_symbols: integer; number of symbols on the encoder side.
+ num_decoder_symbols: integer; number of symbols on the decoder side.
+ output_projection: None or a pair (W, B) of output projection weights and
+ biases; W has shape [cell.output_size x num_decoder_symbols] and B has
+ shape [num_decoder_symbols]; if provided and feed_previous=True, each
+ fed previous output will first be multiplied by W and added B.
+ feed_previous: Boolean or scalar Boolean Tensor; if True, only the first
+ of decoder_inputs will be used (the "GO" symbol), and all other decoder
+ inputs will be taken from previous outputs (as in embedding_rnn_decoder).
+ If False, decoder_inputs are used as given (the standard decoder case).
+ dtype: The dtype of the initial state for both the encoder and encoder
+ rnn cells (default: tf.float32).
+ scope: VariableScope for the created subgraph; defaults to
+ "embedding_rnn_seq2seq"
+
+ Returns:
+ outputs: A list of the same length as decoder_inputs of 2D Tensors with
+ shape [batch_size x num_decoder_symbols] containing the generated outputs.
+ states: The state of each decoder cell in each time-step. This is a list
+ with length len(decoder_inputs) -- one item for each time-step.
+ Each item is a 2D Tensor of shape [batch_size x cell.state_size].
+ """
+ with tf.variable_scope(scope or "embedding_rnn_seq2seq"):
+ # Encoder.
+ encoder_cell = rnn_cell.EmbeddingWrapper(cell, num_encoder_symbols)
+ _, encoder_states = rnn.rnn(encoder_cell, encoder_inputs, dtype=dtype)
+
+ # Decoder.
+ if output_projection is None:
+ cell = rnn_cell.OutputProjectionWrapper(cell, num_decoder_symbols)
+
+ if isinstance(feed_previous, bool):
+ return embedding_rnn_decoder(decoder_inputs, encoder_states[-1], cell,
+ num_decoder_symbols, output_projection,
+ feed_previous)
+ else: # If feed_previous is a Tensor, we construct 2 graphs and use cond.
+ outputs1, states1 = embedding_rnn_decoder(
+ decoder_inputs, encoder_states[-1], cell, num_decoder_symbols,
+ output_projection, True)
+ tf.get_variable_scope().reuse_variables()
+ outputs2, states2 = embedding_rnn_decoder(
+ decoder_inputs, encoder_states[-1], cell, num_decoder_symbols,
+ output_projection, False)
+
+ outputs = tf.control_flow_ops.cond(feed_previous,
+ lambda: outputs1, lambda: outputs2)
+ states = tf.control_flow_ops.cond(feed_previous,
+ lambda: states1, lambda: states2)
+ return outputs, states
+
+
+def embedding_tied_rnn_seq2seq(encoder_inputs, decoder_inputs, cell,
+ num_symbols, output_projection=None,
+ feed_previous=False, dtype=tf.float32,
+ scope=None):
+ """Embedding RNN sequence-to-sequence model with tied (shared) parameters.
+
+ This model first embeds encoder_inputs by a newly created embedding (of shape
+ [num_symbols x cell.input_size]). Then it runs an RNN to encode embedded
+ encoder_inputs into a state vector. Next, it embeds decoder_inputs using
+ the same embedding. Then it runs RNN decoder, initialized with the last
+ encoder state, on embedded decoder_inputs.
+
+ Args:
+ encoder_inputs: a list of 2D Tensors [batch_size x cell.input_size].
+ decoder_inputs: a list of 2D Tensors [batch_size x cell.input_size].
+ cell: rnn_cell.RNNCell defining the cell function and size.
+ num_symbols: integer; number of symbols for both encoder and decoder.
+ output_projection: None or a pair (W, B) of output projection weights and
+ biases; W has shape [cell.output_size x num_symbols] and B has
+ shape [num_symbols]; if provided and feed_previous=True, each
+ fed previous output will first be multiplied by W and added B.
+ feed_previous: Boolean or scalar Boolean Tensor; if True, only the first
+ of decoder_inputs will be used (the "GO" symbol), and all other decoder
+ inputs will be taken from previous outputs (as in embedding_rnn_decoder).
+ If False, decoder_inputs are used as given (the standard decoder case).
+ dtype: The dtype to use for the initial RNN states (default: tf.float32).
+ scope: VariableScope for the created subgraph; defaults to
+ "embedding_tied_rnn_seq2seq".
+
+ Returns:
+ outputs: A list of the same length as decoder_inputs of 2D Tensors with
+ shape [batch_size x num_decoder_symbols] containing the generated outputs.
+ states: The state of each decoder cell in each time-step. This is a list
+ with length len(decoder_inputs) -- one item for each time-step.
+ Each item is a 2D Tensor of shape [batch_size x cell.state_size].
+
+ Raises:
+ ValueError: when output_projection has the wrong shape.
+ """
+ if output_projection is not None:
+ proj_weights = tf.convert_to_tensor(output_projection[0], dtype=dtype)
+ proj_weights.get_shape().assert_is_compatible_with([cell.output_size,
+ num_symbols])
+ proj_biases = tf.convert_to_tensor(output_projection[1], dtype=dtype)
+ proj_biases.get_shape().assert_is_compatible_with([num_symbols])
+
+ with tf.variable_scope(scope or "embedding_tied_rnn_seq2seq"):
+ with tf.device("/cpu:0"):
+ embedding = tf.get_variable("embedding", [num_symbols, cell.input_size])
+
+ emb_encoder_inputs = [tf.nn.embedding_lookup(embedding, x)
+ for x in encoder_inputs]
+ emb_decoder_inputs = [tf.nn.embedding_lookup(embedding, x)
+ for x in decoder_inputs]
+
+ def extract_argmax_and_embed(prev, _):
+ """Loop_function that extracts the symbol from prev and embeds it."""
+ if output_projection is not None:
+ prev = tf.nn.xw_plus_b(prev, output_projection[0], output_projection[1])
+ prev_symbol = tf.stop_gradient(tf.argmax(prev, 1))
+ return tf.nn.embedding_lookup(embedding, prev_symbol)
+
+ if output_projection is None:
+ cell = rnn_cell.OutputProjectionWrapper(cell, num_symbols)
+
+ if isinstance(feed_previous, bool):
+ loop_function = extract_argmax_and_embed if feed_previous else None
+ return tied_rnn_seq2seq(emb_encoder_inputs, emb_decoder_inputs, cell,
+ loop_function=loop_function, dtype=dtype)
+ else: # If feed_previous is a Tensor, we construct 2 graphs and use cond.
+ outputs1, states1 = tied_rnn_seq2seq(
+ emb_encoder_inputs, emb_decoder_inputs, cell,
+ loop_function=extract_argmax_and_embed, dtype=dtype)
+ tf.get_variable_scope().reuse_variables()
+ outputs2, states2 = tied_rnn_seq2seq(
+ emb_encoder_inputs, emb_decoder_inputs, cell, dtype=dtype)
+
+ outputs = tf.control_flow_ops.cond(feed_previous,
+ lambda: outputs1, lambda: outputs2)
+ states = tf.control_flow_ops.cond(feed_previous,
+ lambda: states1, lambda: states2)
+ return outputs, states
+
+
+def attention_decoder(decoder_inputs, initial_state, attention_states, cell,
+ output_size=None, num_heads=1, loop_function=None,
+ dtype=tf.float32, scope=None):
+ """RNN decoder with attention for the sequence-to-sequence model.
+
+ Args:
+ decoder_inputs: a list of 2D Tensors [batch_size x cell.input_size].
+ initial_state: 2D Tensor [batch_size x cell.state_size].
+ attention_states: 3D Tensor [batch_size x attn_length x attn_size].
+ cell: rnn_cell.RNNCell defining the cell function and size.
+ output_size: size of the output vectors; if None, we use cell.output_size.
+ num_heads: number of attention heads that read from attention_states.
+ loop_function: if not None, this function will be applied to i-th output
+ in order to generate i+1-th input, and decoder_inputs will be ignored,
+ except for the first element ("GO" symbol). This can be used for decoding,
+ but also for training to emulate http://arxiv.org/pdf/1506.03099v2.pdf.
+ Signature -- loop_function(prev, i) = next
+ * prev is a 2D Tensor of shape [batch_size x cell.output_size],
+ * i is an integer, the step number (when advanced control is needed),
+ * next is a 2D Tensor of shape [batch_size x cell.input_size].
+ dtype: The dtype to use for the RNN initial state (default: tf.float32).
+ scope: VariableScope for the created subgraph; default: "attention_decoder".
+
+ Returns:
+ outputs: A list of the same length as decoder_inputs of 2D Tensors of shape
+ [batch_size x output_size]. These represent the generated outputs.
+ Output i is computed from input i (which is either i-th decoder_inputs or
+ loop_function(output {i-1}, i)) as follows. First, we run the cell
+ on a combination of the input and previous attention masks:
+ cell_output, new_state = cell(linear(input, prev_attn), prev_state).
+ Then, we calculate new attention masks:
+ new_attn = softmax(V^T * tanh(W * attention_states + U * new_state))
+ and then we calculate the output:
+ output = linear(cell_output, new_attn).
+ states: The state of each decoder cell in each time-step. This is a list
+ with length len(decoder_inputs) -- one item for each time-step.
+ Each item is a 2D Tensor of shape [batch_size x cell.state_size].
+
+ Raises:
+ ValueError: when num_heads is not positive, there are no inputs, or shapes
+ of attention_states are not set.
+ """
+ if not decoder_inputs:
+ raise ValueError("Must provide at least 1 input to attention decoder.")
+ if num_heads < 1:
+ raise ValueError("With less than 1 heads, use a non-attention decoder.")
+ if not attention_states.get_shape()[1:2].is_fully_defined():
+ raise ValueError("Shape[1] and [2] of attention_states must be known: %s"
+ % attention_states.get_shape())
+ if output_size is None:
+ output_size = cell.output_size
+
+ with tf.variable_scope(scope or "attention_decoder"):
+ batch_size = tf.shape(decoder_inputs[0])[0] # Needed for reshaping.
+ attn_length = attention_states.get_shape()[1].value
+ attn_size = attention_states.get_shape()[2].value
+
+ # To calculate W1 * h_t we use a 1-by-1 convolution, need to reshape before.
+ hidden = tf.reshape(attention_states, [-1, attn_length, 1, attn_size])
+ hidden_features = []
+ v = []
+ attention_vec_size = attn_size # Size of query vectors for attention.
+ for a in xrange(num_heads):
+ k = tf.get_variable("AttnW_%d" % a, [1, 1, attn_size, attention_vec_size])
+ hidden_features.append(tf.nn.conv2d(hidden, k, [1, 1, 1, 1], "SAME"))
+ v.append(tf.get_variable("AttnV_%d" % a, [attention_vec_size]))
+
+ states = [initial_state]
+
+ def attention(query):
+ """Put attention masks on hidden using hidden_features and query."""
+ ds = [] # Results of attention reads will be stored here.
+ for a in xrange(num_heads):
+ with tf.variable_scope("Attention_%d" % a):
+ y = linear.linear(query, attention_vec_size, True)
+ y = tf.reshape(y, [-1, 1, 1, attention_vec_size])
+ # Attention mask is a softmax of v^T * tanh(...).
+ s = tf.reduce_sum(v[a] * tf.tanh(hidden_features[a] + y), [2, 3])
+ a = tf.nn.softmax(s)
+ # Now calculate the attention-weighted vector d.
+ d = tf.reduce_sum(tf.reshape(a, [-1, attn_length, 1, 1]) * hidden,
+ [1, 2])
+ ds.append(tf.reshape(d, [-1, attn_size]))
+ return ds
+
+ outputs = []
+ prev = None
+ batch_attn_size = tf.pack([batch_size, attn_size])
+ attns = [tf.zeros(batch_attn_size, dtype=dtype)
+ for _ in xrange(num_heads)]
+ for a in attns: # Ensure the second shape of attention vectors is set.
+ a.set_shape([None, attn_size])
+ for i in xrange(len(decoder_inputs)):
+ if i > 0:
+ tf.get_variable_scope().reuse_variables()
+ inp = decoder_inputs[i]
+ # If loop_function is set, we use it instead of decoder_inputs.
+ if loop_function is not None and prev is not None:
+ with tf.variable_scope("loop_function", reuse=True):
+ inp = tf.stop_gradient(loop_function(prev, i))
+ # Merge input and previous attentions into one vector of the right size.
+ x = linear.linear([inp] + attns, cell.input_size, True)
+ # Run the RNN.
+ cell_output, new_state = cell(x, states[-1])
+ states.append(new_state)
+ # Run the attention mechanism.
+ attns = attention(new_state)
+ with tf.variable_scope("AttnOutputProjection"):
+ output = linear.linear([cell_output] + attns, output_size, True)
+ if loop_function is not None:
+ # We do not propagate gradients over the loop function.
+ prev = tf.stop_gradient(output)
+ outputs.append(output)
+
+ return outputs, states
+
+
+def embedding_attention_decoder(decoder_inputs, initial_state, attention_states,
+ cell, num_symbols, num_heads=1,
+ output_size=None, output_projection=None,
+ feed_previous=False, dtype=tf.float32,
+ scope=None):
+ """RNN decoder with embedding and attention and a pure-decoding option.
+
+ Args:
+ decoder_inputs: a list of 1D batch-sized int32-Tensors (decoder inputs).
+ initial_state: 2D Tensor [batch_size x cell.state_size].
+ attention_states: 3D Tensor [batch_size x attn_length x attn_size].
+ cell: rnn_cell.RNNCell defining the cell function.
+ num_symbols: integer, how many symbols come into the embedding.
+ num_heads: number of attention heads that read from attention_states.
+ output_size: size of the output vectors; if None, use cell.output_size.
+ output_projection: None or a pair (W, B) of output projection weights and
+ biases; W has shape [output_size x num_symbols] and B has shape
+ [num_symbols]; if provided and feed_previous=True, each fed previous
+ output will first be multiplied by W and added B.
+ feed_previous: Boolean; if True, only the first of decoder_inputs will be
+ used (the "GO" symbol), and all other decoder inputs will be generated by:
+ next = embedding_lookup(embedding, argmax(previous_output)),
+ In effect, this implements a greedy decoder. It can also be used
+ during training to emulate http://arxiv.org/pdf/1506.03099v2.pdf.
+ If False, decoder_inputs are used as given (the standard decoder case).
+ dtype: The dtype to use for the RNN initial states (default: tf.float32).
+ scope: VariableScope for the created subgraph; defaults to
+ "embedding_attention_decoder".
+
+ Returns:
+ outputs: A list of the same length as decoder_inputs of 2D Tensors with
+ shape [batch_size x output_size] containing the generated outputs.
+ states: The state of each decoder cell in each time-step. This is a list
+ with length len(decoder_inputs) -- one item for each time-step.
+ Each item is a 2D Tensor of shape [batch_size x cell.state_size].
+
+ Raises:
+ ValueError: when output_projection has the wrong shape.
+ """
+ if output_size is None:
+ output_size = cell.output_size
+ if output_projection is not None:
+ proj_weights = tf.convert_to_tensor(output_projection[0], dtype=dtype)
+ proj_weights.get_shape().assert_is_compatible_with([cell.output_size,
+ num_symbols])
+ proj_biases = tf.convert_to_tensor(output_projection[1], dtype=dtype)
+ proj_biases.get_shape().assert_is_compatible_with([num_symbols])
+
+ with tf.variable_scope(scope or "embedding_attention_decoder"):
+ with tf.device("/cpu:0"):
+ embedding = tf.get_variable("embedding", [num_symbols, cell.input_size])
+
+ def extract_argmax_and_embed(prev, _):
+ """Loop_function that extracts the symbol from prev and embeds it."""
+ if output_projection is not None:
+ prev = tf.nn.xw_plus_b(prev, output_projection[0], output_projection[1])
+ prev_symbol = tf.stop_gradient(tf.argmax(prev, 1))
+ emb_prev = tf.nn.embedding_lookup(embedding, prev_symbol)
+ return emb_prev
+
+ loop_function = None
+ if feed_previous:
+ loop_function = extract_argmax_and_embed
+
+ emb_inp = [tf.nn.embedding_lookup(embedding, i) for i in decoder_inputs]
+ return attention_decoder(
+ emb_inp, initial_state, attention_states, cell, output_size=output_size,
+ num_heads=num_heads, loop_function=loop_function)
+
+
+def embedding_attention_seq2seq(encoder_inputs, decoder_inputs, cell,
+ num_encoder_symbols, num_decoder_symbols,
+ num_heads=1, output_projection=None,
+ feed_previous=False, dtype=tf.float32,
+ scope=None):
+ """Embedding sequence-to-sequence model with attention.
+
+ This model first embeds encoder_inputs by a newly created embedding (of shape
+ [num_encoder_symbols x cell.input_size]). Then it runs an RNN to encode
+ embedded encoder_inputs into a state vector. It keeps the outputs of this
+ RNN at every step to use for attention later. Next, it embeds decoder_inputs
+ by another newly created embedding (of shape [num_decoder_symbols x
+ cell.input_size]). Then it runs attention decoder, initialized with the last
+ encoder state, on embedded decoder_inputs and attending to encoder outputs.
+
+ Args:
+ encoder_inputs: a list of 2D Tensors [batch_size x cell.input_size].
+ decoder_inputs: a list of 2D Tensors [batch_size x cell.input_size].
+ cell: rnn_cell.RNNCell defining the cell function and size.
+ num_encoder_symbols: integer; number of symbols on the encoder side.
+ num_decoder_symbols: integer; number of symbols on the decoder side.
+ num_heads: number of attention heads that read from attention_states.
+ output_projection: None or a pair (W, B) of output projection weights and
+ biases; W has shape [cell.output_size x num_decoder_symbols] and B has
+ shape [num_decoder_symbols]; if provided and feed_previous=True, each
+ fed previous output will first be multiplied by W and added B.
+ feed_previous: Boolean or scalar Boolean Tensor; if True, only the first
+ of decoder_inputs will be used (the "GO" symbol), and all other decoder
+ inputs will be taken from previous outputs (as in embedding_rnn_decoder).
+ If False, decoder_inputs are used as given (the standard decoder case).
+ dtype: The dtype of the initial RNN state (default: tf.float32).
+ scope: VariableScope for the created subgraph; defaults to
+ "embedding_attention_seq2seq".
+
+ Returns:
+ outputs: A list of the same length as decoder_inputs of 2D Tensors with
+ shape [batch_size x num_decoder_symbols] containing the generated outputs.
+ states: The state of each decoder cell in each time-step. This is a list
+ with length len(decoder_inputs) -- one item for each time-step.
+ Each item is a 2D Tensor of shape [batch_size x cell.state_size].
+ """
+ with tf.variable_scope(scope or "embedding_attention_seq2seq"):
+ # Encoder.
+ encoder_cell = rnn_cell.EmbeddingWrapper(cell, num_encoder_symbols)
+ encoder_outputs, encoder_states = rnn.rnn(
+ encoder_cell, encoder_inputs, dtype=dtype)
+
+ # First calculate a concatenation of encoder outputs to put attention on.
+ top_states = [tf.reshape(e, [-1, 1, cell.output_size])
+ for e in encoder_outputs]
+ attention_states = tf.concat(1, top_states)
+
+ # Decoder.
+ output_size = None
+ if output_projection is None:
+ cell = rnn_cell.OutputProjectionWrapper(cell, num_decoder_symbols)
+ output_size = num_decoder_symbols
+
+ if isinstance(feed_previous, bool):
+ return embedding_attention_decoder(
+ decoder_inputs, encoder_states[-1], attention_states, cell,
+ num_decoder_symbols, num_heads, output_size, output_projection,
+ feed_previous)
+ else: # If feed_previous is a Tensor, we construct 2 graphs and use cond.
+ outputs1, states1 = embedding_attention_decoder(
+ decoder_inputs, encoder_states[-1], attention_states, cell,
+ num_decoder_symbols, num_heads, output_size, output_projection, True)
+ tf.get_variable_scope().reuse_variables()
+ outputs2, states2 = embedding_attention_decoder(
+ decoder_inputs, encoder_states[-1], attention_states, cell,
+ num_decoder_symbols, num_heads, output_size, output_projection, False)
+
+ outputs = tf.control_flow_ops.cond(feed_previous,
+ lambda: outputs1, lambda: outputs2)
+ states = tf.control_flow_ops.cond(feed_previous,
+ lambda: states1, lambda: states2)
+ return outputs, states
+
+
+def sequence_loss_by_example(logits, targets, weights, num_decoder_symbols,
+ average_across_timesteps=True,
+ softmax_loss_function=None, name=None):
+ """Weighted cross-entropy loss for a sequence of logits (per example).
+
+ Args:
+ logits: list of 2D Tensors of shape [batch_size x num_decoder_symbols].
+ targets: list of 1D batch-sized int32-Tensors of the same length as logits.
+ weights: list of 1D batch-sized float-Tensors of the same length as logits.
+ num_decoder_symbols: integer, number of decoder symbols (output classes).
+ average_across_timesteps: If set, divide the returned cost by the total
+ label weight.
+ softmax_loss_function: function (inputs-batch, labels-batch) -> loss-batch
+ to be used instead of the standard softmax (the default if this is None).
+ name: optional name for this operation, default: "sequence_loss_by_example".
+
+ Returns:
+ 1D batch-sized float Tensor: the log-perplexity for each sequence.
+
+ Raises:
+ ValueError: if len(logits) is different from len(targets) or len(weights).
+ """
+ if len(targets) != len(logits) or len(weights) != len(logits):
+ raise ValueError("Lengths of logits, weights, and targets must be the same "
+ "%d, %d, %d." % (len(logits), len(weights), len(targets)))
+ with tf.op_scope(logits + targets + weights, name,
+ "sequence_loss_by_example"):
+ batch_size = tf.shape(targets[0])[0]
+ log_perp_list = []
+ length = batch_size * num_decoder_symbols
+ for i in xrange(len(logits)):
+ if softmax_loss_function is None:
+ # TODO(lukaszkaiser): There is no SparseCrossEntropy in TensorFlow, so
+ # we need to first cast targets into a dense representation, and as
+ # SparseToDense does not accept batched inputs, we need to do this by
+ # re-indexing and re-sizing. When TensorFlow adds SparseCrossEntropy,
+ # rewrite this method.
+ indices = targets[i] + num_decoder_symbols * tf.range(0, batch_size)
+ with tf.device("/cpu:0"): # Sparse-to-dense must happen on CPU for now.
+ dense = tf.sparse_to_dense(indices, tf.expand_dims(length, 0), 1.0,
+ 0.0)
+ target = tf.reshape(dense, [-1, num_decoder_symbols])
+ crossent = tf.nn.softmax_cross_entropy_with_logits(
+ logits[i], target, name="SequenceLoss/CrossEntropy{0}".format(i))
+ else:
+ crossent = softmax_loss_function(logits[i], targets[i])
+ log_perp_list.append(crossent * weights[i])
+ log_perps = tf.add_n(log_perp_list)
+ if average_across_timesteps:
+ total_size = tf.add_n(weights)
+ total_size += 1e-12 # Just to avoid division by 0 for all-0 weights.
+ log_perps /= total_size
+ return log_perps
+
+
+def sequence_loss(logits, targets, weights, num_decoder_symbols,
+ average_across_timesteps=True, average_across_batch=True,
+ softmax_loss_function=None, name=None):
+ """Weighted cross-entropy loss for a sequence of logits, batch-collapsed.
+
+ Args:
+ logits: list of 2D Tensors os shape [batch_size x num_decoder_symbols].
+ targets: list of 1D batch-sized int32-Tensors of the same length as logits.
+ weights: list of 1D batch-sized float-Tensors of the same length as logits.
+ num_decoder_symbols: integer, number of decoder symbols (output classes).
+ average_across_timesteps: If set, divide the returned cost by the total
+ label weight.
+ average_across_batch: If set, divide the returned cost by the batch size.
+ softmax_loss_function: function (inputs-batch, labels-batch) -> loss-batch
+ to be used instead of the standard softmax (the default if this is None).
+ name: optional name for this operation, defaults to "sequence_loss".
+
+ Returns:
+ A scalar float Tensor: the average log-perplexity per symbol (weighted).
+
+ Raises:
+ ValueError: if len(logits) is different from len(targets) or len(weights).
+ """
+ with tf.op_scope(logits + targets + weights, name, "sequence_loss"):
+ cost = tf.reduce_sum(sequence_loss_by_example(
+ logits, targets, weights, num_decoder_symbols,
+ average_across_timesteps=average_across_timesteps,
+ softmax_loss_function=softmax_loss_function))
+ if average_across_batch:
+ batch_size = tf.shape(targets[0])[0]
+ return cost / tf.cast(batch_size, tf.float32)
+ else:
+ return cost
+
+
+def model_with_buckets(encoder_inputs, decoder_inputs, targets, weights,
+ buckets, num_decoder_symbols, seq2seq,
+ softmax_loss_function=None, name=None):
+ """Create a sequence-to-sequence model with support for bucketing.
+
+ The seq2seq argument is a function that defines a sequence-to-sequence model,
+ e.g., seq2seq = lambda x, y: basic_rnn_seq2seq(x, y, rnn_cell.GRUCell(24))
+
+ Args:
+ encoder_inputs: a list of Tensors to feed the encoder; first seq2seq input.
+ decoder_inputs: a list of Tensors to feed the decoder; second seq2seq input.
+ targets: a list of 1D batch-sized int32-Tensors (desired output sequence).
+ weights: list of 1D batch-sized float-Tensors to weight the targets.
+ buckets: a list of pairs of (input size, output size) for each bucket.
+ num_decoder_symbols: integer, number of decoder symbols (output classes).
+ seq2seq: a sequence-to-sequence model function; it takes 2 input that
+ agree with encoder_inputs and decoder_inputs, and returns a pair
+ consisting of outputs and states (as, e.g., basic_rnn_seq2seq).
+ softmax_loss_function: function (inputs-batch, labels-batch) -> loss-batch
+ to be used instead of the standard softmax (the default if this is None).
+ name: optional name for this operation, defaults to "model_with_buckets".
+
+ Returns:
+ outputs: The outputs for each bucket. Its j'th element consists of a list
+ of 2D Tensors of shape [batch_size x num_decoder_symbols] (j'th outputs).
+ losses: List of scalar Tensors, representing losses for each bucket.
+ Raises:
+ ValueError: if length of encoder_inputsut, targets, or weights is smaller
+ than the largest (last) bucket.
+ """
+ if len(encoder_inputs) < buckets[-1][0]:
+ raise ValueError("Length of encoder_inputs (%d) must be at least that of la"
+ "st bucket (%d)." % (len(encoder_inputs), buckets[-1][0]))
+ if len(targets) < buckets[-1][1]:
+ raise ValueError("Length of targets (%d) must be at least that of last"
+ "bucket (%d)." % (len(targets), buckets[-1][1]))
+ if len(weights) < buckets[-1][1]:
+ raise ValueError("Length of weights (%d) must be at least that of last"
+ "bucket (%d)." % (len(weights), buckets[-1][1]))
+
+ all_inputs = encoder_inputs + decoder_inputs + targets + weights
+ losses = []
+ outputs = []
+ with tf.op_scope(all_inputs, name, "model_with_buckets"):
+ for j in xrange(len(buckets)):
+ if j > 0:
+ tf.get_variable_scope().reuse_variables()
+ bucket_encoder_inputs = [encoder_inputs[i]
+ for i in xrange(buckets[j][0])]
+ bucket_decoder_inputs = [decoder_inputs[i]
+ for i in xrange(buckets[j][1])]
+ bucket_outputs, _ = seq2seq(bucket_encoder_inputs,
+ bucket_decoder_inputs)
+ outputs.append(bucket_outputs)
+
+ bucket_targets = [targets[i] for i in xrange(buckets[j][1])]
+ bucket_weights = [weights[i] for i in xrange(buckets[j][1])]
+ losses.append(sequence_loss(
+ outputs[-1], bucket_targets, bucket_weights, num_decoder_symbols,
+ softmax_loss_function=softmax_loss_function))
+
+ return outputs, losses
diff --git a/tensorflow/models/rnn/seq2seq_test.py b/tensorflow/models/rnn/seq2seq_test.py
new file mode 100644
index 0000000000..c5125acc21
--- /dev/null
+++ b/tensorflow/models/rnn/seq2seq_test.py
@@ -0,0 +1,384 @@
+"""Tests for functional style sequence-to-sequence models."""
+import math
+import random
+
+import tensorflow.python.platform
+
+import numpy as np
+import tensorflow as tf
+
+from tensorflow.models.rnn import rnn
+from tensorflow.models.rnn import rnn_cell
+from tensorflow.models.rnn import seq2seq
+
+
+class Seq2SeqTest(tf.test.TestCase):
+
+ def testRNNDecoder(self):
+ with self.test_session() as sess:
+ with tf.variable_scope("root", initializer=tf.constant_initializer(0.5)):
+ inp = [tf.constant(0.5, shape=[2, 2]) for _ in xrange(2)]
+ _, enc_states = rnn.rnn(rnn_cell.GRUCell(2), inp, dtype=tf.float32)
+ dec_inp = [tf.constant(0.4, shape=[2, 2]) for _ in xrange(3)]
+ cell = rnn_cell.OutputProjectionWrapper(rnn_cell.GRUCell(2), 4)
+ dec, mem = seq2seq.rnn_decoder(dec_inp, enc_states[-1], cell)
+ sess.run([tf.initialize_all_variables()])
+ res = sess.run(dec)
+ self.assertEqual(len(res), 3)
+ self.assertEqual(res[0].shape, (2, 4))
+
+ res = sess.run(mem)
+ self.assertEqual(len(res), 4)
+ self.assertEqual(res[0].shape, (2, 2))
+
+ def testBasicRNNSeq2Seq(self):
+ with self.test_session() as sess:
+ with tf.variable_scope("root", initializer=tf.constant_initializer(0.5)):
+ inp = [tf.constant(0.5, shape=[2, 2]) for _ in xrange(2)]
+ dec_inp = [tf.constant(0.4, shape=[2, 2]) for _ in xrange(3)]
+ cell = rnn_cell.OutputProjectionWrapper(rnn_cell.GRUCell(2), 4)
+ dec, mem = seq2seq.basic_rnn_seq2seq(inp, dec_inp, cell)
+ sess.run([tf.initialize_all_variables()])
+ res = sess.run(dec)
+ self.assertEqual(len(res), 3)
+ self.assertEqual(res[0].shape, (2, 4))
+
+ res = sess.run(mem)
+ self.assertEqual(len(res), 4)
+ self.assertEqual(res[0].shape, (2, 2))
+
+ def testTiedRNNSeq2Seq(self):
+ with self.test_session() as sess:
+ with tf.variable_scope("root", initializer=tf.constant_initializer(0.5)):
+ inp = [tf.constant(0.5, shape=[2, 2]) for _ in xrange(2)]
+ dec_inp = [tf.constant(0.4, shape=[2, 2]) for _ in xrange(3)]
+ cell = rnn_cell.OutputProjectionWrapper(rnn_cell.GRUCell(2), 4)
+ dec, mem = seq2seq.tied_rnn_seq2seq(inp, dec_inp, cell)
+ sess.run([tf.initialize_all_variables()])
+ res = sess.run(dec)
+ self.assertEqual(len(res), 3)
+ self.assertEqual(res[0].shape, (2, 4))
+
+ res = sess.run(mem)
+ self.assertEqual(len(res), 4)
+ self.assertEqual(res[0].shape, (2, 2))
+
+ def testEmbeddingRNNDecoder(self):
+ with self.test_session() as sess:
+ with tf.variable_scope("root", initializer=tf.constant_initializer(0.5)):
+ inp = [tf.constant(0.5, shape=[2, 2]) for _ in xrange(2)]
+ cell = rnn_cell.BasicLSTMCell(2)
+ _, enc_states = rnn.rnn(cell, inp, dtype=tf.float32)
+ dec_inp = [tf.constant(i, tf.int32, shape=[2]) for i in xrange(3)]
+ dec, mem = seq2seq.embedding_rnn_decoder(dec_inp, enc_states[-1],
+ cell, 4)
+ sess.run([tf.initialize_all_variables()])
+ res = sess.run(dec)
+ self.assertEqual(len(res), 3)
+ self.assertEqual(res[0].shape, (2, 2))
+
+ res = sess.run(mem)
+ self.assertEqual(len(res), 4)
+ self.assertEqual(res[0].shape, (2, 4))
+
+ def testEmbeddingRNNSeq2Seq(self):
+ with self.test_session() as sess:
+ with tf.variable_scope("root", initializer=tf.constant_initializer(0.5)):
+ enc_inp = [tf.constant(1, tf.int32, shape=[2]) for i in xrange(2)]
+ dec_inp = [tf.constant(i, tf.int32, shape=[2]) for i in xrange(3)]
+ cell = rnn_cell.BasicLSTMCell(2)
+ dec, mem = seq2seq.embedding_rnn_seq2seq(enc_inp, dec_inp, cell, 2, 5)
+ sess.run([tf.variables.initialize_all_variables()])
+ res = sess.run(dec)
+ self.assertEqual(len(res), 3)
+ self.assertEqual(res[0].shape, (2, 5))
+
+ res = sess.run(mem)
+ self.assertEqual(len(res), 4)
+ self.assertEqual(res[0].shape, (2, 4))
+
+ # Test externally provided output projection.
+ w = tf.get_variable("proj_w", [2, 5])
+ b = tf.get_variable("proj_b", [5])
+ with tf.variable_scope("proj_seq2seq"):
+ dec, _ = seq2seq.embedding_rnn_seq2seq(
+ enc_inp, dec_inp, cell, 2, 5, output_projection=(w, b))
+ sess.run([tf.variables.initialize_all_variables()])
+ res = sess.run(dec)
+ self.assertEqual(len(res), 3)
+ self.assertEqual(res[0].shape, (2, 2))
+
+ # Test that previous-feeding model ignores inputs after the first.
+ dec_inp2 = [tf.constant(0, tf.int32, shape=[2]) for _ in xrange(3)]
+ tf.get_variable_scope().reuse_variables()
+ d1, _ = seq2seq.embedding_rnn_seq2seq(enc_inp, dec_inp, cell, 2, 5,
+ feed_previous=True)
+ d2, _ = seq2seq.embedding_rnn_seq2seq(enc_inp, dec_inp2, cell, 2, 5,
+ feed_previous=True)
+ d3, _ = seq2seq.embedding_rnn_seq2seq(enc_inp, dec_inp2, cell, 2, 5,
+ feed_previous=tf.constant(True))
+ res1 = sess.run(d1)
+ res2 = sess.run(d2)
+ res3 = sess.run(d3)
+ self.assertAllClose(res1, res2)
+ self.assertAllClose(res1, res3)
+
+ def testEmbeddingTiedRNNSeq2Seq(self):
+ with self.test_session() as sess:
+ with tf.variable_scope("root", initializer=tf.constant_initializer(0.5)):
+ enc_inp = [tf.constant(1, tf.int32, shape=[2]) for i in xrange(2)]
+ dec_inp = [tf.constant(i, tf.int32, shape=[2]) for i in xrange(3)]
+ cell = rnn_cell.BasicLSTMCell(2)
+ dec, mem = seq2seq.embedding_tied_rnn_seq2seq(enc_inp, dec_inp, cell, 5)
+ sess.run([tf.variables.initialize_all_variables()])
+ res = sess.run(dec)
+ self.assertEqual(len(res), 3)
+ self.assertEqual(res[0].shape, (2, 5))
+
+ res = sess.run(mem)
+ self.assertEqual(len(res), 4)
+ self.assertEqual(res[0].shape, (2, 4))
+
+ # Test externally provided output projection.
+ w = tf.get_variable("proj_w", [2, 5])
+ b = tf.get_variable("proj_b", [5])
+ with tf.variable_scope("proj_seq2seq"):
+ dec, _ = seq2seq.embedding_tied_rnn_seq2seq(
+ enc_inp, dec_inp, cell, 5, output_projection=(w, b))
+ sess.run([tf.variables.initialize_all_variables()])
+ res = sess.run(dec)
+ self.assertEqual(len(res), 3)
+ self.assertEqual(res[0].shape, (2, 2))
+
+ # Test that previous-feeding model ignores inputs after the first.
+ dec_inp2 = [tf.constant(0, tf.int32, shape=[2]) for _ in xrange(3)]
+ tf.get_variable_scope().reuse_variables()
+ d1, _ = seq2seq.embedding_tied_rnn_seq2seq(enc_inp, dec_inp, cell, 5,
+ feed_previous=True)
+ d2, _ = seq2seq.embedding_tied_rnn_seq2seq(enc_inp, dec_inp2, cell, 5,
+ feed_previous=True)
+ d3, _ = seq2seq.embedding_tied_rnn_seq2seq(
+ enc_inp, dec_inp2, cell, 5, feed_previous=tf.constant(True))
+ res1 = sess.run(d1)
+ res2 = sess.run(d2)
+ res3 = sess.run(d3)
+ self.assertAllClose(res1, res2)
+ self.assertAllClose(res1, res3)
+
+ def testAttentionDecoder1(self):
+ with self.test_session() as sess:
+ with tf.variable_scope("root", initializer=tf.constant_initializer(0.5)):
+ cell = rnn_cell.GRUCell(2)
+ inp = [tf.constant(0.5, shape=[2, 2]) for _ in xrange(2)]
+ enc_outputs, enc_states = rnn.rnn(cell, inp, dtype=tf.float32)
+ attn_states = tf.concat(1, [tf.reshape(e, [-1, 1, cell.output_size])
+ for e in enc_outputs])
+ dec_inp = [tf.constant(0.4, shape=[2, 2]) for _ in xrange(3)]
+ dec, mem = seq2seq.attention_decoder(dec_inp, enc_states[-1],
+ attn_states, cell, output_size=4)
+ sess.run([tf.initialize_all_variables()])
+ res = sess.run(dec)
+ self.assertEqual(len(res), 3)
+ self.assertEqual(res[0].shape, (2, 4))
+
+ res = sess.run(mem)
+ self.assertEqual(len(res), 4)
+ self.assertEqual(res[0].shape, (2, 2))
+
+ def testAttentionDecoder2(self):
+ with self.test_session() as sess:
+ with tf.variable_scope("root", initializer=tf.constant_initializer(0.5)):
+ cell = rnn_cell.GRUCell(2)
+ inp = [tf.constant(0.5, shape=[2, 2]) for _ in xrange(2)]
+ enc_outputs, enc_states = rnn.rnn(cell, inp, dtype=tf.float32)
+ attn_states = tf.concat(1, [tf.reshape(e, [-1, 1, cell.output_size])
+ for e in enc_outputs])
+ dec_inp = [tf.constant(0.4, shape=[2, 2]) for _ in xrange(3)]
+ dec, mem = seq2seq.attention_decoder(dec_inp, enc_states[-1],
+ attn_states, cell, output_size=4,
+ num_heads=2)
+ sess.run([tf.initialize_all_variables()])
+ res = sess.run(dec)
+ self.assertEqual(len(res), 3)
+ self.assertEqual(res[0].shape, (2, 4))
+
+ res = sess.run(mem)
+ self.assertEqual(len(res), 4)
+ self.assertEqual(res[0].shape, (2, 2))
+
+ def testEmbeddingAttentionDecoder(self):
+ with self.test_session() as sess:
+ with tf.variable_scope("root", initializer=tf.constant_initializer(0.5)):
+ inp = [tf.constant(0.5, shape=[2, 2]) for _ in xrange(2)]
+ cell = rnn_cell.GRUCell(2)
+ enc_outputs, enc_states = rnn.rnn(cell, inp, dtype=tf.float32)
+ attn_states = tf.concat(1, [tf.reshape(e, [-1, 1, cell.output_size])
+ for e in enc_outputs])
+ dec_inp = [tf.constant(i, tf.int32, shape=[2]) for i in xrange(3)]
+ dec, mem = seq2seq.embedding_attention_decoder(dec_inp, enc_states[-1],
+ attn_states, cell, 4,
+ output_size=3)
+ sess.run([tf.initialize_all_variables()])
+ res = sess.run(dec)
+ self.assertEqual(len(res), 3)
+ self.assertEqual(res[0].shape, (2, 3))
+
+ res = sess.run(mem)
+ self.assertEqual(len(res), 4)
+ self.assertEqual(res[0].shape, (2, 2))
+
+ def testEmbeddingAttentionSeq2Seq(self):
+ with self.test_session() as sess:
+ with tf.variable_scope("root", initializer=tf.constant_initializer(0.5)):
+ enc_inp = [tf.constant(1, tf.int32, shape=[2]) for i in xrange(2)]
+ dec_inp = [tf.constant(i, tf.int32, shape=[2]) for i in xrange(3)]
+ cell = rnn_cell.BasicLSTMCell(2)
+ dec, mem = seq2seq.embedding_attention_seq2seq(
+ enc_inp, dec_inp, cell, 2, 5)
+ sess.run([tf.initialize_all_variables()])
+ res = sess.run(dec)
+ self.assertEqual(len(res), 3)
+ self.assertEqual(res[0].shape, (2, 5))
+
+ res = sess.run(mem)
+ self.assertEqual(len(res), 4)
+ self.assertEqual(res[0].shape, (2, 4))
+
+ # Test externally provided output projection.
+ w = tf.get_variable("proj_w", [2, 5])
+ b = tf.get_variable("proj_b", [5])
+ with tf.variable_scope("proj_seq2seq"):
+ dec, _ = seq2seq.embedding_attention_seq2seq(
+ enc_inp, dec_inp, cell, 2, 5, output_projection=(w, b))
+ sess.run([tf.variables.initialize_all_variables()])
+ res = sess.run(dec)
+ self.assertEqual(len(res), 3)
+ self.assertEqual(res[0].shape, (2, 2))
+
+ # Test that previous-feeding model ignores inputs after the first.
+ dec_inp2 = [tf.constant(0, tf.int32, shape=[2]) for _ in xrange(3)]
+ tf.get_variable_scope().reuse_variables()
+ d1, _ = seq2seq.embedding_attention_seq2seq(
+ enc_inp, dec_inp, cell, 2, 5, feed_previous=True)
+ d2, _ = seq2seq.embedding_attention_seq2seq(
+ enc_inp, dec_inp2, cell, 2, 5, feed_previous=True)
+ d3, _ = seq2seq.embedding_attention_seq2seq(
+ enc_inp, dec_inp2, cell, 2, 5, feed_previous=tf.constant(True))
+ res1 = sess.run(d1)
+ res2 = sess.run(d2)
+ res3 = sess.run(d3)
+ self.assertAllClose(res1, res2)
+ self.assertAllClose(res1, res3)
+
+ def testSequenceLoss(self):
+ with self.test_session() as sess:
+ output_classes = 5
+ logits = [tf.constant(i + 0.5, shape=[2, 5]) for i in xrange(3)]
+ targets = [tf.constant(i, tf.int32, shape=[2]) for i in xrange(3)]
+ weights = [tf.constant(1.0, shape=[2]) for i in xrange(3)]
+
+ average_loss_per_example = seq2seq.sequence_loss(
+ logits, targets, weights, output_classes,
+ average_across_timesteps=True,
+ average_across_batch=True)
+ res = sess.run(average_loss_per_example)
+ self.assertAllClose(res, 1.60944)
+
+ average_loss_per_sequence = seq2seq.sequence_loss(
+ logits, targets, weights, output_classes,
+ average_across_timesteps=False,
+ average_across_batch=True)
+ res = sess.run(average_loss_per_sequence)
+ self.assertAllClose(res, 4.828314)
+
+ total_loss = seq2seq.sequence_loss(
+ logits, targets, weights, output_classes,
+ average_across_timesteps=False,
+ average_across_batch=False)
+ res = sess.run(total_loss)
+ self.assertAllClose(res, 9.656628)
+
+ def testSequenceLossByExample(self):
+ with self.test_session() as sess:
+ output_classes = 5
+ logits = [tf.constant(i + 0.5, shape=[2, output_classes])
+ for i in xrange(3)]
+ targets = [tf.constant(i, tf.int32, shape=[2]) for i in xrange(3)]
+ weights = [tf.constant(1.0, shape=[2]) for i in xrange(3)]
+
+ average_loss_per_example = seq2seq.sequence_loss_by_example(
+ logits, targets, weights, output_classes,
+ average_across_timesteps=True)
+ res = sess.run(average_loss_per_example)
+ self.assertAllClose(res, np.asarray([1.609438, 1.609438]))
+
+ loss_per_sequence = seq2seq.sequence_loss_by_example(
+ logits, targets, weights, output_classes,
+ average_across_timesteps=False)
+ res = sess.run(loss_per_sequence)
+ self.assertAllClose(res, np.asarray([4.828314, 4.828314]))
+
+ def testModelWithBuckets(self):
+ """Larger tests that does full sequence-to-sequence model training."""
+ # We learn to copy 10 symbols in 2 buckets: length 4 and length 8.
+ classes = 10
+ buckets = [(4, 4), (8, 8)]
+ # We use sampled softmax so we keep output projection separate.
+ w = tf.get_variable("proj_w", [24, classes])
+ w_t = tf.transpose(w)
+ b = tf.get_variable("proj_b", [classes])
+ # Here comes a sample Seq2Seq model using GRU cells.
+ def SampleGRUSeq2Seq(enc_inp, dec_inp, weights):
+ """Example sequence-to-sequence model that uses GRU cells."""
+ def GRUSeq2Seq(enc_inp, dec_inp):
+ cell = rnn_cell.MultiRNNCell([rnn_cell.GRUCell(24)] * 2)
+ return seq2seq.embedding_attention_seq2seq(
+ enc_inp, dec_inp, cell, classes, classes, output_projection=(w, b))
+ targets = [dec_inp[i+1] for i in xrange(len(dec_inp) - 1)] + [0]
+ def SampledLoss(inputs, labels):
+ labels = tf.reshape(labels, [-1, 1])
+ return tf.nn.sampled_softmax_loss(w_t, b, inputs, labels, 8, classes)
+ return seq2seq.model_with_buckets(enc_inp, dec_inp, targets, weights,
+ buckets, classes, GRUSeq2Seq,
+ softmax_loss_function=SampledLoss)
+ # Now we construct the copy model.
+ with self.test_session() as sess:
+ tf.set_random_seed(111)
+ batch_size = 32
+ inp = [tf.placeholder(tf.int32, shape=[None]) for _ in xrange(8)]
+ out = [tf.placeholder(tf.int32, shape=[None]) for _ in xrange(8)]
+ weights = [tf.ones_like(inp[0], dtype=tf.float32) for _ in xrange(8)]
+ with tf.variable_scope("root"):
+ _, losses = SampleGRUSeq2Seq(inp, out, weights)
+ updates = []
+ params = tf.all_variables()
+ optimizer = tf.train.AdamOptimizer(0.03, epsilon=1e-5)
+ for i in xrange(len(buckets)):
+ full_grads = tf.gradients(losses[i], params)
+ grads, _ = tf.clip_by_global_norm(full_grads, 30.0)
+ update = optimizer.apply_gradients(zip(grads, params))
+ updates.append(update)
+ sess.run([tf.initialize_all_variables()])
+ for ep in xrange(3):
+ log_perp = 0.0
+ for _ in xrange(50):
+ bucket = random.choice(range(len(buckets)))
+ length = buckets[bucket][0]
+ i = [np.array([np.random.randint(9) + 1 for _ in xrange(batch_size)],
+ dtype=np.int32) for _ in xrange(length)]
+ # 0 is our "GO" symbol here.
+ o = [np.array([0 for _ in xrange(batch_size)], dtype=np.int32)] + i
+ feed = {}
+ for l in xrange(length):
+ feed[inp[l].name] = i[l]
+ feed[out[l].name] = o[l]
+ if length < 8: # For the 4-bucket, we need the 5th as target.
+ feed[out[length].name] = o[length]
+ res = sess.run([updates[bucket], losses[bucket]], feed)
+ log_perp += float(res[1])
+ perp = math.exp(log_perp / 100)
+ print "step %d avg. perp %f" % ((ep + 1)*50, perp)
+ self.assertLess(perp, 2.5)
+
+if __name__ == "__main__":
+ tf.test.main()
diff --git a/tensorflow/models/rnn/translate/BUILD b/tensorflow/models/rnn/translate/BUILD
new file mode 100644
index 0000000000..0899bf689e
--- /dev/null
+++ b/tensorflow/models/rnn/translate/BUILD
@@ -0,0 +1,71 @@
+# Description:
+# Example neural translation models.
+
+package(default_visibility = ["//visibility:public"])
+
+licenses(["notice"]) # Apache 2.0
+
+exports_files(["LICENSE"])
+
+py_library(
+ name = "data_utils",
+ srcs = [
+ "data_utils.py",
+ ],
+ deps = [
+ "//tensorflow:tensorflow_py",
+ ],
+)
+
+py_library(
+ name = "seq2seq_model",
+ srcs = [
+ "seq2seq_model.py",
+ ],
+ deps = [
+ ":data_utils",
+ "//tensorflow:tensorflow_py",
+ "//tensorflow/models/rnn:seq2seq",
+ ],
+)
+
+py_binary(
+ name = "translate",
+ srcs = [
+ "translate.py",
+ ],
+ deps = [
+ ":data_utils",
+ ":seq2seq_model",
+ "//tensorflow:tensorflow_py",
+ ],
+)
+
+py_test(
+ name = "translate_test",
+ size = "medium",
+ srcs = [
+ "translate.py",
+ ],
+ args = [
+ "--self_test=True",
+ ],
+ main = "translate.py",
+ deps = [
+ ":data_utils",
+ ":seq2seq_model",
+ "//tensorflow:tensorflow_py",
+ ],
+)
+
+filegroup(
+ name = "all_files",
+ srcs = glob(
+ ["**/*"],
+ exclude = [
+ "**/METADATA",
+ "**/OWNERS",
+ ],
+ ),
+ visibility = ["//tensorflow:__subpackages__"],
+)
diff --git a/tensorflow/models/rnn/translate/__init__.py b/tensorflow/models/rnn/translate/__init__.py
new file mode 100755
index 0000000000..e69de29bb2
--- /dev/null
+++ b/tensorflow/models/rnn/translate/__init__.py
diff --git a/tensorflow/models/rnn/translate/data_utils.py b/tensorflow/models/rnn/translate/data_utils.py
new file mode 100644
index 0000000000..28bc54354c
--- /dev/null
+++ b/tensorflow/models/rnn/translate/data_utils.py
@@ -0,0 +1,264 @@
+"""Utilities for downloading data from WMT, tokenizing, vocabularies."""
+
+import gzip
+import os
+import re
+import tarfile
+import urllib
+
+from tensorflow.python.platform import gfile
+
+# Special vocabulary symbols - we always put them at the start.
+_PAD = "_PAD"
+_GO = "_GO"
+_EOS = "_EOS"
+_UNK = "_UNK"
+_START_VOCAB = [_PAD, _GO, _EOS, _UNK]
+
+PAD_ID = 0
+GO_ID = 1
+EOS_ID = 2
+UNK_ID = 3
+
+# Regular expressions used to tokenize.
+_WORD_SPLIT = re.compile("([.,!?\"':;)(])")
+_DIGIT_RE = re.compile(r"\d")
+
+# URLs for WMT data.
+_WMT_ENFR_TRAIN_URL = "http://www.statmt.org/wmt10/training-giga-fren.tar"
+_WMT_ENFR_DEV_URL = "http://www.statmt.org/wmt15/dev-v2.tgz"
+
+
+def maybe_download(directory, filename, url):
+ """Download filename from url unless it's already in directory."""
+ if not os.path.exists(directory):
+ print "Creating directory %s" % directory
+ os.mkdir(directory)
+ filepath = os.path.join(directory, filename)
+ if not os.path.exists(filepath):
+ print "Downloading %s to %s" % (url, filepath)
+ filepath, _ = urllib.urlretrieve(url, filepath)
+ statinfo = os.stat(filepath)
+ print "Succesfully downloaded", filename, statinfo.st_size, "bytes"
+ return filepath
+
+
+def gunzip_file(gz_path, new_path):
+ """Unzips from gz_path into new_path."""
+ print "Unpacking %s to %s" % (gz_path, new_path)
+ with gzip.open(gz_path, "rb") as gz_file:
+ with open(new_path, "w") as new_file:
+ for line in gz_file:
+ new_file.write(line)
+
+
+def get_wmt_enfr_train_set(directory):
+ """Download the WMT en-fr training corpus to directory unless it's there."""
+ train_path = os.path.join(directory, "giga-fren.release2")
+ if not (gfile.Exists(train_path +".fr") and gfile.Exists(train_path +".en")):
+ corpus_file = maybe_download(directory, "training-giga-fren.tar",
+ _WMT_ENFR_TRAIN_URL)
+ print "Extracting tar file %s" % corpus_file
+ with tarfile.open(corpus_file, "r") as corpus_tar:
+ corpus_tar.extractall(directory)
+ gunzip_file(train_path + ".fr.gz", train_path + ".fr")
+ gunzip_file(train_path + ".en.gz", train_path + ".en")
+ return train_path
+
+
+def get_wmt_enfr_dev_set(directory):
+ """Download the WMT en-fr training corpus to directory unless it's there."""
+ dev_name = "newstest2013"
+ dev_path = os.path.join(directory, dev_name)
+ if not (gfile.Exists(dev_path + ".fr") and gfile.Exists(dev_path + ".en")):
+ dev_file = maybe_download(directory, "dev-v2.tgz", _WMT_ENFR_DEV_URL)
+ print "Extracting tgz file %s" % dev_file
+ with tarfile.open(dev_file, "r:gz") as dev_tar:
+ fr_dev_file = dev_tar.getmember("dev/" + dev_name + ".fr")
+ en_dev_file = dev_tar.getmember("dev/" + dev_name + ".en")
+ fr_dev_file.name = dev_name + ".fr" # Extract without "dev/" prefix.
+ en_dev_file.name = dev_name + ".en"
+ dev_tar.extract(fr_dev_file, directory)
+ dev_tar.extract(en_dev_file, directory)
+ return dev_path
+
+
+def basic_tokenizer(sentence):
+ """Very basic tokenizer: split the sentence into a list of tokens."""
+ words = []
+ for space_separated_fragment in sentence.strip().split():
+ words.extend(re.split(_WORD_SPLIT, space_separated_fragment))
+ return [w for w in words if w]
+
+
+def create_vocabulary(vocabulary_path, data_path, max_vocabulary_size,
+ tokenizer=None, normalize_digits=True):
+ """Create vocabulary file (if it does not exist yet) from data file.
+
+ Data file is assumed to contain one sentence per line. Each sentence is
+ tokenized and digits are normalized (if normalize_digits is set).
+ Vocabulary contains the most-frequent tokens up to max_vocabulary_size.
+ We write it to vocabulary_path in a one-token-per-line format, so that later
+ token in the first line gets id=0, second line gets id=1, and so on.
+
+ Args:
+ vocabulary_path: path where the vocabulary will be created.
+ data_path: data file that will be used to create vocabulary.
+ max_vocabulary_size: limit on the size of the created vocabulary.
+ tokenizer: a function to use to tokenize each data sentence;
+ if None, basic_tokenizer will be used.
+ normalize_digits: Boolean; if true, all digits are replaced by 0s.
+ """
+ if not gfile.Exists(vocabulary_path):
+ print "Creating vocabulary %s from data %s" % (vocabulary_path, data_path)
+ vocab = {}
+ with gfile.GFile(data_path, mode="r") as f:
+ counter = 0
+ for line in f:
+ counter += 1
+ if counter % 100000 == 0: print " processing line %d" % counter
+ tokens = tokenizer(line) if tokenizer else basic_tokenizer(line)
+ for w in tokens:
+ word = re.sub(_DIGIT_RE, "0", w) if normalize_digits else w
+ if word in vocab:
+ vocab[word] += 1
+ else:
+ vocab[word] = 1
+ vocab_list = _START_VOCAB + sorted(vocab, key=vocab.get, reverse=True)
+ if len(vocab_list) > max_vocabulary_size:
+ vocab_list = vocab_list[:max_vocabulary_size]
+ with gfile.GFile(vocabulary_path, mode="w") as vocab_file:
+ for w in vocab_list:
+ vocab_file.write(w + "\n")
+
+
+def initialize_vocabulary(vocabulary_path):
+ """Initialize vocabulary from file.
+
+ We assume the vocabulary is stored one-item-per-line, so a file:
+ dog
+ cat
+ will result in a vocabulary {"dog": 0, "cat": 1}, and this function will
+ also return the reversed-vocabulary ["dog", "cat"].
+
+ Args:
+ vocabulary_path: path to the file containing the vocabulary.
+
+ Returns:
+ a pair: the vocabulary (a dictionary mapping string to integers), and
+ the reversed vocabulary (a list, which reverses the vocabulary mapping).
+
+ Raises:
+ ValueError: if the provided vocabulary_path does not exist.
+ """
+ if gfile.Exists(vocabulary_path):
+ rev_vocab = []
+ with gfile.GFile(vocabulary_path, mode="r") as f:
+ rev_vocab.extend(f.readlines())
+ rev_vocab = [line.strip() for line in rev_vocab]
+ vocab = dict([(x, y) for (y, x) in enumerate(rev_vocab)])
+ return vocab, rev_vocab
+ else:
+ raise ValueError("Vocabulary file %s not found.", vocabulary_path)
+
+
+def sentence_to_token_ids(sentence, vocabulary,
+ tokenizer=None, normalize_digits=True):
+ """Convert a string to list of integers representing token-ids.
+
+ For example, a sentence "I have a dog" may become tokenized into
+ ["I", "have", "a", "dog"] and with vocabulary {"I": 1, "have": 2,
+ "a": 4, "dog": 7"} this function will return [1, 2, 4, 7].
+
+ Args:
+ sentence: a string, the sentence to convert to token-ids.
+ vocabulary: a dictionary mapping tokens to integers.
+ tokenizer: a function to use to tokenize each sentence;
+ if None, basic_tokenizer will be used.
+ normalize_digits: Boolean; if true, all digits are replaced by 0s.
+
+ Returns:
+ a list of integers, the token-ids for the sentence.
+ """
+ if tokenizer:
+ words = tokenizer(sentence)
+ else:
+ words = basic_tokenizer(sentence)
+ if not normalize_digits:
+ return [vocabulary.get(w, UNK_ID) for w in words]
+ # Normalize digits by 0 before looking words up in the vocabulary.
+ return [vocabulary.get(re.sub(_DIGIT_RE, "0", w), UNK_ID) for w in words]
+
+
+def data_to_token_ids(data_path, target_path, vocabulary_path,
+ tokenizer=None, normalize_digits=True):
+ """Tokenize data file and turn into token-ids using given vocabulary file.
+
+ This function loads data line-by-line from data_path, calls the above
+ sentence_to_token_ids, and saves the result to target_path. See comment
+ for sentence_to_token_ids on the details of token-ids format.
+
+ Args:
+ data_path: path to the data file in one-sentence-per-line format.
+ target_path: path where the file with token-ids will be created.
+ vocabulary_path: path to the vocabulary file.
+ tokenizer: a function to use to tokenize each sentence;
+ if None, basic_tokenizer will be used.
+ normalize_digits: Boolean; if true, all digits are replaced by 0s.
+ """
+ if not gfile.Exists(target_path):
+ print "Tokenizing data in %s" % data_path
+ vocab, _ = initialize_vocabulary(vocabulary_path)
+ with gfile.GFile(data_path, mode="r") as data_file:
+ with gfile.GFile(target_path, mode="w") as tokens_file:
+ counter = 0
+ for line in data_file:
+ counter += 1
+ if counter % 100000 == 0: print " tokenizing line %d" % counter
+ token_ids = sentence_to_token_ids(line, vocab, tokenizer,
+ normalize_digits)
+ tokens_file.write(" ".join([str(tok) for tok in token_ids]) + "\n")
+
+
+def prepare_wmt_data(data_dir, en_vocabulary_size, fr_vocabulary_size):
+ """Get WMT data into data_dir, create vocabularies and tokenize data.
+
+ Args:
+ data_dir: directory in which the data sets will be stored.
+ en_vocabulary_size: size of the English vocabulary to create and use.
+ fr_vocabulary_size: size of the French vocabulary to create and use.
+
+ Returns:
+ A tuple of 6 elements:
+ (1) path to the token-ids for English training data-set,
+ (2) path to the token-ids for French training data-set,
+ (3) path to the token-ids for English development data-set,
+ (4) path to the token-ids for French development data-set,
+ (5) path to the English vocabulary file,
+ (6) path to the French vocabluary file.
+ """
+ # Get wmt data to the specified directory.
+ train_path = get_wmt_enfr_train_set(data_dir)
+ dev_path = get_wmt_enfr_dev_set(data_dir)
+
+ # Create vocabularies of the appropriate sizes.
+ fr_vocab_path = os.path.join(data_dir, "vocab%d.fr" % fr_vocabulary_size)
+ en_vocab_path = os.path.join(data_dir, "vocab%d.en" % en_vocabulary_size)
+ create_vocabulary(fr_vocab_path, train_path + ".fr", fr_vocabulary_size)
+ create_vocabulary(en_vocab_path, train_path + ".en", en_vocabulary_size)
+
+ # Create token ids for the training data.
+ fr_train_ids_path = train_path + (".ids%d.fr" % fr_vocabulary_size)
+ en_train_ids_path = train_path + (".ids%d.en" % en_vocabulary_size)
+ data_to_token_ids(train_path + ".fr", fr_train_ids_path, fr_vocab_path)
+ data_to_token_ids(train_path + ".en", fr_train_ids_path, fr_vocab_path)
+
+ # Create token ids for the development data.
+ fr_dev_ids_path = dev_path + (".ids%d.fr" % fr_vocabulary_size)
+ en_dev_ids_path = dev_path + (".ids%d.en" % en_vocabulary_size)
+ data_to_token_ids(dev_path + ".fr", fr_dev_ids_path, fr_vocab_path)
+ data_to_token_ids(dev_path + ".en", en_dev_ids_path, en_vocab_path)
+
+ return (en_train_ids_path, fr_train_ids_path,
+ en_dev_ids_path, fr_dev_ids_path,
+ en_vocab_path, fr_vocab_path)
diff --git a/tensorflow/models/rnn/translate/seq2seq_model.py b/tensorflow/models/rnn/translate/seq2seq_model.py
new file mode 100644
index 0000000000..3c9cfb007f
--- /dev/null
+++ b/tensorflow/models/rnn/translate/seq2seq_model.py
@@ -0,0 +1,268 @@
+"""Sequence-to-sequence model with an attention mechanism."""
+
+import random
+
+import numpy as np
+import tensorflow as tf
+
+from tensorflow.models.rnn import rnn_cell
+from tensorflow.models.rnn import seq2seq
+
+from tensorflow.models.rnn.translate import data_utils
+
+
+class Seq2SeqModel(object):
+ """Sequence-to-sequence model with attention and for multiple buckets.
+
+ This class implements a multi-layer recurrent neural network as encoder,
+ and an attention-based decoder. This is the same as the model described in
+ this paper: http://arxiv.org/abs/1412.7449 - please look there for details,
+ or into the seq2seq library for complete model implementation.
+ This class also allows to use GRU cells in addition to LSTM cells, and
+ sampled softmax to handle large output vocabulary size. A single-layer
+ version of this model, but with bi-directional encoder, was presented in
+ http://arxiv.org/abs/1409.0473
+ and sampled softmax is described in Section 3 of the following paper.
+ http://arxiv.org/pdf/1412.2007v2.pdf
+ """
+
+ def __init__(self, source_vocab_size, target_vocab_size, buckets, size,
+ num_layers, max_gradient_norm, batch_size, learning_rate,
+ learning_rate_decay_factor, use_lstm=False,
+ num_samples=512, forward_only=False):
+ """Create the model.
+
+ Args:
+ source_vocab_size: size of the source vocabulary.
+ target_vocab_size: size of the target vocabulary.
+ buckets: a list of pairs (I, O), where I specifies maximum input length
+ that will be processed in that bucket, and O specifies maximum output
+ length. Training instances that have inputs longer than I or outputs
+ longer than O will be pushed to the next bucket and padded accordingly.
+ We assume that the list is sorted, e.g., [(2, 4), (8, 16)].
+ size: number of units in each layer of the model.
+ num_layers: number of layers in the model.
+ max_gradient_norm: gradients will be clipped to maximally this norm.
+ batch_size: the size of the batches used during training;
+ the model construction is independent of batch_size, so it can be
+ changed after initialization if this is convenient, e.g., for decoding.
+ learning_rate: learning rate to start with.
+ learning_rate_decay_factor: decay learning rate by this much when needed.
+ use_lstm: if true, we use LSTM cells instead of GRU cells.
+ num_samples: number of samples for sampled softmax.
+ forward_only: if set, we do not construct the backward pass in the model.
+ """
+ self.source_vocab_size = source_vocab_size
+ self.target_vocab_size = target_vocab_size
+ self.buckets = buckets
+ self.batch_size = batch_size
+ self.learning_rate = tf.Variable(float(learning_rate), trainable=False)
+ self.learning_rate_decay_op = self.learning_rate.assign(
+ self.learning_rate * learning_rate_decay_factor)
+ self.global_step = tf.Variable(0, trainable=False)
+
+ # If we use sampled softmax, we need an output projection.
+ output_projection = None
+ softmax_loss_function = None
+ # Sampled softmax only makes sense if we sample less than vocabulary size.
+ if num_samples > 0 and num_samples < self.target_vocab_size:
+ with tf.device("/cpu:0"):
+ w = tf.get_variable("proj_w", [size, self.target_vocab_size])
+ w_t = tf.transpose(w)
+ b = tf.get_variable("proj_b", [self.target_vocab_size])
+ output_projection = (w, b)
+
+ def sampled_loss(inputs, labels):
+ with tf.device("/cpu:0"):
+ labels = tf.reshape(labels, [-1, 1])
+ return tf.nn.sampled_softmax_loss(w_t, b, inputs, labels, num_samples,
+ self.target_vocab_size)
+ softmax_loss_function = sampled_loss
+
+ # Create the internal multi-layer cell for our RNN.
+ single_cell = rnn_cell.GRUCell(size)
+ if use_lstm:
+ single_cell = rnn_cell.BasicLSTMCell(size)
+ cell = single_cell
+ if num_layers > 1:
+ cell = rnn_cell.MultiRNNCell([single_cell] * num_layers)
+
+ # The seq2seq function: we use embedding for the input and attention.
+ def seq2seq_f(encoder_inputs, decoder_inputs, do_decode):
+ return seq2seq.embedding_attention_seq2seq(
+ encoder_inputs, decoder_inputs, cell, source_vocab_size,
+ target_vocab_size, output_projection=output_projection,
+ feed_previous=do_decode)
+
+ # Feeds for inputs.
+ self.encoder_inputs = []
+ self.decoder_inputs = []
+ self.target_weights = []
+ for i in xrange(buckets[-1][0]): # Last bucket is the biggest one.
+ self.encoder_inputs.append(tf.placeholder(tf.int32, shape=[None],
+ name="encoder{0}".format(i)))
+ for i in xrange(buckets[-1][1] + 1):
+ self.decoder_inputs.append(tf.placeholder(tf.int32, shape=[None],
+ name="decoder{0}".format(i)))
+ self.target_weights.append(tf.placeholder(tf.float32, shape=[None],
+ name="weight{0}".format(i)))
+
+ # Our targets are decoder inputs shifted by one.
+ targets = [self.decoder_inputs[i + 1]
+ for i in xrange(len(self.decoder_inputs) - 1)]
+
+ # Training outputs and losses.
+ if forward_only:
+ self.outputs, self.losses = seq2seq.model_with_buckets(
+ self.encoder_inputs, self.decoder_inputs, targets,
+ self.target_weights, buckets, self.target_vocab_size,
+ lambda x, y: seq2seq_f(x, y, True),
+ softmax_loss_function=softmax_loss_function)
+ # If we use output projection, we need to project outputs for decoding.
+ if output_projection is not None:
+ for b in xrange(len(buckets)):
+ self.outputs[b] = [tf.nn.xw_plus_b(output, output_projection[0],
+ output_projection[1])
+ for output in self.outputs[b]]
+ else:
+ self.outputs, self.losses = seq2seq.model_with_buckets(
+ self.encoder_inputs, self.decoder_inputs, targets,
+ self.target_weights, buckets, self.target_vocab_size,
+ lambda x, y: seq2seq_f(x, y, False),
+ softmax_loss_function=softmax_loss_function)
+
+ # Gradients and SGD update operation for training the model.
+ params = tf.trainable_variables()
+ if not forward_only:
+ self.gradient_norms = []
+ self.updates = []
+ opt = tf.train.GradientDescentOptimizer(self.learning_rate)
+ for b in xrange(len(buckets)):
+ gradients = tf.gradients(self.losses[b], params)
+ clipped_gradients, norm = tf.clip_by_global_norm(gradients,
+ max_gradient_norm)
+ self.gradient_norms.append(norm)
+ self.updates.append(opt.apply_gradients(
+ zip(clipped_gradients, params), global_step=self.global_step))
+
+ self.saver = tf.train.Saver(tf.all_variables())
+
+ def step(self, session, encoder_inputs, decoder_inputs, target_weights,
+ bucket_id, forward_only):
+ """Run a step of the model feeding the given inputs.
+
+ Args:
+ session: tensorflow session to use.
+ encoder_inputs: list of numpy int vectors to feed as encoder inputs.
+ decoder_inputs: list of numpy int vectors to feed as decoder inputs.
+ target_weights: list of numpy float vectors to feed as target weights.
+ bucket_id: which bucket of the model to use.
+ forward_only: whether to do the backward step or only forward.
+
+ Returns:
+ A triple consisting of gradient norm (or None if we did not do backward),
+ average perplexity, and the outputs.
+
+ Raises:
+ ValueError: if length of enconder_inputs, decoder_inputs, or
+ target_weights disagrees with bucket size for the specified bucket_id.
+ """
+ # Check if the sizes match.
+ encoder_size, decoder_size = self.buckets[bucket_id]
+ if len(encoder_inputs) != encoder_size:
+ raise ValueError("Encoder length must be equal to the one in bucket,"
+ " %d != %d." % (len(encoder_inputs), encoder_size))
+ if len(decoder_inputs) != decoder_size:
+ raise ValueError("Decoder length must be equal to the one in bucket,"
+ " %d != %d." % (len(decoder_inputs), decoder_size))
+ if len(target_weights) != decoder_size:
+ raise ValueError("Weights length must be equal to the one in bucket,"
+ " %d != %d." % (len(target_weights), decoder_size))
+
+ # Input feed: encoder inputs, decoder inputs, target_weights, as provided.
+ input_feed = {}
+ for l in xrange(encoder_size):
+ input_feed[self.encoder_inputs[l].name] = encoder_inputs[l]
+ for l in xrange(decoder_size):
+ input_feed[self.decoder_inputs[l].name] = decoder_inputs[l]
+ input_feed[self.target_weights[l].name] = target_weights[l]
+
+ # Since our targets are decoder inputs shifted by one, we need one more.
+ last_target = self.decoder_inputs[decoder_size].name
+ input_feed[last_target] = np.zeros([self.batch_size], dtype=np.int32)
+
+ # Output feed: depends on whether we do a backward step or not.
+ if not forward_only:
+ output_feed = [self.updates[bucket_id], # Update Op that does SGD.
+ self.gradient_norms[bucket_id], # Gradient norm.
+ self.losses[bucket_id]] # Loss for this batch.
+ else:
+ output_feed = [self.losses[bucket_id]] # Loss for this batch.
+ for l in xrange(decoder_size): # Output logits.
+ output_feed.append(self.outputs[bucket_id][l])
+
+ outputs = session.run(output_feed, input_feed)
+ if not forward_only:
+ return outputs[1], outputs[2], None # Gradient norm, loss, no outputs.
+ else:
+ return None, outputs[0], outputs[1:] # No gradient norm, loss, outputs.
+
+ def get_batch(self, data, bucket_id):
+ """Get a random batch of data from the specified bucket, prepare for step.
+
+ To feed data in step(..) it must be a list of batch-major vectors, while
+ data here contains single length-major cases. So the main logic of this
+ function is to re-index data cases to be in the proper format for feeding.
+
+ Args:
+ data: a tuple of size len(self.buckets) in which each element contains
+ lists of pairs of input and output data that we use to create a batch.
+ bucket_id: integer, which bucket to get the batch for.
+
+ Returns:
+ The triple (encoder_inputs, decoder_inputs, target_weights) for
+ the constructed batch that has the proper format to call step(...) later.
+ """
+ encoder_size, decoder_size = self.buckets[bucket_id]
+ encoder_inputs, decoder_inputs = [], []
+
+ # Get a random batch of encoder and decoder inputs from data,
+ # pad them if needed, reverse encoder inputs and add GO to decoder.
+ for _ in xrange(self.batch_size):
+ encoder_input, decoder_input = random.choice(data[bucket_id])
+
+ # Encoder inputs are padded and then reversed.
+ encoder_pad = [data_utils.PAD_ID] * (encoder_size - len(encoder_input))
+ encoder_inputs.append(list(reversed(encoder_input + encoder_pad)))
+
+ # Decoder inputs get an extra "GO" symbol, and are padded then.
+ decoder_pad_size = decoder_size - len(decoder_input) - 1
+ decoder_inputs.append([data_utils.GO_ID] + decoder_input +
+ [data_utils.PAD_ID] * decoder_pad_size)
+
+ # Now we create batch-major vectors from the data selected above.
+ batch_encoder_inputs, batch_decoder_inputs, batch_weights = [], [], []
+
+ # Batch encoder inputs are just re-indexed encoder_inputs.
+ for length_idx in xrange(encoder_size):
+ batch_encoder_inputs.append(
+ np.array([encoder_inputs[batch_idx][length_idx]
+ for batch_idx in xrange(self.batch_size)], dtype=np.int32))
+
+ # Batch decoder inputs are re-indexed decoder_inputs, we create weights.
+ for length_idx in xrange(decoder_size):
+ batch_decoder_inputs.append(
+ np.array([decoder_inputs[batch_idx][length_idx]
+ for batch_idx in xrange(self.batch_size)], dtype=np.int32))
+
+ # Create target_weights to be 0 for targets that are padding.
+ batch_weight = np.ones(self.batch_size, dtype=np.float32)
+ for batch_idx in xrange(self.batch_size):
+ # We set weight to 0 if the corresponding target is a PAD symbol.
+ # The corresponding target is decoder_input shifted by 1 forward.
+ if length_idx < decoder_size - 1:
+ target = decoder_inputs[batch_idx][length_idx + 1]
+ if length_idx == decoder_size - 1 or target == data_utils.PAD_ID:
+ batch_weight[batch_idx] = 0.0
+ batch_weights.append(batch_weight)
+ return batch_encoder_inputs, batch_decoder_inputs, batch_weights
diff --git a/tensorflow/models/rnn/translate/translate.py b/tensorflow/models/rnn/translate/translate.py
new file mode 100644
index 0000000000..abf4c7c57b
--- /dev/null
+++ b/tensorflow/models/rnn/translate/translate.py
@@ -0,0 +1,260 @@
+"""Binary for training translation models and decoding from them.
+
+Running this program without --decode will download the WMT corpus into
+the directory specified as --data_dir and tokenize it in a very basic way,
+and then start training a model saving checkpoints to --train_dir.
+
+Running with --decode starts an interactive loop so you can see how
+the current checkpoint translates English sentences into French.
+
+See the following papers for more information on neural translation models.
+ * http://arxiv.org/abs/1409.3215
+ * http://arxiv.org/abs/1409.0473
+ * http://arxiv.org/pdf/1412.2007v2.pdf
+"""
+
+import math
+import os
+import random
+import sys
+import time
+
+import tensorflow.python.platform
+
+import numpy as np
+import tensorflow as tf
+
+from tensorflow.models.rnn.translate import data_utils
+from tensorflow.models.rnn.translate import seq2seq_model
+from tensorflow.python.platform import gfile
+
+
+tf.app.flags.DEFINE_float("learning_rate", 0.5, "Learning rate.")
+tf.app.flags.DEFINE_float("learning_rate_decay_factor", 0.99,
+ "Learning rate decays by this much.")
+tf.app.flags.DEFINE_float("max_gradient_norm", 5.0,
+ "Clip gradients to this norm.")
+tf.app.flags.DEFINE_integer("batch_size", 64,
+ "Batch size to use during training.")
+tf.app.flags.DEFINE_integer("size", 1024, "Size of each model layer.")
+tf.app.flags.DEFINE_integer("num_layers", 3, "Number of layers in the model.")
+tf.app.flags.DEFINE_integer("en_vocab_size", 40000, "English vocabulary size.")
+tf.app.flags.DEFINE_integer("fr_vocab_size", 40000, "French vocabulary size.")
+tf.app.flags.DEFINE_string("data_dir", "/tmp", "Data directory")
+tf.app.flags.DEFINE_string("train_dir", "/tmp", "Training directory.")
+tf.app.flags.DEFINE_integer("max_train_data_size", 0,
+ "Limit on the size of training data (0: no limit).")
+tf.app.flags.DEFINE_integer("steps_per_checkpoint", 200,
+ "How many training steps to do per checkpoint.")
+tf.app.flags.DEFINE_boolean("decode", False,
+ "Set to True for interactive decoding.")
+tf.app.flags.DEFINE_boolean("self_test", False,
+ "Run a self-test if this is set to True.")
+
+FLAGS = tf.app.flags.FLAGS
+
+# We use a number of buckets and pad to the closest one for efficiency.
+# See seq2seq_model.Seq2SeqModel for details of how they work.
+_buckets = [(5, 10), (10, 15), (20, 25), (40, 50)]
+
+
+def read_data(source_path, target_path, max_size=None):
+ """Read data from source and target files and put into buckets.
+
+ Args:
+ source_path: path to the files with token-ids for the source language.
+ target_path: path to the file with token-ids for the target language;
+ it must be aligned with the source file: n-th line contains the desired
+ output for n-th line from the source_path.
+ max_size: maximum number of lines to read, all other will be ignored;
+ if 0 or None, data files will be read completely (no limit).
+
+ Returns:
+ data_set: a list of length len(_buckets); data_set[n] contains a list of
+ (source, target) pairs read from the provided data files that fit
+ into the n-th bucket, i.e., such that len(source) < _buckets[n][0] and
+ len(target) < _buckets[n][1]; source and target are lists of token-ids.
+ """
+ data_set = [[] for _ in _buckets]
+ with gfile.GFile(source_path, mode="r") as source_file:
+ with gfile.GFile(target_path, mode="r") as target_file:
+ source, target = source_file.readline(), target_file.readline()
+ counter = 0
+ while source and target and (not max_size or counter < max_size):
+ counter += 1
+ if counter % 100000 == 0:
+ print " reading data line %d" % counter
+ sys.stdout.flush()
+ source_ids = [int(x) for x in source.split()]
+ target_ids = [int(x) for x in target.split()]
+ target_ids.append(data_utils.EOS_ID)
+ for bucket_id, (source_size, target_size) in enumerate(_buckets):
+ if len(source_ids) < source_size and len(target_ids) < target_size:
+ data_set[bucket_id].append([source_ids, target_ids])
+ break
+ source, target = source_file.readline(), target_file.readline()
+ return data_set
+
+
+def create_model(session, forward_only):
+ """Create translation model and initialize or load parameters in session."""
+ model = seq2seq_model.Seq2SeqModel(
+ FLAGS.en_vocab_size, FLAGS.fr_vocab_size, _buckets,
+ FLAGS.size, FLAGS.num_layers, FLAGS.max_gradient_norm, FLAGS.batch_size,
+ FLAGS.learning_rate, FLAGS.learning_rate_decay_factor,
+ forward_only=forward_only)
+ ckpt = tf.train.get_checkpoint_state(FLAGS.train_dir)
+ if ckpt and gfile.Exists(ckpt.model_checkpoint_path):
+ print "Reading model parameters from %s" % ckpt.model_checkpoint_path
+ model.saver.restore(session, ckpt.model_checkpoint_path)
+ else:
+ print "Created model with fresh parameters."
+ session.run(tf.variables.initialize_all_variables())
+ return model
+
+
+def train():
+ """Train a en->fr translation model using WMT data."""
+ # Prepare WMT data.
+ print "Preparing WMT data in %s" % FLAGS.data_dir
+ en_train, fr_train, en_dev, fr_dev, _, _ = data_utils.prepare_wmt_data(
+ FLAGS.data_dir, FLAGS.en_vocab_size, FLAGS.fr_vocab_size)
+
+ with tf.Session() as sess:
+ # Create model.
+ print "Creating %d layers of %d units." % (FLAGS.num_layers, FLAGS.size)
+ model = create_model(sess, False)
+
+ # Read data into buckets and compute their sizes.
+ print ("Reading development and training data (limit: %d)."
+ % FLAGS.max_train_data_size)
+ dev_set = read_data(en_dev, fr_dev)
+ train_set = read_data(en_train, fr_train, FLAGS.max_train_data_size)
+ train_bucket_sizes = [len(train_set[b]) for b in xrange(len(_buckets))]
+ train_total_size = float(sum(train_bucket_sizes))
+
+ # A bucket scale is a list of increasing numbers from 0 to 1 that we'll use
+ # to select a bucket. Length of [scale[i], scale[i+1]] is proportional to
+ # the size if i-th training bucket, as used later.
+ train_buckets_scale = [sum(train_bucket_sizes[:i + 1]) / train_total_size
+ for i in xrange(len(train_bucket_sizes))]
+
+ # This is the training loop.
+ step_time, loss = 0.0, 0.0
+ current_step = 0
+ previous_losses = []
+ while True:
+ # Choose a bucket according to data distribution. We pick a random number
+ # in [0, 1] and use the corresponding interval in train_buckets_scale.
+ random_number_01 = np.random.random_sample()
+ bucket_id = min([i for i in xrange(len(train_buckets_scale))
+ if train_buckets_scale[i] > random_number_01])
+
+ # Get a batch and make a step.
+ start_time = time.time()
+ encoder_inputs, decoder_inputs, target_weights = model.get_batch(
+ train_set, bucket_id)
+ _, step_loss, _ = model.step(sess, encoder_inputs, decoder_inputs,
+ target_weights, bucket_id, False)
+ step_time += (time.time() - start_time) / FLAGS.steps_per_checkpoint
+ loss += step_loss / FLAGS.steps_per_checkpoint
+ current_step += 1
+
+ # Once in a while, we save checkpoint, print statistics, and run evals.
+ if current_step % FLAGS.steps_per_checkpoint == 0:
+ # Print statistics for the previous epoch.
+ perplexity = math.exp(loss) if loss < 300 else float('inf')
+ print ("global step %d learning rate %.4f step-time %.2f perplexity "
+ "%.2f" % (model.global_step.eval(), model.learning_rate.eval(),
+ step_time, perplexity))
+ # Decrease learning rate if no improvement was seen over last 3 times.
+ if len(previous_losses) > 2 and loss > max(previous_losses[-3:]):
+ sess.run(model.learning_rate_decay_op)
+ previous_losses.append(loss)
+ # Save checkpoint and zero timer and loss.
+ checkpoint_path = os.path.join(FLAGS.train_dir, "translate.ckpt")
+ model.saver.save(sess, checkpoint_path, global_step=model.global_step)
+ step_time, loss = 0.0, 0.0
+ # Run evals on development set and print their perplexity.
+ for bucket_id in xrange(len(_buckets)):
+ encoder_inputs, decoder_inputs, target_weights = model.get_batch(
+ dev_set, bucket_id)
+ _, eval_loss, _ = model.step(sess, encoder_inputs, decoder_inputs,
+ target_weights, bucket_id, True)
+ eval_ppx = math.exp(eval_loss) if eval_loss < 300 else float('inf')
+ print " eval: bucket %d perplexity %.2f" % (bucket_id, eval_ppx)
+ sys.stdout.flush()
+
+
+def decode():
+ with tf.Session() as sess:
+ # Create model and load parameters.
+ model = create_model(sess, True)
+ model.batch_size = 1 # We decode one sentence at a time.
+
+ # Load vocabularies.
+ en_vocab_path = os.path.join(FLAGS.data_dir,
+ "vocab%d.en" % FLAGS.en_vocab_size)
+ fr_vocab_path = os.path.join(FLAGS.data_dir,
+ "vocab%d.fr" % FLAGS.fr_vocab_size)
+ en_vocab, _ = data_utils.initialize_vocabulary(en_vocab_path)
+ _, rev_fr_vocab = data_utils.initialize_vocabulary(fr_vocab_path)
+
+ # Decode from standard input.
+ sys.stdout.write("> ")
+ sys.stdout.flush()
+ sentence = sys.stdin.readline()
+ while sentence:
+ # Get token-ids for the input sentence.
+ token_ids = data_utils.sentence_to_token_ids(sentence, en_vocab)
+ # Which bucket does it belong to?
+ bucket_id = min([b for b in xrange(len(_buckets))
+ if _buckets[b][0] > len(token_ids)])
+ # Get a 1-element batch to feed the sentence to the model.
+ encoder_inputs, decoder_inputs, target_weights = model.get_batch(
+ {bucket_id: [(token_ids, [])]}, bucket_id)
+ # Get output logits for the sentence.
+ _, _, output_logits = model.step(sess, encoder_inputs, decoder_inputs,
+ target_weights, bucket_id, True)
+ # This is a greedy decoder - outputs are just argmaxes of output_logits.
+ outputs = [int(np.argmax(logit, axis=1)) for logit in output_logits]
+ # If there is an EOS symbol in outputs, cut them at that point.
+ if data_utils.EOS_ID in outputs:
+ outputs = outputs[:outputs.index(data_utils.EOS_ID)]
+ # Print out French sentence corresponding to outputs.
+ print " ".join([rev_fr_vocab[output] for output in outputs])
+ print "> ",
+ sys.stdout.flush()
+ sentence = sys.stdin.readline()
+
+
+def self_test():
+ """Test the translation model."""
+ with tf.Session() as sess:
+ print "Self-test for neural translation model."
+ # Create model with vocabularies of 10, 2 small buckets, 2 layers of 32.
+ model = seq2seq_model.Seq2SeqModel(10, 10, [(3, 3), (6, 6)], 32, 2,
+ 5.0, 32, 0.3, 0.99, num_samples=8)
+ sess.run(tf.variables.initialize_all_variables())
+
+ # Fake data set for both the (3, 3) and (6, 6) bucket.
+ data_set = ([([1, 1], [2, 2]), ([3, 3], [4]), ([5], [6])],
+ [([1, 1, 1, 1, 1], [2, 2, 2, 2, 2]), ([3, 3, 3], [5, 6])])
+ for _ in xrange(5): # Train the fake model for 5 steps.
+ bucket_id = random.choice([0, 1])
+ encoder_inputs, decoder_inputs, target_weights = model.get_batch(
+ data_set, bucket_id)
+ model.step(sess, encoder_inputs, decoder_inputs, target_weights,
+ bucket_id, False)
+
+
+def main(_):
+ if FLAGS.self_test:
+ self_test()
+ elif FLAGS.decode:
+ decode()
+ else:
+ train()
+
+if __name__ == "__main__":
+ tf.app.run()
diff --git a/tensorflow/python/framework/docs.py b/tensorflow/python/framework/docs.py
index fcf269c717..99d02a5380 100644
--- a/tensorflow/python/framework/docs.py
+++ b/tensorflow/python/framework/docs.py
@@ -233,6 +233,14 @@ class Library(Document):
# signatures.
continue
args_list.append(arg)
+
+ # TODO(mrry): This is a workaround for documenting signature of
+ # functions that have the @contextlib.contextmanager decorator.
+ # We should do something better.
+ if argspec.varargs == "args" and argspec.keywords == "kwds":
+ original_func = func.func_closure[0].cell_contents
+ return self._generate_signature_for_function(original_func)
+
if argspec.defaults:
for arg, default in zip(
argspec.args[first_arg_with_default:], argspec.defaults):
diff --git a/tensorflow/python/framework/gen_docs_combined.py b/tensorflow/python/framework/gen_docs_combined.py
index e82c45d95f..8256acc514 100644
--- a/tensorflow/python/framework/gen_docs_combined.py
+++ b/tensorflow/python/framework/gen_docs_combined.py
@@ -76,7 +76,8 @@ def all_libraries(module_to_name, members, documented):
"xw_plus_b", "relu_layer", "lrn",
"batch_norm_with_global_normalization",
"batch_norm_with_global_normalization_grad",
- "all_candidate_sampler"],
+ "all_candidate_sampler",
+ "embedding_lookup_sparse"],
prefix=PREFIX_TEXT),
library('client', "Running Graphs", client_lib,
exclude_symbols=["InteractiveSession"]),
diff --git a/tensorflow/python/ops/embedding_ops.py b/tensorflow/python/ops/embedding_ops.py
index bc64593d23..d9a377bab5 100644
--- a/tensorflow/python/ops/embedding_ops.py
+++ b/tensorflow/python/ops/embedding_ops.py
@@ -8,24 +8,31 @@ from tensorflow.python.ops import math_ops
def embedding_lookup(params, ids, name=None):
- """Return a tensor of embedding values by looking up "ids" in "params".
+ """Looks up `ids` in a list of embedding tensors.
+
+ This function is used to perform parallel lookups on the list of
+ tensors in `params`. It is a generalization of
+ [`tf.gather()`](array_ops.md#gather), where `params` is interpreted
+ as a partition of a larger embedding tensor.
+
+ If `len(params) > 1`, each element `id` of `ids` is partitioned between
+ the elements of `params` by computing `p = id % len(params)`, and is
+ then used to look up the slice `params[p][id // len(params), ...]`.
+
+ The results of the lookup are then concatenated into a dense
+ tensor. The returned tensor has shape `shape(ids) + shape(params)[1:]`.
Args:
- params: List of tensors of the same shape. A single tensor is
- treated as a singleton list.
- ids: Tensor of integers containing the ids to be looked up in
- 'params'. Let P be len(params). If P > 1, then the ids are
- partitioned by id % P, and we do separate lookups in params[p]
- for 0 <= p < P, and then stitch the results back together into
- a single result tensor.
- name: Optional name for the op.
+ params: A list of tensors with the same shape and type.
+ ids: A `Tensor` with type `int32` containing the ids to be looked
+ up in `params`.
+ name: A name for the operation (optional).
Returns:
- A tensor of shape ids.shape + params[0].shape[1:] containing the
- values params[i % P][i] for each i in ids.
+ A `Tensor` with the same type as the tensors in `params`.
Raises:
- ValueError: if some parameters are invalid.
+ ValueError: If `params` is empty.
"""
if not isinstance(params, list):
params = [params]
diff --git a/tensorflow/python/ops/nn.py b/tensorflow/python/ops/nn.py
index 8eccc63eb1..004ceb51c6 100644
--- a/tensorflow/python/ops/nn.py
+++ b/tensorflow/python/ops/nn.py
@@ -43,10 +43,10 @@ are as follows. If the 4-D `input` has shape
`[batch, in_height, in_width, ...]` and the 4-D `filter` has shape
`[filter_height, filter_width, ...]`, then
- output.shape = [batch,
- (in_height - filter_height + 1) / strides[1],
- (in_width - filter_width + 1) / strides[2],
- ...]
+ shape(output) = [batch,
+ (in_height - filter_height + 1) / strides[1],
+ (in_width - filter_width + 1) / strides[2],
+ ...]
output[b, i, j, :] =
sum_{di, dj} input[b, strides[1] * i + di, strides[2] * j + dj, ...] *
@@ -58,7 +58,7 @@ vectors. For `depthwise_conv_2d`, each scalar component `input[b, i, j, k]`
is multiplied by a vector `filter[di, dj, k]`, and all the vectors are
concatenated.
-In the formula for `output.shape`, the rounding direction depends on padding:
+In the formula for `shape(output)`, the rounding direction depends on padding:
* `padding = 'SAME'`: Round down (only full size windows are considered).
* `padding = 'VALID'`: Round up (partial windows are included).
@@ -81,7 +81,7 @@ In detail, the output is
for each tuple of indices `i`. The output shape is
- output.shape = (value.shape - ksize + 1) / strides
+ shape(output) = (shape(value) - ksize + 1) / strides
where the rounding direction depends on padding:
@@ -119,10 +119,10 @@ TensorFlow provides several operations that help you perform classification.
## Embeddings
-TensorFlow provides several operations that help you compute embeddings.
+TensorFlow provides library support for looking up values in embedding
+tensors.
@@embedding_lookup
-@@embedding_lookup_sparse
## Evaluation
@@ -336,15 +336,16 @@ def dropout(x, keep_prob, noise_shape=None, seed=None, name=None):
By default, each element is kept or dropped independently. If `noise_shape`
is specified, it must be
[broadcastable](http://docs.scipy.org/doc/numpy/user/basics.broadcasting.html)
- to the shape of `x`, and only dimensions with `noise_shape[i] == x.shape[i]`
- will make independent decisions. For example, if `x.shape = [b, x, y, c]` and
- `noise_shape = [b, 1, 1, c]`, each batch and channel component will be
+ to the shape of `x`, and only dimensions with `noise_shape[i] == shape(x)[i]`
+ will make independent decisions. For example, if `shape(x) = [k, l, m, n]`
+ and `noise_shape = [k, 1, 1, n]`, each batch and channel component will be
kept independently and each row and column will be kept or not kept together.
Args:
x: A tensor.
- keep_prob: Float probability that each element is kept.
- noise_shape: Shape for randomly generated keep/drop flags.
+ keep_prob: A Python float. The probability that each element is kept.
+ noise_shape: A 1-D `Tensor` of type `int32`, representing the
+ shape for randomly generated keep/drop flags.
seed: A Python integer. Used to create a random seed.
See [`set_random_seed`](constant_op.md#set_random_seed) for behavior.
name: A name for this operation (optional).
diff --git a/tensorflow/tools/docker/Dockerfile.cpu b/tensorflow/tools/docker/Dockerfile.cpu
index c93a6e8bd2..da0a7abb18 100644
--- a/tensorflow/tools/docker/Dockerfile.cpu
+++ b/tensorflow/tools/docker/Dockerfile.cpu
@@ -66,3 +66,7 @@ RUN bazel clean && \
bazel build -c opt tensorflow/tools/docker:simple_console
ENV PYTHONPATH=/tensorflow/bazel-bin/tensorflow/tools/docker/simple_console.runfiles/:$PYTHONPATH
+
+# We want to start Jupyter in the directory with our getting started
+# tutorials.
+WORKDIR /notebooks