author Patrick Nguyen <drpng@google.com> 2018-04-09 16:55:06 -0700
committer TensorFlower Gardener <gardener@tensorflow.org> 2018-04-09 16:57:31 -0700
commit 26e36ec2c9fb061e7349b2259bc69b2140d18819 (patch)
tree 9a1e2efa757def9f977c45b538e4e2c058073e8a /tensorflow/contrib/recurrent
parent 66a601eece46e91c7c19cb22ebe526cf8b2253d5 (diff)
Export recurrent and its RNN implementation in tf.contrib.
PiperOrigin-RevId: 192210794
Diffstat (limited to 'tensorflow/contrib/recurrent')
-rw-r--r-- tensorflow/contrib/recurrent/BUILD | 106
-rw-r--r-- tensorflow/contrib/recurrent/README.md | 13
-rw-r--r-- tensorflow/contrib/recurrent/python/kernel_tests/functional_rnn_test.py | 163
-rw-r--r-- tensorflow/contrib/recurrent/python/kernel_tests/recurrent_test.py | 192
-rw-r--r-- tensorflow/contrib/recurrent/python/ops/functional_rnn.py | 396
-rw-r--r-- tensorflow/contrib/recurrent/python/ops/recurrent.py | 720
-rw-r--r-- tensorflow/contrib/recurrent/python/recurrent_api.py | 29
7 files changed, 1619 insertions(+), 0 deletions(-)
diff --git a/tensorflow/contrib/recurrent/BUILD b/tensorflow/contrib/recurrent/BUILD
new file mode 100644
index 0000000000..b3cb04ce26
--- /dev/null
+++ b/tensorflow/contrib/recurrent/BUILD
@@ -0,0 +1,106 @@
+# Recurrent library.
+
+package(default_visibility = ["//visibility:public"])
+
+licenses(["notice"]) # Apache 2.0
+
+exports_files(["LICENSE"])
+
+load("//tensorflow:tensorflow.bzl", "cuda_py_tests")
+
+py_library(
+ name = "recurrent_py",
+ srcs = ["python/recurrent_api.py"],
+ srcs_version = "PY2AND3",
+ deps = [
+ ":functional_rnn_ops_py",
+ ":recurrent_ops_py",
+ ],
+)
+
+py_library(
+ name = "recurrent_ops_py",
+ srcs = ["python/ops/recurrent.py"],
+ srcs_version = "PY2AND3",
+ deps = [
+ "//tensorflow/contrib/framework:framework_py",
+ "//tensorflow/contrib/util:util_py",
+ "//tensorflow/core:protos_all_py",
+ "//tensorflow/python:dtypes",
+ "//tensorflow/python:framework_ops",
+ "//tensorflow/python:function",
+ "//tensorflow/python:functional_ops",
+ "//tensorflow/python:math_ops",
+ "//tensorflow/python:platform",
+ ],
+)
+
+py_library(
+ name = "functional_rnn_ops_py",
+ srcs = ["python/ops/functional_rnn.py"],
+ srcs_version = "PY2AND3",
+ deps = [
+ ":recurrent_ops_py",
+ "//tensorflow/contrib/framework:framework_py",
+ "//tensorflow/contrib/util:util_py",
+ "//tensorflow/core:protos_all_py",
+ "//tensorflow/python:dtypes",
+ "//tensorflow/python:framework_ops",
+ "//tensorflow/python:function",
+ "//tensorflow/python:gradients",
+ "//tensorflow/python:math_ops",
+ "//tensorflow/python:standard_ops",
+ ],
+)
+
+cuda_py_tests(
+ name = "recurrent_ops_test",
+ size = "small",
+ srcs = ["python/kernel_tests/recurrent_test.py"],
+ additional_deps = [
+ ":recurrent_ops_py",
+ "//third_party/py/numpy",
+ "//tensorflow/core:protos_all_py",
+ "//tensorflow/python:array_ops",
+ "//tensorflow/python:client_testlib",
+ "//tensorflow/python:constant_op",
+ "//tensorflow/python:dtypes",
+ "//tensorflow/python:errors",
+ "//tensorflow/python:framework_for_generated_wrappers",
+ "//tensorflow/python:framework_ops",
+ "//tensorflow/python:framework_test_lib",
+ "//tensorflow/python:math_ops",
+ "//tensorflow/python:platform",
+ "//tensorflow/python:random_ops",
+ "//tensorflow/python:random_seed",
+ "//tensorflow/python:script_ops",
+ "//tensorflow/python:variables",
+ ],
+ tags = ["nopip"],
+)
+
+cuda_py_tests(
+ name = "functional_rnn_ops_test",
+ size = "small",
+ srcs = ["python/kernel_tests/functional_rnn_test.py"],
+ additional_deps = [
+ ":functional_rnn_ops_py",
+ "//third_party/py/numpy",
+ "//tensorflow/contrib/layers:layers_py",
+ "//tensorflow/contrib/tpu:tpu",
+ "//tensorflow/core:protos_all_py",
+ "//tensorflow/python:array_ops",
+ "//tensorflow/python:client_testlib",
+ "//tensorflow/python:framework_for_generated_wrappers",
+ "//tensorflow/python:framework_test_lib",
+ "//tensorflow/python:gradients",
+ "//tensorflow/python:init_ops",
+ "//tensorflow/python:platform_test",
+ "//tensorflow/python:rnn",
+ "//tensorflow/python:rnn_cell",
+ "//tensorflow/python:util",
+ "//tensorflow/python:variable_scope",
+ "//tensorflow/python:variables",
+ ],
+ tags = ["nopip"],
+)
diff --git a/tensorflow/contrib/recurrent/README.md b/tensorflow/contrib/recurrent/README.md
new file mode 100644
index 0000000000..86e10eee51
--- /dev/null
+++ b/tensorflow/contrib/recurrent/README.md
@@ -0,0 +1,13 @@
+# Recurrent computation library
+
+The recurrent computation library contains code to perform recurrent
+computations.
+
+Its chief application is implementing recurrent neural networks (RNNs, LSTMs,
+etc.), which is done in `functional_rnn.py`. Similar techniques may be used
+to implement deep networks.
+
+The computation saves the activations in the forward pass, and computes the
+gradients in the backward pass using a single accumulator.
+
+The `functional_rnn` interface is compatible with the `dynamic_rnn` API.
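
As a quick orientation before the code below, here is a minimal usage sketch (not part of this change) of how `functional_rnn` is meant to slot in where `tf.nn.dynamic_rnn` would otherwise be called; the cell, shapes, and placeholder names are illustrative assumptions for a TF 1.x graph-mode program.

    import tensorflow as tf
    from tensorflow.contrib.recurrent.python.ops import functional_rnn

    batch, time, depth, units = 3, 5, 11, 7
    inputs = tf.placeholder(tf.float32, [batch, time, depth])  # batch-major
    lengths = tf.placeholder(tf.int32, [batch])                # per-example lengths
    cell = tf.nn.rnn_cell.LSTMCell(units)

    # Same call signature and return values as tf.nn.dynamic_rnn.
    outputs, final_state = functional_rnn.functional_rnn(
        cell=cell, inputs=inputs, sequence_length=lengths, dtype=tf.float32)
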
diff --git a/tensorflow/contrib/recurrent/python/kernel_tests/functional_rnn_test.py b/tensorflow/contrib/recurrent/python/kernel_tests/functional_rnn_test.py
new file mode 100644
index 0000000000..0f19ac7dbe
--- /dev/null
+++ b/tensorflow/contrib/recurrent/python/kernel_tests/functional_rnn_test.py
@@ -0,0 +1,163 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for Functional RNN."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import numpy as np
+
+
+from tensorflow.contrib.recurrent.python.ops import functional_rnn
+from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import ops
+from tensorflow.python.framework import test_util
+from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import gradients_impl
+from tensorflow.python.ops import rnn as rnn_lib
+from tensorflow.python.ops import rnn_cell_impl
+from tensorflow.python.ops import variables
+import tensorflow.python.ops.nn_grad # pylint: disable=unused-import
+import tensorflow.python.ops.tensor_array_grad # pylint: disable=unused-import
+from tensorflow.python.platform import test as test_lib
+from tensorflow.python.platform import tf_logging as logging
+
+
+def _CreateStackedLstmCell(*cell_sizes):
+ subcells = [rnn_cell_impl.LSTMCell(cell_size) for cell_size in cell_sizes]
+ return rnn_cell_impl.MultiRNNCell(subcells)
+
+
+class FunctionalRnnTest(test_util.TensorFlowTestCase):
+
+ _BATCH_SIZE = 3
+ _TOTAL_TIME = 5
+ _INPUT_SIZE = 11
+ _NUM_UNITS = 7
+
+ # Set this to a file path if you want the LSTM GraphDef to be dumped there.
+ _LSTM_GRAPH_DEF_FILEPATH = None
+
+ _CELLDEFS = {
+ 'gru': (rnn_cell_impl.GRUCell, [_NUM_UNITS]),
+ 'lstm': (rnn_cell_impl.LSTMCell, [_NUM_UNITS]),
+ 'stacked_lstm': (_CreateStackedLstmCell, [_NUM_UNITS] * 3)
+ }
+
+ def _CreateCell(self, celldef_name):
+ func, args = self._CELLDEFS[celldef_name]
+ return func(*args)
+
+ def _CreateInputs(self):
+ inputs = np.random.random([FunctionalRnnTest._BATCH_SIZE,
+ FunctionalRnnTest._TOTAL_TIME,
+ FunctionalRnnTest._INPUT_SIZE])
+ # Always leave one time slot empty, to check max_length behavior.
+ sequence_length = np.random.randint(
+ 0, high=FunctionalRnnTest._TOTAL_TIME - 1,
+ size=FunctionalRnnTest._BATCH_SIZE,
+ dtype=np.int)
+ return (inputs, sequence_length)
+
+ def _CreateRnnGraph(self, create_rnn_computation_func, cell, tf_inputs,
+ tf_sequence_length, initial_state=None,
+ time_major=None, scope=None):
+ tf_result = create_rnn_computation_func(cell=cell, inputs=tf_inputs,
+ sequence_length=tf_sequence_length,
+ initial_state=initial_state,
+ dtype=dtypes.float32,
+ time_major=time_major,
+ scope=scope)
+ grad = gradients_impl.gradients(tf_result, variables.trainable_variables())
+ return {'inference': tf_result, 'grad': grad}
+
+ def _MaybeResetVariables(self, variable_cache, sess, var_list):
+ """Possibly resets the variables to a previously seen value."""
+ reset_ops = []
+ fetches = []
+ for var in var_list:
+ if var.name in variable_cache:
+ reset_ops += [var.assign(variable_cache[var.name])]
+ else:
+ fetches += [(var.name, var)]
+ if reset_ops:
+ sess.run(reset_ops)
+ if fetches:
+ val = sess.run(dict(fetches))
+ for n, v in val.items():
+ assert n not in variable_cache
+ variable_cache[n] = v
+
+ def _RunRnn(self, numpy_inputs, numpy_slen, cell_name, variable_cache,
+ is_dynamic):
+ with ops.Graph().as_default() as graph:
+ tf_inputs = array_ops.placeholder(
+ dtypes.float32, shape=numpy_inputs.shape)
+ tf_slen = array_ops.placeholder(dtypes.int32)
+ feeds = {tf_inputs: numpy_inputs, tf_slen: numpy_slen}
+ cell = self._CreateCell(cell_name)
+ fn = rnn_lib.dynamic_rnn if is_dynamic else functional_rnn.functional_rnn
+ fetches = self._CreateRnnGraph(fn, cell, tf_inputs, tf_slen)
+ with self.test_session(graph=graph) as sess:
+ sess.run(variables.global_variables_initializer())
+ # Note that cell.trainable_variables is not always set.
+ self._MaybeResetVariables(variable_cache, sess,
+ variables.trainable_variables())
+ val = sess.run(fetches, feed_dict=feeds)
+ graph_def = graph.as_graph_def()
+ return graph_def, val
+
+ def testRunLstm(self):
+ """Runs a simple LSTM. Does not check output."""
+ np_inputs, np_slen = self._CreateInputs()
+ var_cache = {}
+ graphdef, _ = self._RunRnn(np_inputs, np_slen, 'lstm', var_cache, False)
+ logging.info('graphdef: %s', graphdef)
+ if self._LSTM_GRAPH_DEF_FILEPATH:
+ with open(self._LSTM_GRAPH_DEF_FILEPATH, 'w') as f:
+ f.write(str(graphdef))
+
+ def testLstm(self):
+ """Checks an LSTM against the reference implementation."""
+ np_inputs, np_slen = self._CreateInputs()
+ var_cache = {}
+ _, func_rnn = self._RunRnn(np_inputs, np_slen, 'lstm', var_cache, False)
+ _, dyn_rnn = self._RunRnn(np_inputs, np_slen, 'lstm', var_cache, True)
+ self.assertAllClose(dyn_rnn['inference'], func_rnn['inference'])
+ self.assertAllClose(dyn_rnn['grad'], func_rnn['grad'])
+
+ def testGru(self):
+ """Checks a GRU cell against the reference implementation."""
+ np_inputs, np_slen = self._CreateInputs()
+ var_cache = {}
+ _, func_rnn = self._RunRnn(np_inputs, np_slen, 'gru', var_cache, False)
+ _, dyn_rnn = self._RunRnn(np_inputs, np_slen, 'gru', var_cache, True)
+ self.assertAllClose(dyn_rnn['inference'], func_rnn['inference'])
+ self.assertAllClose(dyn_rnn['grad'], func_rnn['grad'])
+
+ def testStackedLstm(self):
+ """Checks a stacked LSTM cell against the reference implementation."""
+ np_inputs, np_slen = self._CreateInputs()
+ var_cache = {}
+ args = [np_inputs, np_slen, 'stacked_lstm', var_cache]
+ _, func_rnn = self._RunRnn(*(args + [False]))
+ _, dyn_rnn = self._RunRnn(*(args + [True]))
+ self.assertAllClose(dyn_rnn['inference'], func_rnn['inference'])
+ self.assertAllClose(dyn_rnn['grad'], func_rnn['grad'])
+
+
+if __name__ == '__main__':
+ test_lib.main()
diff --git a/tensorflow/contrib/recurrent/python/kernel_tests/recurrent_test.py b/tensorflow/contrib/recurrent/python/kernel_tests/recurrent_test.py
new file mode 100644
index 0000000000..00fbd4fbb8
--- /dev/null
+++ b/tensorflow/contrib/recurrent/python/kernel_tests/recurrent_test.py
@@ -0,0 +1,192 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for Recurrent ops."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import collections
+
+from six.moves import xrange # pylint: disable=redefined-builtin
+
+from tensorflow.contrib.recurrent.python.ops import recurrent
+from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import function
+from tensorflow.python.framework import random_seed
+from tensorflow.python.framework import test_util
+from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import gradients_impl
+from tensorflow.python.ops import math_ops
+from tensorflow.python.ops import random_ops
+from tensorflow.python.platform import test as test_lib
+from tensorflow.python.platform import tf_logging as logging
+
+
+_ElmanState = collections.namedtuple('ElmanState', ('h'))
+_ElmanTheta = collections.namedtuple('ElmanTheta', ('w', 'b'))
+_ElmanInputs = collections.namedtuple('ElmanInputs', ('x'))
+_ElmanOut = collections.namedtuple('ElmanOut', ('x'))
+
+
+# TODO(drpng): add test for max length computation.
+class RecurrentTest(test_util.TensorFlowTestCase):
+
+ def testBasic(self):
+ # pylint:disable=invalid-name
+ _PolyState = collections.namedtuple('PolyState', ('value', 'x_power'))
+ _PolyTheta = collections.namedtuple('PolyTheta', ('x'))
+ _PolyInputs = collections.namedtuple('PolyInputs', ('coeff'))
+ # pylint:enable=invalid-name
+
+ def Poly(theta, state, inputs):
+ next_state = _PolyState(
+ value=state.value + inputs.coeff * state.x_power,
+ x_power=state.x_power * theta.x)
+ return next_state, []
+
+ with self.test_session() as sess:
+ theta = _PolyTheta(x=array_ops.constant(2.0))
+ state = _PolyState(
+ value=array_ops.constant(0.0),
+ x_power=array_ops.constant(1.0))
+ inputs = _PolyInputs(coeff=array_ops.constant([1., 2., 3.]))
+
+ # x = 2
+ # 1 + 2*x + 3*x^2
+ ret = recurrent.Recurrent(theta, state, inputs, Poly)
+
+ acc, state = sess.run(ret)
+ self.assertAllClose(acc.value, [1., 5., 17.])
+ self.assertAllClose(acc.x_power, [2., 4., 8.])
+ self.assertAllClose(state.value, 17.)
+ self.assertAllClose(state.x_power, 8.)
+
+ y = ret[1].value
+ dx, d_coeff = gradients_impl.gradients(ys=[y], xs=[theta.x, inputs.coeff])
+ dx_val, d_coeff_val = sess.run([dx, d_coeff])
+
+ # 2 + 6*x
+ self.assertAllClose(dx_val, 14.)
+ self.assertAllClose(d_coeff_val, [1., 2., 4.])
+
+ # acc = [1, 1+2x, 1+2x+3x^2]
+ # sum(acc) = 3 + 4x + 3x^2
+ acc = ret[0].value
+ dx, d_coeff = gradients_impl.gradients(
+ ys=[math_ops.reduce_sum(acc)], xs=[theta.x, inputs.coeff])
+ dx_val, d_coeff_val = sess.run([dx, d_coeff])
+ # 4 + 6*x
+ self.assertAllClose(dx_val, 16.)
+ self.assertAllClose(d_coeff_val, [3., 4., 4.])
+
+ @staticmethod
+ def Rand(shape):
+ return random_ops.random_uniform(
+ shape, minval=-0.2, maxval=0.2, dtype=dtypes.float64)
+
+ @staticmethod
+ def Elman(theta, state0, inputs):
+ h0, w, b, x = state0.h, theta.w, theta.b, inputs.x
+ xw = math_ops.matmul(array_ops.concat([x, h0], axis=1), w)
+ h1 = math_ops.sigmoid(xw + b)
+ state1 = _ElmanState(h=h1)
+ return (state1, state1)
+
+ @staticmethod
+ def ElmanGrad(theta, state0, inputs, extras, dstate1):
+
+ @function.Defun()
+ def Grad(h0, w, b, x, h1, dh1):
+ del b
+ # We hand-roll the gradient for the 2nd half of the cell as a demo.
+ dxwb = (dh1 * (1 - h1) * h1)
+ dxw, db = dxwb, math_ops.reduce_sum(dxwb, axis=0)
+
+ # Uses tf.gradients for the 1st half of the cell as a demo.
+ xw = math_ops.matmul(array_ops.concat([x, h0], axis=1), w)
+ dh0, dx, dw = gradients_impl.gradients(
+ ys=[xw], xs=[h0, x, w], grad_ys=[dxw])
+
+ return dh0, dx, dw, db
+
+ dh0, dx, dw, db = Grad(state0.h, theta.w, theta.b, inputs.x,
+ extras.h, dstate1.h)
+ dstate0 = _ElmanState(h=dh0)
+ dinputs = _ElmanInputs(x=dx)
+ return (_ElmanTheta(w=dw, b=db), dstate0, dinputs)
+
+ @staticmethod
+ def ElmanOut(state1):
+ return _ElmanOut(x=state1.h)
+
+ @staticmethod
+ def ElmanOutGrad(dout):
+ return _ElmanState(h=dout.x)
+
+ def testElman(self):
+ for seqlen, use_grad in [(1, False), (1, True), (7, False), (7, True)]:
+ logging.info('== Elman: seqlen=%s, use_grad=%s', seqlen, use_grad)
+ self._ParameterizedTestElman(seqlen, use_grad)
+
+ def _ParameterizedTestElman(self, seqlen, use_grad):
+
+ with self.test_session() as sess:
+ random_seed.set_random_seed(342462)
+
+ batch = 3
+ dims = 4
+ theta = _ElmanTheta(w=RecurrentTest.Rand([2 * dims, dims]),
+ b=RecurrentTest.Rand([dims]))
+ state0 = _ElmanState(h=RecurrentTest.Rand([batch, dims]))
+ inputs = _ElmanInputs(x=RecurrentTest.Rand([seqlen, batch, dims]))
+
+ # Statically unrolled.
+ s = state0
+ out = []
+ for i in xrange(seqlen):
+ inp = _ElmanInputs(x=inputs.x[i, :])
+ s, _ = RecurrentTest.Elman(theta, s, inp)
+ out += [s.h]
+ acc0, final0 = array_ops.stack(out), s.h
+ loss0 = math_ops.reduce_sum(acc0) + math_ops.reduce_sum(final0)
+ (dw0, db0, dh0, di0) = gradients_impl.gradients(
+ loss0, [theta.w, theta.b, state0.h, inputs.x])
+
+ acc1, final1 = recurrent.Recurrent(
+ theta=theta,
+ state0=state0,
+ inputs=inputs,
+ cell_fn=RecurrentTest.Elman,
+ cell_grad=RecurrentTest.ElmanGrad if use_grad else None)
+ assert isinstance(acc1, _ElmanState)
+ assert isinstance(final1, _ElmanState)
+ acc1, final1 = acc1.h, final1.h
+ loss1 = math_ops.reduce_sum(acc1) + math_ops.reduce_sum(final1)
+ (dw1, db1, dh1, di1) = gradients_impl.gradients(
+ loss1, [theta.w, theta.b, state0.h, inputs.x])
+
+ # Fetches a few values and compare them.
+ (acc0, acc1, final0, final1, dw0, dw1, db0, db1, dh0, dh1, di0,
+ di1) = sess.run(
+ [acc0, acc1, final0, final1, dw0, dw1, db0, db1, dh0, dh1, di0, di1])
+ self.assertAllClose(acc0, acc1)
+ self.assertAllClose(final0, final1)
+ self.assertAllClose(dw0, dw1)
+ self.assertAllClose(db0, db1)
+ self.assertAllClose(dh0, dh1)
+ self.assertAllClose(di0, di1)
+
+if __name__ == '__main__':
+ test_lib.main()
diff --git a/tensorflow/contrib/recurrent/python/ops/functional_rnn.py b/tensorflow/contrib/recurrent/python/ops/functional_rnn.py
new file mode 100644
index 0000000000..a085474c1b
--- /dev/null
+++ b/tensorflow/contrib/recurrent/python/ops/functional_rnn.py
@@ -0,0 +1,396 @@
+# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""A tf.nn.dynamic_rnn variant, built on the Recurrent class.
+"""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import copy
+
+from tensorflow.contrib.recurrent.python.ops import recurrent
+from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import function
+from tensorflow.python.framework import ops
+from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import math_ops
+from tensorflow.python.ops import variable_scope
+from tensorflow.python.util import nest
+
+
+def _GetDTypesFromStructure(struct):
+ dtypes_list = []
+ for x in nest.flatten(struct):
+ x = ops.convert_to_tensor(x)
+ dtypes_list.append(x.dtype)
+ return dtypes_list
+
+
+def _SetShapeFromTemplate(struct, struct_template):
+ as_list = nest.flatten(struct)
+ template_as_list = nest.flatten(struct_template)
+ for element, template in zip(as_list, template_as_list):
+ element.set_shape(template.shape)
+
+
+class _FunctionalRnnCell(object):
+ """Wrapper around RNNCell which separates state from computation.
+
+ This class accomplishes the following:
+ * Turn the cell's `__call__` function into a pure function. The global
+ side effects are separated as `theta`. They are the variables created
+ for the weights of the computation.
+ * Unless the output is aliased as part of the state, extend the state to
+ contain the output so that we store the history in `Recurrent`.
+ * Set static shapes as required.
+ """
+
+ def __init__(self, rnn_cell, seq_inputs, initial_state):
+ assert initial_state is not None
+
+ # TODO(drpng): Dtype needs to be configurable.
+ input_dtypes = [dtypes.float32] + _GetDTypesFromStructure(initial_state)
+ # See _index.
+ like_inputs_t = nest.map_structure(
+ lambda x: array_ops.stop_gradient(array_ops.gather(x, 0)), seq_inputs)
+ input_structure = (like_inputs_t, initial_state)
+
+ @function.Defun(*input_dtypes)
+ def FlatCellStep(*flat_inputs):
+ """The flattened version of `rnn_cell`."""
+ inputs_t, state0 = nest.pack_sequence_as(input_structure, flat_inputs)
+ _SetShapeFromTemplate(state0, initial_state)
+ _SetShapeFromTemplate(inputs_t, like_inputs_t)
+ outputs_t, state1 = rnn_cell(inputs_t, state0)
+ state_list = nest.flatten(state1)
+ self._output_shape = outputs_t.shape
+
+ if outputs_t in state_list:
+ output_index_in_state = state_list.index(outputs_t)
+ else:
+ output_index_in_state = None
+
+ if output_index_in_state is None:
+ self._prepend_output = True
+ self._output_state_idx = 0
+ return [outputs_t] + state_list
+ else:
+ self._output_state_idx = output_index_in_state
+ self._prepend_output = False
+ # To save memory, we don't return the output separately from the
+ # state list, since we know it's the same.
+ return state_list
+
+ def _ToPureFunction(func):
+ # NOTE: This forces the creation of the function.
+ if func.captured_inputs:
+ pure_func = copy.copy(func)
+ # pylint: disable=protected-access
+ pure_func._extra_inputs = []
+ return pure_func
+ return func
+
+ pure_flat_cell_step = _ToPureFunction(FlatCellStep)
+
+ def CellStep(theta, extended_state0, inputs_t):
+ """Performs one time steps on structured inputs.
+
+ The purpose of this function is to turn the parameters into flattened
+ versions, and to resolve the parameter order difference between
+ `Recurrent` and `RNNCell`.
+
+ In the event the cell returns a transformed output that is not aliased
+ within its state, the `extended_state0` also contains the output as its
+ first element.
+
+ Args:
+ theta: Weights required for the computation. A structure of tensors.
+ extended_state0: the state0, and possibly the output at the previous
+ time step. A structure of tensors.
+ inputs_t: the inputs at time t.
+
+ Returns:
+ A pair of the next state (inclusive of the output), and an empty list
+ (unused `extras`).
+ The next state is congruent to state0.
+ """
+ extended_state0_flat = nest.flatten(extended_state0)
+ state0_flat = self.MaybeRemoveOutputFromState(extended_state0_flat)
+ full_inputs = [inputs_t] + state0_flat + theta
+ # Note that the thetas are additional inputs appended as extra
+ # parameters.
+ cell_out = pure_flat_cell_step(*full_inputs)
+ return cell_out, []
+
+ self._cell_step = CellStep
+ self._theta = FlatCellStep.captured_inputs
+ self._zero_state = rnn_cell.zero_state
+ self._state_template = initial_state
+ self._output_size = rnn_cell.output_size
+
+ @property
+ def extended_initial_state(self):
+ if self._prepend_output:
+ return [array_ops.zeros(self._output_shape), self._state_template]
+ else:
+ # The base case, where the output is just the hidden state.
+ return self._state_template
+
+ @property
+ def cell_step(self):
+ return self._cell_step
+
+ @property
+ def theta(self):
+ return self._theta
+
+ @property
+ def state_template(self):
+ return self._state_template
+
+ @property
+ def output_shape(self):
+ return self._output_shape
+
+ def GetOutputFromState(self, state):
+ return nest.flatten(state)[self._output_state_idx]
+
+ def MaybeRemoveOutputFromState(self, flat_state):
+ if self._prepend_output:
+ return flat_state[1:]
+ return flat_state
+
+
+def _ApplyLengthsToBatch(sequence_lengths, tf_output):
+ # TODO(drpng): just use Update so that we don't carry over the gradients?
+ """Sets the output to be zero at the end of the sequence."""
+ # output is batch major.
+ batch_size, max_time, vector_size = tf_output.shape
+ output_time = array_ops.tile(math_ops.range(0, max_time), [batch_size])
+ output_time = array_ops.reshape(output_time, [batch_size, max_time])
+ lengths = array_ops.tile(
+ array_ops.reshape(sequence_lengths, [-1, 1]), [1, max_time])
+ is_less = math_ops.cast(
+ math_ops.less(output_time, lengths), dtype=dtypes.float32)
+ keep_mask = array_ops.tile(
+ array_ops.expand_dims(is_less, -1),
+ [1, 1, vector_size])
+ final_output = keep_mask * tf_output
+ return final_output
+
+
+def _PickFinalStateFromHistory(acc_state, sequence_length):
+ """Implements acc_state[sequence_length - 1]."""
+ # This will work on all platforms, unlike the regular slice.
+ last_value = []
+ for state_var in nest.flatten(acc_state):
+ # We compute the following with matrix operations:
+ # last_var = state_var[sequence_length - 1]
+ shape = array_ops.shape(state_var)
+ max_time, batch_size = shape[0], shape[1]
+ output_time = array_ops.tile(math_ops.range(0, max_time), [batch_size])
+ output_time = array_ops.reshape(output_time, [batch_size, max_time])
+ lengths = array_ops.tile(array_ops.reshape(sequence_length,
+ [-1, 1]), [1, max_time])
+ last_idx = math_ops.cast(math_ops.equal(output_time, lengths - 1),
+ dtype=dtypes.float32)
+ last_idx = array_ops.transpose(last_idx)
+ last_idx_for_bcast = array_ops.expand_dims(last_idx, -1)
+ sliced = math_ops.multiply(last_idx_for_bcast, state_var)
+ last_var = math_ops.reduce_sum(sliced, 0)
+ last_value += [last_var]
+ return nest.pack_sequence_as(acc_state, last_value)
+
+
+def _PostProcessOutput(extended_acc_state, extended_final_state, func_cell,
+ total_time, inputs_lengths):
+ """Post-process output of recurrent.
+
+ This function takes the accumulated extended state and extracts the requested
+ state and output.
+
+ When `inputs_lengths` has been set, the final state is extracted from the
+ accumulated state history, and outputs past each sequence's length are zeroed.
+
+ It also sets the static shape information.
+
+ Args:
+ extended_acc_state: A structure containing the accumulated state at each
+ time. It may contain the output at each time as well.
+ extended_final_state: A structure containing the final state. It may
+ contain the output at the final time.
+ func_cell: The functional wrapper around the cell.
+ total_time: A scalar integer tensor.
+ inputs_lengths: An integer tensor with one sequence length per batch element.
+
+ Returns:
+ A tuple with the outputs at each time, and the final state.
+ """
+ if inputs_lengths is None:
+ flat_final_state = func_cell.MaybeRemoveOutputFromState(
+ nest.flatten(extended_final_state))
+ tf_state = nest.pack_sequence_as(func_cell.state_template, flat_final_state)
+ else:
+ # The accumulated state is over the entire sequence, so we pick it
+ # out from the acc_state sequence.
+ flat_acc_state = func_cell.MaybeRemoveOutputFromState(
+ nest.flatten(extended_acc_state))
+ acc_state = nest.pack_sequence_as(
+ func_cell.state_template, flat_acc_state)
+ tf_state = _PickFinalStateFromHistory(acc_state, inputs_lengths)
+
+ output_from_state = func_cell.GetOutputFromState(extended_acc_state)
+ tf_output = array_ops.transpose(output_from_state, [1, 0, 2])
+ tf_output.set_shape(
+ [func_cell.output_shape[0], total_time, func_cell.output_shape[1]])
+ if inputs_lengths is not None:
+ # Need to zero out the outputs past each sequence's length.
+ tf_output = _ApplyLengthsToBatch(inputs_lengths, tf_output)
+ _SetShapeFromTemplate(tf_state, func_cell.state_template)
+ return tf_output, tf_state
+
+
+# pylint: disable=invalid-name
+def functional_rnn(cell, inputs, sequence_length=None,
+ initial_state=None, dtype=None, time_major=False,
+ scope=None, use_tpu=False):
+ """Same interface as `tf.nn.dynamic_rnn`."""
+ with variable_scope.variable_scope(scope or 'rnn'):
+ if not time_major:
+ inputs = nest.map_structure(
+ lambda t: array_ops.transpose(t, [1, 0, 2]), inputs)
+ inputs_flat = nest.flatten(inputs)
+ batch_size = array_ops.shape(inputs_flat[0])[1]
+ if initial_state is None:
+ initial_state = cell.zero_state(batch_size, dtype)
+ func_cell = _FunctionalRnnCell(cell, inputs, initial_state)
+ extended_acc_state, extended_final_state = recurrent.Recurrent(
+ theta=func_cell.theta,
+ state0=func_cell.extended_initial_state,
+ inputs=inputs,
+ cell_fn=func_cell.cell_step,
+ use_tpu=use_tpu)
+ return _PostProcessOutput(extended_acc_state, extended_final_state,
+ func_cell, inputs_flat[0].shape[0], sequence_length)
+
+
+def bidirectional_functional_rnn(
+ cell_fw,
+ cell_bw,
+ inputs,
+ initial_state_fw=None,
+ initial_state_bw=None,
+ dtype=None,
+ sequence_length=None,
+ time_major=False,
+ use_tpu=False,
+ scope=None):
+ """Creates a bidirectional recurrent neural network.
+
+ Performs fully dynamic unrolling of inputs in both directions. Built to be API
+ compatible with `tf.nn.bidirectional_dynamic_rnn`, but implemented with
+ functional control flow for TPU compatibility.
+
+ Args:
+ cell_fw: An instance of `tf.contrib.rnn.RNNCell`.
+ cell_bw: An instance of `tf.contrib.rnn.RNNCell`.
+ inputs: The RNN inputs. If time_major == False (default), this must be a
+ Tensor (or hierarchical structure of Tensors) of shape
+ [batch_size, max_time, ...]. If time_major == True, this must be a Tensor
+ (or hierarchical structure of Tensors) of shape:
+ [max_time, batch_size, ...]. The first two dimensions must match across
+ all the inputs, but otherwise the ranks and other shape components may
+ differ.
+ initial_state_fw: An optional initial state for `cell_fw`. Should match
+ `cell_fw.zero_state` in structure and type.
+ initial_state_bw: An optional initial state for `cell_bw`. Should match
+ `cell_bw.zero_state` in structure and type.
+ dtype: (optional) The data type for the initial state and expected output.
+ Required if initial_states are not provided or RNN state has a
+ heterogeneous dtype.
+ sequence_length: An optional int32/int64 vector sized [batch_size]. Used to
+ copy-through state and zero-out outputs when past a batch element's
+ sequence length. So it's more for correctness than performance.
+ time_major: Whether the `inputs` tensor is in "time major" format.
+ use_tpu: Whether to enable TPU-compatible operation. If True, does not truly
+ reverse `inputs` in the backwards RNN. Once b/69305369 is fixed, we can
+ remove this flag.
+ scope: An optional scope name for the dynamic RNN.
+
+ Returns:
+ outputs: A tuple of `(output_fw, output_bw)`. The output of the forward and
+ backward RNN. If time_major == False (default), these will
+ be Tensors shaped: [batch_size, max_time, cell.output_size]. If
+ time_major == True, these will be Tensors shaped:
+ [max_time, batch_size, cell.output_size]. Note, if cell.output_size is a
+ (possibly nested) tuple of integers or TensorShape objects, then the
+ output for that direction will be a tuple having the same structure as
+ cell.output_size, containing Tensors having shapes corresponding to the
+ shape data in cell.output_size.
+ final_states: A tuple of `(final_state_fw, final_state_bw)`. A Tensor or
+ hierarchical structure of Tensors indicating the final cell state in each
+ direction. Must have the same structure and shape as cell.zero_state.
+
+ Raises:
+ ValueError: If `initial_state_fw` is None or `initial_state_bw` is None and
+ `dtype` is not provided.
+ """
+ # Keep this code in sync with tf.nn.dynamic_rnn for compatibility.
+ with variable_scope.variable_scope(scope or 'bidirectional_rnn'):
+ # Forward direction
+ with variable_scope.variable_scope('fw') as fw_scope:
+ output_fw, output_state_fw = functional_rnn(
+ cell=cell_fw, inputs=inputs, sequence_length=sequence_length,
+ initial_state=initial_state_fw, dtype=dtype,
+ time_major=time_major, scope=fw_scope, use_tpu=use_tpu)
+ # Backward direction
+ if not time_major:
+ time_dim = 1
+ batch_dim = 0
+ else:
+ time_dim = 0
+ batch_dim = 1
+
+ def _reverse(input_, seq_lengths, seq_dim, batch_dim):
+ if seq_lengths is not None:
+ return array_ops.reverse_sequence(
+ input=input_, seq_lengths=seq_lengths,
+ seq_dim=seq_dim, batch_dim=batch_dim)
+ else:
+ # See b/69305369.
+ assert not use_tpu, (
+ 'Bidirectional with variable sequence lengths unsupported on TPU')
+ return array_ops.reverse(input_, axis=[seq_dim])
+
+ with variable_scope.variable_scope('bw') as bw_scope:
+ inputs_reverse = _reverse(
+ inputs, seq_lengths=sequence_length,
+ seq_dim=time_dim, batch_dim=batch_dim)
+ tmp, output_state_bw = functional_rnn(
+ cell=cell_bw, inputs=inputs_reverse, sequence_length=sequence_length,
+ initial_state=initial_state_bw, dtype=dtype,
+ time_major=time_major, scope=bw_scope, use_tpu=use_tpu)
+
+ output_bw = _reverse(
+ tmp, seq_lengths=sequence_length,
+ seq_dim=time_dim, batch_dim=batch_dim)
+
+ outputs = (output_fw, output_bw)
+ output_states = (output_state_fw, output_state_bw)
+
+ return (outputs, output_states)
+# pylint: enable=invalid-name
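
To make the bidirectional wrapper above concrete, here is a hedged sketch of calling `bidirectional_functional_rnn` the same way one would call `tf.nn.bidirectional_dynamic_rnn`; the GRU cells, sizes, and placeholders are assumptions, not part of the commit.

    import tensorflow as tf
    from tensorflow.contrib.recurrent.python.ops import functional_rnn

    inputs = tf.placeholder(tf.float32, [8, 20, 16])  # [batch, time, depth]
    lengths = tf.placeholder(tf.int32, [8])
    cell_fw = tf.nn.rnn_cell.GRUCell(32)
    cell_bw = tf.nn.rnn_cell.GRUCell(32)

    (out_fw, out_bw), (state_fw, state_bw) = (
        functional_rnn.bidirectional_functional_rnn(
            cell_fw, cell_bw, inputs,
            sequence_length=lengths, dtype=tf.float32))
    # Concatenate the two directions along the feature axis, as is typical.
    encoded = tf.concat([out_fw, out_bw], axis=-1)  # [batch, time, 64]
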
diff --git a/tensorflow/contrib/recurrent/python/ops/recurrent.py b/tensorflow/contrib/recurrent/python/ops/recurrent.py
new file mode 100644
index 0000000000..fa16b82ab6
--- /dev/null
+++ b/tensorflow/contrib/recurrent/python/ops/recurrent.py
@@ -0,0 +1,720 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Recurrent computation.
+
+The main interface of this module is Recurrent().
+A recurrent computation describes an auto-regressive process, where the output
+of one time step is fed as input to the next time step.
+
+This module uses:
+ theta: the "weights" each RNN uses.
+ state0: the initial state of each RNN.
+ cell_fn: A python function describing the RNN cell. It must have the
+ following signature:
+ cell_fn: (theta, state0, inputs) -> (state1, extras)
+ state1 is the next RNN state, extras are computed by cell_fn
+ and the library forwards extras to cell_fn's gradient function.
+ cell_grad: A python function describing the backprop gradient function
+ for the RNN cell. It must have the following signature:
+ cell_grad: (theta, state0, inputs, extras, dstate1) -> (
+ dtheta, dstate0, dinputs)
+ dstate1 is what the backprop algorithm provides representing
+ gradients of state1 w.r.t. the final loss.
+
+In this module, we handle structures of tensors for theta, state0, inputs,
+and extras. The structure is an arbitrarily nested python structure, such
+as a dictionary of named tuples.
+
+Because the computation is a left-to-right chain, a single in-place accumulator
+can be used rather than a stack. Thus a special gradient was written to reduce
+unnecessary memory usage.
+"""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import function
+from tensorflow.python.framework import ops
+from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import functional_ops
+from tensorflow.python.ops import gradients_impl
+from tensorflow.python.ops import inplace_ops
+from tensorflow.python.ops import math_ops
+from tensorflow.python.ops.inplace_ops import alias_inplace_update
+from tensorflow.python.util import nest
+
+
+def _AssertIsCompatible(a, b):
+ """Checks that `a` and `b` are nested structures of the same type."""
+ # TODO(drpng): implement.
+ del a
+ del b
+
+
+def _Index(struct, index):
+ """Returns a structure with `x[index]` for each tensor `x` in the structure.
+
+ Args:
+ struct: A structure of tensors.
+ index: A scalar integer tensor. Performance is better if `index` is
+ on the host memory.
+
+ Returns:
+ A structure of tensors congruent to `struct`.
+ For each key in `ret`, `ret[key] = struct[key][index]`.
+ """
+ index = ops.convert_to_tensor(index)
+ index.get_shape().assert_has_rank(0)
+ return nest.map_structure(lambda x: x[index], struct)
+
+
+def _Update(struct_acc, struct_x, t):
+ """Updates t-th row in accumulators.
+
+ Args:
+ struct_acc: The accumulators. A structure of tensors.
+ struct_x: The new values. A structure of tensors congruent to `struct_acc`.
+ t: A scalar integer. Performance is better if `t` is on the device
+ memory.
+
+ Returns:
+ A structure of tensors. Say, ret is a returned dictionary. Then, for
+ each key, we have:
+ ret[key] = struct_acc[key];
+ ret[key][t, :] = struct_x[key]
+ """
+ to_skip_update = set()
+ acc_lst = nest.flatten(struct_acc)
+ x_lst = nest.flatten(struct_x)
+ t = math_ops.to_int32([t]) # tf.to_int32 casts on-device tensors.
+ lst = []
+ for acc, x in zip(acc_lst, x_lst):
+ if acc in to_skip_update:
+ # Until b/62105730 is fixed, we need to avoid inplace update for tensors
+ # of rank 1. We could reshape to handle it, but we don't really need the
+ # values applied to these, so just skip their modification.
+ lst += [acc]
+ else:
+ lst += [alias_inplace_update(acc, t, array_ops.expand_dims(x, 0))]
+ return nest.pack_sequence_as(struct_acc, lst)
+
+
+def _SeqLenDim(struct):
+ """Returns the 0-th dim size of tensors in a structure of tensors.
+
+ This is the max sequence length according to the shape of the inputs.
+
+ Args:
+ struct: A structure of tensors. Every tensor's 0-th dim has the same size.
+
+ Returns:
+ A scalar tensor which is the size of the 0-th dim of every tensor in struct.
+ """
+ xs = nest.flatten(struct)
+ assert xs
+ dim0 = array_ops.shape(xs[0])[0]
+ return dim0
+
+
+def _Flatten(struct):
+ """Flattens a structure."""
+ return nest.flatten(struct)
+
+
+def _Pack(elements, struct_template):
+ """Packs the list of tensors according to the structure.
+
+ In the event that `elements` should be a scalar, `struct_template` must
+ contain exactly one non-trivial element (for instance, `[[], {'x':elt}]`).
+
+ Args:
+ elements: Elements to be packed. A list of tensor, or a single tensor.
+ struct_template: The container structure in which to pack them.
+ Returns:
+ A python structure of the same type as `struct_template`, containing
+ `elements` as its contained elements.
+ """
+ if not nest.is_sequence(elements):
+ return nest.pack_sequence_as(struct_template, [elements])
+ return nest.pack_sequence_as(struct_template, elements)
+
+
+def _EmptyAcc(slen, struct_template):
+ """Creates a set of accumulators for tensors in structure.
+
+ Args:
+ slen: The sequence length. A scalar tensor.
+ struct_template: A structure of tensors.
+
+ Returns:
+ A structure congruent to `struct_template`. Say ret is a returned
+ dictionary. Then, `ret.key`, a tensor, has the same dtype as
+ `struct_template.key`. The tensor's shape has 1 more dimension
+ than the tensor `struct_template.key`. The extra 0-th dimension is of size
+ `slen`. E.g., if `slen=10` and `struct_template.key`'s shape is `[3, 5]`,
+ then, `ret.key`'s shape is `[10, 3, 5]`.
+ """
+
+ def _EmptyAccForTensor(tensor):
+ return inplace_ops.empty(
+ array_ops.concat([[slen], array_ops.shape(tensor)], axis=0),
+ tensor.dtype,
+ init=True)
+
+ return nest.map_structure(_EmptyAccForTensor, struct_template)
+
+
+def _EmptyLike(struct):
+ """Creates a set of empty initialized tensors.
+
+ Args:
+ struct: A structure of tensors.
+
+ Returns:
+ A struct of tensors. Each tensor has the same shape and dtype as
+ its corresponding tensor in `struct`. And each tensor is initialized.
+ """
+ return nest.map_structure(
+ lambda x: inplace_ops.empty_like(x, init=True), struct)
+
+
+def _Add(struct_x, struct_y):
+ """Adds tensors in `struct_x` with respective tensors in `struct_y`.
+
+ Args:
+ struct_x: A struct of tensors.
+ struct_y: A struct of tensors congruent to `struct_x`.
+
+ Returns:
+ A struct of tensors. Each element of the returned value
+ equals `x + y`, with corresponding values in `struct_x` and `struct_y`.
+ """
+ list_x = nest.flatten(struct_x)
+ list_y = nest.flatten(struct_y)
+ z = []
+ for x, y in zip(list_x, list_y):
+ z += [math_ops.add(x, y)]
+ return nest.pack_sequence_as(struct_x, z)
+
+
+def _Dtypes(struct):
+ """Returns all tensors' data types in a list."""
+ return [x.dtype for x in nest.flatten(struct)]
+
+
+def _ConvertNoneGradientToZeros(xs, dxs):
+ """Sanitize dxs so that None becomes zeros appropriately.
+
+ Args:
+ xs: A list of tensors.
+ dxs: A list of tensors. dxs[i] corresponds to xs[i]'s gradient.
+
+ Returns:
+ A structure the same as `dxs`, with each `None` replaced by a zero tensor.
+ """
+ list_xs = nest.flatten(xs)
+ list_dxs = nest.flatten(dxs)
+
+ # If x does not get any backprop-ed gradient, propagate zeros.
+ rets = []
+ for (x, dx) in zip(list_xs, list_dxs):
+ if dx is None:
+ rets.append(array_ops.zeros_like(x))
+ else:
+ rets.append(dx)
+
+ return nest.pack_sequence_as(dxs, rets)
+
+
+# All structures are flattened for use internally. This is for simplicity
+# and also to use the Defun construct.
+# In the forward pass (inference), the computation is structured as follows.
+# Forward: [gradient = _Recurrent.Grad]
+# Flatten structures, create accumulators.
+# for t = 0..max_input_length:
+# Defun ForwardLoopBody:
+# Defun Fwd: flatten/pack around cell_fn
+# state1 = Fwd(inputs[t], state0)
+# acc_state += [state1]
+# Pack structures.
+# During the backward pass (backpropping the gradient from the last time
+# step to the first, through the structure), the computation is structured
+# as follows.
+# Grad:
+# Flatten structures.
+# Defun Backward:
+# Create accumulated derivatives: d_theta, d_inputs, d_acc_state.
+# Regarding the note at the top of the file, there is only one accumulator
+# for d_theta accumulated over the whole sequence.
+# for t = max_input_length -1..0:
+# Defun BackwardLoopBody:
+# Retrieve acc_state[t] computed in the forward pass.
+# Defun Bak: flatten/pack around cell_grad.
+# d_state1 is d_state0 from the previous step (i.e., the next time step).
+# d_acc_state[dev_t] += d_state1
+# d_theta_t, d_state0, d_inputs_t, = Bak()
+# d_inputs[dev_t] += d_inputs
+# d_theta += d_theta_t
+# d_acc_state[t] += d_state1
+# Pack structures and return.
+class _Recurrent(object):
+ """A helper class to construct a recurrent neural net."""
+
+ def __init__(self, cell_fn, cell_grad, theta, state0, inputs,
+ max_input_length, extras, use_tpu):
+ """RNN helper class.
+
+ Args:
+ cell_fn: A python function, which computes:
+ state1, extras = cell_fn(theta, state0, inputs[t, :])
+ cell_grad: A python function which computes:
+ dtheta, dstate0, dinputs[t, :] = cell_grad(
+ theta, state0, inputs[t, :], extras, dstate1)
+ theta: weights. A structure of tensors.
+ state0: initial state. A structure of tensors.
+ inputs: inputs. A structure of tensors.
+ max_input_length: None, or the maximum effective length of the input over
+ all batches. A scalar tensor.
+ extras: A structure of tensors. The 2nd return value of every
+ invocation of cell_fn is a structure of tensors with keys and
+ shapes matching this `extras`.
+ use_tpu: A boolean indicating whether the computation is meant to
+ run on a TPU.
+ """
+ self._theta = theta
+ self._state = state0
+ self._inputs = inputs
+ self._max_input_length = self._MaybeComputeMaxInputLength(
+ inputs, max_input_length)
+ self._cell_fn = cell_fn
+ self._cell_grad = cell_grad
+ self._extras = extras
+
+ # pylint: disable=unbalanced-tuple-unpacking
+
+ # NOTE: TF Function (Fwd, Bak, ForwardLoopBody, BackwardLoopBody,
+ # Forward and Backward defined below) simply takes a list of
+ # Tensors and returns a list of Tensors. When we pass in a
+ # structure (a list of structures of Tensors), we use _Flatten to
+ # convert the structure into a list of tensor. Conversely, the
+ # following code often uses _Pack to formulate a structure from a
+ # list of tensors based on a "template".
+
+ # Wraps cell_fn in a TF Function:
+ # state1 = cell_fn(theta, state0, inputs)
+ fwd_sig = [self._theta, self._state, self._inputs]
+
+ compiled = use_tpu
+ noinline = not compiled
+ dev_t_type = dtypes.int32 if use_tpu else dtypes.int64
+
+ @function.Defun(*_Dtypes(fwd_sig))
+ def Fwd(*args):
+ (theta, state0, inputs) = _Pack(args, fwd_sig)
+ state1, extras = self._cell_fn(theta, state0, inputs)
+ assert not function.get_extra_args(), (
+ 'cell_fn is not pure with extra args: %s.' %
+ (function.get_extra_args()))
+ _AssertIsCompatible(state1, self._state)
+ _AssertIsCompatible(extras, self._extras)
+ return _Flatten([state1, extras])
+
+ # Wraps cell_fn in a TF Function as a for-loop's body.
+ #
+ # The loop state is composed of:
+ # t: The loop variable. Timestep id.
+ # dev_t: The loop variable mirrored on the device.
+ # theta: the recurrent net's weights.
+ # state0: the previous recurrent state.
+ # inputs: inputs to the recurrent net. inputs[t, :] are for the timestep t.
+ # acc_state: Each timestep's computed new state is also stashed into
+ # acc_state.
+ # acc_extras: Each timestep's computed extras is stashed into acc_extras
+ fwdloop_sig = [
+ self._theta, self._state, self._inputs, self._state, self._extras
+ ]
+
+ @function.Defun(dtypes.int32, dev_t_type, *_Dtypes(fwdloop_sig))
+ def ForwardLoopBody(*args):
+ """The body of forward loop."""
+ t, dev_t = args[0], args[1]
+ (theta, state0, inputs, acc_state, acc_extras) = _Pack(
+ args[2:], fwdloop_sig)
+ inputs_t = _Index(inputs, t) # external input at time step t.
+ fwd = Fwd(*_Flatten([theta, state0, inputs_t]))
+ state1, extras = _Pack(fwd, [self._state, self._extras])
+ # Saves state1 and extras in their accumulators.
+ acc_state = _Update(acc_state, state1, dev_t)
+ acc_extras = _Update(acc_extras, extras, dev_t)
+
+ return [math_ops.add(dev_t, 1)] + _Flatten(
+ [theta, state1, inputs, acc_state, acc_extras])
+
+ def Grad(op, *args):
+ """The python grad function for the Forward function."""
+
+ # NOTE: tf.gradients backprops None for int32/int64 while zeros
+ # for float32/float64. For consistency, we always backprop
+ # zeros.
+ args = list(args)
+ for i, dy in enumerate(args):
+ if dy is None:
+ args[i] = array_ops.zeros_like(op.outputs[i])
+ # TODO(drpng): getting the extra state here?
+ op_inputs = [x for x in op.inputs]
+ op_struct = [
+ self._theta, self._state, self._inputs, self._max_input_length,
+ self._extras
+ ]
+ (theta, state0, inputs, max_input_length, _) = _Pack(op_inputs, op_struct)
+ # acc_state and acc_extras are computed by the Forward pass and
+ # needed by the Backward pass.
+ acc_state, _, acc_extras = _Pack([x for x in op.outputs],
+ [self._state, self._state, self._extras])
+
+ # Forward computes acc_state, the final state and
+ # acc_extras. tf.gradients gives us their gradients w.r.t. the
+ # final loss. Because acc_extras are not exposed by Compute(),
+ # it has no gradients w.r.t. the final loss (i.e., by
+ # construction, it must be zeros).
+ d_acc_state, d_state1, _ = _Pack(args,
+ [self._state, self._state, self._extras])
+ return Backward(*_Flatten([
+ theta, state0, inputs, max_input_length, acc_state, acc_extras,
+ d_acc_state, d_state1
+ ]))
+
+ # Forward calls ForwardLoopBody n times. Each time computes one
+ # time step of the recurrent net.
+ forward_sig = [
+ self._theta, self._state, self._inputs, self._max_input_length,
+ self._extras
+ ]
+
+ @function.Defun(
+ *_Dtypes(forward_sig), python_grad_func=Grad, noinline=noinline)
+ def Forward(*args):
+ """Forward pass of the recurrent net."""
+ theta, state0, inputs, max_input_length, extras = _Pack(args, forward_sig)
+
+ slen_dim = _SeqLenDim(inputs)
+
+ # Creates accumulators for state0 and extras.
+ acc_state = _EmptyAcc(slen_dim, state0)
+ acc_extras = _EmptyAcc(slen_dim, extras)
+
+ dev_t = array_ops.constant(0, dtype=dev_t_type)
+ run = functional_ops.For(
+ start=0,
+ limit=max_input_length,
+ delta=1,
+ inputs=[dev_t] + _Flatten(
+ [theta, state0, inputs, acc_state, acc_extras]),
+ body=ForwardLoopBody,
+ rewrite_with_while=compiled)
+ _, state1, _, acc_state, acc_extras = _Pack(
+ run[1:],
+ [self._theta, self._state, self._inputs, self._state, self._extras])
+
+ return _Flatten([acc_state, state1, acc_extras])
+
+ # The per-step backward computes:
+ # d_theta, d_state0, d_inputs = cell_grad(
+ # theta, state0, inputs, extras, d_state1)
+ # where d_state1 is the backprop-ed gradient for state1, and
+ # extras is the computed by the forward step to facilitate the
+ # backward step.
+ bak_sig = [
+ self._theta, self._state, self._inputs, self._extras, self._state
+ ]
+
+ @function.Defun(*_Dtypes(bak_sig))
+ def Bak(*args):
+ """Backward step."""
+ (theta, state0, inputs, extras, d_state1) = _Pack(args, bak_sig)
+ (dtheta, dstate0, dinputs) = self._cell_grad(theta, state0, inputs,
+ extras, d_state1)
+ assert not function.get_extra_args(), (
+ 'cell_grad is not pure with extra args: %s.' %
+ (function.get_extra_args()))
+ _AssertIsCompatible(dtheta, self._theta)
+ _AssertIsCompatible(dstate0, self._state)
+ _AssertIsCompatible(dinputs, self._inputs)
+ return _Flatten(
+ _ConvertNoneGradientToZeros([theta, state0, inputs],
+ [dtheta, dstate0, dinputs]))
+
+ # Define defuns used by a functional_ops.If in BackwardLoopBody.
+ state_if_sig = [self._state, self._state]
+
+ @function.Defun(*_Dtypes(state_if_sig))
+ def ReturnOrigState0(*args):
+ """Returns original state0 from inputs."""
+ (_, orig_state0) = _Pack(args, state_if_sig)
+ return nest.flatten(orig_state0)
+
+ @function.Defun(*_Dtypes(state_if_sig))
+ def ReturnAccState(*args):
+ """Returns acc_state[t-1] from inputs."""
+ (acc_state, _) = _Pack(args, state_if_sig)
+ return nest.flatten(acc_state)
+
+ # Wraps cell_grad gradient function in a TF Function as a
+ # for-loop's body for the Backward pass.
+ #
+ # The loop state is composed of:
+ # t: The loop variable. Timestep id.
+ # state0: the initial state for the entire backward loop.
+ # dev_t: The loop variable mirrored on the device.
+ # theta: the recurrent net's weights.
+ # inputs: inputs to the recurrent net. inputs[t, :] are for the timestep t.
+ # acc_state: Each timestep's computed new state was stashed into
+ # acc_state by the Forward pass.
+ # acc_extras: Each timestep's computed extras was stashed into
+ # acc_extras by the Forward pass.
+ # d_theta: All timestep's gradient for theta is accumulated (added) into
+ # d_theta.
+ # d_state1: The backprop-ed gradient for the new stated computed by
+ # timestep t.
+ # d_inputs: d_inputs[t, :] is populated by the backward time step t.
+ # d_acc_state: The backprop-ed gradient for acc_state.
+ bakloop_sig = [
+ self._theta, self._state, self._inputs, self._state, self._extras,
+ self._theta, self._state, self._inputs, self._state
+ ]
+
+ @function.Defun(dtypes.int32, dev_t_type, *_Dtypes(bakloop_sig))
+ def BackwardLoopBody(*args):
+ """Backward loop body function."""
+ t, dev_t = args[0], args[1]
+ (theta, orig_state0, inputs, acc_state, acc_extras, d_theta, d_state1,
+ d_inputs, d_acc_state) = _Pack(args[2:], bakloop_sig)
+
+ # The input recurrent state for time step t is previous time step's
+ # output, or the original state0 when on time step 0.
+ state_from_acc = _Index(acc_state, math_ops.maximum(0, t - 1))
+ state0 = functional_ops.If(
+ math_ops.equal(t, array_ops.constant(0, dtypes.int32)),
+ _Flatten([state_from_acc, orig_state0]), ReturnOrigState0,
+ ReturnAccState)
+ state0 = nest.pack_sequence_as(orig_state0, state0)
+
+ # The external inputs for time step t.
+ inputs_t = _Index(inputs, t)
+ # The extras for time step t.
+ extras_t = _Index(acc_extras, t)
+
+ d_state1 = _Add(_Index(d_acc_state, t), d_state1)
+ (d_theta_t, d_state0, d_inputs_t) = _Pack(
+ Bak(*_Flatten([theta, state0, inputs_t, extras_t, d_state1])),
+ [self._theta, self._state, self._inputs])
+ d_theta = _Add(d_theta, d_theta_t)
+ d_inputs = _Update(d_inputs, d_inputs_t, dev_t)
+ return [math_ops.subtract(dev_t, 1)] + _Flatten([
+ theta, orig_state0, inputs, acc_state, acc_extras, d_theta, d_state0,
+ d_inputs, d_acc_state
+ ])
+
+ # Backward calls BackwardLoopBody n times. Each time computes the backprop
+ # for one time step of the recurrent net.
+ backward_sig = [
+ self._theta, self._state, self._inputs, self._max_input_length,
+ self._state, self._extras, self._state, self._state
+ ]
+
+ @function.Defun(*_Dtypes(backward_sig), noinline=noinline)
+ def Backward(*args):
+ """Backward pass for the recurrent net."""
+ # theta, state0, inputs are Forward's inputs.
+ # acc_state is the accumulated 1st output of Forward.
+ # acc_extras is the accumulated 2nd output of Forward.
+ # d_acc_state is the gradient for acc_state.
+ # d_state1 is the gradient for the final state computed by Forward.
+ (theta, state0, inputs, max_input_length, acc_state, acc_extras,
+ d_acc_state, d_state1) = _Pack(args, backward_sig)
+
+ # Accumulators for gradients.
+ d_theta = _EmptyLike(theta)
+ d_inputs = _EmptyLike(inputs)
+
+ # Loop backwards. Note the loop's limit is open-ended, so it goes through
+ # t=0.
+ t = max_input_length - 1
+ dev_t = math_ops.to_int32(t) if use_tpu else math_ops.to_int64(t)
+ run = functional_ops.For(
+ start=t,
+ limit=-1,
+ delta=-1,
+ inputs=[dev_t] + _Flatten([
+ theta, state0, inputs, acc_state, acc_extras, d_theta, d_state1,
+ d_inputs, d_acc_state
+ ]),
+ body=BackwardLoopBody,
+ rewrite_with_while=compiled)
+
+ (theta, state0, inputs, acc_state, acc_extras, d_theta, d_state0,
+ d_inputs, d_acc_state) = _Pack(run[1:], bakloop_sig)
+
+ d_max_input_length = array_ops.constant(0, dtype=max_input_length.dtype)
+ return _Flatten(
+ [d_theta, d_state0, d_inputs, d_max_input_length, acc_extras])
+
+ self._forward = Forward
+
+ def _MaybeComputeMaxInputLength(self, inputs, max_input_length):
+ if max_input_length is not None:
+ return max_input_length
+ return math_ops.reduce_max(array_ops.shape(nest.flatten(inputs)[0])[0])
+
+ def Compute(self):
+ return _Pack(
+ self._forward(*_Flatten([
+ self._theta, self._state, self._inputs, self._max_input_length,
+ self._extras
+ ])), [self._state, self._state, self._extras])[:2]
+
+
+def _GetCellGrad(cell_fn, cell_grad):
+ """Returns the gradient function for cell_fn.
+
+ Args:
+ cell_fn: The recurrent neural net's cell function.
+ cell_grad: If not None, cell_fn's gradient function.
+
+ Returns:
+ Returns cell_grad if not None. Otherwise, assume cell_fn is a python
+ function representing the recurrent neural net's cell function, i.e.,
+ cell_fn: (theta, state0, inputs) -> (state1, extras)
+ returns its default gradient python function, i.e.,
+ cell_grad: (theta, state0, inputs, extras, dstate1) -> (
+ dtheta, dstate0, dinputs)
+ """
+
+ if cell_grad:
+ return cell_grad
+
+ def CellGrad(theta, state0, inputs, extras, dstate1):
+ """Default gradient function for cell_fn."""
+ # NOTE: The default grad function recomputes the forward
+ # function and does not take advantage of 'extras' returned by
+ # the forward function.
+ del extras
+ state1, extras = cell_fn(theta, state0, inputs)
+ ys = _Flatten([state1])
+ xs = _Flatten([theta, state0, inputs])
+ grad_ys = _Flatten([dstate1])
+ grads = gradients_impl.gradients(ys=ys, xs=xs, grad_ys=grad_ys)
+ return _ConvertNoneGradientToZeros([theta, state0, inputs],
+ _Pack(grads, [theta, state0, inputs]))
+
+ return CellGrad
+
+
+def _IsSingleTimeStep(inputs, max_input_length):
+ """Returns True only if the time dimension of inputs is 1."""
+ if not isinstance(max_input_length, ops.Tensor):
+ return max_input_length == 1
+ for x in nest.flatten(inputs):
+ if x.shape.dims is None or x.shape[0].value != 1:
+ return False
+ return True
+
+
+def Recurrent(theta,
+ state0,
+ inputs,
+ cell_fn,
+ cell_grad=None,
+ extras=None,
+ max_input_length=None,
+ use_tpu=False):
+ """Compute a recurrent neural net.
+
+ Roughly, Recurrent() computes the following:
+ state = state0
+ for t in inputs' sequence length:
+ state = cell_fn(theta, state, inputs[t, :])
+ accumulate_state[t, :] = state
+ return accumulate_state, state
+
+ theta, state, inputs are all structures of tensors.
+
+ inputs[t, :] means taking a slice out from every tensor in the inputs.
+
+ accumulate_state[t, :] = state means that we stash every tensor in
+ 'state' into a slice of the corresponding tensor in
+ accumulate_state.
+
+ cell_fn is a python callable that computes (i.e., builds the TensorFlow
+ graph for) one forward step of the recurrent neural network. Two calls of
+ cell_fn must describe two identical computations.
+
+ By construction, Recurrent()'s backward computation does not access
+ any intermediate values computed by cell_fn during forward
+ computation. We may extend Recurrent() to support that by taking a
+ customized backward function of cell_fn.
+
+ Args:
+ theta: weights. A structure of tensors.
+ state0: initial state. A structure of tensors.
+ inputs: inputs. A structure of tensors.
+ cell_fn: A python function, which computes:
+ state1, extras = cell_fn(theta, state0, inputs[t, :])
+ cell_grad: A python function which computes:
+ dtheta, dstate0, dinputs[t, :] = cell_grad(
+ theta, state0, inputs[t, :], extras, dstate1)
+ extras: A structure of tensors. The 2nd return value of every
+ invocation of cell_fn is a structure of tensors with keys and
+ shapes matching this `extras`.
+ max_input_length: maximum length of effective input. This is used to
+ truncate the computation if the inputs have been allocated to a
+ larger size. A scalar tensor.
+ use_tpu: whether or not we are on TPU.
+
+ Returns:
+ accumulate_state and the final state.
+ """
+ if cell_grad is None and _IsSingleTimeStep(inputs, max_input_length):
+ # The sequence length is statically known to be 1. Hence, we just need to
+ # call cell_fn once without putting it into a loop.
+ inputs = nest.map_structure(lambda x: array_ops.squeeze(x, axis=0), inputs)
+ state1, _ = cell_fn(theta, state0, inputs)
+ acc_state = nest.map_structure(lambda x: array_ops.expand_dims(x, axis=0),
+ state1)
+ return acc_state, state1
+
+ # If cell_grad is not given, derives the gradient function from
+ # cell_fn.
+ cell_grad = _GetCellGrad(cell_fn, cell_grad)
+
+ if extras is None:
+ # Derives 'extras' so that we can allocate extras' accumulator.
+ _, extras = cell_fn(theta, state0, _Index(inputs, 0))
+ extras = nest.map_structure(array_ops.zeros_like, extras)
+ else:
+ _, actual = cell_fn(theta, state0, _Index(inputs, 0))
+ _AssertIsCompatible(extras, actual)
+
+ return _Recurrent(
+ cell_fn=cell_fn,
+ cell_grad=cell_grad,
+ theta=theta,
+ state0=state0,
+ inputs=inputs,
+ max_input_length=max_input_length,
+ extras=extras,
+ use_tpu=use_tpu).Compute()
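
The contract described in the `Recurrent()` docstring can be illustrated with a small sketch (assumptions: a TF 1.x session, namedtuple structures, and a toy running-sum cell; none of this is in the commit itself):

    import collections
    import tensorflow as tf
    from tensorflow.contrib.recurrent.python.ops import recurrent

    _Theta = collections.namedtuple('Theta', ['scale'])
    _State = collections.namedtuple('State', ['total'])
    _Inputs = collections.namedtuple('Inputs', ['x'])

    def CellFn(theta, state0, inputs_t):
      # state1 = scale * total + x_t; no extras are produced.
      return _State(total=theta.scale * state0.total + inputs_t.x), []

    theta = _Theta(scale=tf.constant(0.5))
    state0 = _State(total=tf.constant(0.0))
    inputs = _Inputs(x=tf.constant([1., 2., 3., 4.]))  # time-major, seqlen 4

    acc_state, final_state = recurrent.Recurrent(theta, state0, inputs, CellFn)
    with tf.Session() as sess:
      acc, final = sess.run([acc_state, final_state])
      # acc.total holds the state after every step; final.total is the last one.
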
diff --git a/tensorflow/contrib/recurrent/python/recurrent_api.py b/tensorflow/contrib/recurrent/python/recurrent_api.py
new file mode 100644
index 0000000000..ffe1dcf7dc
--- /dev/null
+++ b/tensorflow/contrib/recurrent/python/recurrent_api.py
@@ -0,0 +1,29 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Recurrent computations library."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+# pylint: disable=unused-import
+from tensorflow.contrib.recurrent.python.ops.functional_rnn import bidirectional_functional_rnn
+from tensorflow.contrib.recurrent.python.ops.functional_rnn import functional_rnn
+from tensorflow.contrib.recurrent.python.ops.recurrent import Recurrent
+# pylint: enable=unused-import
+
+del absolute_import
+del division
+del print_function
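
Per the commit description, the intent is that these re-exported symbols become reachable under `tf.contrib.recurrent` (the contrib `__init__` wiring is not shown in this diff, so treat the namespace below as an assumption):

    import tensorflow as tf

    cell = tf.nn.rnn_cell.LSTMCell(7)
    inputs = tf.placeholder(tf.float32, [3, 5, 11])
    # Assumed public surface after the export: Recurrent, functional_rnn,
    # and bidirectional_functional_rnn under tf.contrib.recurrent.
    outputs, state = tf.contrib.recurrent.functional_rnn(
        cell=cell, inputs=inputs, dtype=tf.float32)
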