diff options
author | 2016-04-18 17:56:51 -0800 | |
---|---|---|
committer | 2016-04-18 19:03:29 -0700 | |
commit | 5c9bc51857bc0c330d3ab976871ee3509647d1e7 (patch) | |
tree | a58def7cbf316c6e091b3b36657f120f1388ec54 /tensorflow/contrib/grid_rnn | |
parent | fc432e37a7ddd408ff09a7b90b1c4cd5af1b134e (diff) |
Merge changes from github.
Change: 120185825
Diffstat (limited to 'tensorflow/contrib/grid_rnn')
-rw-r--r-- | tensorflow/contrib/grid_rnn/BUILD | 39 | ||||
-rw-r--r-- | tensorflow/contrib/grid_rnn/__init__.py | 27 | ||||
-rw-r--r-- | tensorflow/contrib/grid_rnn/python/__init__.py | 18 | ||||
-rw-r--r-- | tensorflow/contrib/grid_rnn/python/kernel_tests/grid_rnn_test.py | 489 | ||||
-rw-r--r-- | tensorflow/contrib/grid_rnn/python/ops/__init__.py | 18 | ||||
-rw-r--r-- | tensorflow/contrib/grid_rnn/python/ops/grid_rnn_cell.py | 352 |
6 files changed, 943 insertions, 0 deletions
diff --git a/tensorflow/contrib/grid_rnn/BUILD b/tensorflow/contrib/grid_rnn/BUILD new file mode 100644 index 0000000000..c3b9b5a9dd --- /dev/null +++ b/tensorflow/contrib/grid_rnn/BUILD @@ -0,0 +1,39 @@ +# Description: +# Contains classes to construct GridRNN cells +# APIs here are meant to evolve over time. + +licenses(["notice"]) # Apache 2.0 + +exports_files(["LICENSE"]) + +package(default_visibility = ["//tensorflow:__subpackages__"]) + +load("//tensorflow:tensorflow.bzl", "cuda_py_tests") + +py_library( + name = "grid_rnn_py", + srcs = ["__init__.py"] + glob(["python/ops/*.py"]), + srcs_version = "PY2AND3", +) + +cuda_py_tests( + name = "grid_rnn_test", + srcs = ["python/kernel_tests/grid_rnn_test.py"], + additional_deps = [ + ":grid_rnn_py", + "//tensorflow/python:framework_test_lib", + "//tensorflow/python:platform_test", + ], +) + +filegroup( + name = "all_files", + srcs = glob( + ["**/*"], + exclude = [ + "**/METADATA", + "**/OWNERS", + ], + ), + visibility = ["//tensorflow:__subpackages__"], +) diff --git a/tensorflow/contrib/grid_rnn/__init__.py b/tensorflow/contrib/grid_rnn/__init__.py new file mode 100644 index 0000000000..fbe380f1f3 --- /dev/null +++ b/tensorflow/contrib/grid_rnn/__init__.py @@ -0,0 +1,27 @@ +# Copyright 2016 Google Inc. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +"""GridRNN cells + +## This package provides classes for GridRNN + +""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +# pylint: disable=unused-import,wildcard-import, line-too-long +from tensorflow.contrib.grid_rnn.python.ops.grid_rnn_cell import * diff --git a/tensorflow/contrib/grid_rnn/python/__init__.py b/tensorflow/contrib/grid_rnn/python/__init__.py new file mode 100644 index 0000000000..94872f4f0c --- /dev/null +++ b/tensorflow/contrib/grid_rnn/python/__init__.py @@ -0,0 +1,18 @@ +# Copyright 2016 Google Inc. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function diff --git a/tensorflow/contrib/grid_rnn/python/kernel_tests/grid_rnn_test.py b/tensorflow/contrib/grid_rnn/python/kernel_tests/grid_rnn_test.py new file mode 100644 index 0000000000..d79cafe944 --- /dev/null +++ b/tensorflow/contrib/grid_rnn/python/kernel_tests/grid_rnn_test.py @@ -0,0 +1,489 @@ +# Copyright 2016 Google Inc. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +"""Tests for GridRNN cells.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import numpy as np +import tensorflow as tf + + +class GridRNNCellTest(tf.test.TestCase): + + def testGrid2BasicLSTMCell(self): + with self.test_session() as sess: + with tf.variable_scope('root', initializer=tf.constant_initializer(0.2)) as root_scope: + x = tf.zeros([1, 3]) + m = tf.zeros([1, 8]) + cell = tf.contrib.grid_rnn.Grid2BasicLSTMCell(2) + self.assertEqual(cell.state_size, 8) + + g, s = cell(x, m) + self.assertEqual(g.get_shape(), (1, 2)) + self.assertEqual(s.get_shape(), (1, 8)) + + sess.run([tf.initialize_all_variables()]) + res = sess.run([g, s], {x: np.array([[1., 1., 1.]]), + m: np.array([[0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8]])}) + self.assertEqual(res[0].shape, (1, 2)) + self.assertEqual(res[1].shape, (1, 8)) + self.assertAllClose(res[0], [[ 0.36617181, 0.36617181]]) + self.assertAllClose(res[1], [[ 0.71053141, 0.71053141, 0.36617181, 0.36617181, + 0.72320831, 0.80555487, 0.39102408, 0.42150158]]) + + # emulate a loop through the input sequence, where we call cell() multiple times + root_scope.reuse_variables() + g2, s2 = cell(x, m) + self.assertEqual(g2.get_shape(), (1, 2)) + self.assertEqual(s2.get_shape(), (1, 8)) + + res = sess.run([g2, s2], {x: np.array([[2., 2., 2.]]), m: res[1]}) + self.assertEqual(res[0].shape, (1, 2)) + self.assertEqual(res[1].shape, (1, 8)) + self.assertAllClose(res[0], [[0.58847463, 0.58847463]]) + self.assertAllClose(res[1], [[1.40469193, 1.40469193, 0.58847463, 0.58847463, + 0.97726452, 1.04626071, 0.4927212, 0.51137757]]) + + def testGrid2BasicLSTMCellTied(self): + with self.test_session() as sess: + with tf.variable_scope('root', initializer=tf.constant_initializer(0.2)): + x = tf.zeros([1, 3]) + m = tf.zeros([1, 8]) + cell = tf.contrib.grid_rnn.Grid2BasicLSTMCell(2, tied=True) + self.assertEqual(cell.state_size, 8) + + g, s = cell(x, m) + self.assertEqual(g.get_shape(), (1, 2)) + self.assertEqual(s.get_shape(), (1, 8)) + + sess.run([tf.initialize_all_variables()]) + res = sess.run([g, s], {x: np.array([[1., 1., 1.]]), + m: np.array([[0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8]])}) + self.assertEqual(res[0].shape, (1, 2)) + self.assertEqual(res[1].shape, (1, 8)) + self.assertAllClose(res[0], [[ 0.36617181, 0.36617181]]) + self.assertAllClose(res[1], [[ 0.71053141, 0.71053141, 0.36617181, 0.36617181, + 0.72320831, 0.80555487, 0.39102408, 0.42150158]]) + + res = sess.run([g, s], {x: np.array([[1., 1., 1.]]), m: res[1]}) + self.assertEqual(res[0].shape, (1, 2)) + self.assertEqual(res[1].shape, (1, 8)) + self.assertAllClose(res[0], [[0.36703536, 0.36703536]]) + self.assertAllClose(res[1], [[0.71200621, 0.71200621, 0.36703536, 0.36703536, + 0.80941606, 0.87550586, 0.40108523, 0.42199609]]) + + def testGrid2BasicLSTMCellWithRelu(self): + with self.test_session() as sess: + with tf.variable_scope('root', initializer=tf.constant_initializer(0.2)): + x = tf.zeros([1, 3]) + m = tf.zeros([1, 4]) + cell = tf.contrib.grid_rnn.Grid2BasicLSTMCell(2, tied=False, non_recurrent_fn=tf.nn.relu) + self.assertEqual(cell.state_size, 4) + + g, s = cell(x, m) + self.assertEqual(g.get_shape(), (1, 2)) + self.assertEqual(s.get_shape(), (1, 4)) + + sess.run([tf.initialize_all_variables()]) + res = sess.run([g, s], {x: np.array([[1., 1., 1.]]), + m: np.array([[0.1, 0.2, 0.3, 0.4]])}) + self.assertEqual(res[0].shape, (1, 2)) + self.assertEqual(res[1].shape, (1, 4)) + self.assertAllClose(res[0], [[ 0.31667367, 0.31667367]]) + self.assertAllClose(res[1], [[ 0.29530135, 0.37520045, 0.17044567, 0.21292259]]) + + """ + LSTMCell + """ + + def testGrid2LSTMCell(self): + with self.test_session() as sess: + with tf.variable_scope('root', initializer=tf.constant_initializer(0.5)): + x = tf.zeros([1, 3]) + m = tf.zeros([1, 8]) + cell = tf.contrib.grid_rnn.Grid2LSTMCell(2, use_peepholes=True) + self.assertEqual(cell.state_size, 8) + + g, s = cell(x, m) + self.assertEqual(g.get_shape(), (1, 2)) + self.assertEqual(s.get_shape(), (1, 8)) + + sess.run([tf.initialize_all_variables()]) + res = sess.run([g, s], {x: np.array([[1., 1., 1.]]), + m: np.array([[0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8]])}) + self.assertEqual(res[0].shape, (1, 2)) + self.assertEqual(res[1].shape, (1, 8)) + self.assertAllClose(res[0], [[ 0.95686918, 0.95686918]]) + self.assertAllClose(res[1], [[ 2.41515064, 2.41515064, 0.95686918, 0.95686918, + 1.38917875, 1.49043763, 0.83884692, 0.86036491]]) + + def testGrid2LSTMCellTied(self): + with self.test_session() as sess: + with tf.variable_scope('root', initializer=tf.constant_initializer(0.5)): + x = tf.zeros([1, 3]) + m = tf.zeros([1, 8]) + cell = tf.contrib.grid_rnn.Grid2LSTMCell(2, tied=True, use_peepholes=True) + self.assertEqual(cell.state_size, 8) + + g, s = cell(x, m) + self.assertEqual(g.get_shape(), (1, 2)) + self.assertEqual(s.get_shape(), (1, 8)) + + sess.run([tf.initialize_all_variables()]) + res = sess.run([g, s], {x: np.array([[1., 1., 1.]]), + m: np.array([[0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8]])}) + self.assertEqual(res[0].shape, (1, 2)) + self.assertEqual(res[1].shape, (1, 8)) + self.assertAllClose(res[0], [[ 0.95686918, 0.95686918]]) + self.assertAllClose(res[1], [[ 2.41515064, 2.41515064, 0.95686918, 0.95686918, + 1.38917875, 1.49043763, 0.83884692, 0.86036491]]) + + def testGrid2LSTMCellWithRelu(self): + with self.test_session() as sess: + with tf.variable_scope('root', initializer=tf.constant_initializer(0.5)): + x = tf.zeros([1, 3]) + m = tf.zeros([1, 4]) + cell = tf.contrib.grid_rnn.Grid2LSTMCell(2, use_peepholes=True, non_recurrent_fn=tf.nn.relu) + self.assertEqual(cell.state_size, 4) + + g, s = cell(x, m) + self.assertEqual(g.get_shape(), (1, 2)) + self.assertEqual(s.get_shape(), (1, 4)) + + sess.run([tf.initialize_all_variables()]) + res = sess.run([g, s], {x: np.array([[1., 1., 1.]]), + m: np.array([[0.1, 0.2, 0.3, 0.4]])}) + self.assertEqual(res[0].shape, (1, 2)) + self.assertEqual(res[1].shape, (1, 4)) + self.assertAllClose(res[0], [[ 2.1831727, 2.1831727]]) + self.assertAllClose(res[1], [[ 0.92270052, 1.02325559, 0.66159075, 0.70475441]]) + + """ + RNNCell + """ + + def testGrid2BasicRNNCell(self): + with self.test_session() as sess: + with tf.variable_scope('root', initializer=tf.constant_initializer(0.5)): + x = tf.zeros([2, 2]) + m = tf.zeros([2, 4]) + cell = tf.contrib.grid_rnn.Grid2BasicRNNCell(2) + self.assertEqual(cell.state_size, 4) + + g, s = cell(x, m) + self.assertEqual(g.get_shape(), (2, 2)) + self.assertEqual(s.get_shape(), (2, 4)) + + sess.run([tf.initialize_all_variables()]) + res = sess.run([g, s], {x: np.array([[1., 1.], [2., 2.]]), + m: np.array([[0.1, 0.1, 0.1, 0.1], [0.2, 0.2, 0.2, 0.2]])}) + self.assertEqual(res[0].shape, (2, 2)) + self.assertEqual(res[1].shape, (2, 4)) + self.assertAllClose(res[0], [[0.94685763, 0.94685763], + [0.99480951, 0.99480951]]) + self.assertAllClose(res[1], [[0.94685763, 0.94685763, 0.80049908, 0.80049908], + [0.99480951, 0.99480951, 0.97574311, 0.97574311]]) + + def testGrid2BasicRNNCellTied(self): + with self.test_session() as sess: + with tf.variable_scope('root', initializer=tf.constant_initializer(0.5)): + x = tf.zeros([2, 2]) + m = tf.zeros([2, 4]) + cell = tf.contrib.grid_rnn.Grid2BasicRNNCell(2, tied=True) + self.assertEqual(cell.state_size, 4) + + g, s = cell(x, m) + self.assertEqual(g.get_shape(), (2, 2)) + self.assertEqual(s.get_shape(), (2, 4)) + + sess.run([tf.initialize_all_variables()]) + res = sess.run([g, s], {x: np.array([[1., 1.], [2., 2.]]), + m: np.array([[0.1, 0.1, 0.1, 0.1], [0.2, 0.2, 0.2, 0.2]])}) + self.assertEqual(res[0].shape, (2, 2)) + self.assertEqual(res[1].shape, (2, 4)) + self.assertAllClose(res[0], [[0.94685763, 0.94685763], + [0.99480951, 0.99480951]]) + self.assertAllClose(res[1], [[0.94685763, 0.94685763, 0.80049908, 0.80049908], + [0.99480951, 0.99480951, 0.97574311, 0.97574311]]) + + def testGrid2BasicRNNCellWithRelu(self): + with self.test_session() as sess: + with tf.variable_scope('root', initializer=tf.constant_initializer(0.5)): + x = tf.zeros([1, 2]) + m = tf.zeros([1, 2]) + cell = tf.contrib.grid_rnn.Grid2BasicRNNCell(2, non_recurrent_fn=tf.nn.relu) + self.assertEqual(cell.state_size, 2) + + g, s = cell(x, m) + self.assertEqual(g.get_shape(), (1, 2)) + self.assertEqual(s.get_shape(), (1, 2)) + + sess.run([tf.initialize_all_variables()]) + res = sess.run([g, s], {x: np.array([[1., 1.]]), m: np.array([[0.1, 0.1]])}) + self.assertEqual(res[0].shape, (1, 2)) + self.assertEqual(res[1].shape, (1, 2)) + self.assertAllClose(res[0], [[1.80049896, 1.80049896]]) + self.assertAllClose(res[1], [[0.80049896, 0.80049896]]) + + """ + 1-LSTM + """ + + def testGrid1LSTMCell(self): + with self.test_session() as sess: + with tf.variable_scope('root', initializer=tf.constant_initializer(0.5)) as root_scope: + x = tf.zeros([1, 3]) + m = tf.zeros([1, 4]) + cell = tf.contrib.grid_rnn.Grid1LSTMCell(2, use_peepholes=True) + self.assertEqual(cell.state_size, 4) + + g, s = cell(x, m) + self.assertEqual(g.get_shape(), (1, 2)) + self.assertEqual(s.get_shape(), (1, 4)) + + sess.run([tf.initialize_all_variables()]) + res = sess.run([g, s], {x: np.array([[1., 1., 1.]]), + m: np.array([[0.1, 0.2, 0.3, 0.4]])}) + self.assertEqual(res[0].shape, (1, 2)) + self.assertEqual(res[1].shape, (1, 4)) + self.assertAllClose(res[0], [[0.91287315, 0.91287315]]) + self.assertAllClose(res[1], [[2.26285243, 2.26285243, 0.91287315, 0.91287315]]) + + root_scope.reuse_variables() + + x2 = tf.zeros([0, 0]) + g2, s2 = cell(x2, m) + self.assertEqual(g2.get_shape(), (1, 2)) + self.assertEqual(s2.get_shape(), (1, 4)) + + sess.run([tf.initialize_all_variables()]) + res = sess.run([g2, s2], {m: res[1]}) + self.assertEqual(res[0].shape, (1, 2)) + self.assertEqual(res[1].shape, (1, 4)) + self.assertAllClose(res[0], [[0.9032144, 0.9032144]]) + self.assertAllClose(res[1], [[2.79966092, 2.79966092, 0.9032144, 0.9032144]]) + + g3, s3 = cell(x2, m) + self.assertEqual(g3.get_shape(), (1, 2)) + self.assertEqual(s3.get_shape(), (1, 4)) + + sess.run([tf.initialize_all_variables()]) + res = sess.run([g3, s3], {m: res[1]}) + self.assertEqual(res[0].shape, (1, 2)) + self.assertEqual(res[1].shape, (1, 4)) + self.assertAllClose(res[0], [[0.92727238, 0.92727238]]) + self.assertAllClose(res[1], [[3.3529923, 3.3529923, 0.92727238, 0.92727238]]) + + """ + 3-LSTM + """ + def testGrid3LSTMCell(self): + with self.test_session() as sess: + with tf.variable_scope('root', initializer=tf.constant_initializer(0.5)): + x = tf.zeros([1, 3]) + m = tf.zeros([1, 12]) + cell = tf.contrib.grid_rnn.Grid3LSTMCell(2, use_peepholes=True) + self.assertEqual(cell.state_size, 12) + + g, s = cell(x, m) + self.assertEqual(g.get_shape(), (1, 2)) + self.assertEqual(s.get_shape(), (1, 12)) + + sess.run([tf.initialize_all_variables()]) + res = sess.run([g, s], {x: np.array([[1., 1., 1.]]), + m: np.array([[0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, -0.1, -0.2, -0.3, -0.4]])}) + self.assertEqual(res[0].shape, (1, 2)) + self.assertEqual(res[1].shape, (1, 12)) + + self.assertAllClose(res[0], [[0.96892911, 0.96892911]]) + self.assertAllClose(res[1], [[2.45227885, 2.45227885, 0.96892911, 0.96892911, + 1.33592629, 1.4373529, 0.80867189, 0.83247656, + 0.7317788, 0.63205892, 0.56548983, 0.50446129]]) + + """ + Edge cases + """ + def testGridRNNEdgeCasesLikeRelu(self): + with self.test_session() as sess: + with tf.variable_scope('root', initializer=tf.constant_initializer(0.5)): + x = tf.zeros([3, 2]) + m = tf.zeros([0, 0]) + + # this is equivalent to relu + cell = tf.contrib.grid_rnn.GridRNNCell(num_units=2, num_dims=1, input_dims=0, output_dims=0, + non_recurrent_dims=0, non_recurrent_fn=tf.nn.relu) + g, s = cell(x, m) + self.assertEqual(g.get_shape(), (3, 2)) + self.assertEqual(s.get_shape(), (0, 0)) + + sess.run([tf.initialize_all_variables()]) + res = sess.run([g, s], {x: np.array([[1., -1.], [-2, 1], [2, -1]])}) + self.assertEqual(res[0].shape, (3, 2)) + self.assertEqual(res[1].shape, (0, 0)) + self.assertAllClose(res[0], [[0, 0], [0, 0], [0.5, 0.5]]) + + def testGridRNNEdgeCasesNoOutput(self): + with self.test_session() as sess: + with tf.variable_scope('root', initializer=tf.constant_initializer(0.5)): + x = tf.zeros([1, 2]) + m = tf.zeros([1, 4]) + + # This cell produces no output + cell = tf.contrib.grid_rnn.GridRNNCell(num_units=2, num_dims=2, input_dims=0, output_dims=None, + non_recurrent_dims=0, non_recurrent_fn=tf.nn.relu) + g, s = cell(x, m) + self.assertEqual(g.get_shape(), (0, 0)) + self.assertEqual(s.get_shape(), (1, 4)) + + sess.run([tf.initialize_all_variables()]) + res = sess.run([g, s], {x: np.array([[1., 1.]]), + m: np.array([[0.1, 0.1, 0.1, 0.1]])}) + self.assertEqual(res[0].shape, (0, 0)) + self.assertEqual(res[1].shape, (1, 4)) + + """ + Test with tf.nn.rnn + """ + + def testGrid2LSTMCellWithRNN(self): + batch_size = 3 + input_size = 5 + max_length = 6 # unrolled up to this length + num_units = 2 + + with tf.variable_scope('root', initializer=tf.constant_initializer(0.5)): + cell = tf.contrib.grid_rnn.Grid2LSTMCell(num_units=num_units) + + inputs = max_length * [ + tf.placeholder(tf.float32, shape=(batch_size, input_size))] + + outputs, state = tf.nn.rnn(cell, inputs, dtype=tf.float32) + + self.assertEqual(len(outputs), len(inputs)) + self.assertEqual(state.get_shape(), (batch_size, 8)) + + for out, inp in zip(outputs, inputs): + self.assertEqual(out.get_shape()[0], inp.get_shape()[0]) + self.assertEqual(out.get_shape()[1], num_units) + self.assertEqual(out.dtype, inp.dtype) + + with self.test_session() as sess: + sess.run(tf.initialize_all_variables()) + + input_value = np.ones((batch_size, input_size)) + values = sess.run(outputs + [state], + feed_dict={inputs[0]: input_value}) + for v in values: + self.assertTrue(np.all(np.isfinite(v))) + + def testGrid2LSTMCellReLUWithRNN(self): + batch_size = 3 + input_size = 5 + max_length = 6 # unrolled up to this length + num_units = 2 + + with tf.variable_scope('root', initializer=tf.constant_initializer(0.5)): + cell = tf.contrib.grid_rnn.Grid2LSTMCell(num_units=num_units, non_recurrent_fn=tf.nn.relu) + + inputs = max_length * [ + tf.placeholder(tf.float32, shape=(batch_size, input_size))] + + outputs, state = tf.nn.rnn(cell, inputs, dtype=tf.float32) + + self.assertEqual(len(outputs), len(inputs)) + self.assertEqual(state.get_shape(), (batch_size, 4)) + + for out, inp in zip(outputs, inputs): + self.assertEqual(out.get_shape()[0], inp.get_shape()[0]) + self.assertEqual(out.get_shape()[1], num_units) + self.assertEqual(out.dtype, inp.dtype) + + with self.test_session() as sess: + sess.run(tf.initialize_all_variables()) + + input_value = np.ones((batch_size, input_size)) + values = sess.run(outputs + [state], + feed_dict={inputs[0]: input_value}) + for v in values: + self.assertTrue(np.all(np.isfinite(v))) + + def testGrid3LSTMCellReLUWithRNN(self): + batch_size = 3 + input_size = 5 + max_length = 6 # unrolled up to this length + num_units = 2 + + with tf.variable_scope('root', initializer=tf.constant_initializer(0.5)): + cell = tf.contrib.grid_rnn.Grid3LSTMCell(num_units=num_units, non_recurrent_fn=tf.nn.relu) + + inputs = max_length * [ + tf.placeholder(tf.float32, shape=(batch_size, input_size))] + + outputs, state = tf.nn.rnn(cell, inputs, dtype=tf.float32) + + self.assertEqual(len(outputs), len(inputs)) + self.assertEqual(state.get_shape(), (batch_size, 8)) + + for out, inp in zip(outputs, inputs): + self.assertEqual(out.get_shape()[0], inp.get_shape()[0]) + self.assertEqual(out.get_shape()[1], num_units) + self.assertEqual(out.dtype, inp.dtype) + + with self.test_session() as sess: + sess.run(tf.initialize_all_variables()) + + input_value = np.ones((batch_size, input_size)) + values = sess.run(outputs + [state], + feed_dict={inputs[0]: input_value}) + for v in values: + self.assertTrue(np.all(np.isfinite(v))) + + + def testGrid1LSTMCellWithRNN(self): + batch_size = 3 + input_size = 5 + max_length = 6 # unrolled up to this length + num_units = 2 + + with tf.variable_scope('root', initializer=tf.constant_initializer(0.5)): + cell = tf.contrib.grid_rnn.Grid1LSTMCell(num_units=num_units) + + # for 1-LSTM, we only feed the first step + inputs = [tf.placeholder(tf.float32, shape=(batch_size, input_size))] \ + + (max_length - 1) * [tf.zeros([0, 0])] + + outputs, state = tf.nn.rnn(cell, inputs, dtype=tf.float32) + + self.assertEqual(len(outputs), len(inputs)) + self.assertEqual(state.get_shape(), (batch_size, 4)) + + for out, inp in zip(outputs, inputs): + self.assertEqual(out.get_shape(), (3, num_units)) + self.assertEqual(out.dtype, inp.dtype) + + with self.test_session() as sess: + sess.run(tf.initialize_all_variables()) + + input_value = np.ones((batch_size, input_size)) + values = sess.run(outputs + [state], + feed_dict={inputs[0]: input_value}) + for v in values: + self.assertTrue(np.all(np.isfinite(v))) + +if __name__ == "__main__": + tf.test.main() diff --git a/tensorflow/contrib/grid_rnn/python/ops/__init__.py b/tensorflow/contrib/grid_rnn/python/ops/__init__.py new file mode 100644 index 0000000000..94872f4f0c --- /dev/null +++ b/tensorflow/contrib/grid_rnn/python/ops/__init__.py @@ -0,0 +1,18 @@ +# Copyright 2016 Google Inc. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function diff --git a/tensorflow/contrib/grid_rnn/python/ops/grid_rnn_cell.py b/tensorflow/contrib/grid_rnn/python/ops/grid_rnn_cell.py new file mode 100644 index 0000000000..9c75b4ae2f --- /dev/null +++ b/tensorflow/contrib/grid_rnn/python/ops/grid_rnn_cell.py @@ -0,0 +1,352 @@ +# Copyright 2016 Google Inc. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +"""Module for constructing GridRNN cells""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from collections import namedtuple + +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import math_ops +from tensorflow.python.ops import nn +from tensorflow.python.ops import variable_scope as vs +from tensorflow.python.ops import rnn_cell +from tensorflow.contrib import layers + + +class GridRNNCell(rnn_cell.RNNCell): + """Grid recurrent cell. + + This implementation is based on: + + http://arxiv.org/pdf/1507.01526v3.pdf + + This is the generic implementation of GridRNN. Users can specify arbitrary number of dimensions, + set some of them to be priority (section 3.2), non-recurrent (section 3.3) + and input/output dimensions (section 3.4). + Weight sharing can also be specified using the `tied` parameter. + Type of recurrent units can be specified via `cell_fn`. + """ + + def __init__(self, num_units, num_dims=1, input_dims=None, output_dims=None, priority_dims=None, + non_recurrent_dims=None, tied=False, cell_fn=None, non_recurrent_fn=None): + """Initialize the parameters of a Grid RNN cell + + Args: + num_units: int, The number of units in all dimensions of this GridRNN cell + num_dims: int, Number of dimensions of this grid. + input_dims: int or list, List of dimensions which will receive input data. + output_dims: int or list, List of dimensions from which the output will be recorded. + priority_dims: int or list, List of dimensions to be considered as priority dimensions. + If None, no dimension is prioritized. + non_recurrent_dims: int or list, List of dimensions that are not recurrent. + The transfer function for non-recurrent dimensions is specified via `non_recurrent_fn`, + which is default to be `tensorflow.nn.relu`. + tied: bool, Whether to share the weights among the dimensions of this GridRNN cell. + If there are non-recurrent dimensions in the grid, weights are shared between each + group of recurrent and non-recurrent dimensions. + cell_fn: function, a function which returns the recurrent cell object. Has to be in the following signature: + def cell_func(num_units, input_size): + # ... + + and returns an object of type `RNNCell`. If None, LSTMCell with default parameters will be used. + non_recurrent_fn: a tensorflow Op that will be the transfer function of the non-recurrent dimensions + """ + if num_dims < 1: + raise ValueError('dims must be >= 1: {}'.format(num_dims)) + + self._config = _parse_rnn_config(num_dims, input_dims, output_dims, priority_dims, + non_recurrent_dims, non_recurrent_fn or nn.relu, tied, num_units) + + cell_input_size = (self._config.num_dims - 1) * num_units + if cell_fn is None: + self._cell = rnn_cell.LSTMCell(num_units=num_units, input_size=cell_input_size) + else: + self._cell = cell_fn(num_units, cell_input_size) + if not isinstance(self._cell, rnn_cell.RNNCell): + raise ValueError('cell_fn must return an object of type RNNCell') + + @property + def input_size(self): + # temporarily using num_units as the input_size of each dimension. + # The actual input size only determined when this cell get invoked, + # so this information can be considered unreliable. + return self._config.num_units * len(self._config.inputs) + + @property + def output_size(self): + return self._cell.output_size * len(self._config.outputs) + + @property + def state_size(self): + return self._cell.state_size * len(self._config.recurrents) + + def __call__(self, inputs, state, scope=None): + """Run one step of GridRNN. + + Args: + inputs: input Tensor, 2D, batch x input_size. Or None + state: state Tensor, 2D, batch x state_size. Note that state_size = cell_state_size * recurrent_dims + scope: VariableScope for the created subgraph; defaults to "GridRNNCell". + + Returns: + A tuple containing: + - A 2D, batch x output_size, Tensor representing the output of the cell + after reading "inputs" when previous state was "state". + - A 2D, batch x state_size, Tensor representing the new state of the cell + after reading "inputs" when previous state was "state". + """ + state_sz = state.get_shape().as_list()[1] + if self.state_size != state_sz: + raise ValueError('Actual state size not same as specified: {} vs {}.'.format(state_sz, self.state_size)) + + conf = self._config + dtype = inputs.dtype if inputs is not None else state.dtype + + # c_prev is `m`, and m_prev is `h` in the paper. Keep c and m here for consistency with the codebase + c_prev = [None] * self._config.num_dims + m_prev = [None] * self._config.num_dims + cell_output_size = self._cell.state_size - conf.num_units + + # for LSTM : state = memory cell + output, hence cell_output_size > 0 + # for GRU/RNN: state = output (whose size is equal to _num_units), hence cell_output_size = 0 + for recurrent_dim, start_idx in zip(self._config.recurrents, range(0, self.state_size, self._cell.state_size)): + if cell_output_size > 0: + c_prev[recurrent_dim] = array_ops.slice(state, [0, start_idx], [-1, conf.num_units]) + m_prev[recurrent_dim] = array_ops.slice(state, [0, start_idx + conf.num_units], [-1, cell_output_size]) + else: + m_prev[recurrent_dim] = array_ops.slice(state, [0, start_idx], [-1, conf.num_units]) + + new_output = [None] * conf.num_dims + new_state = [None] * conf.num_dims + + with vs.variable_scope(scope or type(self).__name__): # GridRNNCell + + # project input + if inputs is not None and sum(inputs.get_shape().as_list()) > 0 and len(conf.inputs) > 0: + input_splits = array_ops.split(1, len(conf.inputs), inputs) + input_sz = input_splits[0].get_shape().as_list()[1] + + for i, j in enumerate(conf.inputs): + input_project_m = vs.get_variable('project_m_{}'.format(j), [input_sz, conf.num_units], dtype=dtype) + m_prev[j] = math_ops.matmul(input_splits[i], input_project_m) + + if cell_output_size > 0: + input_project_c = vs.get_variable('project_c_{}'.format(j), [input_sz, conf.num_units], dtype=dtype) + c_prev[j] = math_ops.matmul(input_splits[i], input_project_c) + + + _propagate(conf.non_priority, conf, self._cell, c_prev, m_prev, new_output, new_state, True) + _propagate(conf.priority, conf, self._cell, c_prev, m_prev, new_output, new_state, False) + + output_tensors = [new_output[i] for i in self._config.outputs] + output = array_ops.zeros([0, 0], dtype) if len(output_tensors) == 0 else array_ops.concat(1, + output_tensors) + + state_tensors = [new_state[i] for i in self._config.recurrents] + states = array_ops.zeros([0, 0], dtype) if len(state_tensors) == 0 else array_ops.concat(1, state_tensors) + + return output, states + + +""" +Specialized cells, for convenience +""" + +class Grid1BasicRNNCell(GridRNNCell): + """1D BasicRNN cell""" + def __init__(self, num_units): + super(Grid1BasicRNNCell, self).__init__(num_units=num_units, num_dims=1, + input_dims=0, output_dims=0, priority_dims=0, tied=False, + cell_fn=lambda n, i: rnn_cell.BasicRNNCell(num_units=n, input_size=i)) + + +class Grid2BasicRNNCell(GridRNNCell): + """2D BasicRNN cell + This creates a 2D cell which receives input and gives output in the first dimension. + The first dimension can optionally be non-recurrent if `non_recurrent_fn` is specified. + """ + def __init__(self, num_units, tied=False, non_recurrent_fn=None): + super(Grid2BasicRNNCell, self).__init__(num_units=num_units, num_dims=2, + input_dims=0, output_dims=0, priority_dims=0, tied=tied, + non_recurrent_dims=None if non_recurrent_fn is None else 0, + cell_fn=lambda n, i: rnn_cell.BasicRNNCell(num_units=n, input_size=i), + non_recurrent_fn=non_recurrent_fn) + + +class Grid1BasicLSTMCell(GridRNNCell): + """1D BasicLSTM cell""" + def __init__(self, num_units, forget_bias=1): + super(Grid1BasicLSTMCell, self).__init__(num_units=num_units, num_dims=1, + input_dims=0, output_dims=0, priority_dims=0, tied=False, + cell_fn=lambda n, i: rnn_cell.BasicLSTMCell(num_units=n, + forget_bias=forget_bias, input_size=i)) + + +class Grid2BasicLSTMCell(GridRNNCell): + """2D BasicLSTM cell + This creates a 2D cell which receives input and gives output in the first dimension. + The first dimension can optionally be non-recurrent if `non_recurrent_fn` is specified. + """ + def __init__(self, num_units, tied=False, non_recurrent_fn=None, forget_bias=1): + super(Grid2BasicLSTMCell, self).__init__(num_units=num_units, num_dims=2, + input_dims=0, output_dims=0, priority_dims=0, tied=tied, + non_recurrent_dims=None if non_recurrent_fn is None else 0, + cell_fn=lambda n, i: rnn_cell.BasicLSTMCell( + num_units=n, forget_bias=forget_bias, input_size=i), + non_recurrent_fn=non_recurrent_fn) + + +class Grid1LSTMCell(GridRNNCell): + """1D LSTM cell + This is different from Grid1BasicLSTMCell because it gives options to specify the forget bias and enabling peepholes + """ + def __init__(self, num_units, use_peepholes=False, forget_bias=1.0): + super(Grid1LSTMCell, self).__init__(num_units=num_units, num_dims=1, + input_dims=0, output_dims=0, priority_dims=0, + cell_fn=lambda n, i: rnn_cell.LSTMCell( + num_units=n, input_size=i, use_peepholes=use_peepholes, + forget_bias=forget_bias)) + + +class Grid2LSTMCell(GridRNNCell): + """2D LSTM cell + This creates a 2D cell which receives input and gives output in the first dimension. + The first dimension can optionally be non-recurrent if `non_recurrent_fn` is specified. + """ + def __init__(self, num_units, tied=False, non_recurrent_fn=None, + use_peepholes=False, forget_bias=1.0): + super(Grid2LSTMCell, self).__init__(num_units=num_units, num_dims=2, + input_dims=0, output_dims=0, priority_dims=0, tied=tied, + non_recurrent_dims=None if non_recurrent_fn is None else 0, + cell_fn=lambda n, i: rnn_cell.LSTMCell( + num_units=n, input_size=i, forget_bias=forget_bias, + use_peepholes=use_peepholes), + non_recurrent_fn=non_recurrent_fn) + + +class Grid3LSTMCell(GridRNNCell): + """3D BasicLSTM cell + This creates a 2D cell which receives input and gives output in the first dimension. + The first dimension can optionally be non-recurrent if `non_recurrent_fn` is specified. + The second and third dimensions are LSTM. + """ + def __init__(self, num_units, tied=False, non_recurrent_fn=None, + use_peepholes=False, forget_bias=1.0): + super(Grid3LSTMCell, self).__init__(num_units=num_units, num_dims=3, + input_dims=0, output_dims=0, priority_dims=0, tied=tied, + non_recurrent_dims=None if non_recurrent_fn is None else 0, + cell_fn=lambda n, i: rnn_cell.LSTMCell( + num_units=n, input_size=i, forget_bias=forget_bias, + use_peepholes=use_peepholes), + non_recurrent_fn=non_recurrent_fn) + +class Grid2GRUCell(GridRNNCell): + """2D LSTM cell + This creates a 2D cell which receives input and gives output in the first dimension. + The first dimension can optionally be non-recurrent if `non_recurrent_fn` is specified. + """ + + def __init__(self, num_units, tied=False, non_recurrent_fn=None): + super(Grid2GRUCell, self).__init__(num_units=num_units, num_dims=2, + input_dims=0, output_dims=0, priority_dims=0, tied=tied, + non_recurrent_dims=None if non_recurrent_fn is None else 0, + cell_fn=lambda n, i: rnn_cell.GRUCell(num_units=n, input_size=i), + non_recurrent_fn=non_recurrent_fn) + +""" +Helpers +""" + +_GridRNNDimension = namedtuple('_GridRNNDimension', ['idx', 'is_input', 'is_output', 'is_priority', 'non_recurrent_fn']) + +_GridRNNConfig = namedtuple('_GridRNNConfig', ['num_dims', 'dims', + 'inputs', 'outputs', 'recurrents', + 'priority', 'non_priority', 'tied', 'num_units']) + + +def _parse_rnn_config(num_dims, ls_input_dims, ls_output_dims, ls_priority_dims, ls_non_recurrent_dims, + non_recurrent_fn, tied, num_units): + def check_dim_list(ls): + if ls is None: + ls = [] + if not isinstance(ls, (list, tuple)): + ls = [ls] + ls = sorted(set(ls)) + if any(_ < 0 or _ >= num_dims for _ in ls): + raise ValueError('Invalid dims: {}. Must be in [0, {})'.format(ls, num_dims)) + return ls + + input_dims = check_dim_list(ls_input_dims) + output_dims = check_dim_list(ls_output_dims) + priority_dims = check_dim_list(ls_priority_dims) + non_recurrent_dims = check_dim_list(ls_non_recurrent_dims) + + rnn_dims = [] + for i in range(num_dims): + rnn_dims.append(_GridRNNDimension(idx=i, is_input=(i in input_dims), is_output=(i in output_dims), + is_priority=(i in priority_dims), + non_recurrent_fn=non_recurrent_fn if i in non_recurrent_dims else None)) + return _GridRNNConfig(num_dims=num_dims, dims=rnn_dims, inputs=input_dims, outputs=output_dims, + recurrents=[x for x in range(num_dims) if x not in non_recurrent_dims], + priority=priority_dims, + non_priority=[x for x in range(num_dims) if x not in priority_dims], + tied=tied, num_units=num_units) + + +def _propagate(dim_indices, conf, cell, c_prev, m_prev, new_output, new_state, first_call): + """ + Propagates through all the cells in dim_indices dimensions. + """ + if len(dim_indices) == 0: + return + + # Because of the way RNNCells are implemented, we take the last dimension (H_{N-1}) out + # and feed it as the state of the RNN cell (in `last_dim_output`) + # The input of the cell (H_0 to H_{N-2}) are concatenated into `cell_inputs` + if conf.num_dims > 1: + ls_cell_inputs = [None] * (conf.num_dims - 1) + for d in conf.dims[:-1]: + ls_cell_inputs[d.idx] = new_output[d.idx] if new_output[d.idx] is not None else m_prev[d.idx] + cell_inputs = array_ops.concat(1, ls_cell_inputs) + else: + cell_inputs = array_ops.zeros([m_prev[0].get_shape().as_list()[0], 0], m_prev[0].dtype) + + last_dim_output = new_output[-1] if new_output[-1] is not None else m_prev[-1] + + for i in dim_indices: + d = conf.dims[i] + if d.non_recurrent_fn: + linear_args = array_ops.concat(1, [cell_inputs, last_dim_output]) if conf.num_dims > 1 else last_dim_output + with vs.variable_scope('non_recurrent' if conf.tied else 'non_recurrent/cell_{}'.format(i)): + if conf.tied and not(first_call and i == dim_indices[0]): + vs.get_variable_scope().reuse_variables() + new_output[d.idx] = layers.fully_connected(linear_args, num_output_units=conf.num_units, + activation_fn=d.non_recurrent_fn, + weight_init=vs.get_variable_scope().initializer or + layers.initializers.xavier_initializer) + else: + if c_prev[i] is not None: + cell_state = array_ops.concat(1, [c_prev[i], last_dim_output]) + else: + # for GRU/RNN, the state is just the previous output + cell_state = last_dim_output + + with vs.variable_scope('recurrent' if conf.tied else 'recurrent/cell_{}'.format(i)): + if conf.tied and not (first_call and i == dim_indices[0]): + vs.get_variable_scope().reuse_variables() + new_output[d.idx], new_state[d.idx] = cell(cell_inputs, cell_state) |