author    Illia Polosukhin <ilblackdragon@gmail.com>  2016-04-18 17:56:51 -0800
committer TensorFlower Gardener <gardener@tensorflow.org>  2016-04-18 19:03:29 -0700
commit    5c9bc51857bc0c330d3ab976871ee3509647d1e7 (patch)
tree      a58def7cbf316c6e091b3b36657f120f1388ec54 /tensorflow/contrib/grid_rnn
parent    fc432e37a7ddd408ff09a7b90b1c4cd5af1b134e (diff)
Merge changes from github.
Change: 120185825
Diffstat (limited to 'tensorflow/contrib/grid_rnn')
-rw-r--r--  tensorflow/contrib/grid_rnn/BUILD                                  39
-rw-r--r--  tensorflow/contrib/grid_rnn/__init__.py                            27
-rw-r--r--  tensorflow/contrib/grid_rnn/python/__init__.py                     18
-rw-r--r--  tensorflow/contrib/grid_rnn/python/kernel_tests/grid_rnn_test.py  489
-rw-r--r--  tensorflow/contrib/grid_rnn/python/ops/__init__.py                 18
-rw-r--r--  tensorflow/contrib/grid_rnn/python/ops/grid_rnn_cell.py           352
6 files changed, 943 insertions, 0 deletions
diff --git a/tensorflow/contrib/grid_rnn/BUILD b/tensorflow/contrib/grid_rnn/BUILD
new file mode 100644
index 0000000000..c3b9b5a9dd
--- /dev/null
+++ b/tensorflow/contrib/grid_rnn/BUILD
@@ -0,0 +1,39 @@
+# Description:
+# Contains classes to construct GridRNN cells
+# APIs here are meant to evolve over time.
+
+licenses(["notice"]) # Apache 2.0
+
+exports_files(["LICENSE"])
+
+package(default_visibility = ["//tensorflow:__subpackages__"])
+
+load("//tensorflow:tensorflow.bzl", "cuda_py_tests")
+
+py_library(
+ name = "grid_rnn_py",
+ srcs = ["__init__.py"] + glob(["python/ops/*.py"]),
+ srcs_version = "PY2AND3",
+)
+
+cuda_py_tests(
+ name = "grid_rnn_test",
+ srcs = ["python/kernel_tests/grid_rnn_test.py"],
+ additional_deps = [
+ ":grid_rnn_py",
+ "//tensorflow/python:framework_test_lib",
+ "//tensorflow/python:platform_test",
+ ],
+)
+
+filegroup(
+ name = "all_files",
+ srcs = glob(
+ ["**/*"],
+ exclude = [
+ "**/METADATA",
+ "**/OWNERS",
+ ],
+ ),
+ visibility = ["//tensorflow:__subpackages__"],
+)
diff --git a/tensorflow/contrib/grid_rnn/__init__.py b/tensorflow/contrib/grid_rnn/__init__.py
new file mode 100644
index 0000000000..fbe380f1f3
--- /dev/null
+++ b/tensorflow/contrib/grid_rnn/__init__.py
@@ -0,0 +1,27 @@
+# Copyright 2016 Google Inc. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+"""GridRNN cells
+
+## This package provides classes for GridRNN
+
+"""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+# pylint: disable=unused-import,wildcard-import,line-too-long
+from tensorflow.contrib.grid_rnn.python.ops.grid_rnn_cell import *
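+
+# A minimal usage sketch (illustrative; see the unit tests in
+# grid_rnn_test.py for concrete values):
+#
+#   cell = Grid2LSTMCell(num_units=2)
+#   output, new_state = cell(inputs, state)  # inputs: batch x input_size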
diff --git a/tensorflow/contrib/grid_rnn/python/__init__.py b/tensorflow/contrib/grid_rnn/python/__init__.py
new file mode 100644
index 0000000000..94872f4f0c
--- /dev/null
+++ b/tensorflow/contrib/grid_rnn/python/__init__.py
@@ -0,0 +1,18 @@
+# Copyright 2016 Google Inc. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
diff --git a/tensorflow/contrib/grid_rnn/python/kernel_tests/grid_rnn_test.py b/tensorflow/contrib/grid_rnn/python/kernel_tests/grid_rnn_test.py
new file mode 100644
index 0000000000..d79cafe944
--- /dev/null
+++ b/tensorflow/contrib/grid_rnn/python/kernel_tests/grid_rnn_test.py
@@ -0,0 +1,489 @@
+# Copyright 2016 Google Inc. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+"""Tests for GridRNN cells."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import numpy as np
+import tensorflow as tf
+
+
+class GridRNNCellTest(tf.test.TestCase):
+
+ def testGrid2BasicLSTMCell(self):
+ with self.test_session() as sess:
+ with tf.variable_scope('root', initializer=tf.constant_initializer(0.2)) as root_scope:
+ x = tf.zeros([1, 3])
+ m = tf.zeros([1, 8])
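+ # State size is 8: 2 recurrent dimensions x (2-unit cell state + 2-unit output).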
+ cell = tf.contrib.grid_rnn.Grid2BasicLSTMCell(2)
+ self.assertEqual(cell.state_size, 8)
+
+ g, s = cell(x, m)
+ self.assertEqual(g.get_shape(), (1, 2))
+ self.assertEqual(s.get_shape(), (1, 8))
+
+ sess.run([tf.initialize_all_variables()])
+ res = sess.run([g, s], {x: np.array([[1., 1., 1.]]),
+ m: np.array([[0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8]])})
+ self.assertEqual(res[0].shape, (1, 2))
+ self.assertEqual(res[1].shape, (1, 8))
+ self.assertAllClose(res[0], [[ 0.36617181, 0.36617181]])
+ self.assertAllClose(res[1], [[ 0.71053141, 0.71053141, 0.36617181, 0.36617181,
+ 0.72320831, 0.80555487, 0.39102408, 0.42150158]])
+
+ # emulate a loop through the input sequence, where we call cell() multiple times
+ root_scope.reuse_variables()
+ g2, s2 = cell(x, m)
+ self.assertEqual(g2.get_shape(), (1, 2))
+ self.assertEqual(s2.get_shape(), (1, 8))
+
+ res = sess.run([g2, s2], {x: np.array([[2., 2., 2.]]), m: res[1]})
+ self.assertEqual(res[0].shape, (1, 2))
+ self.assertEqual(res[1].shape, (1, 8))
+ self.assertAllClose(res[0], [[0.58847463, 0.58847463]])
+ self.assertAllClose(res[1], [[1.40469193, 1.40469193, 0.58847463, 0.58847463,
+ 0.97726452, 1.04626071, 0.4927212, 0.51137757]])
+
+ def testGrid2BasicLSTMCellTied(self):
+ with self.test_session() as sess:
+ with tf.variable_scope('root', initializer=tf.constant_initializer(0.2)):
+ x = tf.zeros([1, 3])
+ m = tf.zeros([1, 8])
+ cell = tf.contrib.grid_rnn.Grid2BasicLSTMCell(2, tied=True)
+ self.assertEqual(cell.state_size, 8)
+
+ g, s = cell(x, m)
+ self.assertEqual(g.get_shape(), (1, 2))
+ self.assertEqual(s.get_shape(), (1, 8))
+
+ sess.run([tf.initialize_all_variables()])
+ res = sess.run([g, s], {x: np.array([[1., 1., 1.]]),
+ m: np.array([[0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8]])})
+ self.assertEqual(res[0].shape, (1, 2))
+ self.assertEqual(res[1].shape, (1, 8))
+ self.assertAllClose(res[0], [[ 0.36617181, 0.36617181]])
+ self.assertAllClose(res[1], [[ 0.71053141, 0.71053141, 0.36617181, 0.36617181,
+ 0.72320831, 0.80555487, 0.39102408, 0.42150158]])
+
+ res = sess.run([g, s], {x: np.array([[1., 1., 1.]]), m: res[1]})
+ self.assertEqual(res[0].shape, (1, 2))
+ self.assertEqual(res[1].shape, (1, 8))
+ self.assertAllClose(res[0], [[0.36703536, 0.36703536]])
+ self.assertAllClose(res[1], [[0.71200621, 0.71200621, 0.36703536, 0.36703536,
+ 0.80941606, 0.87550586, 0.40108523, 0.42199609]])
+
+ def testGrid2BasicLSTMCellWithRelu(self):
+ with self.test_session() as sess:
+ with tf.variable_scope('root', initializer=tf.constant_initializer(0.2)):
+ x = tf.zeros([1, 3])
+ m = tf.zeros([1, 4])
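+ # With the first dimension non-recurrent (relu), only one dimension carries
+ # state: 2-unit cell state + 2-unit output = 4.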
+ cell = tf.contrib.grid_rnn.Grid2BasicLSTMCell(2, tied=False, non_recurrent_fn=tf.nn.relu)
+ self.assertEqual(cell.state_size, 4)
+
+ g, s = cell(x, m)
+ self.assertEqual(g.get_shape(), (1, 2))
+ self.assertEqual(s.get_shape(), (1, 4))
+
+ sess.run([tf.initialize_all_variables()])
+ res = sess.run([g, s], {x: np.array([[1., 1., 1.]]),
+ m: np.array([[0.1, 0.2, 0.3, 0.4]])})
+ self.assertEqual(res[0].shape, (1, 2))
+ self.assertEqual(res[1].shape, (1, 4))
+ self.assertAllClose(res[0], [[ 0.31667367, 0.31667367]])
+ self.assertAllClose(res[1], [[ 0.29530135, 0.37520045, 0.17044567, 0.21292259]])
+
+ """
+ LSTMCell
+ """
+
+ def testGrid2LSTMCell(self):
+ with self.test_session() as sess:
+ with tf.variable_scope('root', initializer=tf.constant_initializer(0.5)):
+ x = tf.zeros([1, 3])
+ m = tf.zeros([1, 8])
+ cell = tf.contrib.grid_rnn.Grid2LSTMCell(2, use_peepholes=True)
+ self.assertEqual(cell.state_size, 8)
+
+ g, s = cell(x, m)
+ self.assertEqual(g.get_shape(), (1, 2))
+ self.assertEqual(s.get_shape(), (1, 8))
+
+ sess.run([tf.initialize_all_variables()])
+ res = sess.run([g, s], {x: np.array([[1., 1., 1.]]),
+ m: np.array([[0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8]])})
+ self.assertEqual(res[0].shape, (1, 2))
+ self.assertEqual(res[1].shape, (1, 8))
+ self.assertAllClose(res[0], [[ 0.95686918, 0.95686918]])
+ self.assertAllClose(res[1], [[ 2.41515064, 2.41515064, 0.95686918, 0.95686918,
+ 1.38917875, 1.49043763, 0.83884692, 0.86036491]])
+
+ def testGrid2LSTMCellTied(self):
+ with self.test_session() as sess:
+ with tf.variable_scope('root', initializer=tf.constant_initializer(0.5)):
+ x = tf.zeros([1, 3])
+ m = tf.zeros([1, 8])
+ cell = tf.contrib.grid_rnn.Grid2LSTMCell(2, tied=True, use_peepholes=True)
+ self.assertEqual(cell.state_size, 8)
+
+ g, s = cell(x, m)
+ self.assertEqual(g.get_shape(), (1, 2))
+ self.assertEqual(s.get_shape(), (1, 8))
+
+ sess.run([tf.initialize_all_variables()])
+ res = sess.run([g, s], {x: np.array([[1., 1., 1.]]),
+ m: np.array([[0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8]])})
+ self.assertEqual(res[0].shape, (1, 2))
+ self.assertEqual(res[1].shape, (1, 8))
+ self.assertAllClose(res[0], [[ 0.95686918, 0.95686918]])
+ self.assertAllClose(res[1], [[ 2.41515064, 2.41515064, 0.95686918, 0.95686918,
+ 1.38917875, 1.49043763, 0.83884692, 0.86036491]])
+
+ def testGrid2LSTMCellWithRelu(self):
+ with self.test_session() as sess:
+ with tf.variable_scope('root', initializer=tf.constant_initializer(0.5)):
+ x = tf.zeros([1, 3])
+ m = tf.zeros([1, 4])
+ cell = tf.contrib.grid_rnn.Grid2LSTMCell(2, use_peepholes=True, non_recurrent_fn=tf.nn.relu)
+ self.assertEqual(cell.state_size, 4)
+
+ g, s = cell(x, m)
+ self.assertEqual(g.get_shape(), (1, 2))
+ self.assertEqual(s.get_shape(), (1, 4))
+
+ sess.run([tf.initialize_all_variables()])
+ res = sess.run([g, s], {x: np.array([[1., 1., 1.]]),
+ m: np.array([[0.1, 0.2, 0.3, 0.4]])})
+ self.assertEqual(res[0].shape, (1, 2))
+ self.assertEqual(res[1].shape, (1, 4))
+ self.assertAllClose(res[0], [[ 2.1831727, 2.1831727]])
+ self.assertAllClose(res[1], [[ 0.92270052, 1.02325559, 0.66159075, 0.70475441]])
+
+ """
+ RNNCell
+ """
+
+ def testGrid2BasicRNNCell(self):
+ with self.test_session() as sess:
+ with tf.variable_scope('root', initializer=tf.constant_initializer(0.5)):
+ x = tf.zeros([2, 2])
+ m = tf.zeros([2, 4])
+ cell = tf.contrib.grid_rnn.Grid2BasicRNNCell(2)
+ self.assertEqual(cell.state_size, 4)
+
+ g, s = cell(x, m)
+ self.assertEqual(g.get_shape(), (2, 2))
+ self.assertEqual(s.get_shape(), (2, 4))
+
+ sess.run([tf.initialize_all_variables()])
+ res = sess.run([g, s], {x: np.array([[1., 1.], [2., 2.]]),
+ m: np.array([[0.1, 0.1, 0.1, 0.1], [0.2, 0.2, 0.2, 0.2]])})
+ self.assertEqual(res[0].shape, (2, 2))
+ self.assertEqual(res[1].shape, (2, 4))
+ self.assertAllClose(res[0], [[0.94685763, 0.94685763],
+ [0.99480951, 0.99480951]])
+ self.assertAllClose(res[1], [[0.94685763, 0.94685763, 0.80049908, 0.80049908],
+ [0.99480951, 0.99480951, 0.97574311, 0.97574311]])
+
+ def testGrid2BasicRNNCellTied(self):
+ with self.test_session() as sess:
+ with tf.variable_scope('root', initializer=tf.constant_initializer(0.5)):
+ x = tf.zeros([2, 2])
+ m = tf.zeros([2, 4])
+ cell = tf.contrib.grid_rnn.Grid2BasicRNNCell(2, tied=True)
+ self.assertEqual(cell.state_size, 4)
+
+ g, s = cell(x, m)
+ self.assertEqual(g.get_shape(), (2, 2))
+ self.assertEqual(s.get_shape(), (2, 4))
+
+ sess.run([tf.initialize_all_variables()])
+ res = sess.run([g, s], {x: np.array([[1., 1.], [2., 2.]]),
+ m: np.array([[0.1, 0.1, 0.1, 0.1], [0.2, 0.2, 0.2, 0.2]])})
+ self.assertEqual(res[0].shape, (2, 2))
+ self.assertEqual(res[1].shape, (2, 4))
+ self.assertAllClose(res[0], [[0.94685763, 0.94685763],
+ [0.99480951, 0.99480951]])
+ self.assertAllClose(res[1], [[0.94685763, 0.94685763, 0.80049908, 0.80049908],
+ [0.99480951, 0.99480951, 0.97574311, 0.97574311]])
+
+ def testGrid2BasicRNNCellWithRelu(self):
+ with self.test_session() as sess:
+ with tf.variable_scope('root', initializer=tf.constant_initializer(0.5)):
+ x = tf.zeros([1, 2])
+ m = tf.zeros([1, 2])
+ cell = tf.contrib.grid_rnn.Grid2BasicRNNCell(2, non_recurrent_fn=tf.nn.relu)
+ self.assertEqual(cell.state_size, 2)
+
+ g, s = cell(x, m)
+ self.assertEqual(g.get_shape(), (1, 2))
+ self.assertEqual(s.get_shape(), (1, 2))
+
+ sess.run([tf.initialize_all_variables()])
+ res = sess.run([g, s], {x: np.array([[1., 1.]]), m: np.array([[0.1, 0.1]])})
+ self.assertEqual(res[0].shape, (1, 2))
+ self.assertEqual(res[1].shape, (1, 2))
+ self.assertAllClose(res[0], [[1.80049896, 1.80049896]])
+ self.assertAllClose(res[1], [[0.80049896, 0.80049896]])
+
+ """
+ 1-LSTM
+ """
+
+ def testGrid1LSTMCell(self):
+ with self.test_session() as sess:
+ with tf.variable_scope('root', initializer=tf.constant_initializer(0.5)) as root_scope:
+ x = tf.zeros([1, 3])
+ m = tf.zeros([1, 4])
+ cell = tf.contrib.grid_rnn.Grid1LSTMCell(2, use_peepholes=True)
+ self.assertEqual(cell.state_size, 4)
+
+ g, s = cell(x, m)
+ self.assertEqual(g.get_shape(), (1, 2))
+ self.assertEqual(s.get_shape(), (1, 4))
+
+ sess.run([tf.initialize_all_variables()])
+ res = sess.run([g, s], {x: np.array([[1., 1., 1.]]),
+ m: np.array([[0.1, 0.2, 0.3, 0.4]])})
+ self.assertEqual(res[0].shape, (1, 2))
+ self.assertEqual(res[1].shape, (1, 4))
+ self.assertAllClose(res[0], [[0.91287315, 0.91287315]])
+ self.assertAllClose(res[1], [[2.26285243, 2.26285243, 0.91287315, 0.91287315]])
+
+ root_scope.reuse_variables()
+
+ x2 = tf.zeros([0, 0])
+ g2, s2 = cell(x2, m)
+ self.assertEqual(g2.get_shape(), (1, 2))
+ self.assertEqual(s2.get_shape(), (1, 4))
+
+ sess.run([tf.initialize_all_variables()])
+ res = sess.run([g2, s2], {m: res[1]})
+ self.assertEqual(res[0].shape, (1, 2))
+ self.assertEqual(res[1].shape, (1, 4))
+ self.assertAllClose(res[0], [[0.9032144, 0.9032144]])
+ self.assertAllClose(res[1], [[2.79966092, 2.79966092, 0.9032144, 0.9032144]])
+
+ g3, s3 = cell(x2, m)
+ self.assertEqual(g3.get_shape(), (1, 2))
+ self.assertEqual(s3.get_shape(), (1, 4))
+
+ sess.run([tf.initialize_all_variables()])
+ res = sess.run([g3, s3], {m: res[1]})
+ self.assertEqual(res[0].shape, (1, 2))
+ self.assertEqual(res[1].shape, (1, 4))
+ self.assertAllClose(res[0], [[0.92727238, 0.92727238]])
+ self.assertAllClose(res[1], [[3.3529923, 3.3529923, 0.92727238, 0.92727238]])
+
+ """
+ 3-LSTM
+ """
+ def testGrid3LSTMCell(self):
+ with self.test_session() as sess:
+ with tf.variable_scope('root', initializer=tf.constant_initializer(0.5)):
+ x = tf.zeros([1, 3])
+ m = tf.zeros([1, 12])
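+ # 3 recurrent dimensions x (2-unit cell state + 2-unit output) = 12.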
+ cell = tf.contrib.grid_rnn.Grid3LSTMCell(2, use_peepholes=True)
+ self.assertEqual(cell.state_size, 12)
+
+ g, s = cell(x, m)
+ self.assertEqual(g.get_shape(), (1, 2))
+ self.assertEqual(s.get_shape(), (1, 12))
+
+ sess.run([tf.initialize_all_variables()])
+ res = sess.run([g, s], {x: np.array([[1., 1., 1.]]),
+ m: np.array([[0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, -0.1, -0.2, -0.3, -0.4]])})
+ self.assertEqual(res[0].shape, (1, 2))
+ self.assertEqual(res[1].shape, (1, 12))
+
+ self.assertAllClose(res[0], [[0.96892911, 0.96892911]])
+ self.assertAllClose(res[1], [[2.45227885, 2.45227885, 0.96892911, 0.96892911,
+ 1.33592629, 1.4373529, 0.80867189, 0.83247656,
+ 0.7317788, 0.63205892, 0.56548983, 0.50446129]])
+
+ """
+ Edge cases
+ """
+ def testGridRNNEdgeCasesLikeRelu(self):
+ with self.test_session() as sess:
+ with tf.variable_scope('root', initializer=tf.constant_initializer(0.5)):
+ x = tf.zeros([3, 2])
+ m = tf.zeros([0, 0])
+
+ # this is equivalent to relu
+ cell = tf.contrib.grid_rnn.GridRNNCell(num_units=2, num_dims=1, input_dims=0, output_dims=0,
+ non_recurrent_dims=0, non_recurrent_fn=tf.nn.relu)
+ g, s = cell(x, m)
+ self.assertEqual(g.get_shape(), (3, 2))
+ self.assertEqual(s.get_shape(), (0, 0))
+
+ sess.run([tf.initialize_all_variables()])
+ res = sess.run([g, s], {x: np.array([[1., -1.], [-2, 1], [2, -1]])})
+ self.assertEqual(res[0].shape, (3, 2))
+ self.assertEqual(res[1].shape, (0, 0))
+ self.assertAllClose(res[0], [[0, 0], [0, 0], [0.5, 0.5]])
+
+ def testGridRNNEdgeCasesNoOutput(self):
+ with self.test_session() as sess:
+ with tf.variable_scope('root', initializer=tf.constant_initializer(0.5)):
+ x = tf.zeros([1, 2])
+ m = tf.zeros([1, 4])
+
+ # This cell produces no output
+ cell = tf.contrib.grid_rnn.GridRNNCell(num_units=2, num_dims=2, input_dims=0, output_dims=None,
+ non_recurrent_dims=0, non_recurrent_fn=tf.nn.relu)
+ g, s = cell(x, m)
+ self.assertEqual(g.get_shape(), (0, 0))
+ self.assertEqual(s.get_shape(), (1, 4))
+
+ sess.run([tf.initialize_all_variables()])
+ res = sess.run([g, s], {x: np.array([[1., 1.]]),
+ m: np.array([[0.1, 0.1, 0.1, 0.1]])})
+ self.assertEqual(res[0].shape, (0, 0))
+ self.assertEqual(res[1].shape, (1, 4))
+
+ """
+ Test with tf.nn.rnn
+ """
+
+ def testGrid2LSTMCellWithRNN(self):
+ batch_size = 3
+ input_size = 5
+ max_length = 6 # unrolled up to this length
+ num_units = 2
+
+ with tf.variable_scope('root', initializer=tf.constant_initializer(0.5)):
+ cell = tf.contrib.grid_rnn.Grid2LSTMCell(num_units=num_units)
+
+ inputs = max_length * [
+ tf.placeholder(tf.float32, shape=(batch_size, input_size))]
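+ # Note: this builds a list of six references to a single placeholder, so
+ # feeding inputs[0] below feeds every unrolled step with the same value.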
+
+ outputs, state = tf.nn.rnn(cell, inputs, dtype=tf.float32)
+
+ self.assertEqual(len(outputs), len(inputs))
+ self.assertEqual(state.get_shape(), (batch_size, 8))
+
+ for out, inp in zip(outputs, inputs):
+ self.assertEqual(out.get_shape()[0], inp.get_shape()[0])
+ self.assertEqual(out.get_shape()[1], num_units)
+ self.assertEqual(out.dtype, inp.dtype)
+
+ with self.test_session() as sess:
+ sess.run(tf.initialize_all_variables())
+
+ input_value = np.ones((batch_size, input_size))
+ values = sess.run(outputs + [state],
+ feed_dict={inputs[0]: input_value})
+ for v in values:
+ self.assertTrue(np.all(np.isfinite(v)))
+
+ def testGrid2LSTMCellReLUWithRNN(self):
+ batch_size = 3
+ input_size = 5
+ max_length = 6 # unrolled up to this length
+ num_units = 2
+
+ with tf.variable_scope('root', initializer=tf.constant_initializer(0.5)):
+ cell = tf.contrib.grid_rnn.Grid2LSTMCell(num_units=num_units, non_recurrent_fn=tf.nn.relu)
+
+ inputs = max_length * [
+ tf.placeholder(tf.float32, shape=(batch_size, input_size))]
+
+ outputs, state = tf.nn.rnn(cell, inputs, dtype=tf.float32)
+
+ self.assertEqual(len(outputs), len(inputs))
+ self.assertEqual(state.get_shape(), (batch_size, 4))
+
+ for out, inp in zip(outputs, inputs):
+ self.assertEqual(out.get_shape()[0], inp.get_shape()[0])
+ self.assertEqual(out.get_shape()[1], num_units)
+ self.assertEqual(out.dtype, inp.dtype)
+
+ with self.test_session() as sess:
+ sess.run(tf.initialize_all_variables())
+
+ input_value = np.ones((batch_size, input_size))
+ values = sess.run(outputs + [state],
+ feed_dict={inputs[0]: input_value})
+ for v in values:
+ self.assertTrue(np.all(np.isfinite(v)))
+
+ def testGrid3LSTMCellReLUWithRNN(self):
+ batch_size = 3
+ input_size = 5
+ max_length = 6 # unrolled up to this length
+ num_units = 2
+
+ with tf.variable_scope('root', initializer=tf.constant_initializer(0.5)):
+ cell = tf.contrib.grid_rnn.Grid3LSTMCell(num_units=num_units, non_recurrent_fn=tf.nn.relu)
+
+ inputs = max_length * [
+ tf.placeholder(tf.float32, shape=(batch_size, input_size))]
+
+ outputs, state = tf.nn.rnn(cell, inputs, dtype=tf.float32)
+
+ self.assertEqual(len(outputs), len(inputs))
+ self.assertEqual(state.get_shape(), (batch_size, 8))
+
+ for out, inp in zip(outputs, inputs):
+ self.assertEqual(out.get_shape()[0], inp.get_shape()[0])
+ self.assertEqual(out.get_shape()[1], num_units)
+ self.assertEqual(out.dtype, inp.dtype)
+
+ with self.test_session() as sess:
+ sess.run(tf.initialize_all_variables())
+
+ input_value = np.ones((batch_size, input_size))
+ values = sess.run(outputs + [state],
+ feed_dict={inputs[0]: input_value})
+ for v in values:
+ self.assertTrue(np.all(np.isfinite(v)))
+
+
+ def testGrid1LSTMCellWithRNN(self):
+ batch_size = 3
+ input_size = 5
+ max_length = 6 # unrolled up to this length
+ num_units = 2
+
+ with tf.variable_scope('root', initializer=tf.constant_initializer(0.5)):
+ cell = tf.contrib.grid_rnn.Grid1LSTMCell(num_units=num_units)
+
+ # for 1-LSTM, we only feed the first step
+ inputs = [tf.placeholder(tf.float32, shape=(batch_size, input_size))] \
+ + (max_length - 1) * [tf.zeros([0, 0])]
+
+ outputs, state = tf.nn.rnn(cell, inputs, dtype=tf.float32)
+
+ self.assertEqual(len(outputs), len(inputs))
+ self.assertEqual(state.get_shape(), (batch_size, 4))
+
+ for out, inp in zip(outputs, inputs):
+ self.assertEqual(out.get_shape(), (3, num_units))
+ self.assertEqual(out.dtype, inp.dtype)
+
+ with self.test_session() as sess:
+ sess.run(tf.initialize_all_variables())
+
+ input_value = np.ones((batch_size, input_size))
+ values = sess.run(outputs + [state],
+ feed_dict={inputs[0]: input_value})
+ for v in values:
+ self.assertTrue(np.all(np.isfinite(v)))
+
+if __name__ == "__main__":
+ tf.test.main()
diff --git a/tensorflow/contrib/grid_rnn/python/ops/__init__.py b/tensorflow/contrib/grid_rnn/python/ops/__init__.py
new file mode 100644
index 0000000000..94872f4f0c
--- /dev/null
+++ b/tensorflow/contrib/grid_rnn/python/ops/__init__.py
@@ -0,0 +1,18 @@
+# Copyright 2016 Google Inc. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
diff --git a/tensorflow/contrib/grid_rnn/python/ops/grid_rnn_cell.py b/tensorflow/contrib/grid_rnn/python/ops/grid_rnn_cell.py
new file mode 100644
index 0000000000..9c75b4ae2f
--- /dev/null
+++ b/tensorflow/contrib/grid_rnn/python/ops/grid_rnn_cell.py
@@ -0,0 +1,352 @@
+# Copyright 2016 Google Inc. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+"""Module for constructing GridRNN cells"""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from collections import namedtuple
+
+from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import math_ops
+from tensorflow.python.ops import nn
+from tensorflow.python.ops import variable_scope as vs
+from tensorflow.python.ops import rnn_cell
+from tensorflow.contrib import layers
+
+
+class GridRNNCell(rnn_cell.RNNCell):
+ """Grid recurrent cell.
+
+ This implementation is based on:
+
+ http://arxiv.org/pdf/1507.01526v3.pdf
+
+ This is the generic implementation of GridRNN. Users can specify an arbitrary number of dimensions,
+ mark some of them as priority (section 3.2 of the paper) or non-recurrent (section 3.3),
+ and choose which dimensions receive input and emit output (section 3.4).
+ Weight sharing across dimensions can be enabled with the `tied` parameter.
+ The type of recurrent unit can be specified via `cell_fn`.
+ """
+
+ def __init__(self, num_units, num_dims=1, input_dims=None, output_dims=None, priority_dims=None,
+ non_recurrent_dims=None, tied=False, cell_fn=None, non_recurrent_fn=None):
+ """Initialize the parameters of a Grid RNN cell
+
+ Args:
+ num_units: int, The number of units in all dimensions of this GridRNN cell
+ num_dims: int, Number of dimensions of this grid.
+ input_dims: int or list, List of dimensions which will receive input data.
+ output_dims: int or list, List of dimensions from which the output will be recorded.
+ priority_dims: int or list, List of dimensions to be considered as priority dimensions.
+ If None, no dimension is prioritized.
+ non_recurrent_dims: int or list, List of dimensions that are not recurrent.
+ The transfer function for non-recurrent dimensions is specified via `non_recurrent_fn`,
+ which is default to be `tensorflow.nn.relu`.
+ tied: bool, Whether to share the weights among the dimensions of this GridRNN cell.
+ If there are non-recurrent dimensions in the grid, weights are shared between each
+ group of recurrent and non-recurrent dimensions.
+ cell_fn: function, a function which returns the recurrent cell object. Has to be in the following signature:
+ def cell_func(num_units, input_size):
+ # ...
+
+ and returns an object of type `RNNCell`. If None, LSTMCell with default parameters will be used.
+ non_recurrent_fn: a tensorflow Op that will be the transfer function of the non-recurrent dimensions
+ """
+ if num_dims < 1:
+ raise ValueError('dims must be >= 1: {}'.format(num_dims))
+
+ self._config = _parse_rnn_config(num_dims, input_dims, output_dims, priority_dims,
+ non_recurrent_dims, non_recurrent_fn or nn.relu, tied, num_units)
+
+ cell_input_size = (self._config.num_dims - 1) * num_units
+ if cell_fn is None:
+ self._cell = rnn_cell.LSTMCell(num_units=num_units, input_size=cell_input_size)
+ else:
+ self._cell = cell_fn(num_units, cell_input_size)
+ if not isinstance(self._cell, rnn_cell.RNNCell):
+ raise ValueError('cell_fn must return an object of type RNNCell')
+
+ @property
+ def input_size(self):
+ # Temporarily use num_units as the input_size of each dimension.
+ # The actual input size is only determined when this cell is invoked,
+ # so this value should be considered unreliable.
+ return self._config.num_units * len(self._config.inputs)
+
+ @property
+ def output_size(self):
+ return self._cell.output_size * len(self._config.outputs)
+
+ @property
+ def state_size(self):
+ return self._cell.state_size * len(self._config.recurrents)
+
+ def __call__(self, inputs, state, scope=None):
+ """Run one step of GridRNN.
+
+ Args:
+ inputs: input Tensor, 2D, batch x input_size, or None.
+ state: state Tensor, 2D, batch x state_size. Note that state_size = cell_state_size * recurrent_dims
+ scope: VariableScope for the created subgraph; defaults to "GridRNNCell".
+
+ Returns:
+ A tuple containing:
+ - A 2D, batch x output_size, Tensor representing the output of the cell
+ after reading "inputs" when previous state was "state".
+ - A 2D, batch x state_size, Tensor representing the new state of the cell
+ after reading "inputs" when previous state was "state".
+ """
+ state_sz = state.get_shape().as_list()[1]
+ if self.state_size != state_sz:
+ raise ValueError('Actual state size {} does not match the specified state size {}.'.format(state_sz, self.state_size))
+
+ conf = self._config
+ dtype = inputs.dtype if inputs is not None else state.dtype
+
+ # c_prev is `m`, and m_prev is `h` in the paper. Keep c and m here for consistency with the codebase
+ c_prev = [None] * self._config.num_dims
+ m_prev = [None] * self._config.num_dims
+ cell_output_size = self._cell.state_size - conf.num_units
+
+ # for LSTM: state = memory cell + output, hence cell_output_size > 0
+ # for GRU/RNN: state = output (whose size is equal to _num_units), hence cell_output_size = 0
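+ # For example, with an LSTM and num_units=2 the per-dimension state is 4 wide:
+ # the first 2 values are the cell state `c` and the next 2 are the output `m`.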
+ for recurrent_dim, start_idx in zip(self._config.recurrents, range(0, self.state_size, self._cell.state_size)):
+ if cell_output_size > 0:
+ c_prev[recurrent_dim] = array_ops.slice(state, [0, start_idx], [-1, conf.num_units])
+ m_prev[recurrent_dim] = array_ops.slice(state, [0, start_idx + conf.num_units], [-1, cell_output_size])
+ else:
+ m_prev[recurrent_dim] = array_ops.slice(state, [0, start_idx], [-1, conf.num_units])
+
+ new_output = [None] * conf.num_dims
+ new_state = [None] * conf.num_dims
+
+ with vs.variable_scope(scope or type(self).__name__): # GridRNNCell
+
+ # project input
+ if inputs is not None and sum(inputs.get_shape().as_list()) > 0 and len(conf.inputs) > 0:
+ input_splits = array_ops.split(1, len(conf.inputs), inputs)
+ input_sz = input_splits[0].get_shape().as_list()[1]
+
+ for i, j in enumerate(conf.inputs):
+ input_project_m = vs.get_variable('project_m_{}'.format(j), [input_sz, conf.num_units], dtype=dtype)
+ m_prev[j] = math_ops.matmul(input_splits[i], input_project_m)
+
+ if cell_output_size > 0:
+ input_project_c = vs.get_variable('project_c_{}'.format(j), [input_sz, conf.num_units], dtype=dtype)
+ c_prev[j] = math_ops.matmul(input_splits[i], input_project_c)
+
+
+ _propagate(conf.non_priority, conf, self._cell, c_prev, m_prev, new_output, new_state, True)
+ _propagate(conf.priority, conf, self._cell, c_prev, m_prev, new_output, new_state, False)
+
+ output_tensors = [new_output[i] for i in self._config.outputs]
+ output = array_ops.zeros([0, 0], dtype) if len(output_tensors) == 0 else array_ops.concat(1, output_tensors)
+
+ state_tensors = [new_state[i] for i in self._config.recurrents]
+ states = array_ops.zeros([0, 0], dtype) if len(state_tensors) == 0 else array_ops.concat(1, state_tensors)
+
+ return output, states
+
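+# A minimal usage sketch (illustrative; see grid_rnn_test.py for concrete
+# numbers). A hypothetical 2D grid with input and output on dimension 0:
+#
+#   cell = GridRNNCell(num_units=2, num_dims=2, input_dims=0, output_dims=0)
+#   x = tf.zeros([1, 3])                # batch of 1, input size 3
+#   m = tf.zeros([1, cell.state_size])  # state_size = 8 for a 2D LSTM grid
+#   output, new_state = cell(x, m)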
+
+"""
+Specialized cells, for convenience
+"""
+
+class Grid1BasicRNNCell(GridRNNCell):
+ """1D BasicRNN cell"""
+ def __init__(self, num_units):
+ super(Grid1BasicRNNCell, self).__init__(num_units=num_units, num_dims=1,
+ input_dims=0, output_dims=0, priority_dims=0, tied=False,
+ cell_fn=lambda n, i: rnn_cell.BasicRNNCell(num_units=n, input_size=i))
+
+
+class Grid2BasicRNNCell(GridRNNCell):
+ """2D BasicRNN cell
+ This creates a 2D cell which receives input and gives output in the first dimension.
+ The first dimension can optionally be non-recurrent if `non_recurrent_fn` is specified.
+ """
+ def __init__(self, num_units, tied=False, non_recurrent_fn=None):
+ super(Grid2BasicRNNCell, self).__init__(num_units=num_units, num_dims=2,
+ input_dims=0, output_dims=0, priority_dims=0, tied=tied,
+ non_recurrent_dims=None if non_recurrent_fn is None else 0,
+ cell_fn=lambda n, i: rnn_cell.BasicRNNCell(num_units=n, input_size=i),
+ non_recurrent_fn=non_recurrent_fn)
+
+
+class Grid1BasicLSTMCell(GridRNNCell):
+ """1D BasicLSTM cell"""
+ def __init__(self, num_units, forget_bias=1):
+ super(Grid1BasicLSTMCell, self).__init__(num_units=num_units, num_dims=1,
+ input_dims=0, output_dims=0, priority_dims=0, tied=False,
+ cell_fn=lambda n, i: rnn_cell.BasicLSTMCell(num_units=n,
+ forget_bias=forget_bias, input_size=i))
+
+
+class Grid2BasicLSTMCell(GridRNNCell):
+ """2D BasicLSTM cell
+ This creates a 2D cell which receives input and gives output in the first dimension.
+ The first dimension can optionally be non-recurrent if `non_recurrent_fn` is specified.
+ """
+ def __init__(self, num_units, tied=False, non_recurrent_fn=None, forget_bias=1):
+ super(Grid2BasicLSTMCell, self).__init__(num_units=num_units, num_dims=2,
+ input_dims=0, output_dims=0, priority_dims=0, tied=tied,
+ non_recurrent_dims=None if non_recurrent_fn is None else 0,
+ cell_fn=lambda n, i: rnn_cell.BasicLSTMCell(
+ num_units=n, forget_bias=forget_bias, input_size=i),
+ non_recurrent_fn=non_recurrent_fn)
+
+
+class Grid1LSTMCell(GridRNNCell):
+ """1D LSTM cell
+ This is different from Grid1BasicLSTMCell because it allows specifying the forget bias and enabling peepholes.
+ """
+ def __init__(self, num_units, use_peepholes=False, forget_bias=1.0):
+ super(Grid1LSTMCell, self).__init__(num_units=num_units, num_dims=1,
+ input_dims=0, output_dims=0, priority_dims=0,
+ cell_fn=lambda n, i: rnn_cell.LSTMCell(
+ num_units=n, input_size=i, use_peepholes=use_peepholes,
+ forget_bias=forget_bias))
+
+
+class Grid2LSTMCell(GridRNNCell):
+ """2D LSTM cell
+ This creates a 2D cell which receives input and gives output in the first dimension.
+ The first dimension can optionally be non-recurrent if `non_recurrent_fn` is specified.
+ """
+ def __init__(self, num_units, tied=False, non_recurrent_fn=None,
+ use_peepholes=False, forget_bias=1.0):
+ super(Grid2LSTMCell, self).__init__(num_units=num_units, num_dims=2,
+ input_dims=0, output_dims=0, priority_dims=0, tied=tied,
+ non_recurrent_dims=None if non_recurrent_fn is None else 0,
+ cell_fn=lambda n, i: rnn_cell.LSTMCell(
+ num_units=n, input_size=i, forget_bias=forget_bias,
+ use_peepholes=use_peepholes),
+ non_recurrent_fn=non_recurrent_fn)
+
+
+class Grid3LSTMCell(GridRNNCell):
+ """3D BasicLSTM cell
+ This creates a 2D cell which receives input and gives output in the first dimension.
+ The first dimension can optionally be non-recurrent if `non_recurrent_fn` is specified.
+ The second and third dimensions are LSTM.
+ """
+ def __init__(self, num_units, tied=False, non_recurrent_fn=None,
+ use_peepholes=False, forget_bias=1.0):
+ super(Grid3LSTMCell, self).__init__(num_units=num_units, num_dims=3,
+ input_dims=0, output_dims=0, priority_dims=0, tied=tied,
+ non_recurrent_dims=None if non_recurrent_fn is None else 0,
+ cell_fn=lambda n, i: rnn_cell.LSTMCell(
+ num_units=n, input_size=i, forget_bias=forget_bias,
+ use_peepholes=use_peepholes),
+ non_recurrent_fn=non_recurrent_fn)
+
+class Grid2GRUCell(GridRNNCell):
+ """2D LSTM cell
+ This creates a 2D cell which receives input and gives output in the first dimension.
+ The first dimension can optionally be non-recurrent if `non_recurrent_fn` is specified.
+ """
+
+ def __init__(self, num_units, tied=False, non_recurrent_fn=None):
+ super(Grid2GRUCell, self).__init__(num_units=num_units, num_dims=2,
+ input_dims=0, output_dims=0, priority_dims=0, tied=tied,
+ non_recurrent_dims=None if non_recurrent_fn is None else 0,
+ cell_fn=lambda n, i: rnn_cell.GRUCell(num_units=n, input_size=i),
+ non_recurrent_fn=non_recurrent_fn)
+
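+# For instance (mirroring the unit tests), a 2D LSTM grid whose first
+# dimension is replaced by a non-recurrent relu layer:
+#
+#   cell = Grid2LSTMCell(2, use_peepholes=True, non_recurrent_fn=tf.nn.relu)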
+"""
+Helpers
+"""
+
+_GridRNNDimension = namedtuple('_GridRNNDimension', ['idx', 'is_input', 'is_output', 'is_priority', 'non_recurrent_fn'])
+
+_GridRNNConfig = namedtuple('_GridRNNConfig', ['num_dims', 'dims',
+ 'inputs', 'outputs', 'recurrents',
+ 'priority', 'non_priority', 'tied', 'num_units'])
+
+
+def _parse_rnn_config(num_dims, ls_input_dims, ls_output_dims, ls_priority_dims, ls_non_recurrent_dims,
+ non_recurrent_fn, tied, num_units):
+ def check_dim_list(ls):
+ if ls is None:
+ ls = []
+ if not isinstance(ls, (list, tuple)):
+ ls = [ls]
+ ls = sorted(set(ls))
+ if any(_ < 0 or _ >= num_dims for _ in ls):
+ raise ValueError('Invalid dims: {}. Must be in [0, {})'.format(ls, num_dims))
+ return ls
+
+ input_dims = check_dim_list(ls_input_dims)
+ output_dims = check_dim_list(ls_output_dims)
+ priority_dims = check_dim_list(ls_priority_dims)
+ non_recurrent_dims = check_dim_list(ls_non_recurrent_dims)
+
+ rnn_dims = []
+ for i in range(num_dims):
+ rnn_dims.append(_GridRNNDimension(idx=i, is_input=(i in input_dims), is_output=(i in output_dims),
+ is_priority=(i in priority_dims),
+ non_recurrent_fn=non_recurrent_fn if i in non_recurrent_dims else None))
+ return _GridRNNConfig(num_dims=num_dims, dims=rnn_dims, inputs=input_dims, outputs=output_dims,
+ recurrents=[x for x in range(num_dims) if x not in non_recurrent_dims],
+ priority=priority_dims,
+ non_priority=[x for x in range(num_dims) if x not in priority_dims],
+ tied=tied, num_units=num_units)
+
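+# For example, the Grid2LSTMCell configuration (num_dims=2, input_dims=0,
+# output_dims=0, priority_dims=0) parses to inputs=[0], outputs=[0],
+# recurrents=[0, 1], priority=[0], non_priority=[1].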
+
+def _propagate(dim_indices, conf, cell, c_prev, m_prev, new_output, new_state, first_call):
+ """
+ Propagates through all the cells in dim_indices dimensions.
+ """
+ if len(dim_indices) == 0:
+ return
+
+ # Because of the way RNNCells are implemented, we take the last dimension (H_{N-1}) out
+ # and feed it as the state of the RNN cell (in `last_dim_output`).
+ # The inputs of the cell (H_0 to H_{N-2}) are concatenated into `cell_inputs`.
+ if conf.num_dims > 1:
+ ls_cell_inputs = [None] * (conf.num_dims - 1)
+ for d in conf.dims[:-1]:
+ ls_cell_inputs[d.idx] = new_output[d.idx] if new_output[d.idx] is not None else m_prev[d.idx]
+ cell_inputs = array_ops.concat(1, ls_cell_inputs)
+ else:
+ cell_inputs = array_ops.zeros([m_prev[0].get_shape().as_list()[0], 0], m_prev[0].dtype)
+
+ last_dim_output = new_output[-1] if new_output[-1] is not None else m_prev[-1]
+
+ for i in dim_indices:
+ d = conf.dims[i]
+ if d.non_recurrent_fn:
+ linear_args = array_ops.concat(1, [cell_inputs, last_dim_output]) if conf.num_dims > 1 else last_dim_output
+ with vs.variable_scope('non_recurrent' if conf.tied else 'non_recurrent/cell_{}'.format(i)):
+ if conf.tied and not (first_call and i == dim_indices[0]):
+ vs.get_variable_scope().reuse_variables()
+ new_output[d.idx] = layers.fully_connected(linear_args, num_output_units=conf.num_units,
+ activation_fn=d.non_recurrent_fn,
+ weight_init=vs.get_variable_scope().initializer or
+ layers.initializers.xavier_initializer)
+ else:
+ if c_prev[i] is not None:
+ cell_state = array_ops.concat(1, [c_prev[i], last_dim_output])
+ else:
+ # for GRU/RNN, the state is just the previous output
+ cell_state = last_dim_output
+
+ with vs.variable_scope('recurrent' if conf.tied else 'recurrent/cell_{}'.format(i)):
+ if conf.tied and not (first_call and i == dim_indices[0]):
+ vs.get_variable_scope().reuse_variables()
+ new_output[d.idx], new_state[d.idx] = cell(cell_inputs, cell_state)