author     Eugene Brevdo <ebrevdo@google.com>  2017-05-22 17:32:50 -0700
committer  TensorFlower Gardener <gardener@tensorflow.org>  2017-05-22 17:36:11 -0700
commit     827d2e4b9180db67853f60c125e548d83986b96c (patch)
tree       1ccaf8f20bf678ec755330b488eb28946dbe38e6
parent     95719e869c61c78a4b0ac0407e1fb04e60daca35 (diff)
Move many of the "core" RNNCells and rnn functions back to TF core.

Unit test files will move in a followup PR. This is the big API change. The old
behavior (using tf.contrib.rnn....) will continue to work for backwards
compatibility.

PiperOrigin-RevId: 156809677
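A minimal sketch of the aliasing this change describes (assuming a TF 1.2-era, graph-mode setup; num_units and the placeholder shapes below are illustrative, not taken from this commit):

    import tensorflow as tf

    # Both spellings refer to the same cell class after this change; the
    # contrib name is kept only as a backwards-compatible alias.
    cell_old = tf.contrib.rnn.BasicLSTMCell(num_units=32)   # still works
    cell_new = tf.nn.rnn_cell.BasicLSTMCell(num_units=32)   # now back in core

    # The pre-1.0 tf.nn.rnn is now exposed as tf.nn.static_rnn.
    inputs = [tf.placeholder(tf.float32, [None, 8]) for _ in range(5)]
    outputs, state = tf.nn.static_rnn(cell_new, inputs, dtype=tf.float32)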
-rw-r--r--  RELEASE.md  31
-rw-r--r--  tensorflow/contrib/crf/python/ops/crf.py  4
-rw-r--r--  tensorflow/contrib/cudnn_rnn/python/kernel_tests/cudnn_rnn_ops_benchmark.py  8
-rw-r--r--  tensorflow/contrib/grid_rnn/python/kernel_tests/grid_rnn_test.py  12
-rw-r--r--  tensorflow/contrib/learn/python/learn/estimators/dynamic_rnn_estimator_test.py  14
-rw-r--r--  tensorflow/contrib/learn/python/learn/estimators/state_saving_rnn_estimator.py  4
-rw-r--r--  tensorflow/contrib/learn/python/learn/ops/seq2seq_ops_test.py  4
-rw-r--r--  tensorflow/contrib/legacy_seq2seq/python/kernel_tests/seq2seq_test.py  95
-rw-r--r--  tensorflow/contrib/legacy_seq2seq/python/ops/seq2seq.py  47
-rw-r--r--  tensorflow/contrib/ndlstm/python/lstm1d.py  8
-rw-r--r--  tensorflow/contrib/rnn/__init__.py  43
-rw-r--r--  tensorflow/contrib/rnn/python/kernel_tests/core_rnn_cell_test.py  196
-rw-r--r--  tensorflow/contrib/rnn/python/kernel_tests/core_rnn_test.py  157
-rw-r--r--  tensorflow/contrib/rnn/python/kernel_tests/fused_rnn_cell_test.py  23
-rw-r--r--  tensorflow/contrib/rnn/python/kernel_tests/gru_ops_test.py  18
-rw-r--r--  tensorflow/contrib/rnn/python/kernel_tests/lstm_ops_test.py  69
-rw-r--r--  tensorflow/contrib/rnn/python/kernel_tests/rnn_cell_test.py  133
-rw-r--r--  tensorflow/contrib/rnn/python/kernel_tests/rnn_test.py  16
-rw-r--r--  tensorflow/contrib/rnn/python/ops/core_rnn.py  357
-rw-r--r--  tensorflow/contrib/rnn/python/ops/core_rnn_cell.py  232
-rw-r--r--  tensorflow/contrib/rnn/python/ops/core_rnn_cell_impl.py  1048
-rw-r--r--  tensorflow/contrib/rnn/python/ops/fused_rnn_cell.py  14
-rw-r--r--  tensorflow/contrib/rnn/python/ops/gru_ops.py  4
-rw-r--r--  tensorflow/contrib/rnn/python/ops/lstm_ops.py  15
-rw-r--r--  tensorflow/contrib/rnn/python/ops/rnn.py  3
-rw-r--r--  tensorflow/contrib/rnn/python/ops/rnn_cell.py  66
-rw-r--r--  tensorflow/contrib/seq2seq/python/kernel_tests/attention_wrapper_test.py  8
-rw-r--r--  tensorflow/contrib/seq2seq/python/kernel_tests/basic_decoder_test.py  26
-rw-r--r--  tensorflow/contrib/seq2seq/python/kernel_tests/beam_search_decoder_test.py  4
-rw-r--r--  tensorflow/contrib/seq2seq/python/kernel_tests/decoder_test.py  8
-rw-r--r--  tensorflow/contrib/seq2seq/python/ops/attention_wrapper.py  3
-rw-r--r--  tensorflow/contrib/tfprof/python/tools/tfprof/model_analyzer_testlib.py  6
-rw-r--r--  tensorflow/python/BUILD  9
-rw-r--r--  tensorflow/python/debug/lib/session_debug_testlib.py  2
-rw-r--r--  tensorflow/python/kernel_tests/rnn_test.py  4
-rw-r--r--  tensorflow/python/ops/nn.py  9
-rw-r--r--  tensorflow/python/ops/rnn.py  358
-rw-r--r--  tensorflow/python/ops/rnn_cell.py  51
-rw-r--r--  tensorflow/python/ops/rnn_cell_impl.py  834
-rw-r--r--  tensorflow/tools/api/golden/tensorflow.nn.pbtxt  16
-rw-r--r--  tensorflow/tools/api/golden/tensorflow.nn.rnn_cell.-basic-l-s-t-m-cell.pbtxt  95
-rw-r--r--  tensorflow/tools/api/golden/tensorflow.nn.rnn_cell.-basic-r-n-n-cell.pbtxt  95
-rw-r--r--  tensorflow/tools/api/golden/tensorflow.nn.rnn_cell.-device-wrapper.pbtxt  95
-rw-r--r--  tensorflow/tools/api/golden/tensorflow.nn.rnn_cell.-dropout-wrapper.pbtxt  95
-rw-r--r--  tensorflow/tools/api/golden/tensorflow.nn.rnn_cell.-g-r-u-cell.pbtxt  95
-rw-r--r--  tensorflow/tools/api/golden/tensorflow.nn.rnn_cell.-l-s-t-m-cell.pbtxt  95
-rw-r--r--  tensorflow/tools/api/golden/tensorflow.nn.rnn_cell.-l-s-t-m-state-tuple.pbtxt  27
-rw-r--r--  tensorflow/tools/api/golden/tensorflow.nn.rnn_cell.-multi-r-n-n-cell.pbtxt  95
-rw-r--r--  tensorflow/tools/api/golden/tensorflow.nn.rnn_cell.-r-n-n-cell.pbtxt  94
-rw-r--r--  tensorflow/tools/api/golden/tensorflow.nn.rnn_cell.-residual-wrapper.pbtxt  95
-rw-r--r--  tensorflow/tools/api/golden/tensorflow.nn.rnn_cell.pbtxt  43
51 files changed, 2887 insertions, 1996 deletions
diff --git a/RELEASE.md b/RELEASE.md
index 02bdbd4297..ec24d6fd80 100644
--- a/RELEASE.md
+++ b/RELEASE.md
@@ -4,7 +4,7 @@
* Added `tf.layers.conv3d_transpose` layer for spatio temporal deconvolution.
* Added `tf.Session.make_callable()`, which provides a lower overhead means of running a similar step multiple times.
* Added ibverbs-based RDMA support to contrib (courtesy @junshi15 from Yahoo).
-* `RNNCell` objects now subclass `tf.layers._Layer`. The strictness described
+* `RNNCell` objects now subclass `tf.layers.Layer`. The strictness described
in the TensorFlow 1.1 release is gone: The first time an RNNCell is used,
it caches its scope. All future uses of the RNNCell will reuse variables from
that same scope. This is a breaking change from the behavior of RNNCells
@@ -17,18 +17,33 @@
parameters, write: `MultiRNNCell([LSTMCell(...) for _ in range(5)])`.
If at all unsure, first test your code with TF 1.1; ensure it raises no
errors, and then upgrade to TF 1.2.
+* RNNCells' variable names have been renamed for consistency with Keras layers.
+ Specifically, the previous variable names "weights" and "biases" have
+ been changed to "kernel" and "bias", respectively.
+ This may cause backward incompatibility with regard to your old
+ checkpoints containing such RNN cells, in which case you can use the tool
+ [checkpoint_convert script](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/contrib/rnn/python/tools/checkpoint_convert.py)
+ to convert the variable names in your old checkpoints.
+* Many of the RNN functions and classes that were in the `tf.nn` namespace
+ before the 1.0 release and which were moved to `tf.contrib.rnn` have now
+ been moved back to the core namespace. This includes
+ `RNNCell`, `LSTMCell`, `GRUCell`, and a number of other cells. These
+ now reside in `tf.nn.rnn_cell` (with aliases in `tf.contrib.rnn` for backwards
+ compatibility). The original `tf.nn.rnn` function is now `tf.nn.static_rnn`,
+ and the bidirectional static and state saving static rnn functions are also
+ now back in the `tf.nn` namespace.
+
+ Notable exceptions are the `EmbeddingWrapper`, `InputProjectionWrapper` and
+ `OutputProjectionWrapper`, which will slowly be moved to deprecation
+ in `tf.contrib.rnn`. These are inefficient wrappers that should often
+ be replaced by calling `embedding_lookup` or `layers.dense` as pre- or post-
+ processing of the rnn. For RNN decoding, this functionality has been replaced
+ with an alternative API in `tf.contrib.seq2seq`.
## Bug Fixes and Other Changes
* In python, `Operation.get_attr` on type attributes returns the Python DType
version of the type to match expected get_attr documentation rather than the
protobuf enum.
-* tensorflow/contrib/rnn undergoes RNN cell variable renaming for
- consistency with Keras layers. Specifically, the previous variable names
- "weights" and "biases" are changed to "kernel" and "bias", respectively.
- This may cause backward incompatibility with regard to your old
- checkpoints containing such RNN cells, in which case you can use the
- [checkpoint_convert script](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/contrib/rnn/python/tools/checkpoint_convert.py)
- to convert the variable names in your old checkpoints.
# Release 1.1.0
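The release note above suggests replacing the contrib-only EmbeddingWrapper and OutputProjectionWrapper with explicit pre-/post-processing around the RNN. A minimal sketch of that pattern, assuming a TF 1.2-era graph-mode setup; the sizes and placeholder names here are illustrative assumptions, not part of this commit:

    import tensorflow as tf

    vocab_size, embed_dim, num_units, num_classes = 1000, 64, 128, 10
    token_ids = tf.placeholder(tf.int32, [None, None])   # [batch, time]
    lengths = tf.placeholder(tf.int32, [None])

    # Pre-processing instead of EmbeddingWrapper: embed the int ids yourself.
    embeddings = tf.get_variable("embeddings", [vocab_size, embed_dim])
    inputs = tf.nn.embedding_lookup(embeddings, token_ids)  # [batch, time, embed_dim]

    cell = tf.nn.rnn_cell.LSTMCell(num_units)                # cell now back in core
    outputs, _ = tf.nn.dynamic_rnn(cell, inputs,
                                   sequence_length=lengths, dtype=tf.float32)

    # Post-processing instead of OutputProjectionWrapper: project the outputs.
    logits = tf.layers.dense(outputs, num_classes)           # [batch, time, num_classes]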
diff --git a/tensorflow/contrib/crf/python/ops/crf.py b/tensorflow/contrib/crf/python/ops/crf.py
index 4bcf93e78f..a19c70717a 100644
--- a/tensorflow/contrib/crf/python/ops/crf.py
+++ b/tensorflow/contrib/crf/python/ops/crf.py
@@ -41,11 +41,11 @@ from __future__ import print_function
import numpy as np
-from tensorflow.contrib.rnn.python.ops import core_rnn_cell
from tensorflow.python.framework import dtypes
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import rnn
+from tensorflow.python.ops import rnn_cell
from tensorflow.python.ops import variable_scope as vs
__all__ = [
@@ -225,7 +225,7 @@ def crf_binary_score(tag_indices, sequence_lengths, transition_params):
return binary_scores
-class CrfForwardRnnCell(core_rnn_cell.RNNCell):
+class CrfForwardRnnCell(rnn_cell.RNNCell):
"""Computes the alpha values in a linear-chain CRF.
See http://www.cs.columbia.edu/~mcollins/fb.pdf for reference.
diff --git a/tensorflow/contrib/cudnn_rnn/python/kernel_tests/cudnn_rnn_ops_benchmark.py b/tensorflow/contrib/cudnn_rnn/python/kernel_tests/cudnn_rnn_ops_benchmark.py
index 34b2d49d26..6ca38c2e47 100644
--- a/tensorflow/contrib/cudnn_rnn/python/kernel_tests/cudnn_rnn_ops_benchmark.py
+++ b/tensorflow/contrib/cudnn_rnn/python/kernel_tests/cudnn_rnn_ops_benchmark.py
@@ -22,7 +22,6 @@ import time
from tensorflow.contrib.cudnn_rnn.python.ops import cudnn_rnn_ops
from tensorflow.contrib.rnn.python.ops import core_rnn
-from tensorflow.contrib.rnn.python.ops import core_rnn_cell_impl
from tensorflow.contrib.rnn.python.ops import lstm_ops
from tensorflow.python.client import session
from tensorflow.python.framework import dtypes
@@ -31,6 +30,7 @@ from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import gradients_impl
from tensorflow.python.ops import init_ops
+from tensorflow.python.ops import rnn_cell
from tensorflow.python.ops import variables
from tensorflow.python.platform import test
@@ -131,9 +131,9 @@ class CudnnRNNBenchmark(test.Benchmark):
]
initializer = init_ops.random_uniform_initializer(-0.01, 0.01, seed=127)
- cell = core_rnn_cell_impl.LSTMCell(
+ cell = rnn_cell.LSTMCell(
num_units=num_units, initializer=initializer, state_is_tuple=True)
- multi_cell = core_rnn_cell_impl.MultiRNNCell(
+ multi_cell = rnn_cell.MultiRNNCell(
[cell() for _ in range(num_layers)])
outputs, final_state = core_rnn.static_rnn(
multi_cell, inputs, dtype=dtypes.float32)
@@ -159,7 +159,7 @@ class CudnnRNNBenchmark(test.Benchmark):
]
cell = lambda: lstm_ops.LSTMBlockCell(num_units=num_units) # pylint: disable=cell-var-from-loop
- multi_cell = core_rnn_cell_impl.MultiRNNCell(
+ multi_cell = rnn_cell.MultiRNNCell(
[cell() for _ in range(num_layers)])
outputs, final_state = core_rnn.static_rnn(
multi_cell, inputs, dtype=dtypes.float32)
diff --git a/tensorflow/contrib/grid_rnn/python/kernel_tests/grid_rnn_test.py b/tensorflow/contrib/grid_rnn/python/kernel_tests/grid_rnn_test.py
index 280271a42d..fed8a771cc 100644
--- a/tensorflow/contrib/grid_rnn/python/kernel_tests/grid_rnn_test.py
+++ b/tensorflow/contrib/grid_rnn/python/kernel_tests/grid_rnn_test.py
@@ -21,11 +21,11 @@ from __future__ import print_function
import numpy as np
from tensorflow.contrib.grid_rnn.python.ops import grid_rnn_cell
-from tensorflow.contrib.rnn.python.ops import core_rnn
from tensorflow.python.framework import dtypes
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import init_ops
from tensorflow.python.ops import nn_ops
+from tensorflow.python.ops import rnn
from tensorflow.python.ops import variable_scope
from tensorflow.python.ops import variables
from tensorflow.python.platform import test
@@ -527,7 +527,7 @@ class GridRNNCellTest(test.TestCase):
dtypes.float32, shape=(batch_size, input_size))
]
- outputs, state = core_rnn.static_rnn(cell, inputs, dtype=dtypes.float32)
+ outputs, state = rnn.static_rnn(cell, inputs, dtype=dtypes.float32)
self.assertEqual(len(outputs), len(inputs))
self.assertEqual(state[0].c.get_shape(), (batch_size, 2))
@@ -569,7 +569,7 @@ class GridRNNCellTest(test.TestCase):
array_ops.placeholder(dtypes.float32, shape=(batch_size, input_size))
]
- outputs, state = core_rnn.static_rnn(cell, inputs, dtype=dtypes.float32)
+ outputs, state = rnn.static_rnn(cell, inputs, dtype=dtypes.float32)
self.assertEqual(len(outputs), len(inputs))
self.assertEqual(state[0].c.get_shape(), (batch_size, 2))
@@ -609,7 +609,7 @@ class GridRNNCellTest(test.TestCase):
array_ops.placeholder(dtypes.float32, shape=(batch_size, input_size))
]
- outputs, state = core_rnn.static_rnn(cell, inputs, dtype=dtypes.float32)
+ outputs, state = rnn.static_rnn(cell, inputs, dtype=dtypes.float32)
self.assertEqual(len(outputs), len(inputs))
self.assertEqual(state[0].c.get_shape(), (batch_size, 2))
@@ -652,7 +652,7 @@ class GridRNNCellTest(test.TestCase):
dtypes.float32, shape=(batch_size, input_size))
] + (max_length - 1) * [array_ops.zeros([batch_size, input_size])])
- outputs, state = core_rnn.static_rnn(cell, inputs, dtype=dtypes.float32)
+ outputs, state = rnn.static_rnn(cell, inputs, dtype=dtypes.float32)
self.assertEqual(len(outputs), len(inputs))
self.assertEqual(state[0].c.get_shape(), (batch_size, 2))
@@ -690,7 +690,7 @@ class GridRNNCellTest(test.TestCase):
array_ops.placeholder(dtypes.float32, shape=(None, input_size))
]
- outputs, state = core_rnn.static_rnn(cell, inputs, dtype=dtypes.float32)
+ outputs, state = rnn.static_rnn(cell, inputs, dtype=dtypes.float32)
self.assertEqual(len(outputs), len(inputs))
diff --git a/tensorflow/contrib/learn/python/learn/estimators/dynamic_rnn_estimator_test.py b/tensorflow/contrib/learn/python/learn/estimators/dynamic_rnn_estimator_test.py
index 6fc028ab70..d518e38fe0 100644
--- a/tensorflow/contrib/learn/python/learn/estimators/dynamic_rnn_estimator_test.py
+++ b/tensorflow/contrib/learn/python/learn/estimators/dynamic_rnn_estimator_test.py
@@ -31,7 +31,6 @@ from tensorflow.contrib.learn.python.learn.estimators import model_fn as model_f
from tensorflow.contrib.learn.python.learn.estimators import prediction_key
from tensorflow.contrib.learn.python.learn.estimators import rnn_common
from tensorflow.contrib.learn.python.learn.estimators import run_config
-from tensorflow.contrib.rnn.python.ops import core_rnn_cell_impl
from tensorflow.python.client import session
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
@@ -42,6 +41,7 @@ from tensorflow.python.ops import functional_ops
from tensorflow.python.ops import lookup_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import random_ops
+from tensorflow.python.ops import rnn_cell
from tensorflow.python.ops import variables
from tensorflow.python.platform import test
@@ -107,7 +107,7 @@ class DynamicRnnEstimatorTest(test.TestCase):
def setUp(self):
super(DynamicRnnEstimatorTest, self).setUp()
- self.rnn_cell = core_rnn_cell_impl.BasicRNNCell(self.NUM_RNN_CELL_UNITS)
+ self.rnn_cell = rnn_cell.BasicRNNCell(self.NUM_RNN_CELL_UNITS)
self.mock_target_column = MockTargetColumn(
num_label_columns=self.NUM_LABEL_COLUMNS)
@@ -312,19 +312,19 @@ class DynamicRnnEstimatorTest(test.TestCase):
# A MultiRNNCell of LSTMCells is both a common choice and an interesting
# test case, because it has two levels of nesting, with an inner class that
# is not a plain tuple.
- cell = core_rnn_cell_impl.MultiRNNCell(
- [core_rnn_cell_impl.LSTMCell(i) for i in cell_sizes])
+ cell = rnn_cell.MultiRNNCell(
+ [rnn_cell.LSTMCell(i) for i in cell_sizes])
state_dict = {
dynamic_rnn_estimator._get_state_name(i):
array_ops.expand_dims(math_ops.range(cell_size), 0)
for i, cell_size in enumerate([5, 5, 3, 3, 7, 7])
}
- expected_state = (core_rnn_cell_impl.LSTMStateTuple(
+ expected_state = (rnn_cell.LSTMStateTuple(
np.reshape(np.arange(5), [1, -1]), np.reshape(np.arange(5), [1, -1])),
- core_rnn_cell_impl.LSTMStateTuple(
+ rnn_cell.LSTMStateTuple(
np.reshape(np.arange(3), [1, -1]),
np.reshape(np.arange(3), [1, -1])),
- core_rnn_cell_impl.LSTMStateTuple(
+ rnn_cell.LSTMStateTuple(
np.reshape(np.arange(7), [1, -1]),
np.reshape(np.arange(7), [1, -1])))
actual_state = dynamic_rnn_estimator.dict_to_state_tuple(state_dict, cell)
diff --git a/tensorflow/contrib/learn/python/learn/estimators/state_saving_rnn_estimator.py b/tensorflow/contrib/learn/python/learn/estimators/state_saving_rnn_estimator.py
index 9cb4c3515a..0cea35e219 100644
--- a/tensorflow/contrib/learn/python/learn/estimators/state_saving_rnn_estimator.py
+++ b/tensorflow/contrib/learn/python/learn/estimators/state_saving_rnn_estimator.py
@@ -26,13 +26,13 @@ from tensorflow.contrib.learn.python.learn.estimators import constants
from tensorflow.contrib.learn.python.learn.estimators import estimator
from tensorflow.contrib.learn.python.learn.estimators import model_fn
from tensorflow.contrib.learn.python.learn.estimators import rnn_common
-from tensorflow.contrib.rnn.python.ops import core_rnn
from tensorflow.contrib.training.python.training import sequence_queueing_state_saver as sqss
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import sparse_tensor
from tensorflow.python.framework import tensor_shape
from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import rnn
from tensorflow.python.training import momentum as momentum_opt
from tensorflow.python.util import nest
@@ -64,7 +64,7 @@ def construct_state_saving_rnn(cell,
final_state: The final state output by the RNN
"""
with ops.name_scope(scope):
- rnn_outputs, final_state = core_rnn.static_state_saving_rnn(
+ rnn_outputs, final_state = rnn.static_state_saving_rnn(
cell=cell,
inputs=inputs,
state_saver=state_saver,
diff --git a/tensorflow/contrib/learn/python/learn/ops/seq2seq_ops_test.py b/tensorflow/contrib/learn/python/learn/ops/seq2seq_ops_test.py
index a9116c2d54..95aec61955 100644
--- a/tensorflow/contrib/learn/python/learn/ops/seq2seq_ops_test.py
+++ b/tensorflow/contrib/learn/python/learn/ops/seq2seq_ops_test.py
@@ -21,9 +21,9 @@ from __future__ import print_function
import numpy as np
from tensorflow.contrib.learn.python.learn import ops
-from tensorflow.contrib.rnn.python.ops import core_rnn_cell_impl
from tensorflow.python.framework import dtypes
from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import rnn_cell
from tensorflow.python.platform import test
@@ -82,7 +82,7 @@ class Seq2SeqOpsTest(test.TestCase):
array_ops.placeholder(dtypes.float32, [2, 2]) for _ in range(3)
]
encoding = array_ops.placeholder(dtypes.float32, [2, 2])
- cell = core_rnn_cell_impl.GRUCell(2)
+ cell = rnn_cell.GRUCell(2)
outputs, states, sampling_outputs, sampling_states = (
ops.rnn_decoder(decoder_inputs, encoding, cell))
self.assertEqual(len(outputs), 3)
diff --git a/tensorflow/contrib/legacy_seq2seq/python/kernel_tests/seq2seq_test.py b/tensorflow/contrib/legacy_seq2seq/python/kernel_tests/seq2seq_test.py
index 2898935a47..4395138e20 100644
--- a/tensorflow/contrib/legacy_seq2seq/python/kernel_tests/seq2seq_test.py
+++ b/tensorflow/contrib/legacy_seq2seq/python/kernel_tests/seq2seq_test.py
@@ -25,8 +25,7 @@ import random
import numpy as np
from tensorflow.contrib.legacy_seq2seq.python.ops import seq2seq as seq2seq_lib
-from tensorflow.contrib.rnn.python.ops import core_rnn
-from tensorflow.contrib.rnn.python.ops import core_rnn_cell_impl
+from tensorflow.contrib.rnn.python.ops import core_rnn_cell
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
@@ -37,6 +36,7 @@ from tensorflow.python.ops import gradients_impl
from tensorflow.python.ops import init_ops
from tensorflow.python.ops import nn_impl
from tensorflow.python.ops import rnn
+from tensorflow.python.ops import rnn_cell
from tensorflow.python.ops import state_ops
from tensorflow.python.ops import variable_scope
from tensorflow.python.ops import variables
@@ -51,11 +51,10 @@ class Seq2SeqTest(test.TestCase):
with variable_scope.variable_scope(
"root", initializer=init_ops.constant_initializer(0.5)):
inp = [constant_op.constant(0.5, shape=[2, 2])] * 2
- _, enc_state = core_rnn.static_rnn(
- core_rnn_cell_impl.GRUCell(2), inp, dtype=dtypes.float32)
+ _, enc_state = rnn.static_rnn(
+ rnn_cell.GRUCell(2), inp, dtype=dtypes.float32)
dec_inp = [constant_op.constant(0.4, shape=[2, 2])] * 3
- cell = core_rnn_cell_impl.OutputProjectionWrapper(
- core_rnn_cell_impl.GRUCell(2), 4)
+ cell = core_rnn_cell.OutputProjectionWrapper(rnn_cell.GRUCell(2), 4)
dec, mem = seq2seq_lib.rnn_decoder(dec_inp, enc_state, cell)
sess.run([variables.global_variables_initializer()])
res = sess.run(dec)
@@ -71,8 +70,7 @@ class Seq2SeqTest(test.TestCase):
"root", initializer=init_ops.constant_initializer(0.5)):
inp = [constant_op.constant(0.5, shape=[2, 2])] * 2
dec_inp = [constant_op.constant(0.4, shape=[2, 2])] * 3
- cell = core_rnn_cell_impl.OutputProjectionWrapper(
- core_rnn_cell_impl.GRUCell(2), 4)
+ cell = core_rnn_cell.OutputProjectionWrapper(rnn_cell.GRUCell(2), 4)
dec, mem = seq2seq_lib.basic_rnn_seq2seq(inp, dec_inp, cell)
sess.run([variables.global_variables_initializer()])
res = sess.run(dec)
@@ -88,8 +86,7 @@ class Seq2SeqTest(test.TestCase):
"root", initializer=init_ops.constant_initializer(0.5)):
inp = [constant_op.constant(0.5, shape=[2, 2])] * 2
dec_inp = [constant_op.constant(0.4, shape=[2, 2])] * 3
- cell = core_rnn_cell_impl.OutputProjectionWrapper(
- core_rnn_cell_impl.GRUCell(2), 4)
+ cell = core_rnn_cell.OutputProjectionWrapper(rnn_cell.GRUCell(2), 4)
dec, mem = seq2seq_lib.tied_rnn_seq2seq(inp, dec_inp, cell)
sess.run([variables.global_variables_initializer()])
res = sess.run(dec)
@@ -105,9 +102,9 @@ class Seq2SeqTest(test.TestCase):
with variable_scope.variable_scope(
"root", initializer=init_ops.constant_initializer(0.5)):
inp = [constant_op.constant(0.5, shape=[2, 2])] * 2
- cell_fn = lambda: core_rnn_cell_impl.BasicLSTMCell(2)
+ cell_fn = lambda: rnn_cell.BasicLSTMCell(2)
cell = cell_fn()
- _, enc_state = core_rnn.static_rnn(cell, inp, dtype=dtypes.float32)
+ _, enc_state = rnn.static_rnn(cell, inp, dtype=dtypes.float32)
dec_inp = [
constant_op.constant(
i, dtypes.int32, shape=[2]) for i in range(3)
@@ -138,7 +135,7 @@ class Seq2SeqTest(test.TestCase):
constant_op.constant(
i, dtypes.int32, shape=[2]) for i in range(3)
]
- cell_fn = lambda: core_rnn_cell_impl.BasicLSTMCell(2)
+ cell_fn = lambda: rnn_cell.BasicLSTMCell(2)
cell = cell_fn()
dec, mem = seq2seq_lib.embedding_rnn_seq2seq(
enc_inp,
@@ -158,7 +155,7 @@ class Seq2SeqTest(test.TestCase):
# Test with state_is_tuple=False.
with variable_scope.variable_scope("no_tuple"):
- cell_nt = core_rnn_cell_impl.BasicLSTMCell(2, state_is_tuple=False)
+ cell_nt = rnn_cell.BasicLSTMCell(2, state_is_tuple=False)
dec, mem = seq2seq_lib.embedding_rnn_seq2seq(
enc_inp,
dec_inp,
@@ -242,9 +239,7 @@ class Seq2SeqTest(test.TestCase):
constant_op.constant(
i, dtypes.int32, shape=[2]) for i in range(3)
]
- cell = functools.partial(
- core_rnn_cell_impl.BasicLSTMCell,
- 2, state_is_tuple=True)
+ cell = functools.partial(rnn_cell.BasicLSTMCell, 2, state_is_tuple=True)
dec, mem = seq2seq_lib.embedding_tied_rnn_seq2seq(
enc_inp, dec_inp, cell(), num_symbols=5, embedding_size=2)
sess.run([variables.global_variables_initializer()])
@@ -324,11 +319,10 @@ class Seq2SeqTest(test.TestCase):
with self.test_session() as sess:
with variable_scope.variable_scope(
"root", initializer=init_ops.constant_initializer(0.5)):
- cell_fn = lambda: core_rnn_cell_impl.GRUCell(2)
+ cell_fn = lambda: rnn_cell.GRUCell(2)
cell = cell_fn()
inp = [constant_op.constant(0.5, shape=[2, 2])] * 2
- enc_outputs, enc_state = core_rnn.static_rnn(
- cell, inp, dtype=dtypes.float32)
+ enc_outputs, enc_state = rnn.static_rnn(cell, inp, dtype=dtypes.float32)
attn_states = array_ops.concat([
array_ops.reshape(e, [-1, 1, cell.output_size]) for e in enc_outputs
], 1)
@@ -350,11 +344,10 @@ class Seq2SeqTest(test.TestCase):
with self.test_session() as sess:
with variable_scope.variable_scope(
"root", initializer=init_ops.constant_initializer(0.5)):
- cell_fn = lambda: core_rnn_cell_impl.GRUCell(2)
+ cell_fn = lambda: rnn_cell.GRUCell(2)
cell = cell_fn()
inp = [constant_op.constant(0.5, shape=[2, 2])] * 2
- enc_outputs, enc_state = core_rnn.static_rnn(
- cell, inp, dtype=dtypes.float32)
+ enc_outputs, enc_state = rnn.static_rnn(cell, inp, dtype=dtypes.float32)
attn_states = array_ops.concat([
array_ops.reshape(e, [-1, 1, cell.output_size]) for e in enc_outputs
], 1)
@@ -377,7 +370,7 @@ class Seq2SeqTest(test.TestCase):
with self.test_session() as sess:
with variable_scope.variable_scope(
"root", initializer=init_ops.constant_initializer(0.5)):
- cell_fn = lambda: core_rnn_cell_impl.GRUCell(2)
+ cell_fn = lambda: rnn_cell.GRUCell(2)
cell = cell_fn()
inp = constant_op.constant(0.5, shape=[2, 2, 2])
enc_outputs, enc_state = rnn.dynamic_rnn(
@@ -401,7 +394,7 @@ class Seq2SeqTest(test.TestCase):
with self.test_session() as sess:
with variable_scope.variable_scope(
"root", initializer=init_ops.constant_initializer(0.5)):
- cell_fn = lambda: core_rnn_cell_impl.GRUCell(2)
+ cell_fn = lambda: rnn_cell.GRUCell(2)
cell = cell_fn()
inp = constant_op.constant(0.5, shape=[2, 2, 2])
enc_outputs, enc_state = rnn.dynamic_rnn(
@@ -426,14 +419,13 @@ class Seq2SeqTest(test.TestCase):
with self.test_session() as sess:
with variable_scope.variable_scope(
"root", initializer=init_ops.constant_initializer(0.5)):
- single_cell = lambda: core_rnn_cell_impl.BasicLSTMCell( # pylint: disable=g-long-lambda
+ single_cell = lambda: rnn_cell.BasicLSTMCell( # pylint: disable=g-long-lambda
2, state_is_tuple=True)
- cell_fn = lambda: core_rnn_cell_impl.MultiRNNCell( # pylint: disable=g-long-lambda
+ cell_fn = lambda: rnn_cell.MultiRNNCell( # pylint: disable=g-long-lambda
cells=[single_cell() for _ in range(2)], state_is_tuple=True)
cell = cell_fn()
inp = [constant_op.constant(0.5, shape=[2, 2])] * 2
- enc_outputs, enc_state = core_rnn.static_rnn(
- cell, inp, dtype=dtypes.float32)
+ enc_outputs, enc_state = rnn.static_rnn(cell, inp, dtype=dtypes.float32)
attn_states = array_ops.concat([
array_ops.reshape(e, [-1, 1, cell.output_size]) for e in enc_outputs
], 1)
@@ -459,12 +451,11 @@ class Seq2SeqTest(test.TestCase):
with self.test_session() as sess:
with variable_scope.variable_scope(
"root", initializer=init_ops.constant_initializer(0.5)):
- cell_fn = lambda: core_rnn_cell_impl.MultiRNNCell( # pylint: disable=g-long-lambda
- cells=[core_rnn_cell_impl.BasicLSTMCell(2) for _ in range(2)])
+ cell_fn = lambda: rnn_cell.MultiRNNCell( # pylint: disable=g-long-lambda
+ cells=[rnn_cell.BasicLSTMCell(2) for _ in range(2)])
cell = cell_fn()
inp = [constant_op.constant(0.5, shape=[2, 2])] * 2
- enc_outputs, enc_state = core_rnn.static_rnn(
- cell, inp, dtype=dtypes.float32)
+ enc_outputs, enc_state = rnn.static_rnn(cell, inp, dtype=dtypes.float32)
attn_states = array_ops.concat([
array_ops.reshape(e, [-1, 1, cell.output_size])
for e in enc_outputs
@@ -492,10 +483,9 @@ class Seq2SeqTest(test.TestCase):
with variable_scope.variable_scope(
"root", initializer=init_ops.constant_initializer(0.5)):
inp = [constant_op.constant(0.5, shape=[2, 2])] * 2
- cell_fn = lambda: core_rnn_cell_impl.GRUCell(2)
+ cell_fn = lambda: rnn_cell.GRUCell(2)
cell = cell_fn()
- enc_outputs, enc_state = core_rnn.static_rnn(
- cell, inp, dtype=dtypes.float32)
+ enc_outputs, enc_state = rnn.static_rnn(cell, inp, dtype=dtypes.float32)
attn_states = array_ops.concat([
array_ops.reshape(e, [-1, 1, cell.output_size]) for e in enc_outputs
], 1)
@@ -534,7 +524,7 @@ class Seq2SeqTest(test.TestCase):
constant_op.constant(
i, dtypes.int32, shape=[2]) for i in range(3)
]
- cell_fn = lambda: core_rnn_cell_impl.BasicLSTMCell(2)
+ cell_fn = lambda: rnn_cell.BasicLSTMCell(2)
cell = cell_fn()
dec, mem = seq2seq_lib.embedding_attention_seq2seq(
enc_inp,
@@ -555,8 +545,7 @@ class Seq2SeqTest(test.TestCase):
# Test with state_is_tuple=False.
with variable_scope.variable_scope("no_tuple"):
cell_fn = functools.partial(
- core_rnn_cell_impl.BasicLSTMCell,
- 2, state_is_tuple=False)
+ rnn_cell.BasicLSTMCell, 2, state_is_tuple=False)
cell_nt = cell_fn()
dec, mem = seq2seq_lib.embedding_attention_seq2seq(
enc_inp,
@@ -651,11 +640,10 @@ class Seq2SeqTest(test.TestCase):
]
dec_symbols_dict = {"0": 5, "1": 6}
def EncCellFn():
- return core_rnn_cell_impl.BasicLSTMCell(2, state_is_tuple=True)
+ return rnn_cell.BasicLSTMCell(2, state_is_tuple=True)
def DecCellsFn():
- return dict(
- (k, core_rnn_cell_impl.BasicLSTMCell(2, state_is_tuple=True))
- for k in dec_symbols_dict)
+ return dict((k, rnn_cell.BasicLSTMCell(2, state_is_tuple=True))
+ for k in dec_symbols_dict)
outputs_dict, state_dict = (seq2seq_lib.one2many_rnn_seq2seq(
enc_inp, dec_inp_dict, EncCellFn(), DecCellsFn(),
2, dec_symbols_dict, embedding_size=2))
@@ -796,8 +784,8 @@ class Seq2SeqTest(test.TestCase):
# """Example sequence-to-sequence model that uses GRU cells."""
# def GRUSeq2Seq(enc_inp, dec_inp):
- # cell = core_rnn_cell_impl.MultiRNNCell(
- # [core_rnn_cell_impl.GRUCell(24) for _ in range(2)])
+ # cell = rnn_cell.MultiRNNCell(
+ # [rnn_cell.GRUCell(24) for _ in range(2)])
# return seq2seq_lib.embedding_attention_seq2seq(
# enc_inp,
# dec_inp,
@@ -862,9 +850,8 @@ class Seq2SeqTest(test.TestCase):
"""Example sequence-to-sequence model that uses GRU cells."""
def GRUSeq2Seq(enc_inp, dec_inp):
- cell = core_rnn_cell_impl.MultiRNNCell(
- [core_rnn_cell_impl.GRUCell(24) for _ in range(2)],
- state_is_tuple=True)
+ cell = rnn_cell.MultiRNNCell(
+ [rnn_cell.GRUCell(24) for _ in range(2)], state_is_tuple=True)
return seq2seq_lib.embedding_attention_seq2seq(
enc_inp,
dec_inp,
@@ -1040,7 +1027,7 @@ class Seq2SeqTest(test.TestCase):
self.assertAllClose(v_true.eval(), v_false.eval())
def EmbeddingRNNSeq2SeqF(enc_inp, dec_inp, feed_previous):
- cell = core_rnn_cell_impl.BasicLSTMCell(2, state_is_tuple=True)
+ cell = rnn_cell.BasicLSTMCell(2, state_is_tuple=True)
return seq2seq_lib.embedding_rnn_seq2seq(
enc_inp,
dec_inp,
@@ -1051,7 +1038,7 @@ class Seq2SeqTest(test.TestCase):
feed_previous=feed_previous)
def EmbeddingRNNSeq2SeqNoTupleF(enc_inp, dec_inp, feed_previous):
- cell = core_rnn_cell_impl.BasicLSTMCell(2, state_is_tuple=False)
+ cell = rnn_cell.BasicLSTMCell(2, state_is_tuple=False)
return seq2seq_lib.embedding_rnn_seq2seq(
enc_inp,
dec_inp,
@@ -1062,7 +1049,7 @@ class Seq2SeqTest(test.TestCase):
feed_previous=feed_previous)
def EmbeddingTiedRNNSeq2Seq(enc_inp, dec_inp, feed_previous):
- cell = core_rnn_cell_impl.BasicLSTMCell(2, state_is_tuple=True)
+ cell = rnn_cell.BasicLSTMCell(2, state_is_tuple=True)
return seq2seq_lib.embedding_tied_rnn_seq2seq(
enc_inp,
dec_inp,
@@ -1072,7 +1059,7 @@ class Seq2SeqTest(test.TestCase):
feed_previous=feed_previous)
def EmbeddingTiedRNNSeq2SeqNoTuple(enc_inp, dec_inp, feed_previous):
- cell = core_rnn_cell_impl.BasicLSTMCell(2, state_is_tuple=False)
+ cell = rnn_cell.BasicLSTMCell(2, state_is_tuple=False)
return seq2seq_lib.embedding_tied_rnn_seq2seq(
enc_inp,
dec_inp,
@@ -1082,7 +1069,7 @@ class Seq2SeqTest(test.TestCase):
feed_previous=feed_previous)
def EmbeddingAttentionSeq2Seq(enc_inp, dec_inp, feed_previous):
- cell = core_rnn_cell_impl.BasicLSTMCell(2, state_is_tuple=True)
+ cell = rnn_cell.BasicLSTMCell(2, state_is_tuple=True)
return seq2seq_lib.embedding_attention_seq2seq(
enc_inp,
dec_inp,
@@ -1093,7 +1080,7 @@ class Seq2SeqTest(test.TestCase):
feed_previous=feed_previous)
def EmbeddingAttentionSeq2SeqNoTuple(enc_inp, dec_inp, feed_previous):
- cell = core_rnn_cell_impl.BasicLSTMCell(2, state_is_tuple=False)
+ cell = rnn_cell.BasicLSTMCell(2, state_is_tuple=False)
return seq2seq_lib.embedding_attention_seq2seq(
enc_inp,
dec_inp,
diff --git a/tensorflow/contrib/legacy_seq2seq/python/ops/seq2seq.py b/tensorflow/contrib/legacy_seq2seq/python/ops/seq2seq.py
index a80b898156..23b4a73b23 100644
--- a/tensorflow/contrib/legacy_seq2seq/python/ops/seq2seq.py
+++ b/tensorflow/contrib/legacy_seq2seq/python/ops/seq2seq.py
@@ -62,9 +62,7 @@ import copy
from six.moves import xrange # pylint: disable=redefined-builtin
from six.moves import zip # pylint: disable=redefined-builtin
-from tensorflow.contrib.rnn.python.ops import core_rnn
from tensorflow.contrib.rnn.python.ops import core_rnn_cell
-from tensorflow.contrib.rnn.python.ops import core_rnn_cell_impl
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
@@ -72,11 +70,13 @@ from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import embedding_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import nn_ops
+from tensorflow.python.ops import rnn
+from tensorflow.python.ops import rnn_cell_impl
from tensorflow.python.ops import variable_scope
from tensorflow.python.util import nest
# TODO(ebrevdo): Remove once _linear is fully deprecated.
-linear = core_rnn_cell_impl._linear # pylint: disable=protected-access
+linear = rnn_cell_impl._linear # pylint: disable=protected-access
def _extract_argmax_and_embed(embedding,
@@ -119,7 +119,7 @@ def rnn_decoder(decoder_inputs,
Args:
decoder_inputs: A list of 2D Tensors [batch_size x input_size].
initial_state: 2D Tensor with shape [batch_size x cell.state_size].
- cell: core_rnn_cell.RNNCell defining the cell function and size.
+ cell: rnn_cell.RNNCell defining the cell function and size.
loop_function: If not None, this function will be applied to the i-th output
in order to generate the i+1-st input, and decoder_inputs will be ignored,
except for the first element ("GO" symbol). This can be used for decoding,
@@ -170,7 +170,7 @@ def basic_rnn_seq2seq(encoder_inputs,
Args:
encoder_inputs: A list of 2D Tensors [batch_size x input_size].
decoder_inputs: A list of 2D Tensors [batch_size x input_size].
- cell: core_rnn_cell.RNNCell defining the cell function and size.
+ cell: tf.nn.rnn_cell.RNNCell defining the cell function and size.
dtype: The dtype of the initial state of the RNN cell (default: tf.float32).
scope: VariableScope for the created subgraph; default: "basic_rnn_seq2seq".
@@ -183,7 +183,7 @@ def basic_rnn_seq2seq(encoder_inputs,
"""
with variable_scope.variable_scope(scope or "basic_rnn_seq2seq"):
enc_cell = copy.deepcopy(cell)
- _, enc_state = core_rnn.static_rnn(enc_cell, encoder_inputs, dtype=dtype)
+ _, enc_state = rnn.static_rnn(enc_cell, encoder_inputs, dtype=dtype)
return rnn_decoder(decoder_inputs, enc_state, cell)
@@ -202,7 +202,7 @@ def tied_rnn_seq2seq(encoder_inputs,
Args:
encoder_inputs: A list of 2D Tensors [batch_size x input_size].
decoder_inputs: A list of 2D Tensors [batch_size x input_size].
- cell: core_rnn_cell.RNNCell defining the cell function and size.
+ cell: tf.nn.rnn_cell.RNNCell defining the cell function and size.
loop_function: If not None, this function will be applied to i-th output
in order to generate i+1-th input, and decoder_inputs will be ignored,
except for the first element ("GO" symbol), see rnn_decoder for details.
@@ -219,7 +219,7 @@ def tied_rnn_seq2seq(encoder_inputs,
"""
with variable_scope.variable_scope("combined_tied_rnn_seq2seq"):
scope = scope or "tied_rnn_seq2seq"
- _, enc_state = core_rnn.static_rnn(
+ _, enc_state = rnn.static_rnn(
cell, encoder_inputs, dtype=dtype, scope=scope)
variable_scope.get_variable_scope().reuse_variables()
return rnn_decoder(
@@ -244,7 +244,7 @@ def embedding_rnn_decoder(decoder_inputs,
Args:
decoder_inputs: A list of 1D batch-sized int32 Tensors (decoder inputs).
initial_state: 2D Tensor [batch_size x cell.state_size].
- cell: core_rnn_cell.RNNCell defining the cell function.
+ cell: tf.nn.rnn_cell.RNNCell defining the cell function.
num_symbols: Integer, how many symbols come into the embedding.
embedding_size: Integer, the length of the embedding vector for each symbol.
output_projection: None or a pair (W, B) of output projection weights and
@@ -320,7 +320,7 @@ def embedding_rnn_seq2seq(encoder_inputs,
Args:
encoder_inputs: A list of 1D int32 Tensors of shape [batch_size].
decoder_inputs: A list of 1D int32 Tensors of shape [batch_size].
- cell: core_rnn_cell.RNNCell defining the cell function and size.
+ cell: tf.nn.rnn_cell.RNNCell defining the cell function and size.
num_encoder_symbols: Integer; number of symbols on the encoder side.
num_decoder_symbols: Integer; number of symbols on the decoder side.
embedding_size: Integer, the length of the embedding vector for each symbol.
@@ -360,8 +360,7 @@ def embedding_rnn_seq2seq(encoder_inputs,
encoder_cell,
embedding_classes=num_encoder_symbols,
embedding_size=embedding_size)
- _, encoder_state = core_rnn.static_rnn(
- encoder_cell, encoder_inputs, dtype=dtype)
+ _, encoder_state = rnn.static_rnn(encoder_cell, encoder_inputs, dtype=dtype)
# Decoder.
if output_projection is None:
@@ -431,7 +430,7 @@ def embedding_tied_rnn_seq2seq(encoder_inputs,
Args:
encoder_inputs: A list of 1D int32 Tensors of shape [batch_size].
decoder_inputs: A list of 1D int32 Tensors of shape [batch_size].
- cell: core_rnn_cell.RNNCell defining the cell function and size.
+ cell: tf.nn.rnn_cell.RNNCell defining the cell function and size.
num_symbols: Integer; number of symbols for both encoder and decoder.
embedding_size: Integer, the length of the embedding vector for each symbol.
num_decoder_symbols: Integer; number of output symbols for decoder. If
@@ -560,7 +559,7 @@ def attention_decoder(decoder_inputs,
decoder_inputs: A list of 2D Tensors [batch_size x input_size].
initial_state: 2D Tensor [batch_size x cell.state_size].
attention_states: 3D Tensor [batch_size x attn_length x attn_size].
- cell: core_rnn_cell.RNNCell defining the cell function and size.
+ cell: tf.nn.rnn_cell.RNNCell defining the cell function and size.
output_size: Size of the output vectors; if None, we use cell.output_size.
num_heads: Number of attention heads that read from attention_states.
loop_function: If not None, this function will be applied to i-th output
@@ -720,7 +719,7 @@ def embedding_attention_decoder(decoder_inputs,
decoder_inputs: A list of 1D batch-sized int32 Tensors (decoder inputs).
initial_state: 2D Tensor [batch_size x cell.state_size].
attention_states: 3D Tensor [batch_size x attn_length x attn_size].
- cell: core_rnn_cell.RNNCell defining the cell function.
+ cell: tf.nn.rnn_cell.RNNCell defining the cell function.
num_symbols: Integer, how many symbols come into the embedding.
embedding_size: Integer, the length of the embedding vector for each symbol.
num_heads: Number of attention heads that read from attention_states.
@@ -814,7 +813,7 @@ def embedding_attention_seq2seq(encoder_inputs,
Args:
encoder_inputs: A list of 1D int32 Tensors of shape [batch_size].
decoder_inputs: A list of 1D int32 Tensors of shape [batch_size].
- cell: core_rnn_cell.RNNCell defining the cell function and size.
+ cell: tf.nn.rnn_cell.RNNCell defining the cell function and size.
num_encoder_symbols: Integer; number of symbols on the encoder side.
num_decoder_symbols: Integer; number of symbols on the decoder side.
embedding_size: Integer, the length of the embedding vector for each symbol.
@@ -851,7 +850,7 @@ def embedding_attention_seq2seq(encoder_inputs,
encoder_cell,
embedding_classes=num_encoder_symbols,
embedding_size=embedding_size)
- encoder_outputs, encoder_state = core_rnn.static_rnn(
+ encoder_outputs, encoder_state = rnn.static_rnn(
encoder_cell, encoder_inputs, dtype=dtype)
# First calculate a concatenation of encoder outputs to put attention on.
@@ -937,9 +936,10 @@ def one2many_rnn_seq2seq(encoder_inputs,
the corresponding decoder_inputs; each decoder_inputs is a list of 1D
Tensors of shape [batch_size]; num_decoders is defined as
len(decoder_inputs_dict).
- enc_cell: core_rnn_cell.RNNCell defining the encoder cell function and size.
+ enc_cell: tf.nn.rnn_cell.RNNCell defining the encoder cell function and
+ size.
dec_cells_dict: A dictionary mapping encoder name (string) to an
- instance of core_rnn_cell.RNNCell.
+ instance of tf.nn.rnn_cell.RNNCell.
num_encoder_symbols: Integer; number of symbols on the encoder side.
num_decoder_symbols_dict: A dictionary mapping decoder name (string) to an
integer specifying number of symbols for the corresponding decoder;
@@ -971,12 +971,12 @@ def one2many_rnn_seq2seq(encoder_inputs,
outputs_dict = {}
state_dict = {}
- if not isinstance(enc_cell, core_rnn_cell.RNNCell):
+ if not isinstance(enc_cell, rnn_cell_impl.RNNCell):
raise TypeError("enc_cell is not an RNNCell: %s" % type(enc_cell))
if set(dec_cells_dict) != set(decoder_inputs_dict):
raise ValueError("keys of dec_cells_dict != keys of decodre_inputs_dict")
for dec_cell in dec_cells_dict.values():
- if not isinstance(dec_cell, core_rnn_cell.RNNCell):
+ if not isinstance(dec_cell, rnn_cell_impl.RNNCell):
raise TypeError("dec_cell is not an RNNCell: %s" % type(dec_cell))
with variable_scope.variable_scope(
@@ -988,8 +988,7 @@ def one2many_rnn_seq2seq(encoder_inputs,
enc_cell,
embedding_classes=num_encoder_symbols,
embedding_size=embedding_size)
- _, encoder_state = core_rnn.static_rnn(
- enc_cell, encoder_inputs, dtype=dtype)
+ _, encoder_state = rnn.static_rnn(enc_cell, encoder_inputs, dtype=dtype)
# Decoder.
for name, decoder_inputs in decoder_inputs_dict.items():
@@ -1153,7 +1152,7 @@ def model_with_buckets(encoder_inputs,
The seq2seq argument is a function that defines a sequence-to-sequence model,
e.g., seq2seq = lambda x, y: basic_rnn_seq2seq(
- x, y, core_rnn_cell.GRUCell(24))
+ x, y, rnn_cell.GRUCell(24))
Args:
encoder_inputs: A list of Tensors to feed the encoder; first seq2seq input.
diff --git a/tensorflow/contrib/ndlstm/python/lstm1d.py b/tensorflow/contrib/ndlstm/python/lstm1d.py
index e4edff00a7..d3c3531f40 100644
--- a/tensorflow/contrib/ndlstm/python/lstm1d.py
+++ b/tensorflow/contrib/ndlstm/python/lstm1d.py
@@ -20,13 +20,13 @@ from __future__ import print_function
from six.moves import xrange # pylint: disable=redefined-builtin
from tensorflow.contrib.framework.python.ops import variables
-from tensorflow.contrib.rnn.python.ops import core_rnn_cell_impl
from tensorflow.python.framework import constant_op
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import nn_ops
from tensorflow.python.ops import random_ops
from tensorflow.python.ops import rnn
+from tensorflow.python.ops import rnn_cell
from tensorflow.python.ops import variable_scope
@@ -52,7 +52,7 @@ def ndlstm_base_unrolled(inputs, noutput, scope=None, reverse=False):
"""
with variable_scope.variable_scope(scope, "SeqLstmUnrolled", [inputs]):
length, batch_size, _ = _shape(inputs)
- lstm_cell = core_rnn_cell_impl.BasicLSTMCell(noutput, state_is_tuple=False)
+ lstm_cell = rnn_cell.BasicLSTMCell(noutput, state_is_tuple=False)
state = array_ops.zeros([batch_size, lstm_cell.state_size])
output_u = []
inputs_u = array_ops.unstack(inputs)
@@ -88,7 +88,7 @@ def ndlstm_base_dynamic(inputs, noutput, scope=None, reverse=False):
# TODO(tmb) make batch size, sequence_length dynamic
# example: sequence_length = tf.shape(inputs)[0]
_, batch_size, _ = _shape(inputs)
- lstm_cell = core_rnn_cell_impl.BasicLSTMCell(noutput, state_is_tuple=False)
+ lstm_cell = rnn_cell.BasicLSTMCell(noutput, state_is_tuple=False)
state = array_ops.zeros([batch_size, lstm_cell.state_size])
sequence_length = int(inputs.get_shape()[0])
sequence_lengths = math_ops.to_int64(
@@ -145,7 +145,7 @@ def sequence_to_final(inputs, noutput, scope=None, name=None, reverse=False):
"""
with variable_scope.variable_scope(scope, "SequenceToFinal", [inputs]):
length, batch_size, _ = _shape(inputs)
- lstm = core_rnn_cell_impl.BasicLSTMCell(noutput, state_is_tuple=False)
+ lstm = rnn_cell.BasicLSTMCell(noutput, state_is_tuple=False)
state = array_ops.zeros([batch_size, lstm.state_size])
inputs_u = array_ops.unstack(inputs)
if reverse:
diff --git a/tensorflow/contrib/rnn/__init__.py b/tensorflow/contrib/rnn/__init__.py
index 2420c3e179..d39c1f062a 100644
--- a/tensorflow/contrib/rnn/__init__.py
+++ b/tensorflow/contrib/rnn/__init__.py
@@ -16,21 +16,26 @@
See @{$python/contrib.rnn} guide.
+# From core
@@RNNCell
@@BasicRNNCell
@@BasicLSTMCell
@@GRUCell
@@LSTMCell
-@@LayerNormBasicLSTMCell
@@LSTMStateTuple
-@@MultiRNNCell
-@@LSTMBlockWrapper
@@DropoutWrapper
+@@MultiRNNCell
+@@DeviceWrapper
+@@ResidualWrapper
+
+# Used to be in core, but kept in contrib.
@@EmbeddingWrapper
@@InputProjectionWrapper
@@OutputProjectionWrapper
-@@DeviceWrapper
-@@ResidualWrapper
+
+# Created in contrib, eventual plans to move to core.
+@@LayerNormBasicLSTMCell
+@@LSTMBlockWrapper
@@LSTMBlockCell
@@GRUBlockCell
@@FusedRNNCell
@@ -48,9 +53,11 @@ See @{$python/contrib.rnn} guide.
@@HighwayWrapper
@@GLSTMCell
-### RNNCell wrappers
+# RNNCell wrappers
@@AttentionCellWrapper
@@CompiledWrapper
+
+# RNN functions
@@static_rnn
@@static_state_saving_rnn
@@static_bidirectional_rnn
@@ -62,31 +69,23 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
-from tensorflow.contrib.rnn.python.ops.core_rnn import static_bidirectional_rnn
-from tensorflow.contrib.rnn.python.ops.core_rnn import static_rnn
-from tensorflow.contrib.rnn.python.ops.core_rnn import static_state_saving_rnn
-
-from tensorflow.contrib.rnn.python.ops.core_rnn_cell import BasicLSTMCell
-from tensorflow.contrib.rnn.python.ops.core_rnn_cell import BasicRNNCell
-from tensorflow.contrib.rnn.python.ops.core_rnn_cell import DeviceWrapper
-from tensorflow.contrib.rnn.python.ops.core_rnn_cell import DropoutWrapper
+# pylint: disable=unused-import,wildcard-import,line-too-long
from tensorflow.contrib.rnn.python.ops.core_rnn_cell import EmbeddingWrapper
-from tensorflow.contrib.rnn.python.ops.core_rnn_cell import GRUCell
from tensorflow.contrib.rnn.python.ops.core_rnn_cell import InputProjectionWrapper
-from tensorflow.contrib.rnn.python.ops.core_rnn_cell import LSTMCell
-from tensorflow.contrib.rnn.python.ops.core_rnn_cell import LSTMStateTuple
-from tensorflow.contrib.rnn.python.ops.core_rnn_cell import MultiRNNCell
from tensorflow.contrib.rnn.python.ops.core_rnn_cell import OutputProjectionWrapper
-from tensorflow.contrib.rnn.python.ops.core_rnn_cell import ResidualWrapper
-from tensorflow.contrib.rnn.python.ops.core_rnn_cell import RNNCell
-# pylint: disable=unused-import,wildcard-import,line-too-long
from tensorflow.contrib.rnn.python.ops.fused_rnn_cell import *
from tensorflow.contrib.rnn.python.ops.gru_ops import *
from tensorflow.contrib.rnn.python.ops.lstm_ops import *
from tensorflow.contrib.rnn.python.ops.rnn import *
from tensorflow.contrib.rnn.python.ops.rnn_cell import *
+
+from tensorflow.python.ops.rnn import static_bidirectional_rnn
+from tensorflow.python.ops.rnn import static_rnn
+from tensorflow.python.ops.rnn import static_state_saving_rnn
+
+from tensorflow.python.ops.rnn_cell import *
# pylint: enable=unused-import,wildcard-import,line-too-long
from tensorflow.python.util.all_util import remove_undocumented
-remove_undocumented(__name__, ['core_rnn_cell'])
+remove_undocumented(__name__)
diff --git a/tensorflow/contrib/rnn/python/kernel_tests/core_rnn_cell_test.py b/tensorflow/contrib/rnn/python/kernel_tests/core_rnn_cell_test.py
index 0f207b088d..06954f51d8 100644
--- a/tensorflow/contrib/rnn/python/kernel_tests/core_rnn_cell_test.py
+++ b/tensorflow/contrib/rnn/python/kernel_tests/core_rnn_cell_test.py
@@ -25,8 +25,7 @@ import numpy as np
# TODO(ebrevdo): Remove once _linear is fully deprecated.
# pylint: disable=protected-access
-from tensorflow.contrib.rnn.python.ops import core_rnn_cell_impl
-from tensorflow.contrib.rnn.python.ops.core_rnn_cell_impl import _linear as linear
+from tensorflow.contrib import rnn as contrib_rnn
from tensorflow.core.protobuf import config_pb2
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
@@ -37,12 +36,14 @@ from tensorflow.python.ops import init_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import random_ops
from tensorflow.python.ops import rnn
+from tensorflow.python.ops import rnn_cell_impl
from tensorflow.python.ops import variable_scope
from tensorflow.python.ops import variables as variables_lib
from tensorflow.python.platform import test
# pylint: enable=protected-access
+linear = rnn_cell_impl._linear
class RNNCellTest(test.TestCase):
@@ -74,14 +75,12 @@ class RNNCellTest(test.TestCase):
"root", initializer=init_ops.constant_initializer(0.5)):
x = array_ops.zeros([1, 2])
m = array_ops.zeros([1, 2])
- cell = core_rnn_cell_impl.BasicRNNCell(2)
+ cell = rnn_cell_impl.BasicRNNCell(2)
g, _ = cell(x, m)
- self.assertEqual(
- ["root/basic_rnn_cell/%s:0"
- % core_rnn_cell_impl._WEIGHTS_VARIABLE_NAME,
- "root/basic_rnn_cell/%s:0"
- % core_rnn_cell_impl._BIAS_VARIABLE_NAME],
- [v.name for v in cell.trainable_variables])
+ self.assertEqual([
+ "root/basic_rnn_cell/%s:0" % rnn_cell_impl._WEIGHTS_VARIABLE_NAME,
+ "root/basic_rnn_cell/%s:0" % rnn_cell_impl._BIAS_VARIABLE_NAME
+ ], [v.name for v in cell.trainable_variables])
self.assertFalse(cell.non_trainable_variables)
sess.run([variables_lib.global_variables_initializer()])
res = sess.run(
@@ -100,15 +99,13 @@ class RNNCellTest(test.TestCase):
custom_getter=not_trainable_getter):
x = array_ops.zeros([1, 2])
m = array_ops.zeros([1, 2])
- cell = core_rnn_cell_impl.BasicRNNCell(2)
+ cell = rnn_cell_impl.BasicRNNCell(2)
g, _ = cell(x, m)
self.assertFalse(cell.trainable_variables)
- self.assertEqual(
- ["root/basic_rnn_cell/%s:0"
- % core_rnn_cell_impl._WEIGHTS_VARIABLE_NAME,
- "root/basic_rnn_cell/%s:0"
- % core_rnn_cell_impl._BIAS_VARIABLE_NAME],
- [v.name for v in cell.non_trainable_variables])
+ self.assertEqual([
+ "root/basic_rnn_cell/%s:0" % rnn_cell_impl._WEIGHTS_VARIABLE_NAME,
+ "root/basic_rnn_cell/%s:0" % rnn_cell_impl._BIAS_VARIABLE_NAME
+ ], [v.name for v in cell.non_trainable_variables])
sess.run([variables_lib.global_variables_initializer()])
res = sess.run(
[g], {x.name: np.array([[1., 1.]]),
@@ -121,7 +118,7 @@ class RNNCellTest(test.TestCase):
"root", initializer=init_ops.constant_initializer(0.5)):
x = array_ops.zeros([1, 2])
m = array_ops.zeros([1, 2])
- g, _ = core_rnn_cell_impl.GRUCell(2)(x, m)
+ g, _ = rnn_cell_impl.GRUCell(2)(x, m)
sess.run([variables_lib.global_variables_initializer()])
res = sess.run(
[g], {x.name: np.array([[1., 1.]]),
@@ -133,7 +130,7 @@ class RNNCellTest(test.TestCase):
x = array_ops.zeros(
[1, 3]) # Test GRUCell with input_size != num_units.
m = array_ops.zeros([1, 2])
- g, _ = core_rnn_cell_impl.GRUCell(2)(x, m)
+ g, _ = rnn_cell_impl.GRUCell(2)(x, m)
sess.run([variables_lib.global_variables_initializer()])
res = sess.run(
[g],
@@ -148,20 +145,23 @@ class RNNCellTest(test.TestCase):
"root", initializer=init_ops.constant_initializer(0.5)):
x = array_ops.zeros([1, 2])
m = array_ops.zeros([1, 8])
- cell = core_rnn_cell_impl.MultiRNNCell(
- [core_rnn_cell_impl.BasicLSTMCell(
- 2, state_is_tuple=False) for _ in range(2)],
+ cell = rnn_cell_impl.MultiRNNCell(
+ [
+ rnn_cell_impl.BasicLSTMCell(2, state_is_tuple=False)
+ for _ in range(2)
+ ],
state_is_tuple=False)
g, out_m = cell(x, m)
expected_variable_names = [
- "root/multi_rnn_cell/cell_0/basic_lstm_cell/%s:0"
- % core_rnn_cell_impl._WEIGHTS_VARIABLE_NAME,
- "root/multi_rnn_cell/cell_0/basic_lstm_cell/%s:0"
- % core_rnn_cell_impl._BIAS_VARIABLE_NAME,
- "root/multi_rnn_cell/cell_1/basic_lstm_cell/%s:0"
- % core_rnn_cell_impl._WEIGHTS_VARIABLE_NAME,
- "root/multi_rnn_cell/cell_1/basic_lstm_cell/%s:0"
- % core_rnn_cell_impl._BIAS_VARIABLE_NAME]
+ "root/multi_rnn_cell/cell_0/basic_lstm_cell/%s:0" %
+ rnn_cell_impl._WEIGHTS_VARIABLE_NAME,
+ "root/multi_rnn_cell/cell_0/basic_lstm_cell/%s:0" %
+ rnn_cell_impl._BIAS_VARIABLE_NAME,
+ "root/multi_rnn_cell/cell_1/basic_lstm_cell/%s:0" %
+ rnn_cell_impl._WEIGHTS_VARIABLE_NAME,
+ "root/multi_rnn_cell/cell_1/basic_lstm_cell/%s:0" %
+ rnn_cell_impl._BIAS_VARIABLE_NAME
+ ]
self.assertEqual(
expected_variable_names, [v.name for v in cell.trainable_variables])
self.assertFalse(cell.non_trainable_variables)
@@ -185,8 +185,7 @@ class RNNCellTest(test.TestCase):
x = array_ops.zeros(
[1, 3]) # Test BasicLSTMCell with input_size != num_units.
m = array_ops.zeros([1, 4])
- g, out_m = core_rnn_cell_impl.BasicLSTMCell(
- 2, state_is_tuple=False)(x, m)
+ g, out_m = rnn_cell_impl.BasicLSTMCell(2, state_is_tuple=False)(x, m)
sess.run([variables_lib.global_variables_initializer()])
res = sess.run(
[g, out_m],
@@ -206,7 +205,7 @@ class RNNCellTest(test.TestCase):
x = array_ops.zeros([batch_size, input_size])
m = array_ops.zeros([batch_size - 1, state_size])
with self.assertRaises(ValueError):
- g, out_m = core_rnn_cell_impl.BasicLSTMCell(
+ g, out_m = rnn_cell_impl.BasicLSTMCell(
num_units, state_is_tuple=False)(x, m)
sess.run([variables_lib.global_variables_initializer()])
sess.run([g, out_m],
@@ -225,7 +224,7 @@ class RNNCellTest(test.TestCase):
x = array_ops.zeros([batch_size, input_size])
m = array_ops.zeros([batch_size, state_size])
with self.assertRaises(ValueError):
- g, out_m = core_rnn_cell_impl.BasicLSTMCell(
+ g, out_m = rnn_cell_impl.BasicLSTMCell(
num_units, state_is_tuple=False)(x, m)
sess.run([variables_lib.global_variables_initializer()])
sess.run([g, out_m],
@@ -239,31 +238,29 @@ class RNNCellTest(test.TestCase):
x = array_ops.zeros([1, 2])
m0 = (array_ops.zeros([1, 2]),) * 2
m1 = (array_ops.zeros([1, 2]),) * 2
- cell = core_rnn_cell_impl.MultiRNNCell(
- [core_rnn_cell_impl.BasicLSTMCell(2) for _ in range(2)],
+ cell = rnn_cell_impl.MultiRNNCell(
+ [rnn_cell_impl.BasicLSTMCell(2) for _ in range(2)],
state_is_tuple=True)
self.assertTrue(isinstance(cell.state_size, tuple))
self.assertTrue(
- isinstance(cell.state_size[0], core_rnn_cell_impl.LSTMStateTuple))
+ isinstance(cell.state_size[0], rnn_cell_impl.LSTMStateTuple))
self.assertTrue(
- isinstance(cell.state_size[1], core_rnn_cell_impl.LSTMStateTuple))
+ isinstance(cell.state_size[1], rnn_cell_impl.LSTMStateTuple))
# Pass in regular tuples
_, (out_m0, out_m1) = cell(x, (m0, m1))
- self.assertTrue(isinstance(out_m0, core_rnn_cell_impl.LSTMStateTuple))
- self.assertTrue(isinstance(out_m1, core_rnn_cell_impl.LSTMStateTuple))
+ self.assertTrue(isinstance(out_m0, rnn_cell_impl.LSTMStateTuple))
+ self.assertTrue(isinstance(out_m1, rnn_cell_impl.LSTMStateTuple))
# Pass in LSTMStateTuples
variable_scope.get_variable_scope().reuse_variables()
zero_state = cell.zero_state(1, dtypes.float32)
self.assertTrue(isinstance(zero_state, tuple))
- self.assertTrue(
- isinstance(zero_state[0], core_rnn_cell_impl.LSTMStateTuple))
- self.assertTrue(
- isinstance(zero_state[1], core_rnn_cell_impl.LSTMStateTuple))
+ self.assertTrue(isinstance(zero_state[0], rnn_cell_impl.LSTMStateTuple))
+ self.assertTrue(isinstance(zero_state[1], rnn_cell_impl.LSTMStateTuple))
_, (out_m0, out_m1) = cell(x, zero_state)
- self.assertTrue(isinstance(out_m0, core_rnn_cell_impl.LSTMStateTuple))
- self.assertTrue(isinstance(out_m1, core_rnn_cell_impl.LSTMStateTuple))
+ self.assertTrue(isinstance(out_m0, rnn_cell_impl.LSTMStateTuple))
+ self.assertTrue(isinstance(out_m1, rnn_cell_impl.LSTMStateTuple))
def testBasicLSTMCellWithStateTuple(self):
with self.test_session() as sess:
@@ -272,9 +269,11 @@ class RNNCellTest(test.TestCase):
x = array_ops.zeros([1, 2])
m0 = array_ops.zeros([1, 4])
m1 = array_ops.zeros([1, 4])
- cell = core_rnn_cell_impl.MultiRNNCell(
- [core_rnn_cell_impl.BasicLSTMCell(
- 2, state_is_tuple=False) for _ in range(2)],
+ cell = rnn_cell_impl.MultiRNNCell(
+ [
+ rnn_cell_impl.BasicLSTMCell(2, state_is_tuple=False)
+ for _ in range(2)
+ ],
state_is_tuple=True)
g, (out_m0, out_m1) = cell(x, (m0, m1))
sess.run([variables_lib.global_variables_initializer()])
@@ -306,7 +305,7 @@ class RNNCellTest(test.TestCase):
"root", initializer=init_ops.constant_initializer(0.5)):
x = array_ops.zeros([batch_size, input_size])
m = array_ops.zeros([batch_size, state_size])
- cell = core_rnn_cell_impl.LSTMCell(
+ cell = rnn_cell_impl.LSTMCell(
num_units=num_units,
num_proj=num_proj,
forget_bias=1.0,
@@ -340,7 +339,7 @@ class RNNCellTest(test.TestCase):
"root", initializer=init_ops.constant_initializer(0.5)):
x = array_ops.zeros([batch_size, input_size])
m = array_ops.zeros([batch_size, state_size])
- cell = core_rnn_cell_impl.LSTMCell(
+ cell = rnn_cell_impl.LSTMCell(
num_units=num_units,
num_proj=num_proj,
forget_bias=1.0,
@@ -358,8 +357,7 @@ class RNNCellTest(test.TestCase):
"root", initializer=init_ops.constant_initializer(0.5)):
x = array_ops.zeros([1, 3])
m = array_ops.zeros([1, 3])
- cell = core_rnn_cell_impl.OutputProjectionWrapper(
- core_rnn_cell_impl.GRUCell(3), 2)
+ cell = contrib_rnn.OutputProjectionWrapper(rnn_cell_impl.GRUCell(3), 2)
g, new_m = cell(x, m)
sess.run([variables_lib.global_variables_initializer()])
res = sess.run([g, new_m], {
@@ -376,8 +374,8 @@ class RNNCellTest(test.TestCase):
"root", initializer=init_ops.constant_initializer(0.5)):
x = array_ops.zeros([1, 2])
m = array_ops.zeros([1, 3])
- cell = core_rnn_cell_impl.InputProjectionWrapper(
- core_rnn_cell_impl.GRUCell(3), num_proj=3)
+ cell = contrib_rnn.InputProjectionWrapper(
+ rnn_cell_impl.GRUCell(3), num_proj=3)
g, new_m = cell(x, m)
sess.run([variables_lib.global_variables_initializer()])
res = sess.run(
@@ -394,10 +392,10 @@ class RNNCellTest(test.TestCase):
"root", initializer=init_ops.constant_initializer(0.5)):
x = array_ops.zeros([1, 3])
m = array_ops.zeros([1, 3])
- base_cell = core_rnn_cell_impl.GRUCell(3)
+ base_cell = rnn_cell_impl.GRUCell(3)
g, m_new = base_cell(x, m)
variable_scope.get_variable_scope().reuse_variables()
- g_res, m_new_res = core_rnn_cell_impl.ResidualWrapper(base_cell)(x, m)
+ g_res, m_new_res = rnn_cell_impl.ResidualWrapper(base_cell)(x, m)
sess.run([variables_lib.global_variables_initializer()])
res = sess.run([g, g_res, m_new, m_new_res], {
x: np.array([[1., 1., 1.]]),
@@ -413,8 +411,7 @@ class RNNCellTest(test.TestCase):
"root", initializer=init_ops.constant_initializer(0.5)):
x = array_ops.zeros([1, 3])
m = array_ops.zeros([1, 3])
- cell = core_rnn_cell_impl.DeviceWrapper(
- core_rnn_cell_impl.GRUCell(3), "/cpu:14159")
+ cell = rnn_cell_impl.DeviceWrapper(rnn_cell_impl.GRUCell(3), "/cpu:14159")
outputs, _ = cell(x, m)
self.assertTrue("cpu:14159" in outputs.device.lower())
@@ -427,8 +424,7 @@ class RNNCellTest(test.TestCase):
with variable_scope.variable_scope(
"root", initializer=init_ops.constant_initializer(0.5)):
x = array_ops.zeros([1, 1, 3])
- cell = core_rnn_cell_impl.DeviceWrapper(
- core_rnn_cell_impl.GRUCell(3), "/gpu:0")
+ cell = rnn_cell_impl.DeviceWrapper(rnn_cell_impl.GRUCell(3), "/gpu:0")
with ops.device("/cpu:0"):
outputs, _ = rnn.dynamic_rnn(
cell=cell, inputs=x, dtype=dtypes.float32)
@@ -446,39 +442,14 @@ class RNNCellTest(test.TestCase):
self.assertFalse([s for s in cpu_stats if "gru_cell" in s.node_name])
self.assertTrue([s for s in gpu_stats if "gru_cell" in s.node_name])
- # def testUsingSecondCellInScopeWithExistingVariablesFails(self):
- # # This test should go away when this behavior is no longer an
- # # error (Approx. May 2017)
- # cell1 = core_rnn_cell_impl.LSTMCell(3)
- # cell2 = core_rnn_cell_impl.LSTMCell(3)
- # x = array_ops.zeros([1, 3])
- # m = core_rnn_cell_impl.LSTMStateTuple(*[array_ops.zeros([1, 3])] * 2)
- # cell1(x, m)
- # with self.assertRaisesRegexp(ValueError, r"LSTMCell\(..., reuse=True\)"):
- # cell2(x, m)
-
- # def testUsingCellInDifferentScopeFromFirstCallFails(self):
- # # This test should go away when this behavior is no longer an
- # # error (Approx. May 2017)
- # cell = core_rnn_cell_impl.LSTMCell(3)
- # x = array_ops.zeros([1, 3])
- # m = core_rnn_cell_impl.LSTMStateTuple(*[array_ops.zeros([1, 3])] * 2)
- # with variable_scope.variable_scope("scope1"):
- # cell(x, m)
- # with variable_scope.variable_scope("scope2"):
- # with self.assertRaisesRegexp(ValueError, r"Attempt to reuse RNNCell"):
- # cell(x, m)
-
def testEmbeddingWrapper(self):
with self.test_session() as sess:
with variable_scope.variable_scope(
"root", initializer=init_ops.constant_initializer(0.5)):
x = array_ops.zeros([1, 1], dtype=dtypes.int32)
m = array_ops.zeros([1, 2])
- embedding_cell = core_rnn_cell_impl.EmbeddingWrapper(
- core_rnn_cell_impl.GRUCell(2),
- embedding_classes=3,
- embedding_size=2)
+ embedding_cell = contrib_rnn.EmbeddingWrapper(
+ rnn_cell_impl.GRUCell(2), embedding_classes=3, embedding_size=2)
self.assertEqual(embedding_cell.output_size, 2)
g, new_m = embedding_cell(x, m)
sess.run([variables_lib.global_variables_initializer()])
@@ -495,9 +466,8 @@ class RNNCellTest(test.TestCase):
with variable_scope.variable_scope("root"):
inputs = ops.convert_to_tensor([[[0], [0]]], dtype=dtypes.int64)
input_lengths = ops.convert_to_tensor([2], dtype=dtypes.int64)
- embedding_cell = core_rnn_cell_impl.EmbeddingWrapper(
- core_rnn_cell_impl.BasicLSTMCell(
- 1, state_is_tuple=True),
+ embedding_cell = contrib_rnn.EmbeddingWrapper(
+ rnn_cell_impl.BasicLSTMCell(1, state_is_tuple=True),
embedding_classes=1,
embedding_size=2)
outputs, _ = rnn.dynamic_rnn(
@@ -515,9 +485,9 @@ class RNNCellTest(test.TestCase):
"root", initializer=init_ops.constant_initializer(0.5)):
x = array_ops.zeros([1, 2])
m = array_ops.zeros([1, 4])
- _, ml = core_rnn_cell_impl.MultiRNNCell(
- [core_rnn_cell_impl.GRUCell(2) for _ in range(2)],
- state_is_tuple=False)(x, m)
+ _, ml = rnn_cell_impl.MultiRNNCell(
+ [rnn_cell_impl.GRUCell(2)
+ for _ in range(2)], state_is_tuple=False)(x, m)
sess.run([variables_lib.global_variables_initializer()])
res = sess.run(ml, {
x.name: np.array([[1., 1.]]),
@@ -536,13 +506,13 @@ class RNNCellTest(test.TestCase):
# Test incorrectness of state
with self.assertRaisesRegexp(ValueError, "Expected state .* a tuple"):
- core_rnn_cell_impl.MultiRNNCell(
- [core_rnn_cell_impl.GRUCell(2) for _ in range(2)],
- state_is_tuple=True)(x, m_bad)
+ rnn_cell_impl.MultiRNNCell(
+ [rnn_cell_impl.GRUCell(2)
+ for _ in range(2)], state_is_tuple=True)(x, m_bad)
- _, ml = core_rnn_cell_impl.MultiRNNCell(
- [core_rnn_cell_impl.GRUCell(2) for _ in range(2)],
- state_is_tuple=True)(x, m_good)
+ _, ml = rnn_cell_impl.MultiRNNCell(
+ [rnn_cell_impl.GRUCell(2)
+ for _ in range(2)], state_is_tuple=True)(x, m_good)
sess.run([variables_lib.global_variables_initializer()])
res = sess.run(ml, {
@@ -571,23 +541,23 @@ class DropoutWrapperTest(test.TestCase):
time_steps = 2
x = constant_op.constant(
[[[2., 2., 2.]], [[1., 1., 1.]]], dtype=dtypes.float32)
- m = core_rnn_cell_impl.LSTMStateTuple(
- *[constant_op.constant([[0.1, 0.1, 0.1]],
- dtype=dtypes.float32)] * 2)
+ m = rnn_cell_impl.LSTMStateTuple(
+ *[constant_op.constant([[0.1, 0.1, 0.1]], dtype=dtypes.float32)
+ ] * 2)
else:
x = constant_op.constant(
np.random.randn(time_steps, batch_size, 3).astype(np.float32))
- m = core_rnn_cell_impl.LSTMStateTuple(
- *[constant_op.constant([[0.1, 0.1, 0.1]] * batch_size,
- dtype=dtypes.float32)] * 2)
+ m = rnn_cell_impl.LSTMStateTuple(*[
+ constant_op.constant(
+ [[0.1, 0.1, 0.1]] * batch_size, dtype=dtypes.float32)
+ ] * 2)
outputs, final_state = rnn.dynamic_rnn(
- cell=core_rnn_cell_impl.DropoutWrapper(
- core_rnn_cell_impl.LSTMCell(3),
- dtype=x.dtype,
- **kwargs),
+ cell=rnn_cell_impl.DropoutWrapper(
+ rnn_cell_impl.LSTMCell(3), dtype=x.dtype, **kwargs),
time_major=True,
parallel_iterations=parallel_iterations,
- inputs=x, initial_state=m)
+ inputs=x,
+ initial_state=m)
sess.run([variables_lib.global_variables_initializer()])
res = sess.run([outputs, final_state])
self.assertEqual(res[0].shape, (time_steps, batch_size, 3))
@@ -775,7 +745,7 @@ class SlimRNNCellTest(test.TestCase):
m = array_ops.zeros([1, 2])
my_cell = functools.partial(basic_rnn_cell, num_units=2)
# pylint: disable=protected-access
- g, _ = core_rnn_cell_impl._SlimRNNCell(my_cell)(x, m)
+ g, _ = rnn_cell_impl._SlimRNNCell(my_cell)(x, m)
# pylint: enable=protected-access
sess.run([variables_lib.global_variables_initializer()])
res = sess.run(
@@ -792,12 +762,12 @@ class SlimRNNCellTest(test.TestCase):
"root", initializer=init_ops.constant_initializer(0.5)):
inputs = random_ops.random_uniform((batch_size, input_size))
_, initial_state = basic_rnn_cell(inputs, None, num_units)
- rnn_cell = core_rnn_cell_impl.BasicRNNCell(num_units)
+ rnn_cell = rnn_cell_impl.BasicRNNCell(num_units)
outputs, state = rnn_cell(inputs, initial_state)
variable_scope.get_variable_scope().reuse_variables()
my_cell = functools.partial(basic_rnn_cell, num_units=num_units)
# pylint: disable=protected-access
- slim_cell = core_rnn_cell_impl._SlimRNNCell(my_cell)
+ slim_cell = rnn_cell_impl._SlimRNNCell(my_cell)
# pylint: enable=protected-access
slim_outputs, slim_state = slim_cell(inputs, initial_state)
self.assertEqual(slim_outputs.get_shape(), outputs.get_shape())
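A minimal illustrative sketch (not part of the patch): the hunks above all follow one pattern, replacing call sites that used the contrib implementation modules (core_rnn_cell_impl, core_rnn) with the core ops modules (rnn_cell / rnn_cell_impl and rnn). Assuming a TensorFlow build from around this commit, the equivalent standalone usage looks like the following; the placeholder shapes and unit counts are arbitrary.

    import tensorflow as tf
    from tensorflow.python.ops import rnn
    from tensorflow.python.ops import rnn_cell

    # Old style, as removed by the diff:
    #   from tensorflow.contrib.rnn.python.ops import core_rnn
    #   from tensorflow.contrib.rnn.python.ops import core_rnn_cell_impl
    #   cell = core_rnn_cell_impl.LSTMCell(10)
    #   outputs, state = core_rnn.static_rnn(cell, inputs, dtype=tf.float32)

    # New style, as added by the diff: same cells and rnn functions,
    # imported from tensorflow.python.ops instead of contrib.
    inputs = [tf.placeholder(tf.float32, shape=(4, 5)) for _ in range(6)]
    cell = rnn_cell.LSTMCell(10)
    outputs, state = rnn.static_rnn(cell, inputs, dtype=tf.float32)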
diff --git a/tensorflow/contrib/rnn/python/kernel_tests/core_rnn_test.py b/tensorflow/contrib/rnn/python/kernel_tests/core_rnn_test.py
index 54e3a0dadf..bf24347c43 100644
--- a/tensorflow/contrib/rnn/python/kernel_tests/core_rnn_test.py
+++ b/tensorflow/contrib/rnn/python/kernel_tests/core_rnn_test.py
@@ -24,9 +24,6 @@ import numpy as np
from six.moves import xrange # pylint: disable=redefined-builtin
from tensorflow.contrib import rnn as rnn_lib
-from tensorflow.contrib.rnn.python.ops import core_rnn
-from tensorflow.contrib.rnn.python.ops import core_rnn_cell
-from tensorflow.contrib.rnn.python.ops import core_rnn_cell_impl
from tensorflow.core.protobuf import config_pb2
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
@@ -38,6 +35,7 @@ from tensorflow.python.ops import gradients_impl
from tensorflow.python.ops import init_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import rnn
+from tensorflow.python.ops import rnn_cell
from tensorflow.python.ops import tensor_array_ops
from tensorflow.python.ops import variable_scope
from tensorflow.python.ops import variables as variables_lib
@@ -153,7 +151,7 @@ class RNNTest(test.TestCase):
cell = Plus1RNNCell()
inputs = [array_ops.placeholder(dtypes.float32, shape=(3, 4))]
with self.assertRaisesRegexp(ValueError, "must be a vector"):
- core_rnn.static_rnn(cell, inputs, dtype=dtypes.float32, sequence_length=4)
+ rnn.static_rnn(cell, inputs, dtype=dtypes.float32, sequence_length=4)
def testRNN(self):
cell = Plus1RNNCell()
@@ -164,7 +162,7 @@ class RNNTest(test.TestCase):
array_ops.placeholder(
dtypes.float32, shape=(batch_size, input_size))
]
- outputs, state = core_rnn.static_rnn(cell, inputs, dtype=dtypes.float32)
+ outputs, state = rnn.static_rnn(cell, inputs, dtype=dtypes.float32)
self.assertEqual(len(outputs), len(inputs))
for out, inp in zip(outputs, inputs):
self.assertEqual(out.get_shape(), inp.get_shape())
@@ -186,7 +184,7 @@ class RNNTest(test.TestCase):
def testDropout(self):
cell = Plus1RNNCell()
- full_dropout_cell = core_rnn_cell_impl.DropoutWrapper(
+ full_dropout_cell = rnn_cell.DropoutWrapper(
cell, input_keep_prob=1e-12, seed=0)
batch_size = 2
input_size = 5
@@ -196,9 +194,9 @@ class RNNTest(test.TestCase):
dtypes.float32, shape=(batch_size, input_size))
]
with variable_scope.variable_scope("share_scope"):
- outputs, state = core_rnn.static_rnn(cell, inputs, dtype=dtypes.float32)
+ outputs, state = rnn.static_rnn(cell, inputs, dtype=dtypes.float32)
with variable_scope.variable_scope("drop_scope"):
- dropped_outputs, _ = core_rnn.static_rnn(
+ dropped_outputs, _ = rnn.static_rnn(
full_dropout_cell, inputs, dtype=dtypes.float32)
self.assertEqual(len(outputs), len(inputs))
for out, inp in zip(outputs, inputs):
@@ -227,7 +225,7 @@ class RNNTest(test.TestCase):
dtypes.float32, shape=(batch_size, input_size))
]
with variable_scope.variable_scope("drop_scope"):
- dynamic_outputs, dynamic_state = core_rnn.static_rnn(
+ dynamic_outputs, dynamic_state = rnn.static_rnn(
cell, inputs, sequence_length=sequence_length, dtype=dtypes.float32)
self.assertEqual(len(dynamic_outputs), len(inputs))
@@ -297,8 +295,7 @@ class RNNTest(test.TestCase):
array_ops.placeholder(
dtypes.float32, shape=(batch_size, input_size))
]
- return core_rnn.static_rnn(
- cell, inputs, dtype=dtypes.float32, scope=scope)
+ return rnn.static_rnn(cell, inputs, dtype=dtypes.float32, scope=scope)
self._testScope(factory, use_outer_scope=True)
self._testScope(factory, use_outer_scope=False)
@@ -319,13 +316,13 @@ class LSTMTest(test.TestCase):
with self.test_session(use_gpu=use_gpu, graph=ops_lib.Graph()) as sess:
initializer = init_ops.random_uniform_initializer(
-0.01, 0.01, seed=self._seed)
- cell = core_rnn_cell_impl.LSTMCell(
+ cell = rnn_cell.LSTMCell(
num_units, initializer=initializer, state_is_tuple=False)
inputs = max_length * [
array_ops.placeholder(
dtypes.float32, shape=(batch_size, input_size))
]
- outputs, _ = core_rnn.static_rnn(cell, inputs, dtype=dtypes.float32)
+ outputs, _ = rnn.static_rnn(cell, inputs, dtype=dtypes.float32)
self.assertEqual(len(outputs), len(inputs))
for out in outputs:
self.assertEqual(out.get_shape().as_list(), [batch_size, num_units])
@@ -342,7 +339,7 @@ class LSTMTest(test.TestCase):
with self.test_session(use_gpu=use_gpu, graph=ops_lib.Graph()) as sess:
initializer = init_ops.random_uniform_initializer(
-0.01, 0.01, seed=self._seed)
- cell = core_rnn_cell_impl.LSTMCell(
+ cell = rnn_cell.LSTMCell(
num_units,
use_peepholes=True,
cell_clip=0.0,
@@ -352,7 +349,7 @@ class LSTMTest(test.TestCase):
array_ops.placeholder(
dtypes.float32, shape=(batch_size, input_size))
]
- outputs, _ = core_rnn.static_rnn(cell, inputs, dtype=dtypes.float32)
+ outputs, _ = rnn.static_rnn(cell, inputs, dtype=dtypes.float32)
self.assertEqual(len(outputs), len(inputs))
for out in outputs:
self.assertEqual(out.get_shape().as_list(), [batch_size, num_units])
@@ -374,7 +371,7 @@ class LSTMTest(test.TestCase):
initializer = init_ops.random_uniform_initializer(
-0.01, 0.01, seed=self._seed)
state_saver = TestStateSaver(batch_size, 2 * num_units)
- cell = core_rnn_cell_impl.LSTMCell(
+ cell = rnn_cell.LSTMCell(
num_units,
use_peepholes=False,
initializer=initializer,
@@ -384,7 +381,7 @@ class LSTMTest(test.TestCase):
dtypes.float32, shape=(batch_size, input_size))
]
with variable_scope.variable_scope("share_scope"):
- outputs, state = core_rnn.static_state_saving_rnn(
+ outputs, state = rnn.static_state_saving_rnn(
cell, inputs, state_saver=state_saver, state_name="save_lstm")
self.assertEqual(len(outputs), len(inputs))
for out in outputs:
@@ -406,7 +403,7 @@ class LSTMTest(test.TestCase):
initializer = init_ops.random_uniform_initializer(
-0.01, 0.01, seed=self._seed)
state_saver = TestStateSaver(batch_size, num_units)
- cell = core_rnn_cell_impl.LSTMCell(
+ cell = rnn_cell.LSTMCell(
num_units,
use_peepholes=False,
initializer=initializer,
@@ -416,7 +413,7 @@ class LSTMTest(test.TestCase):
dtypes.float32, shape=(batch_size, input_size))
]
with variable_scope.variable_scope("share_scope"):
- outputs, state = core_rnn.static_state_saving_rnn(
+ outputs, state = rnn.static_state_saving_rnn(
cell, inputs, state_saver=state_saver, state_name=("c", "m"))
self.assertEqual(len(outputs), len(inputs))
for out in outputs:
@@ -450,14 +447,14 @@ class LSTMTest(test.TestCase):
})
def _cell(i):
- return core_rnn_cell_impl.LSTMCell(
+ return rnn_cell.LSTMCell(
num_units + i,
use_peepholes=False,
initializer=initializer,
state_is_tuple=True)
# This creates a state tuple which has 4 sub-tuples of length 2 each.
- cell = core_rnn_cell_impl.MultiRNNCell(
+ cell = rnn_cell.MultiRNNCell(
[_cell(i) for i in range(4)], state_is_tuple=True)
self.assertEqual(len(cell.state_size), 4)
@@ -471,7 +468,7 @@ class LSTMTest(test.TestCase):
state_names = (("c0", "m0"), ("c1", "m1"), ("c2", "m2"), ("c3", "m3"))
with variable_scope.variable_scope("share_scope"):
- outputs, state = core_rnn.static_state_saving_rnn(
+ outputs, state = rnn.static_state_saving_rnn(
cell, inputs, state_saver=state_saver, state_name=state_names)
self.assertEqual(len(outputs), len(inputs))
@@ -508,13 +505,13 @@ class LSTMTest(test.TestCase):
array_ops.placeholder(
dtypes.float32, shape=(None, input_size))
]
- cell = core_rnn_cell_impl.LSTMCell(
+ cell = rnn_cell.LSTMCell(
num_units,
use_peepholes=True,
num_proj=num_proj,
initializer=initializer,
state_is_tuple=False)
- outputs, _ = core_rnn.static_rnn(cell, inputs, dtype=dtypes.float32)
+ outputs, _ = rnn.static_rnn(cell, inputs, dtype=dtypes.float32)
self.assertEqual(len(outputs), len(inputs))
variables_lib.global_variables_initializer().run()
@@ -535,20 +532,20 @@ class LSTMTest(test.TestCase):
array_ops.placeholder(
dtypes.float32, shape=(None, input_size))
]
- cell_notuple = core_rnn_cell_impl.LSTMCell(
+ cell_notuple = rnn_cell.LSTMCell(
num_units,
use_peepholes=True,
num_proj=num_proj,
initializer=initializer,
state_is_tuple=False)
- cell_tuple = core_rnn_cell_impl.LSTMCell(
+ cell_tuple = rnn_cell.LSTMCell(
num_units,
use_peepholes=True,
num_proj=num_proj,
initializer=initializer,
state_is_tuple=True)
with variable_scope.variable_scope("root") as scope:
- outputs_notuple, state_notuple = core_rnn.static_rnn(
+ outputs_notuple, state_notuple = rnn.static_rnn(
cell_notuple,
inputs,
dtype=dtypes.float32,
@@ -562,7 +559,7 @@ class LSTMTest(test.TestCase):
# the parameters from different RNNCell instances. Right now,
# this seems an unrealistic use case except for testing.
cell_tuple._scope = cell_notuple._scope # pylint: disable=protected-access
- outputs_tuple, state_tuple = core_rnn.static_rnn(
+ outputs_tuple, state_tuple = rnn.static_rnn(
cell_tuple,
inputs,
dtype=dtypes.float32,
@@ -603,7 +600,7 @@ class LSTMTest(test.TestCase):
dtypes.float32, shape=(None, input_size))
]
- cell = core_rnn_cell_impl.LSTMCell(
+ cell = rnn_cell.LSTMCell(
num_units,
use_peepholes=True,
num_proj=num_proj,
@@ -612,7 +609,7 @@ class LSTMTest(test.TestCase):
initializer=initializer,
state_is_tuple=False)
- outputs, _ = core_rnn.static_rnn(cell, inputs, dtype=dtypes.float32)
+ outputs, _ = rnn.static_rnn(cell, inputs, dtype=dtypes.float32)
self.assertEqual(len(outputs), len(inputs))
@@ -635,7 +632,7 @@ class LSTMTest(test.TestCase):
dtypes.float64, shape=(None, input_size))
]
- cell = core_rnn_cell_impl.LSTMCell(
+ cell = rnn_cell.LSTMCell(
num_units,
use_peepholes=True,
num_proj=num_proj,
@@ -644,7 +641,7 @@ class LSTMTest(test.TestCase):
initializer=initializer,
state_is_tuple=False)
- outputs, _ = core_rnn.static_rnn(
+ outputs, _ = rnn.static_rnn(
cell,
inputs,
initial_state=cell.zero_state(batch_size, dtypes.float64))
@@ -672,7 +669,7 @@ class LSTMTest(test.TestCase):
]
initializer = init_ops.constant_initializer(0.001)
- cell_noshard = core_rnn_cell_impl.LSTMCell(
+ cell_noshard = rnn_cell.LSTMCell(
num_units,
num_proj=num_proj,
use_peepholes=True,
@@ -681,7 +678,7 @@ class LSTMTest(test.TestCase):
num_proj_shards=num_proj_shards,
state_is_tuple=False)
- cell_shard = core_rnn_cell_impl.LSTMCell(
+ cell_shard = rnn_cell.LSTMCell(
num_units,
use_peepholes=True,
initializer=initializer,
@@ -689,10 +686,10 @@ class LSTMTest(test.TestCase):
state_is_tuple=False)
with variable_scope.variable_scope("noshard_scope"):
- outputs_noshard, state_noshard = core_rnn.static_rnn(
+ outputs_noshard, state_noshard = rnn.static_rnn(
cell_noshard, inputs, dtype=dtypes.float32)
with variable_scope.variable_scope("shard_scope"):
- outputs_shard, state_shard = core_rnn.static_rnn(
+ outputs_shard, state_shard = rnn.static_rnn(
cell_shard, inputs, dtype=dtypes.float32)
self.assertEqual(len(outputs_noshard), len(inputs))
@@ -731,7 +728,7 @@ class LSTMTest(test.TestCase):
dtypes.float64, shape=(None, input_size))
]
- cell = core_rnn_cell_impl.LSTMCell(
+ cell = rnn_cell.LSTMCell(
num_units,
use_peepholes=True,
num_proj=num_proj,
@@ -739,9 +736,9 @@ class LSTMTest(test.TestCase):
num_proj_shards=num_proj_shards,
initializer=initializer,
state_is_tuple=False)
- dropout_cell = core_rnn_cell_impl.DropoutWrapper(cell, 0.5, seed=0)
+ dropout_cell = rnn_cell.DropoutWrapper(cell, 0.5, seed=0)
- outputs, state = core_rnn.static_rnn(
+ outputs, state = rnn.static_rnn(
dropout_cell,
inputs,
sequence_length=sequence_length,
@@ -776,13 +773,13 @@ class LSTMTest(test.TestCase):
array_ops.placeholder(
dtypes.float32, shape=(None, input_size))
]
- cell = core_rnn_cell_impl.LSTMCell(
+ cell = rnn_cell.LSTMCell(
num_units,
use_peepholes=True,
num_proj=num_proj,
initializer=initializer,
state_is_tuple=False)
- cell_d = core_rnn_cell_impl.LSTMCell(
+ cell_d = rnn_cell.LSTMCell(
num_units,
use_peepholes=True,
num_proj=num_proj,
@@ -790,11 +787,11 @@ class LSTMTest(test.TestCase):
state_is_tuple=False)
with variable_scope.variable_scope("share_scope"):
- outputs0, _ = core_rnn.static_rnn(cell, inputs, dtype=dtypes.float32)
+ outputs0, _ = rnn.static_rnn(cell, inputs, dtype=dtypes.float32)
with variable_scope.variable_scope("share_scope", reuse=True):
- outputs1, _ = core_rnn.static_rnn(cell, inputs, dtype=dtypes.float32)
+ outputs1, _ = rnn.static_rnn(cell, inputs, dtype=dtypes.float32)
with variable_scope.variable_scope("diff_scope"):
- outputs2, _ = core_rnn.static_rnn(cell_d, inputs, dtype=dtypes.float32)
+ outputs2, _ = rnn.static_rnn(cell_d, inputs, dtype=dtypes.float32)
variables_lib.global_variables_initializer().run()
input_value = np.random.randn(batch_size, input_size)
@@ -823,7 +820,7 @@ class LSTMTest(test.TestCase):
array_ops.placeholder(
dtypes.float32, shape=(None, input_size))
]
- cell = core_rnn_cell_impl.LSTMCell(
+ cell = rnn_cell.LSTMCell(
num_units,
use_peepholes=True,
num_proj=num_proj,
@@ -832,10 +829,10 @@ class LSTMTest(test.TestCase):
with ops_lib.name_scope("scope0"):
with variable_scope.variable_scope("share_scope"):
- outputs0, _ = core_rnn.static_rnn(cell, inputs, dtype=dtypes.float32)
+ outputs0, _ = rnn.static_rnn(cell, inputs, dtype=dtypes.float32)
with ops_lib.name_scope("scope1"):
with variable_scope.variable_scope("share_scope", reuse=True):
- outputs1, _ = core_rnn.static_rnn(cell, inputs, dtype=dtypes.float32)
+ outputs1, _ = rnn.static_rnn(cell, inputs, dtype=dtypes.float32)
variables_lib.global_variables_initializer().run()
input_value = np.random.randn(batch_size, input_size)
@@ -881,7 +878,7 @@ class LSTMTest(test.TestCase):
def testDynamicRNNAllowsUnknownTimeDimension(self):
inputs = array_ops.placeholder(dtypes.float32, shape=[1, None, 20])
- cell = core_rnn_cell.GRUCell(30)
+ cell = rnn_cell.GRUCell(30)
# Smoke test, this should not raise an error
rnn.dynamic_rnn(cell, inputs, dtype=dtypes.float32)
@@ -900,14 +897,14 @@ class LSTMTest(test.TestCase):
dtypes.float32, shape=(None, input_size))
]
inputs_c = array_ops.stack(inputs)
- cell = core_rnn_cell.LSTMCell(
+ cell = rnn_cell.LSTMCell(
num_units,
use_peepholes=True,
num_proj=num_proj,
initializer=initializer,
state_is_tuple=True)
with variable_scope.variable_scope("root") as scope:
- outputs_static, state_static = core_rnn.static_rnn(
+ outputs_static, state_static = rnn.static_rnn(
cell,
inputs,
dtype=dtypes.float32,
@@ -921,8 +918,8 @@ class LSTMTest(test.TestCase):
time_major=True,
sequence_length=sequence_length,
scope=scope)
- self.assertTrue(isinstance(state_static, core_rnn_cell.LSTMStateTuple))
- self.assertTrue(isinstance(state_dynamic, core_rnn_cell.LSTMStateTuple))
+ self.assertTrue(isinstance(state_static, rnn_cell.LSTMStateTuple))
+ self.assertTrue(isinstance(state_dynamic, rnn_cell.LSTMStateTuple))
self.assertEqual(state_static[0], state_static.c)
self.assertEqual(state_static[1], state_static.h)
self.assertEqual(state_dynamic[0], state_dynamic.c)
@@ -960,7 +957,7 @@ class LSTMTest(test.TestCase):
inputs_c = array_ops.stack(inputs)
def _cell(i):
- return core_rnn_cell.LSTMCell(
+ return rnn_cell.LSTMCell(
num_units + i,
use_peepholes=True,
num_proj=num_proj + i,
@@ -968,7 +965,7 @@ class LSTMTest(test.TestCase):
state_is_tuple=True)
# This creates a state tuple which has 4 sub-tuples of length 2 each.
- cell = core_rnn_cell.MultiRNNCell(
+ cell = rnn_cell.MultiRNNCell(
[_cell(i) for i in range(4)], state_is_tuple=True)
self.assertEqual(len(cell.state_size), 4)
@@ -982,7 +979,7 @@ class LSTMTest(test.TestCase):
self.assertEqual(test_zero[i][1].get_shape()[1], cell.state_size[i][1])
with variable_scope.variable_scope("root") as scope:
- outputs_static, state_static = core_rnn.static_rnn(
+ outputs_static, state_static = rnn.static_rnn(
cell,
inputs,
dtype=dtypes.float32,
@@ -1034,7 +1031,7 @@ class LSTMTest(test.TestCase):
initializer = init_ops.random_uniform_initializer(
-0.01, 0.01, seed=self._seed)
- cell = core_rnn_cell.LSTMCell(
+ cell = rnn_cell.LSTMCell(
num_units,
use_peepholes=True,
initializer=initializer,
@@ -1042,7 +1039,7 @@ class LSTMTest(test.TestCase):
state_is_tuple=False)
with variable_scope.variable_scope("dynamic_scope"):
- outputs_static, state_static = core_rnn.static_rnn(
+ outputs_static, state_static = rnn.static_rnn(
cell, inputs, sequence_length=sequence_length, dtype=dtypes.float32)
feeds = {concat_inputs: input_values}
@@ -1092,7 +1089,7 @@ class LSTMTest(test.TestCase):
initializer = init_ops.random_uniform_initializer(
-0.01, 0.01, seed=self._seed)
- cell = core_rnn_cell.LSTMCell(
+ cell = rnn_cell.LSTMCell(
num_units,
use_peepholes=True,
initializer=initializer,
@@ -1205,16 +1202,16 @@ class BidirectionalRNNTest(test.TestCase):
-0.01, 0.01, seed=self._seed)
sequence_length = array_ops.placeholder(
dtypes.int64) if use_sequence_length else None
- cell_fw = core_rnn_cell_impl.LSTMCell(
+ cell_fw = rnn_cell.LSTMCell(
num_units, input_size, initializer=initializer, state_is_tuple=False)
- cell_bw = core_rnn_cell_impl.LSTMCell(
+ cell_bw = rnn_cell.LSTMCell(
num_units, input_size, initializer=initializer, state_is_tuple=False)
inputs = max_length * [
array_ops.placeholder(
dtypes.float32,
shape=(batch_size, input_size) if use_shape else (None, input_size))
]
- outputs, state_fw, state_bw = core_rnn.static_bidirectional_rnn(
+ outputs, state_fw, state_bw = rnn.static_bidirectional_rnn(
cell_fw,
cell_bw,
inputs,
@@ -1337,9 +1334,9 @@ class BidirectionalRNNTest(test.TestCase):
-0.01, 0.01, seed=self._seed)
sequence_length = (
array_ops.placeholder(dtypes.int64) if use_sequence_length else None)
- cell_fw = core_rnn_cell.LSTMCell(
+ cell_fw = rnn_cell.LSTMCell(
num_units, initializer=initializer, state_is_tuple=use_state_tuple)
- cell_bw = core_rnn_cell.LSTMCell(
+ cell_bw = rnn_cell.LSTMCell(
num_units, initializer=initializer, state_is_tuple=use_state_tuple)
inputs = max_length * [
array_ops.placeholder(
@@ -1530,7 +1527,7 @@ class MultiDimensionalLSTMTest(test.TestCase):
# variables.
cell = DummyMultiDimensionalLSTM(feature_dims)
state_saver = TestStateSaver(batch_size, input_size)
- outputs_static, state_static = core_rnn.static_rnn(
+ outputs_static, state_static = rnn.static_rnn(
cell, inputs, dtype=dtypes.float32, sequence_length=sequence_length)
outputs_dynamic, state_dynamic = rnn.dynamic_rnn(
cell,
@@ -1538,13 +1535,13 @@ class MultiDimensionalLSTMTest(test.TestCase):
dtype=dtypes.float32,
time_major=True,
sequence_length=sequence_length)
- outputs_bid, state_fw, state_bw = core_rnn.static_bidirectional_rnn(
+ outputs_bid, state_fw, state_bw = rnn.static_bidirectional_rnn(
cell,
cell,
inputs_using_dim,
dtype=dtypes.float32,
sequence_length=sequence_length)
- outputs_sav, state_sav = core_rnn.static_state_saving_rnn(
+ outputs_sav, state_sav = rnn.static_state_saving_rnn(
cell,
inputs_using_dim,
sequence_length=sequence_length,
@@ -1634,15 +1631,15 @@ class NestedLSTMTest(test.TestCase):
dtype=dtypes.float32,
time_major=True,
sequence_length=sequence_length)
- outputs_static, state_static = core_rnn.static_rnn(
+ outputs_static, state_static = rnn.static_rnn(
cell, inputs, dtype=dtypes.float32, sequence_length=sequence_length)
- outputs_bid, state_fw, state_bw = core_rnn.static_bidirectional_rnn(
+ outputs_bid, state_fw, state_bw = rnn.static_bidirectional_rnn(
cell,
cell,
inputs_using_dim,
dtype=dtypes.float32,
sequence_length=sequence_length)
- outputs_sav, state_sav = core_rnn.static_state_saving_rnn(
+ outputs_sav, state_sav = rnn.static_state_saving_rnn(
cell,
inputs_using_dim,
sequence_length=sequence_length,
@@ -1738,7 +1735,7 @@ class StateSaverRNNTest(test.TestCase):
initializer = init_ops.random_uniform_initializer(
-0.01, 0.01, seed=self._seed)
state_saver = TestStateSaver(batch_size, 2 * num_units)
- cell = core_rnn_cell_impl.LSTMCell(
+ cell = rnn_cell.LSTMCell(
num_units,
use_peepholes=False,
initializer=initializer,
@@ -1747,7 +1744,7 @@ class StateSaverRNNTest(test.TestCase):
array_ops.placeholder(
dtypes.float32, shape=(batch_size, input_size))
]
- return core_rnn.static_state_saving_rnn(
+ return rnn.static_state_saving_rnn(
cell,
inputs,
state_saver=state_saver,
@@ -1779,7 +1776,7 @@ class GRUTest(test.TestCase):
concat_inputs = array_ops.placeholder(
dtypes.float32, shape=(time_steps, batch_size, input_size))
- cell = core_rnn_cell.GRUCell(num_units=num_units)
+ cell = rnn_cell.GRUCell(num_units=num_units)
with variable_scope.variable_scope("dynamic_scope"):
outputs_dynamic, state_dynamic = rnn.dynamic_rnn(
@@ -1830,7 +1827,7 @@ class GRUTest(test.TestCase):
def factory(scope):
concat_inputs = array_ops.placeholder(
dtypes.float32, shape=(time_steps, batch_size, input_size))
- cell = core_rnn_cell.GRUCell(num_units=num_units)
+ cell = rnn_cell.GRUCell(num_units=num_units)
return rnn.dynamic_rnn(
cell,
inputs=concat_inputs,
@@ -1864,7 +1861,7 @@ class RawRNNTest(test.TestCase):
dtype=dtypes.float32, size=array_ops.shape(inputs)[0])
inputs_ta = inputs_ta.unstack(inputs)
- cell = core_rnn_cell.LSTMCell(num_units, state_is_tuple=True)
+ cell = rnn_cell.LSTMCell(num_units, state_is_tuple=True)
def loop_fn(time_, cell_output, cell_state, unused_loop_state):
emit_output = cell_output # == None for time == 0
@@ -1965,7 +1962,7 @@ class RawRNNTest(test.TestCase):
dtype=dtypes.float32, size=array_ops.shape(inputs)[0])
inputs_ta = inputs_ta.unstack(inputs)
- cell = core_rnn_cell.LSTMCell(num_units, state_is_tuple=True)
+ cell = rnn_cell.LSTMCell(num_units, state_is_tuple=True)
def loop_fn(time_, cell_output, cell_state, loop_state):
if cell_output is None:
@@ -2001,7 +1998,7 @@ class RawRNNTest(test.TestCase):
dtype=dtypes.float32, size=array_ops.shape(inputs)[0])
inputs_ta = inputs_ta.unstack(inputs)
- cell = core_rnn_cell.LSTMCell(num_units, state_is_tuple=True)
+ cell = rnn_cell.LSTMCell(num_units, state_is_tuple=True)
def loop_fn(time_, cell_output, cell_state, loop_state):
if cell_output is None:
@@ -2044,7 +2041,7 @@ class RawRNNTest(test.TestCase):
dtype=dtypes.float32, size=array_ops.shape(inputs)[0])
inputs_ta = inputs_ta.unstack(inputs)
- cell = core_rnn_cell.LSTMCell(num_units, state_is_tuple=True)
+ cell = rnn_cell.LSTMCell(num_units, state_is_tuple=True)
def loop_fn(time_, cell_output, cell_state, _):
if cell_output is None:
@@ -2113,7 +2110,7 @@ class RawRNNTest(test.TestCase):
dtype=dtypes.float32, size=array_ops.shape(inputs)[0])
inputs_ta = inputs_ta.unstack(inputs)
- cell = core_rnn_cell.LSTMCell(num_units, state_is_tuple=True)
+ cell = rnn_cell.LSTMCell(num_units, state_is_tuple=True)
def loop_fn(time_, cell_output, cell_state, unused_loop_state):
emit_output = cell_output # == None for time == 0
@@ -2138,7 +2135,7 @@ class RawRNNTest(test.TestCase):
self._testScope(factory, prefix=None, use_outer_scope=False)
-class DeviceWrapperCell(core_rnn_cell.RNNCell):
+class DeviceWrapperCell(rnn_cell.RNNCell):
"""Class to ensure cell calculation happens on a specific device."""
def __init__(self, cell, device):
@@ -2172,7 +2169,7 @@ class TensorArrayOnCorrectDeviceTest(test.TestCase):
input_size = 5
num_units = 10
- cell = core_rnn_cell.LSTMCell(num_units, use_peepholes=True)
+ cell = rnn_cell.LSTMCell(num_units, use_peepholes=True)
gpu_cell = DeviceWrapperCell(cell, cell_device)
inputs = np.random.randn(batch_size, time_steps,
input_size).astype(np.float32)
diff --git a/tensorflow/contrib/rnn/python/kernel_tests/fused_rnn_cell_test.py b/tensorflow/contrib/rnn/python/kernel_tests/fused_rnn_cell_test.py
index a656831c02..f2a032e41e 100644
--- a/tensorflow/contrib/rnn/python/kernel_tests/fused_rnn_cell_test.py
+++ b/tensorflow/contrib/rnn/python/kernel_tests/fused_rnn_cell_test.py
@@ -20,14 +20,14 @@ from __future__ import print_function
import numpy as np
-from tensorflow.contrib.rnn.python.ops import core_rnn
-from tensorflow.contrib.rnn.python.ops import core_rnn_cell_impl
from tensorflow.contrib.rnn.python.ops import fused_rnn_cell
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import gradients_impl
from tensorflow.python.ops import init_ops
+from tensorflow.python.ops import rnn
+from tensorflow.python.ops import rnn_cell
from tensorflow.python.ops import variable_scope
from tensorflow.python.ops import variables
from tensorflow.python.platform import test
@@ -41,7 +41,7 @@ class FusedRnnCellTest(test.TestCase):
with self.test_session() as sess:
initializer = init_ops.random_uniform_initializer(
-0.01, 0.01, seed=19890212)
- cell = core_rnn_cell_impl.BasicRNNCell(10)
+ cell = rnn_cell.BasicRNNCell(10)
batch_size = 5
input_size = 20
timelen = 15
@@ -49,7 +49,7 @@ class FusedRnnCellTest(test.TestCase):
np.random.randn(timelen, batch_size, input_size))
with variable_scope.variable_scope("basic", initializer=initializer):
unpacked_inputs = array_ops.unstack(inputs)
- outputs, state = core_rnn.static_rnn(
+ outputs, state = rnn.static_rnn(
cell, unpacked_inputs, dtype=dtypes.float64)
packed_outputs = array_ops.stack(outputs)
basic_vars = [
@@ -65,7 +65,7 @@ class FusedRnnCellTest(test.TestCase):
with variable_scope.variable_scope(
"fused_static", initializer=initializer):
fused_cell = fused_rnn_cell.FusedRNNCellAdaptor(
- core_rnn_cell_impl.BasicRNNCell(10))
+ rnn_cell.BasicRNNCell(10))
outputs, state = fused_cell(inputs, dtype=dtypes.float64)
fused_static_vars = [
v for v in variables.trainable_variables()
@@ -86,7 +86,7 @@ class FusedRnnCellTest(test.TestCase):
with variable_scope.variable_scope(
"fused_dynamic", initializer=initializer):
fused_cell = fused_rnn_cell.FusedRNNCellAdaptor(
- core_rnn_cell_impl.BasicRNNCell(10), use_dynamic_rnn=True)
+ rnn_cell.BasicRNNCell(10), use_dynamic_rnn=True)
outputs, state = fused_cell(inputs, dtype=dtypes.float64)
fused_dynamic_vars = [
v for v in variables.trainable_variables()
@@ -109,8 +109,8 @@ class FusedRnnCellTest(test.TestCase):
with self.test_session() as sess:
initializer = init_ops.random_uniform_initializer(
-0.01, 0.01, seed=19890213)
- fw_cell = core_rnn_cell_impl.BasicRNNCell(10)
- bw_cell = core_rnn_cell_impl.BasicRNNCell(10)
+ fw_cell = rnn_cell.BasicRNNCell(10)
+ bw_cell = rnn_cell.BasicRNNCell(10)
batch_size = 5
input_size = 20
timelen = 15
@@ -120,7 +120,7 @@ class FusedRnnCellTest(test.TestCase):
# test bi-directional rnn
with variable_scope.variable_scope("basic", initializer=initializer):
unpacked_inputs = array_ops.unstack(inputs)
- outputs, fw_state, bw_state = core_rnn.static_bidirectional_rnn(
+ outputs, fw_state, bw_state = rnn.static_bidirectional_rnn(
fw_cell, bw_cell, unpacked_inputs, dtype=dtypes.float64)
packed_outputs = array_ops.stack(outputs)
basic_vars = [
@@ -136,10 +136,9 @@ class FusedRnnCellTest(test.TestCase):
with variable_scope.variable_scope("fused", initializer=initializer):
fused_cell = fused_rnn_cell.FusedRNNCellAdaptor(
- core_rnn_cell_impl.BasicRNNCell(10))
+ rnn_cell.BasicRNNCell(10))
fused_bw_cell = fused_rnn_cell.TimeReversedFusedRNN(
- fused_rnn_cell.FusedRNNCellAdaptor(
- core_rnn_cell_impl.BasicRNNCell(10)))
+ fused_rnn_cell.FusedRNNCellAdaptor(rnn_cell.BasicRNNCell(10)))
fw_outputs, fw_state = fused_cell(
inputs, dtype=dtypes.float64, scope="fw")
bw_outputs, bw_state = fused_bw_cell(
diff --git a/tensorflow/contrib/rnn/python/kernel_tests/gru_ops_test.py b/tensorflow/contrib/rnn/python/kernel_tests/gru_ops_test.py
index 4247aeb839..d2ec648537 100644
--- a/tensorflow/contrib/rnn/python/kernel_tests/gru_ops_test.py
+++ b/tensorflow/contrib/rnn/python/kernel_tests/gru_ops_test.py
@@ -22,7 +22,6 @@ import time
import numpy as np
-from tensorflow.contrib.rnn.python.ops import core_rnn_cell_impl
from tensorflow.contrib.rnn.python.ops import gru_ops
from tensorflow.python.client import session
from tensorflow.python.framework import dtypes
@@ -33,6 +32,7 @@ from tensorflow.python.ops import gradients_impl
from tensorflow.python.ops import init_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import rnn
+from tensorflow.python.ops import rnn_cell
from tensorflow.python.ops import variable_scope as vs
from tensorflow.python.ops import variables
from tensorflow.python.platform import test
@@ -78,7 +78,7 @@ class GRUBlockCellTest(test.TestCase):
# Output from the basic GRU cell implementation.
with vs.variable_scope("basic", initializer=initializer):
- output = core_rnn_cell_impl.GRUCell(cell_size)(x, h)
+ output = rnn_cell.GRUCell(cell_size)(x, h)
sess.run([variables.global_variables_initializer()])
basic_res = sess.run([output], {x: x_value, h: h_value})
@@ -128,7 +128,7 @@ class GRUBlockCellTest(test.TestCase):
# Output from the basic GRU cell implementation.
with vs.variable_scope("basic", initializer=initializer):
- cell = core_rnn_cell_impl.GRUCell(cell_size)
+ cell = rnn_cell.GRUCell(cell_size)
outputs_dynamic, state_dynamic = rnn.dynamic_rnn(
cell,
inputs=concat_x,
@@ -192,7 +192,7 @@ class GRUBlockCellTest(test.TestCase):
# Gradients from the basic GRU cell implementation.
with vs.variable_scope("basic", initializer=initializer):
- output = core_rnn_cell_impl.GRUCell(cell_size)(x, h)
+ output = rnn_cell.GRUCell(cell_size)(x, h)
sess.run([variables.global_variables_initializer()])
all_variables = variables.global_variables()[4:8]
@@ -258,7 +258,7 @@ class GRUBlockCellTest(test.TestCase):
# Gradients from the basic GRU cell implementation.
with vs.variable_scope("basic", initializer=initializer):
- cell = core_rnn_cell_impl.GRUCell(cell_size)
+ cell = rnn_cell.GRUCell(cell_size)
outputs_dynamic, _ = rnn.dynamic_rnn(
cell,
@@ -377,7 +377,7 @@ def training_gru_block_vs_gru_cell(batch_size,
# Output from the basic GRU cell implementation.
with vs.variable_scope("basic", initializer=initializer):
- cell = core_rnn_cell_impl.GRUCell(cell_size)
+ cell = rnn_cell.GRUCell(cell_size)
outputs_dynamic, _ = rnn.dynamic_rnn(
cell,
@@ -448,7 +448,7 @@ def inference_gru_block_vs_gru_cell(batch_size,
# Output from the basic GRU cell implementation.
with vs.variable_scope("basic", initializer=initializer):
- cell = core_rnn_cell_impl.GRUCell(cell_size)
+ cell = rnn_cell.GRUCell(cell_size)
outputs_dynamic, _ = rnn.dynamic_rnn(
cell,
inputs=concat_x,
@@ -497,8 +497,8 @@ def single_bprop_step_gru_block_vs_gru_cell(batch_size,
# Output from the basic GRU cell implementation.
with vs.variable_scope("basic", initializer=initializer):
- output = core_rnn_cell_impl.GRUCell(cell_size)(array_ops.identity(x),
- array_ops.identity(h))
+ output = rnn_cell.GRUCell(cell_size)(array_ops.identity(x),
+ array_ops.identity(h))
sess.run([variables.global_variables_initializer()])
grad_output_wrt_input = gradients_impl.gradients([output], h)
basic_time_bprop = time_taken_by_op(grad_output_wrt_input, sess, iters)
diff --git a/tensorflow/contrib/rnn/python/kernel_tests/lstm_ops_test.py b/tensorflow/contrib/rnn/python/kernel_tests/lstm_ops_test.py
index 3a5cbf604d..1e6c44a115 100644
--- a/tensorflow/contrib/rnn/python/kernel_tests/lstm_ops_test.py
+++ b/tensorflow/contrib/rnn/python/kernel_tests/lstm_ops_test.py
@@ -20,8 +20,6 @@ from __future__ import print_function
import numpy as np
-from tensorflow.contrib.rnn.python.ops import core_rnn
-from tensorflow.contrib.rnn.python.ops import core_rnn_cell_impl
from tensorflow.contrib.rnn.python.ops import lstm_ops
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
@@ -30,6 +28,7 @@ from tensorflow.python.ops import array_ops
from tensorflow.python.ops import gradients_impl
from tensorflow.python.ops import init_ops
from tensorflow.python.ops import rnn
+from tensorflow.python.ops import rnn_cell
from tensorflow.python.ops import variable_scope
from tensorflow.python.ops import variables
from tensorflow.python.platform import test
@@ -66,10 +65,9 @@ class LSTMBlockCellTest(test.TestCase):
m1 = array_ops.zeros([1, 2])
m2 = array_ops.zeros([1, 2])
m3 = array_ops.zeros([1, 2])
- g, ((out_m0, out_m1),
- (out_m2, out_m3)) = core_rnn_cell_impl.MultiRNNCell(
- [lstm_ops.LSTMBlockCell(2) for _ in range(2)],
- state_is_tuple=True)(x, ((m0, m1), (m2, m3)))
+ g, ((out_m0, out_m1), (out_m2, out_m3)) = rnn_cell.MultiRNNCell(
+ [lstm_ops.LSTMBlockCell(2)
+ for _ in range(2)], state_is_tuple=True)(x, ((m0, m1), (m2, m3)))
sess.run([variables.global_variables_initializer()])
res = sess.run([g, out_m0, out_m1, out_m2, out_m3], {
x.name: np.array([[1., 1.]]),
@@ -88,11 +86,11 @@ class LSTMBlockCellTest(test.TestCase):
def testCompatibleNames(self):
with self.test_session(use_gpu=self._use_gpu, graph=ops.Graph()):
- cell = core_rnn_cell_impl.LSTMCell(10)
- pcell = core_rnn_cell_impl.LSTMCell(10, use_peepholes=True)
+ cell = rnn_cell.LSTMCell(10)
+ pcell = rnn_cell.LSTMCell(10, use_peepholes=True)
inputs = [array_ops.zeros([4, 5])] * 6
- core_rnn.static_rnn(cell, inputs, dtype=dtypes.float32, scope="basic")
- core_rnn.static_rnn(pcell, inputs, dtype=dtypes.float32, scope="peephole")
+ rnn.static_rnn(cell, inputs, dtype=dtypes.float32, scope="basic")
+ rnn.static_rnn(pcell, inputs, dtype=dtypes.float32, scope="peephole")
basic_names = {
v.name: v.get_shape()
for v in variables.trainable_variables()
@@ -102,8 +100,8 @@ class LSTMBlockCellTest(test.TestCase):
cell = lstm_ops.LSTMBlockCell(10)
pcell = lstm_ops.LSTMBlockCell(10, use_peephole=True)
inputs = [array_ops.zeros([4, 5])] * 6
- core_rnn.static_rnn(cell, inputs, dtype=dtypes.float32, scope="basic")
- core_rnn.static_rnn(pcell, inputs, dtype=dtypes.float32, scope="peephole")
+ rnn.static_rnn(cell, inputs, dtype=dtypes.float32, scope="basic")
+ rnn.static_rnn(pcell, inputs, dtype=dtypes.float32, scope="peephole")
block_names = {
v.name: v.get_shape()
for v in variables.trainable_variables()
@@ -140,11 +138,9 @@ class LSTMBlockCellTest(test.TestCase):
m1 = array_ops.zeros([1, 2])
m2 = array_ops.zeros([1, 2])
m3 = array_ops.zeros([1, 2])
- g, ((out_m0, out_m1),
- (out_m2, out_m3)) = core_rnn_cell_impl.MultiRNNCell(
- [core_rnn_cell_impl.BasicLSTMCell(2, state_is_tuple=True)
- for _ in range(2)],
- state_is_tuple=True)(x, ((m0, m1), (m2, m3)))
+ g, ((out_m0, out_m1), (out_m2, out_m3)) = rnn_cell.MultiRNNCell(
+ [rnn_cell.BasicLSTMCell(2, state_is_tuple=True) for _ in range(2)],
+ state_is_tuple=True)(x, ((m0, m1), (m2, m3)))
sess.run([variables.global_variables_initializer()])
basic_res = sess.run([g, out_m0, out_m1, out_m2, out_m3], {
x.name: x_values,
@@ -159,10 +155,9 @@ class LSTMBlockCellTest(test.TestCase):
m1 = array_ops.zeros([1, 2])
m2 = array_ops.zeros([1, 2])
m3 = array_ops.zeros([1, 2])
- g, ((out_m0, out_m1),
- (out_m2, out_m3)) = core_rnn_cell_impl.MultiRNNCell(
- [lstm_ops.LSTMBlockCell(2) for _ in range(2)],
- state_is_tuple=True)(x, ((m0, m1), (m2, m3)))
+ g, ((out_m0, out_m1), (out_m2, out_m3)) = rnn_cell.MultiRNNCell(
+ [lstm_ops.LSTMBlockCell(2)
+ for _ in range(2)], state_is_tuple=True)(x, ((m0, m1), (m2, m3)))
sess.run([variables.global_variables_initializer()])
block_res = sess.run([g, out_m0, out_m1, out_m2, out_m3], {
x.name: x_values,
@@ -193,12 +188,12 @@ class LSTMBlockCellTest(test.TestCase):
m1 = array_ops.zeros([1, 2])
m2 = array_ops.zeros([1, 2])
m3 = array_ops.zeros([1, 2])
- g, ((out_m0, out_m1),
- (out_m2, out_m3)) = core_rnn_cell_impl.MultiRNNCell(
- [core_rnn_cell_impl.LSTMCell(
- 2, use_peepholes=True, state_is_tuple=True)
- for _ in range(2)],
- state_is_tuple=True)(x, ((m0, m1), (m2, m3)))
+ g, ((out_m0, out_m1), (out_m2, out_m3)) = rnn_cell.MultiRNNCell(
+ [
+ rnn_cell.LSTMCell(2, use_peepholes=True, state_is_tuple=True)
+ for _ in range(2)
+ ],
+ state_is_tuple=True)(x, ((m0, m1), (m2, m3)))
sess.run([variables.global_variables_initializer()])
basic_res = sess.run([g, out_m0, out_m1, out_m2, out_m3], {
x.name: x_values,
@@ -213,11 +208,9 @@ class LSTMBlockCellTest(test.TestCase):
m1 = array_ops.zeros([1, 2])
m2 = array_ops.zeros([1, 2])
m3 = array_ops.zeros([1, 2])
- g, ((out_m0, out_m1),
- (out_m2, out_m3)) = core_rnn_cell_impl.MultiRNNCell(
- [lstm_ops.LSTMBlockCell(2, use_peephole=True)
- for _ in range(2)],
- state_is_tuple=True)(x, ((m0, m1), (m2, m3)))
+ g, ((out_m0, out_m1), (out_m2, out_m3)) = rnn_cell.MultiRNNCell(
+ [lstm_ops.LSTMBlockCell(2, use_peephole=True) for _ in range(2)],
+ state_is_tuple=True)(x, ((m0, m1), (m2, m3)))
sess.run([variables.global_variables_initializer()])
block_res = sess.run([g, out_m0, out_m1, out_m2, out_m3], {
x.name: x_values,
@@ -247,8 +240,8 @@ class LSTMBlockCellTest(test.TestCase):
initializer = init_ops.random_uniform_initializer(
-0.01, 0.01, seed=19890212)
with variable_scope.variable_scope("basic", initializer=initializer):
- cell = core_rnn_cell_impl.BasicLSTMCell(cell_size, state_is_tuple=True)
- outputs, state = core_rnn.static_rnn(cell, inputs, dtype=dtypes.float32)
+ cell = rnn_cell.BasicLSTMCell(cell_size, state_is_tuple=True)
+ outputs, state = rnn.static_rnn(cell, inputs, dtype=dtypes.float32)
sess.run([variables.global_variables_initializer()])
basic_outputs, basic_state = sess.run([outputs, state[0]])
@@ -321,9 +314,9 @@ class LSTMBlockCellTest(test.TestCase):
initializer = init_ops.random_uniform_initializer(
-0.01, 0.01, seed=19890212)
with variable_scope.variable_scope("basic", initializer=initializer):
- cell = core_rnn_cell_impl.LSTMCell(
+ cell = rnn_cell.LSTMCell(
cell_size, use_peepholes=True, state_is_tuple=True)
- outputs, state = core_rnn.static_rnn(cell, inputs, dtype=dtypes.float32)
+ outputs, state = rnn.static_rnn(cell, inputs, dtype=dtypes.float32)
sess.run([variables.global_variables_initializer()])
basic_outputs, basic_state = sess.run([outputs, state[0]])
@@ -410,8 +403,8 @@ class LSTMBlockCellTest(test.TestCase):
initializer = init_ops.random_uniform_initializer(
-0.01, 0.01, seed=19890213)
with variable_scope.variable_scope("basic", initializer=initializer):
- cell = core_rnn_cell_impl.BasicLSTMCell(cell_size, state_is_tuple=True)
- outputs, state = core_rnn.static_rnn(
+ cell = rnn_cell.BasicLSTMCell(cell_size, state_is_tuple=True)
+ outputs, state = rnn.static_rnn(
cell, inputs, dtype=dtypes.float32, sequence_length=seq_lengths)
sess.run([variables.global_variables_initializer()])
basic_outputs, basic_state = sess.run([outputs, state[0]])
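In the next file's hunks, a test needs both the contrib-only cells and the core cells at once, so the contrib module is imported under an alias (contrib_rnn_cell) while the core module keeps the plain name. A minimal sketch of that pattern, assuming a TensorFlow build from around this commit (the unit count and attention length are arbitrary):

    from tensorflow.contrib.rnn.python.ops import rnn_cell as contrib_rnn_cell
    from tensorflow.python.ops import rnn_cell

    # Core cell wrapped by a contrib-only wrapper; the alias keeps the two
    # rnn_cell modules from shadowing each other.
    lstm_cell = rnn_cell.BasicLSTMCell(8, state_is_tuple=True)
    attn_cell = contrib_rnn_cell.AttentionCellWrapper(
        lstm_cell, 4, state_is_tuple=True)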
diff --git a/tensorflow/contrib/rnn/python/kernel_tests/rnn_cell_test.py b/tensorflow/contrib/rnn/python/kernel_tests/rnn_cell_test.py
index 334baa5f9c..04b0c5876b 100644
--- a/tensorflow/contrib/rnn/python/kernel_tests/rnn_cell_test.py
+++ b/tensorflow/contrib/rnn/python/kernel_tests/rnn_cell_test.py
@@ -22,8 +22,7 @@ import itertools
import numpy as np
-from tensorflow.contrib.rnn.python.ops import core_rnn_cell_impl
-from tensorflow.contrib.rnn.python.ops import rnn_cell
+from tensorflow.contrib.rnn.python.ops import rnn_cell as contrib_rnn_cell
from tensorflow.core.protobuf import config_pb2
from tensorflow.python.client import session
from tensorflow.python.framework import constant_op
@@ -37,6 +36,7 @@ from tensorflow.python.ops import init_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import random_ops
from tensorflow.python.ops import rnn
+from tensorflow.python.ops import rnn_cell
from tensorflow.python.ops import variable_scope
from tensorflow.python.ops import variables
from tensorflow.python.platform import test
@@ -65,7 +65,7 @@ class RNNCellTest(test.TestCase):
"root", initializer=init_ops.constant_initializer(0.5)):
x = array_ops.zeros([batch_size, input_size])
m = array_ops.zeros([batch_size, state_size])
- output, state = rnn_cell.CoupledInputForgetGateLSTMCell(
+ output, state = contrib_rnn_cell.CoupledInputForgetGateLSTMCell(
num_units=num_units, forget_bias=1.0, state_is_tuple=False)(x, m)
sess.run([variables.global_variables_initializer()])
res = sess.run([output, state], {
@@ -94,7 +94,7 @@ class RNNCellTest(test.TestCase):
"root", initializer=init_ops.constant_initializer(0.5)):
x = array_ops.zeros([batch_size, input_size])
m = array_ops.zeros([batch_size, state_size * num_shifts])
- output, state = rnn_cell.TimeFreqLSTMCell(
+ output, state = contrib_rnn_cell.TimeFreqLSTMCell(
num_units=num_units,
feature_size=feature_size,
frequency_skip=frequency_skip,
@@ -130,7 +130,7 @@ class RNNCellTest(test.TestCase):
num_shifts = int((input_size - feature_size) / frequency_skip + 1)
with variable_scope.variable_scope(
"root", initializer=init_ops.constant_initializer(0.5)):
- cell = rnn_cell.GridLSTMCell(
+ cell = contrib_rnn_cell.GridLSTMCell(
num_units=num_units,
feature_size=feature_size,
frequency_skip=frequency_skip,
@@ -181,7 +181,7 @@ class RNNCellTest(test.TestCase):
end_freqindex_list = [2, 4]
with variable_scope.variable_scope(
"root", initializer=init_ops.constant_initializer(0.5)):
- cell = rnn_cell.GridLSTMCell(
+ cell = contrib_rnn_cell.GridLSTMCell(
num_units=num_units,
feature_size=feature_size,
frequency_skip=frequency_skip,
@@ -249,7 +249,7 @@ class RNNCellTest(test.TestCase):
with variable_scope.variable_scope(
"state_is_tuple" + str(state_is_tuple),
initializer=init_ops.constant_initializer(0.5)):
- cell = rnn_cell.GridLSTMCell(
+ cell = contrib_rnn_cell.GridLSTMCell(
num_units=num_units,
feature_size=feature_size,
frequency_skip=frequency_skip,
@@ -330,7 +330,7 @@ class RNNCellTest(test.TestCase):
dtype=np.float32)
with variable_scope.variable_scope(
"root", initializer=init_ops.constant_initializer(0.5)):
- cell = rnn_cell.BidirectionalGridLSTMCell(
+ cell = contrib_rnn_cell.BidirectionalGridLSTMCell(
num_units=num_units,
feature_size=feature_size,
share_time_frequency_weights=True,
@@ -403,7 +403,7 @@ class RNNCellTest(test.TestCase):
dtype=np.float32)
with variable_scope.variable_scope(
"root", initializer=init_ops.constant_initializer(0.5)):
- cell = rnn_cell.BidirectionalGridLSTMCell(
+ cell = contrib_rnn_cell.BidirectionalGridLSTMCell(
num_units=num_units,
feature_size=feature_size,
share_time_frequency_weights=True,
@@ -442,28 +442,28 @@ class RNNCellTest(test.TestCase):
def testAttentionCellWrapperFailures(self):
with self.assertRaisesRegexp(TypeError,
"The parameter cell is not RNNCell."):
- rnn_cell.AttentionCellWrapper(None, 0)
+ contrib_rnn_cell.AttentionCellWrapper(None, 0)
num_units = 8
for state_is_tuple in [False, True]:
with ops.Graph().as_default():
- lstm_cell = core_rnn_cell_impl.BasicLSTMCell(
+ lstm_cell = rnn_cell.BasicLSTMCell(
num_units, state_is_tuple=state_is_tuple)
with self.assertRaisesRegexp(
ValueError, "attn_length should be greater than zero, got 0"):
- rnn_cell.AttentionCellWrapper(
+ contrib_rnn_cell.AttentionCellWrapper(
lstm_cell, 0, state_is_tuple=state_is_tuple)
with self.assertRaisesRegexp(
ValueError, "attn_length should be greater than zero, got -1"):
- rnn_cell.AttentionCellWrapper(
+ contrib_rnn_cell.AttentionCellWrapper(
lstm_cell, -1, state_is_tuple=state_is_tuple)
with ops.Graph().as_default():
- lstm_cell = core_rnn_cell_impl.BasicLSTMCell(
- num_units, state_is_tuple=True)
+ lstm_cell = rnn_cell.BasicLSTMCell(num_units, state_is_tuple=True)
with self.assertRaisesRegexp(
ValueError, "Cell returns tuple of states, but the flag "
"state_is_tuple is not set. State size is: *"):
- rnn_cell.AttentionCellWrapper(lstm_cell, 4, state_is_tuple=False)
+ contrib_rnn_cell.AttentionCellWrapper(
+ lstm_cell, 4, state_is_tuple=False)
def testAttentionCellWrapperZeros(self):
num_units = 8
@@ -475,9 +475,9 @@ class RNNCellTest(test.TestCase):
with self.test_session() as sess:
with variable_scope.variable_scope("state_is_tuple_" + str(
state_is_tuple)):
- lstm_cell = core_rnn_cell_impl.BasicLSTMCell(
+ lstm_cell = rnn_cell.BasicLSTMCell(
num_units, state_is_tuple=state_is_tuple)
- cell = rnn_cell.AttentionCellWrapper(
+ cell = contrib_rnn_cell.AttentionCellWrapper(
lstm_cell, attn_length, state_is_tuple=state_is_tuple)
if state_is_tuple:
zeros = array_ops.zeros([batch_size, num_units], dtype=np.float32)
@@ -526,9 +526,9 @@ class RNNCellTest(test.TestCase):
with self.test_session() as sess:
with variable_scope.variable_scope("state_is_tuple_" + str(
state_is_tuple)):
- lstm_cell = core_rnn_cell_impl.BasicLSTMCell(
+ lstm_cell = rnn_cell.BasicLSTMCell(
num_units, state_is_tuple=state_is_tuple)
- cell = rnn_cell.AttentionCellWrapper(
+ cell = contrib_rnn_cell.AttentionCellWrapper(
lstm_cell, attn_length, state_is_tuple=state_is_tuple)
if state_is_tuple:
zeros = constant_op.constant(
@@ -603,9 +603,9 @@ class RNNCellTest(test.TestCase):
with variable_scope.variable_scope(
"state_is_tuple", reuse=state_is_tuple,
initializer=init_ops.glorot_uniform_initializer()):
- lstm_cell = core_rnn_cell_impl.BasicLSTMCell(
+ lstm_cell = rnn_cell.BasicLSTMCell(
num_units, state_is_tuple=state_is_tuple)
- cell = rnn_cell.AttentionCellWrapper(
+ cell = contrib_rnn_cell.AttentionCellWrapper(
lstm_cell, attn_length, state_is_tuple=state_is_tuple)
# This is legacy behavior to preserve the test. Weight
# sharing no longer works by creating a new RNNCell in the
@@ -665,8 +665,7 @@ class RNNCellTest(test.TestCase):
with variable_scope.variable_scope(
"nas_test",
initializer=init_ops.constant_initializer(0.5)):
- cell = rnn_cell.NASCell(
- num_units=num_units)
+ cell = contrib_rnn_cell.NASCell(num_units=num_units)
inputs = constant_op.constant(
np.array([[1., 1., 1., 1.],
[2., 2., 2., 2.],
@@ -677,8 +676,7 @@ class RNNCellTest(test.TestCase):
0.1 * np.ones(
(batch_size, num_units), dtype=np.float32),
dtype=dtypes.float32)
- init_state = core_rnn_cell_impl.LSTMStateTuple(state_value,
- state_value)
+ init_state = rnn_cell.LSTMStateTuple(state_value, state_value)
output, state = cell(inputs, init_state)
sess.run([variables.global_variables_initializer()])
res = sess.run([output, state])
@@ -719,9 +717,7 @@ class RNNCellTest(test.TestCase):
with variable_scope.variable_scope(
"nas_proj_test",
initializer=init_ops.constant_initializer(0.5)):
- cell = rnn_cell.NASCell(
- num_units=num_units,
- num_proj=num_proj)
+ cell = contrib_rnn_cell.NASCell(num_units=num_units, num_proj=num_proj)
inputs = constant_op.constant(
np.array([[1., 1., 1., 1.],
[2., 2., 2., 2.],
@@ -736,8 +732,7 @@ class RNNCellTest(test.TestCase):
0.1 * np.ones(
(batch_size, num_proj), dtype=np.float32),
dtype=dtypes.float32)
- init_state = core_rnn_cell_impl.LSTMStateTuple(state_value_c,
- state_value_h)
+ init_state = rnn_cell.LSTMStateTuple(state_value_c, state_value_h)
output, state = cell(inputs, init_state)
sess.run([variables.global_variables_initializer()])
res = sess.run([output, state])
@@ -767,7 +762,7 @@ class RNNCellTest(test.TestCase):
with variable_scope.variable_scope(
"ugrnn_cell_test",
initializer=init_ops.constant_initializer(0.5)):
- cell = rnn_cell.UGRNNCell(num_units=num_units)
+ cell = contrib_rnn_cell.UGRNNCell(num_units=num_units)
inputs = constant_op.constant(
np.array([[1., 1., 1., 1.],
[2., 2., 2., 2.],
@@ -803,8 +798,8 @@ class RNNCellTest(test.TestCase):
with variable_scope.variable_scope(
"intersection_rnn_cell_test",
initializer=init_ops.constant_initializer(0.5)):
- cell = rnn_cell.IntersectionRNNCell(num_units=num_units,
- num_in_proj=num_units)
+ cell = contrib_rnn_cell.IntersectionRNNCell(
+ num_units=num_units, num_in_proj=num_units)
inputs = constant_op.constant(
np.array([[1., 1., 1., 1.],
[2., 2., 2., 2.],
@@ -826,7 +821,7 @@ class RNNCellTest(test.TestCase):
def testIntersectionRNNCellFailure(self):
num_units = 2
batch_size = 3
- cell = rnn_cell.IntersectionRNNCell(num_units=num_units)
+ cell = contrib_rnn_cell.IntersectionRNNCell(num_units=num_units)
inputs = constant_op.constant(
np.array([[1., 1., 1., 1.],
[2., 2., 2., 2.],
@@ -862,9 +857,9 @@ class RNNCellTest(test.TestCase):
x = array_ops.zeros([batch_size, input_size])
c0 = array_ops.zeros([batch_size, 2])
h0 = array_ops.zeros([batch_size, 2])
- state0 = core_rnn_cell_impl.LSTMStateTuple(c0, h0)
- output, state = rnn_cell.PhasedLSTMCell(num_units=num_units)((t, x),
- state0)
+ state0 = rnn_cell.LSTMStateTuple(c0, h0)
+ output, state = contrib_rnn_cell.PhasedLSTMCell(num_units=num_units)(
+ (t, x), state0)
sess.run([variables.global_variables_initializer()])
res = sess.run([output, state], {
t.name:
@@ -886,12 +881,12 @@ class RNNCellTest(test.TestCase):
"base_cell", initializer=init_ops.constant_initializer(0.5)):
x = array_ops.zeros([1, 3])
m = array_ops.zeros([1, 3])
- base_cell = core_rnn_cell_impl.GRUCell(3)
+ base_cell = rnn_cell.GRUCell(3)
g, m_new = base_cell(x, m)
with variable_scope.variable_scope(
"hw_cell", initializer=init_ops.constant_initializer(0.5)):
- hw_cell = rnn_cell.HighwayWrapper(
- core_rnn_cell_impl.GRUCell(3), carry_bias_init=-100.0)
+ hw_cell = contrib_rnn_cell.HighwayWrapper(
+ rnn_cell.GRUCell(3), carry_bias_init=-100.0)
g_res, m_new_res = hw_cell(x, m)
sess.run([variables.global_variables_initializer()])
res = sess.run([g, g_res, m_new, m_new_res], {
@@ -915,9 +910,9 @@ class RNNCellTest(test.TestCase):
"root1", initializer=init_ops.constant_initializer(0.5)):
x = array_ops.ones([batch_size, num_units])
# When number_of_groups = 1, G-LSTM is equivalent to regular LSTM
- gcell = rnn_cell.GLSTMCell(num_units=num_units,
- number_of_groups=number_of_groups)
- cell = core_rnn_cell_impl.LSTMCell(num_units=num_units)
+ gcell = contrib_rnn_cell.GLSTMCell(
+ num_units=num_units, number_of_groups=number_of_groups)
+ cell = rnn_cell.LSTMCell(num_units=num_units)
self.assertTrue(isinstance(gcell.state_size, tuple))
zero_state = gcell.zero_state(batch_size=batch_size,
dtype=dtypes.float32)
@@ -941,8 +936,8 @@ class RNNCellTest(test.TestCase):
"root2", initializer=init_ops.constant_initializer(0.5)):
# input for G-LSTM with 2 groups
glstm_input = array_ops.ones([batch_size, num_units])
- gcell = rnn_cell.GLSTMCell(num_units=num_units,
- number_of_groups=number_of_groups)
+ gcell = contrib_rnn_cell.GLSTMCell(
+ num_units=num_units, number_of_groups=number_of_groups)
gcell_zero_state = gcell.zero_state(batch_size=batch_size,
dtype=dtypes.float32)
gh, gs = gcell(glstm_input, gcell_zero_state)
@@ -950,8 +945,7 @@ class RNNCellTest(test.TestCase):
# input for LSTM cell simulating single G-LSTM group
lstm_input = array_ops.ones([batch_size, num_units / number_of_groups])
      # Note the division by number_of_groups. This cell simulates a single G-LSTM group.
- cell = core_rnn_cell_impl.LSTMCell(num_units=
- int(num_units / number_of_groups))
+ cell = rnn_cell.LSTMCell(num_units=int(num_units / number_of_groups))
cell_zero_state = cell.zero_state(batch_size=batch_size,
dtype=dtypes.float32)
h, g = cell(lstm_input, cell_zero_state)
@@ -974,13 +968,13 @@ class LayerNormBasicLSTMCellTest(test.TestCase):
x = array_ops.zeros([1, 2])
c0 = array_ops.zeros([1, 2])
h0 = array_ops.zeros([1, 2])
- state0 = core_rnn_cell_impl.LSTMStateTuple(c0, h0)
+ state0 = rnn_cell.LSTMStateTuple(c0, h0)
c1 = array_ops.zeros([1, 2])
h1 = array_ops.zeros([1, 2])
- state1 = core_rnn_cell_impl.LSTMStateTuple(c1, h1)
+ state1 = rnn_cell.LSTMStateTuple(c1, h1)
state = (state0, state1)
- single_cell = lambda: rnn_cell.LayerNormBasicLSTMCell(2)
- cell = core_rnn_cell_impl.MultiRNNCell([single_cell() for _ in range(2)])
+ single_cell = lambda: contrib_rnn_cell.LayerNormBasicLSTMCell(2)
+ cell = rnn_cell.MultiRNNCell([single_cell() for _ in range(2)])
g, out_m = cell(x, state)
sess.run([variables.global_variables_initializer()])
res = sess.run([g, out_m], {
@@ -1015,8 +1009,8 @@ class LayerNormBasicLSTMCellTest(test.TestCase):
[1, 3]) # Test BasicLSTMCell with input_size != num_units.
c = array_ops.zeros([1, 2])
h = array_ops.zeros([1, 2])
- state = core_rnn_cell_impl.LSTMStateTuple(c, h)
- cell = rnn_cell.LayerNormBasicLSTMCell(2)
+ state = rnn_cell.LSTMStateTuple(c, h)
+ cell = contrib_rnn_cell.LayerNormBasicLSTMCell(2)
g, out_m = cell(x, state)
sess.run([variables.global_variables_initializer()])
res = sess.run([g, out_m], {
@@ -1039,12 +1033,12 @@ class LayerNormBasicLSTMCellTest(test.TestCase):
x = array_ops.zeros([1, 2])
c0 = array_ops.zeros([1, 2])
h0 = array_ops.zeros([1, 2])
- state0 = core_rnn_cell_impl.LSTMStateTuple(c0, h0)
+ state0 = rnn_cell.LSTMStateTuple(c0, h0)
c1 = array_ops.zeros([1, 2])
h1 = array_ops.zeros([1, 2])
- state1 = core_rnn_cell_impl.LSTMStateTuple(c1, h1)
- cell = core_rnn_cell_impl.MultiRNNCell(
- [rnn_cell.LayerNormBasicLSTMCell(2) for _ in range(2)])
+ state1 = rnn_cell.LSTMStateTuple(c1, h1)
+ cell = rnn_cell.MultiRNNCell(
+ [contrib_rnn_cell.LayerNormBasicLSTMCell(2) for _ in range(2)])
h, (s0, s1) = cell(x, (state0, state1))
sess.run([variables.global_variables_initializer()])
res = sess.run([h, s0, s1], {
@@ -1094,8 +1088,8 @@ class LayerNormBasicLSTMCellTest(test.TestCase):
x = array_ops.zeros([1, 5])
c = array_ops.zeros([1, 5])
h = array_ops.zeros([1, 5])
- state = core_rnn_cell_impl.LSTMStateTuple(c, h)
- cell = rnn_cell.LayerNormBasicLSTMCell(
+ state = rnn_cell.LSTMStateTuple(c, h)
+ cell = contrib_rnn_cell.LayerNormBasicLSTMCell(
num_units, layer_norm=False, dropout_keep_prob=keep_prob)
g, s = cell(x, state)
@@ -1138,10 +1132,9 @@ def _create_multi_lstm_cell_ops(batch_size, num_units, input_depth,
inputs = variable_scope.get_variable(
"inputs", initializer=random_ops.random_uniform(
(max_time, batch_size, input_depth), seed=1))
- maybe_xla = lambda c: rnn_cell.CompiledWrapper(c) if compiled else c
- cell = core_rnn_cell_impl.MultiRNNCell(
- [maybe_xla(core_rnn_cell_impl.LSTMCell(num_units))
- for _ in range(num_layers)])
+ maybe_xla = lambda c: contrib_rnn_cell.CompiledWrapper(c) if compiled else c
+ cell = rnn_cell.MultiRNNCell(
+ [maybe_xla(rnn_cell.LSTMCell(num_units)) for _ in range(num_layers)])
initial_state = cell.zero_state(
batch_size=batch_size, dtype=dtypes.float32)
outputs, final_state = rnn.dynamic_rnn(
@@ -1219,13 +1212,13 @@ class CompiledWrapperTest(test.TestCase):
# Test incorrectness of state
with self.assertRaisesRegexp(ValueError, "Expected state .* a tuple"):
- core_rnn_cell_impl.MultiRNNCell(
- [core_rnn_cell_impl.GRUCell(2) for _ in range(2)],
- state_is_tuple=True)(x, m_bad)
+ rnn_cell.MultiRNNCell(
+ [rnn_cell.GRUCell(2)
+ for _ in range(2)], state_is_tuple=True)(x, m_bad)
- _, ml = core_rnn_cell_impl.MultiRNNCell(
- [core_rnn_cell_impl.GRUCell(2) for _ in range(2)],
- state_is_tuple=True)(x, m_good)
+ _, ml = rnn_cell.MultiRNNCell(
+ [rnn_cell.GRUCell(2)
+ for _ in range(2)], state_is_tuple=True)(x, m_good)
sess.run([variables.global_variables_initializer()])
res = sess.run(ml, {
diff --git a/tensorflow/contrib/rnn/python/kernel_tests/rnn_test.py b/tensorflow/contrib/rnn/python/kernel_tests/rnn_test.py
index 5f96c565e8..e0d063a1b6 100644
--- a/tensorflow/contrib/rnn/python/kernel_tests/rnn_test.py
+++ b/tensorflow/contrib/rnn/python/kernel_tests/rnn_test.py
@@ -22,12 +22,12 @@ import itertools
import numpy as np
-from tensorflow.contrib.rnn.python.ops import core_rnn_cell_impl
-from tensorflow.contrib.rnn.python.ops import rnn
+from tensorflow.contrib.rnn.python.ops import rnn as contrib_rnn
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import init_ops
+from tensorflow.python.ops import rnn_cell
from tensorflow.python.ops import variable_scope
from tensorflow.python.ops import variables
from tensorflow.python.platform import test
@@ -58,14 +58,14 @@ class StackBidirectionalRNNTest(test.TestCase):
dtypes.int64) if use_sequence_length else None
self.cells_fw = [
- core_rnn_cell_impl.LSTMCell(
+ rnn_cell.LSTMCell(
num_units,
input_size,
initializer=initializer,
state_is_tuple=False) for num_units in self.layers
]
self.cells_bw = [
- core_rnn_cell_impl.LSTMCell(
+ rnn_cell.LSTMCell(
num_units,
input_size,
initializer=initializer,
@@ -77,7 +77,7 @@ class StackBidirectionalRNNTest(test.TestCase):
dtypes.float32,
shape=(batch_size, input_size) if use_shape else (None, input_size))
]
- outputs, state_fw, state_bw = rnn.stack_bidirectional_rnn(
+ outputs, state_fw, state_bw = contrib_rnn.stack_bidirectional_rnn(
self.cells_fw,
self.cells_bw,
inputs,
@@ -237,14 +237,14 @@ class StackBidirectionalRNNTest(test.TestCase):
sequence_length = array_ops.placeholder(dtypes.int64)
self.cells_fw = [
- core_rnn_cell_impl.LSTMCell(
+ rnn_cell.LSTMCell(
num_units,
input_size,
initializer=initializer,
state_is_tuple=False) for num_units in self.layers
]
self.cells_bw = [
- core_rnn_cell_impl.LSTMCell(
+ rnn_cell.LSTMCell(
num_units,
input_size,
initializer=initializer,
@@ -258,7 +258,7 @@ class StackBidirectionalRNNTest(test.TestCase):
]
inputs_c = array_ops.stack(inputs)
inputs_c = array_ops.transpose(inputs_c, [1, 0, 2])
- outputs, st_fw, st_bw = rnn.stack_bidirectional_dynamic_rnn(
+ outputs, st_fw, st_bw = contrib_rnn.stack_bidirectional_dynamic_rnn(
self.cells_fw,
self.cells_bw,
inputs_c,
diff --git a/tensorflow/contrib/rnn/python/ops/core_rnn.py b/tensorflow/contrib/rnn/python/ops/core_rnn.py
deleted file mode 100644
index 3ce075ce9c..0000000000
--- a/tensorflow/contrib/rnn/python/ops/core_rnn.py
+++ /dev/null
@@ -1,357 +0,0 @@
-# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ==============================================================================
-
-"""RNN helpers for TensorFlow models."""
-
-from __future__ import absolute_import
-from __future__ import division
-from __future__ import print_function
-
-from tensorflow.python.framework import ops
-from tensorflow.python.framework import tensor_shape
-from tensorflow.python.ops import array_ops
-from tensorflow.python.ops import math_ops
-from tensorflow.python.ops import rnn
-from tensorflow.python.ops import rnn_cell_impl
-from tensorflow.python.ops import variable_scope as vs
-from tensorflow.python.util import nest
-
-
-# pylint: disable=protected-access
-_concat = rnn_cell_impl._concat
-_like_rnncell = rnn_cell_impl._like_rnncell
-_infer_state_dtype = rnn._infer_state_dtype
-_reverse_seq = rnn._reverse_seq
-_rnn_step = rnn._rnn_step
-# pylint: enable=protected-access
-
-
-def static_rnn(cell, inputs, initial_state=None, dtype=None,
- sequence_length=None, scope=None):
- """Creates a recurrent neural network specified by RNNCell `cell`.
-
- The simplest form of RNN network generated is:
-
- ```python
- state = cell.zero_state(...)
- outputs = []
- for input_ in inputs:
- output, state = cell(input_, state)
- outputs.append(output)
- return (outputs, state)
- ```
- However, a few other options are available:
-
- An initial state can be provided.
- If the sequence_length vector is provided, dynamic calculation is performed.
- This method of calculation does not compute the RNN steps past the maximum
- sequence length of the minibatch (thus saving computational time),
- and properly propagates the state at an example's sequence length
- to the final state output.
-
- The dynamic calculation performed is, at time `t` for batch row `b`,
-
- ```python
- (output, state)(b, t) =
- (t >= sequence_length(b))
- ? (zeros(cell.output_size), states(b, sequence_length(b) - 1))
- : cell(input(b, t), state(b, t - 1))
- ```
-
- Args:
- cell: An instance of RNNCell.
- inputs: A length T list of inputs, each a `Tensor` of shape
- `[batch_size, input_size]`, or a nested tuple of such elements.
- initial_state: (optional) An initial state for the RNN.
- If `cell.state_size` is an integer, this must be
- a `Tensor` of appropriate type and shape `[batch_size, cell.state_size]`.
- If `cell.state_size` is a tuple, this should be a tuple of
- tensors having shapes `[batch_size, s] for s in cell.state_size`.
- dtype: (optional) The data type for the initial state and expected output.
- Required if initial_state is not provided or RNN state has a heterogeneous
- dtype.
- sequence_length: Specifies the length of each sequence in inputs.
- An int32 or int64 vector (tensor) size `[batch_size]`, values in `[0, T)`.
- scope: VariableScope for the created subgraph; defaults to "rnn".
-
- Returns:
- A pair (outputs, state) where:
-
- - outputs is a length T list of outputs (one for each input), or a nested
- tuple of such elements.
- - state is the final state
-
- Raises:
- TypeError: If `cell` is not an instance of RNNCell.
- ValueError: If `inputs` is `None` or an empty list, or if the input depth
- (column size) cannot be inferred from inputs via shape inference.
- """
-
- if not _like_rnncell(cell):
- raise TypeError("cell must be an instance of RNNCell")
- if not nest.is_sequence(inputs):
- raise TypeError("inputs must be a sequence")
- if not inputs:
- raise ValueError("inputs must not be empty")
-
- outputs = []
- # Create a new scope in which the caching device is either
- # determined by the parent scope, or is set to place the cached
- # Variable using the same placement as for the rest of the RNN.
- with vs.variable_scope(scope or "rnn") as varscope:
- if varscope.caching_device is None:
- varscope.set_caching_device(lambda op: op.device)
-
- # Obtain the first sequence of the input
- first_input = inputs
- while nest.is_sequence(first_input):
- first_input = first_input[0]
-
- # Temporarily avoid EmbeddingWrapper and seq2seq badness
- # TODO(lukaszkaiser): remove EmbeddingWrapper
- if first_input.get_shape().ndims != 1:
-
- input_shape = first_input.get_shape().with_rank_at_least(2)
- fixed_batch_size = input_shape[0]
-
- flat_inputs = nest.flatten(inputs)
- for flat_input in flat_inputs:
- input_shape = flat_input.get_shape().with_rank_at_least(2)
- batch_size, input_size = input_shape[0], input_shape[1:]
- fixed_batch_size.merge_with(batch_size)
- for i, size in enumerate(input_size):
- if size.value is None:
- raise ValueError(
- "Input size (dimension %d of inputs) must be accessible via "
- "shape inference, but saw value None." % i)
- else:
- fixed_batch_size = first_input.get_shape().with_rank_at_least(1)[0]
-
- if fixed_batch_size.value:
- batch_size = fixed_batch_size.value
- else:
- batch_size = array_ops.shape(first_input)[0]
- if initial_state is not None:
- state = initial_state
- else:
- if not dtype:
- raise ValueError("If no initial_state is provided, "
- "dtype must be specified")
- state = cell.zero_state(batch_size, dtype)
-
- if sequence_length is not None: # Prepare variables
- sequence_length = ops.convert_to_tensor(
- sequence_length, name="sequence_length")
- if sequence_length.get_shape().ndims not in (None, 1):
- raise ValueError(
- "sequence_length must be a vector of length batch_size")
- def _create_zero_output(output_size):
- # convert int to TensorShape if necessary
- size = _concat(batch_size, output_size)
- output = array_ops.zeros(
- array_ops.stack(size), _infer_state_dtype(dtype, state))
- shape = _concat(fixed_batch_size.value, output_size, static=True)
- output.set_shape(tensor_shape.TensorShape(shape))
- return output
-
- output_size = cell.output_size
- flat_output_size = nest.flatten(output_size)
- flat_zero_output = tuple(
- _create_zero_output(size) for size in flat_output_size)
- zero_output = nest.pack_sequence_as(structure=output_size,
- flat_sequence=flat_zero_output)
-
- sequence_length = math_ops.to_int32(sequence_length)
- min_sequence_length = math_ops.reduce_min(sequence_length)
- max_sequence_length = math_ops.reduce_max(sequence_length)
-
- for time, input_ in enumerate(inputs):
- if time > 0: varscope.reuse_variables()
- # pylint: disable=cell-var-from-loop
- call_cell = lambda: cell(input_, state)
- # pylint: enable=cell-var-from-loop
- if sequence_length is not None:
- (output, state) = _rnn_step(
- time=time,
- sequence_length=sequence_length,
- min_sequence_length=min_sequence_length,
- max_sequence_length=max_sequence_length,
- zero_output=zero_output,
- state=state,
- call_cell=call_cell,
- state_size=cell.state_size)
- else:
- (output, state) = call_cell()
-
- outputs.append(output)
-
- return (outputs, state)
-
-
-def static_state_saving_rnn(cell, inputs, state_saver, state_name,
- sequence_length=None, scope=None):
- """RNN that accepts a state saver for time-truncated RNN calculation.
-
- Args:
- cell: An instance of `RNNCell`.
- inputs: A length T list of inputs, each a `Tensor` of shape
- `[batch_size, input_size]`.
- state_saver: A state saver object with methods `state` and `save_state`.
- state_name: Python string or tuple of strings. The name to use with the
- state_saver. If the cell returns tuples of states (i.e.,
- `cell.state_size` is a tuple) then `state_name` should be a tuple of
- strings having the same length as `cell.state_size`. Otherwise it should
- be a single string.
- sequence_length: (optional) An int32/int64 vector size [batch_size].
- See the documentation for rnn() for more details about sequence_length.
- scope: VariableScope for the created subgraph; defaults to "rnn".
-
- Returns:
- A pair (outputs, state) where:
- outputs is a length T list of outputs (one for each input)
-      state is the final state
-
- Raises:
- TypeError: If `cell` is not an instance of RNNCell.
- ValueError: If `inputs` is `None` or an empty list, or if the arity and
- type of `state_name` does not match that of `cell.state_size`.
- """
- state_size = cell.state_size
- state_is_tuple = nest.is_sequence(state_size)
- state_name_tuple = nest.is_sequence(state_name)
-
- if state_is_tuple != state_name_tuple:
- raise ValueError(
- "state_name should be the same type as cell.state_size. "
- "state_name: %s, cell.state_size: %s"
- % (str(state_name), str(state_size)))
-
- if state_is_tuple:
- state_name_flat = nest.flatten(state_name)
- state_size_flat = nest.flatten(state_size)
-
- if len(state_name_flat) != len(state_size_flat):
- raise ValueError("#elems(state_name) != #elems(state_size): %d vs. %d"
- % (len(state_name_flat), len(state_size_flat)))
-
- initial_state = nest.pack_sequence_as(
- structure=state_size,
- flat_sequence=[state_saver.state(s) for s in state_name_flat])
- else:
- initial_state = state_saver.state(state_name)
-
- (outputs, state) = static_rnn(cell, inputs, initial_state=initial_state,
- sequence_length=sequence_length, scope=scope)
-
- if state_is_tuple:
- flat_state = nest.flatten(state)
- state_name = nest.flatten(state_name)
- save_state = [state_saver.save_state(name, substate)
- for name, substate in zip(state_name, flat_state)]
- else:
- save_state = [state_saver.save_state(state_name, state)]
-
- with ops.control_dependencies(save_state):
- last_output = outputs[-1]
- flat_last_output = nest.flatten(last_output)
- flat_last_output = [
- array_ops.identity(output) for output in flat_last_output]
- outputs[-1] = nest.pack_sequence_as(structure=last_output,
- flat_sequence=flat_last_output)
-
- return (outputs, state)
-
-
-def static_bidirectional_rnn(cell_fw, cell_bw, inputs,
- initial_state_fw=None, initial_state_bw=None,
- dtype=None, sequence_length=None, scope=None):
- """Creates a bidirectional recurrent neural network.
-
- Similar to the unidirectional case above (rnn) but takes input and builds
- independent forward and backward RNNs with the final forward and backward
- outputs depth-concatenated, such that the output will have the format
- [time][batch][cell_fw.output_size + cell_bw.output_size]. The input_size of
- forward and backward cell must match. The initial state for both directions
- is zero by default (but can be set optionally) and no intermediate states are
- ever returned -- the network is fully unrolled for the given (passed in)
- length(s) of the sequence(s) or completely unrolled if length(s) is not given.
-
- Args:
- cell_fw: An instance of RNNCell, to be used for forward direction.
- cell_bw: An instance of RNNCell, to be used for backward direction.
- inputs: A length T list of inputs, each a tensor of shape
- [batch_size, input_size], or a nested tuple of such elements.
- initial_state_fw: (optional) An initial state for the forward RNN.
- This must be a tensor of appropriate type and shape
- `[batch_size, cell_fw.state_size]`.
- If `cell_fw.state_size` is a tuple, this should be a tuple of
- tensors having shapes `[batch_size, s] for s in cell_fw.state_size`.
- initial_state_bw: (optional) Same as for `initial_state_fw`, but using
- the corresponding properties of `cell_bw`.
- dtype: (optional) The data type for the initial state. Required if
- either of the initial states are not provided.
- sequence_length: (optional) An int32/int64 vector, size `[batch_size]`,
- containing the actual lengths for each of the sequences.
- scope: VariableScope for the created subgraph; defaults to
- "bidirectional_rnn"
-
- Returns:
- A tuple (outputs, output_state_fw, output_state_bw) where:
- outputs is a length `T` list of outputs (one for each input), which
- are depth-concatenated forward and backward outputs.
- output_state_fw is the final state of the forward rnn.
- output_state_bw is the final state of the backward rnn.
-
- Raises:
- TypeError: If `cell_fw` or `cell_bw` is not an instance of `RNNCell`.
- ValueError: If inputs is None or an empty list.
- """
-
- if not _like_rnncell(cell_fw):
- raise TypeError("cell_fw must be an instance of RNNCell")
- if not _like_rnncell(cell_bw):
- raise TypeError("cell_bw must be an instance of RNNCell")
- if not nest.is_sequence(inputs):
- raise TypeError("inputs must be a sequence")
- if not inputs:
- raise ValueError("inputs must not be empty")
-
- with vs.variable_scope(scope or "bidirectional_rnn"):
- # Forward direction
- with vs.variable_scope("fw") as fw_scope:
- output_fw, output_state_fw = static_rnn(
- cell_fw, inputs, initial_state_fw, dtype,
- sequence_length, scope=fw_scope)
-
- # Backward direction
- with vs.variable_scope("bw") as bw_scope:
- reversed_inputs = _reverse_seq(inputs, sequence_length)
- tmp, output_state_bw = static_rnn(
- cell_bw, reversed_inputs, initial_state_bw,
- dtype, sequence_length, scope=bw_scope)
-
- output_bw = _reverse_seq(tmp, sequence_length)
- # Concat each of the forward/backward outputs
- flat_output_fw = nest.flatten(output_fw)
- flat_output_bw = nest.flatten(output_bw)
-
- flat_outputs = tuple(
- array_ops.concat([fw, bw], 1)
- for fw, bw in zip(flat_output_fw, flat_output_bw))
-
- outputs = nest.pack_sequence_as(structure=output_fw,
- flat_sequence=flat_outputs)
-
- return (outputs, output_state_fw, output_state_bw)
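With core_rnn.py deleted from contrib, these static RNN helpers now live in core TensorFlow. A minimal sketch of a call site under the new layout, assuming the helpers surface as `tf.nn.static_rnn` and the cells under `tf.nn.rnn_cell`, with the old `tf.contrib.rnn` spellings expected to remain available as aliases:

```python
import tensorflow as tf

batch_size, input_size, num_units, num_steps = 32, 8, 16, 5
# static_rnn takes a Python list with one [batch, input_size] tensor per step.
inputs = [tf.placeholder(tf.float32, [batch_size, input_size])
          for _ in range(num_steps)]
cell = tf.nn.rnn_cell.BasicLSTMCell(num_units)
outputs, final_state = tf.nn.static_rnn(cell, inputs, dtype=tf.float32)
# The contrib spelling, tf.contrib.rnn.static_rnn, is assumed to keep working.
```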
diff --git a/tensorflow/contrib/rnn/python/ops/core_rnn_cell.py b/tensorflow/contrib/rnn/python/ops/core_rnn_cell.py
index c101b68d92..6b6bd503ce 100644
--- a/tensorflow/contrib/rnn/python/ops/core_rnn_cell.py
+++ b/tensorflow/contrib/rnn/python/ops/core_rnn_cell.py
@@ -12,45 +12,219 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
+"""Module implementing RNN Cells that used to be in core.
-"""Module for constructing RNN Cells.
+@@EmbeddingWrapper
+@@InputProjectionWrapper
+@@OutputProjectionWrapper
+"""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
-## Base interface for all RNN Cells
+import math
-@@RNNCell
+from tensorflow.python.framework import ops
+from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import embedding_ops
+from tensorflow.python.ops import init_ops
+from tensorflow.python.ops import rnn_cell_impl
+from tensorflow.python.ops import variable_scope as vs
+from tensorflow.python.platform import tf_logging as logging
-## RNN Cells for use with TensorFlow's core RNN methods
+RNNCell = rnn_cell_impl.RNNCell # pylint: disable=invalid-name
+_linear = rnn_cell_impl._linear # pylint: disable=invalid-name, protected-access
+_like_rnncell = rnn_cell_impl._like_rnncell # pylint: disable=invalid-name, protected-access
-@@BasicRNNCell
-@@BasicLSTMCell
-@@GRUCell
-@@LSTMCell
-## Classes storing split `RNNCell` state
+class EmbeddingWrapper(RNNCell):
+ """Operator adding input embedding to the given cell.
-@@LSTMStateTuple
+ Note: in many cases it may be more efficient to not use this wrapper,
+ but instead concatenate the whole sequence of your inputs in time,
+ do the embedding on this batch-concatenated sequence, then split it and
+ feed into your RNN.
+ """
-## RNN Cell wrappers (RNNCells that wrap other RNNCells)
+ def __init__(self,
+ cell,
+ embedding_classes,
+ embedding_size,
+ initializer=None,
+ reuse=None):
+ """Create a cell with an added input embedding.
-@@MultiRNNCell
-@@DropoutWrapper
-@@EmbeddingWrapper
-@@InputProjectionWrapper
-@@OutputProjectionWrapper
-@@DeviceWrapper
-@@ResidualWrapper
-"""
-from __future__ import absolute_import
-from __future__ import division
-from __future__ import print_function
+ Args:
+ cell: an RNNCell, an embedding will be put before its inputs.
+ embedding_classes: integer, how many symbols will be embedded.
+ embedding_size: integer, the size of the vectors we embed into.
+ initializer: an initializer to use when creating the embedding;
+ if None, the initializer from variable scope or a default one is used.
+ reuse: (optional) Python boolean describing whether to reuse variables
+ in an existing scope. If not `True`, and the existing scope already has
+ the given variables, an error is raised.
+
+ Raises:
+ TypeError: if cell is not an RNNCell.
+ ValueError: if embedding_classes is not positive.
+ """
+ super(EmbeddingWrapper, self).__init__(_reuse=reuse)
+ if not _like_rnncell(cell):
+ raise TypeError("The parameter cell is not RNNCell.")
+ if embedding_classes <= 0 or embedding_size <= 0:
+ raise ValueError("Both embedding_classes and embedding_size must be > 0: "
+ "%d, %d." % (embedding_classes, embedding_size))
+ self._cell = cell
+ self._embedding_classes = embedding_classes
+ self._embedding_size = embedding_size
+ self._initializer = initializer
+
+ @property
+ def state_size(self):
+ return self._cell.state_size
+
+ @property
+ def output_size(self):
+ return self._cell.output_size
+
+ def zero_state(self, batch_size, dtype):
+ with ops.name_scope(type(self).__name__ + "ZeroState", values=[batch_size]):
+ return self._cell.zero_state(batch_size, dtype)
+
+ def call(self, inputs, state):
+ """Run the cell on embedded inputs."""
+ with ops.device("/cpu:0"):
+ if self._initializer:
+ initializer = self._initializer
+ elif vs.get_variable_scope().initializer:
+ initializer = vs.get_variable_scope().initializer
+ else:
+ # Default initializer for embeddings should have variance=1.
+ sqrt3 = math.sqrt(3) # Uniform(-sqrt(3), sqrt(3)) has variance=1.
+ initializer = init_ops.random_uniform_initializer(-sqrt3, sqrt3)
+
+ if isinstance(state, tuple):
+ data_type = state[0].dtype
+ else:
+ data_type = state.dtype
+
+ embedding = vs.get_variable(
+ "embedding", [self._embedding_classes, self._embedding_size],
+ initializer=initializer,
+ dtype=data_type)
+ embedded = embedding_ops.embedding_lookup(embedding,
+ array_ops.reshape(inputs, [-1]))
+
+ return self._cell(embedded, state)
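A minimal usage sketch for the `EmbeddingWrapper` kept here in contrib; sizes are hypothetical, and it assumes `tf.nn.rnn_cell.GRUCell` plus the contrib import path below:

```python
import tensorflow as tf
from tensorflow.contrib.rnn.python.ops import core_rnn_cell

vocab_size, embedding_size, num_units, batch_size = 1000, 32, 64, 8
token_ids = tf.placeholder(tf.int32, [batch_size, 1])  # integer symbol ids
cell = core_rnn_cell.EmbeddingWrapper(
    tf.nn.rnn_cell.GRUCell(num_units),
    embedding_classes=vocab_size,
    embedding_size=embedding_size)
state = cell.zero_state(batch_size, tf.float32)
output, new_state = cell(token_ids, state)  # output: [batch_size, num_units]
```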
+
+
+class InputProjectionWrapper(RNNCell):
+ """Operator adding an input projection to the given cell.
+
+ Note: in many cases it may be more efficient to not use this wrapper,
+ but instead concatenate the whole sequence of your inputs in time,
+ do the projection on this batch-concatenated sequence, then split it.
+ """
+
+ def __init__(self,
+ cell,
+ num_proj,
+ activation=None,
+ input_size=None,
+ reuse=None):
+ """Create a cell with input projection.
+
+ Args:
+ cell: an RNNCell, a projection of inputs is added before it.
+ num_proj: Python integer. The dimension to project to.
+ activation: (optional) an optional activation function.
+ input_size: Deprecated and unused.
+ reuse: (optional) Python boolean describing whether to reuse variables
+ in an existing scope. If not `True`, and the existing scope already has
+ the given variables, an error is raised.
+
+ Raises:
+ TypeError: if cell is not an RNNCell.
+ """
+ super(InputProjectionWrapper, self).__init__(_reuse=reuse)
+ if input_size is not None:
+ logging.warn("%s: The input_size parameter is deprecated.", self)
+ if not _like_rnncell(cell):
+ raise TypeError("The parameter cell is not RNNCell.")
+ self._cell = cell
+ self._num_proj = num_proj
+ self._activation = activation
+
+ @property
+ def state_size(self):
+ return self._cell.state_size
+
+ @property
+ def output_size(self):
+ return self._cell.output_size
+
+ def zero_state(self, batch_size, dtype):
+ with ops.name_scope(type(self).__name__ + "ZeroState", values=[batch_size]):
+ return self._cell.zero_state(batch_size, dtype)
+
+ def call(self, inputs, state):
+ """Run the input projection and then the cell."""
+ # Default scope: "InputProjectionWrapper"
+ projected = _linear(inputs, self._num_proj, True)
+ if self._activation:
+ projected = self._activation(projected)
+ return self._cell(projected, state)
+
+
+class OutputProjectionWrapper(RNNCell):
+ """Operator adding an output projection to the given cell.
+
+ Note: in many cases it may be more efficient to not use this wrapper,
+ but instead concatenate the whole sequence of your outputs in time,
+ do the projection on this batch-concatenated sequence, then split it
+ if needed or directly feed into a softmax.
+ """
+
+ def __init__(self, cell, output_size, activation=None, reuse=None):
+ """Create a cell with output projection.
+
+ Args:
+ cell: an RNNCell, a projection to output_size is added to it.
+ output_size: integer, the size of the output after projection.
+ activation: (optional) an optional activation function.
+ reuse: (optional) Python boolean describing whether to reuse variables
+ in an existing scope. If not `True`, and the existing scope already has
+ the given variables, an error is raised.
+
+ Raises:
+ TypeError: if cell is not an RNNCell.
+ ValueError: if output_size is not positive.
+ """
+ super(OutputProjectionWrapper, self).__init__(_reuse=reuse)
+ if not _like_rnncell(cell):
+ raise TypeError("The parameter cell is not RNNCell.")
+ if output_size < 1:
+ raise ValueError("Parameter output_size must be > 0: %d." % output_size)
+ self._cell = cell
+ self._output_size = output_size
+ self._activation = activation
-# go/tf-wildcard-import
-# pylint: disable=wildcard-import
-from tensorflow.contrib.rnn.python.ops.core_rnn_cell_impl import *
-# pylint: enable=wildcard-import
-from tensorflow.python.util.all_util import remove_undocumented
+ @property
+ def state_size(self):
+ return self._cell.state_size
+ @property
+ def output_size(self):
+ return self._output_size
-_allowed_symbols = []
+ def zero_state(self, batch_size, dtype):
+ with ops.name_scope(type(self).__name__ + "ZeroState", values=[batch_size]):
+ return self._cell.zero_state(batch_size, dtype)
-remove_undocumented(__name__, _allowed_symbols)
+ def call(self, inputs, state):
+ """Run the cell and output projection on inputs, starting from state."""
+ output, res_state = self._cell(inputs, state)
+ projected = _linear(output, self._output_size, True)
+ if self._activation:
+ projected = self._activation(projected)
+ return projected, res_state
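Similarly, a minimal sketch of `OutputProjectionWrapper` mapping cell outputs to vocabulary logits; hypothetical sizes, same assumptions as above:

```python
import tensorflow as tf
from tensorflow.contrib.rnn.python.ops import core_rnn_cell

batch_size, num_units, vocab_size = 8, 64, 1000
inputs = tf.placeholder(tf.float32, [batch_size, num_units])
# Project each cell output to vocabulary logits in a single step.
cell = core_rnn_cell.OutputProjectionWrapper(
    tf.nn.rnn_cell.GRUCell(num_units), output_size=vocab_size)
state = cell.zero_state(batch_size, tf.float32)
logits, new_state = cell(inputs, state)  # logits: [batch_size, vocab_size]
```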
diff --git a/tensorflow/contrib/rnn/python/ops/core_rnn_cell_impl.py b/tensorflow/contrib/rnn/python/ops/core_rnn_cell_impl.py
deleted file mode 100644
index e0616e0678..0000000000
--- a/tensorflow/contrib/rnn/python/ops/core_rnn_cell_impl.py
+++ /dev/null
@@ -1,1048 +0,0 @@
-# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ==============================================================================
-
-"""Module implementing RNN Cells.
-
-This module provides a number of basic commonly used RNN cells, such as LSTM
-(Long Short Term Memory) or GRU (Gated Recurrent Unit), and a number of
-operators that allow adding dropouts, projections, or embeddings for inputs.
-Constructing multi-layer cells is supported by the class `MultiRNNCell`, or by
-calling the `rnn` ops several times.
-"""
-
-from __future__ import absolute_import
-from __future__ import division
-from __future__ import print_function
-
-import collections
-import hashlib
-import math
-import numbers
-
-from tensorflow.python.framework import ops
-from tensorflow.python.framework import tensor_shape
-from tensorflow.python.framework import tensor_util
-from tensorflow.python.ops import array_ops
-from tensorflow.python.ops import clip_ops
-from tensorflow.python.ops import embedding_ops
-from tensorflow.python.ops import init_ops
-from tensorflow.python.ops import math_ops
-from tensorflow.python.ops import nn_ops
-from tensorflow.python.ops import partitioned_variables
-from tensorflow.python.ops import random_ops
-from tensorflow.python.ops import rnn_cell_impl
-from tensorflow.python.ops import variable_scope as vs
-
-from tensorflow.python.ops.math_ops import sigmoid
-from tensorflow.python.ops.math_ops import tanh
-
-from tensorflow.python.platform import tf_logging as logging
-from tensorflow.python.util import nest
-
-
-# pylint: disable=protected-access
-RNNCell = rnn_cell_impl._RNNCell # pylint: disable=invalid-name
-_like_rnncell = rnn_cell_impl._like_rnncell
-# pylint: enable=protected-access
-
-_BIAS_VARIABLE_NAME = "bias"
-_WEIGHTS_VARIABLE_NAME = "kernel"
-
-
-class BasicRNNCell(RNNCell):
- """The most basic RNN cell."""
-
- def __init__(self, num_units, input_size=None, activation=tanh, reuse=None):
- super(BasicRNNCell, self).__init__(_reuse=reuse)
- if input_size is not None:
- logging.warn("%s: The input_size parameter is deprecated.", self)
- self._num_units = num_units
- self._activation = activation
-
- @property
- def state_size(self):
- return self._num_units
-
- @property
- def output_size(self):
- return self._num_units
-
- def call(self, inputs, state):
- """Most basic RNN: output = new_state = act(W * input + U * state + B)."""
- output = self._activation(_linear([inputs, state], self._num_units, True))
- return output, output
-
-
-class GRUCell(RNNCell):
- """Gated Recurrent Unit cell (cf. http://arxiv.org/abs/1406.1078)."""
-
- def __init__(self,
- num_units,
- input_size=None,
- activation=tanh,
- reuse=None,
- kernel_initializer=None,
- bias_initializer=None):
- super(GRUCell, self).__init__(_reuse=reuse)
- if input_size is not None:
- logging.warn("%s: The input_size parameter is deprecated.", self)
- self._num_units = num_units
- self._activation = activation
- self._kernel_initializer = kernel_initializer
- self._bias_initializer = bias_initializer
-
- @property
- def state_size(self):
- return self._num_units
-
- @property
- def output_size(self):
- return self._num_units
-
- def call(self, inputs, state):
- """Gated recurrent unit (GRU) with nunits cells."""
- with vs.variable_scope("gates"): # Reset gate and update gate.
- # We start with bias of 1.0 to not reset and not update.
- bias_ones = self._bias_initializer
- if self._bias_initializer is None:
- dtype = [a.dtype for a in [inputs, state]][0]
- bias_ones = init_ops.constant_initializer(1.0, dtype=dtype)
- value = sigmoid(
- _linear([inputs, state], 2 * self._num_units, True, bias_ones,
- self._kernel_initializer))
- r, u = array_ops.split(value=value, num_or_size_splits=2, axis=1)
- with vs.variable_scope("candidate"):
- c = self._activation(
- _linear([inputs, r * state], self._num_units, True,
- self._bias_initializer, self._kernel_initializer))
- new_h = u * state + (1 - u) * c
- return new_h, new_h
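For reference, one step of the `GRUCell.call` deleted above corresponds to the standard GRU update; a NumPy sketch with the weights assumed given and the bias terms simplified (gates start at 1.0, matching the constant initializer in the code):

```python
import numpy as np

def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))

def gru_step(x, h, w_gates, w_cand):
    """One GRU step; w_gates: [d+n, 2n], w_cand: [d+n, n]."""
    # Reset gate r and update gate u, computed from [x, h].
    value = sigmoid(np.concatenate([x, h], axis=1).dot(w_gates) + 1.0)
    r, u = np.split(value, 2, axis=1)
    # Candidate state computed from [x, r * h]; candidate bias omitted here.
    c = np.tanh(np.concatenate([x, r * h], axis=1).dot(w_cand))
    return u * h + (1.0 - u) * c  # new hidden state (also the output)
```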
-
-
-_LSTMStateTuple = collections.namedtuple("LSTMStateTuple", ("c", "h"))
-
-
-class LSTMStateTuple(_LSTMStateTuple):
- """Tuple used by LSTM Cells for `state_size`, `zero_state`, and output state.
-
- Stores two elements: `(c, h)`, in that order.
-
- Only used when `state_is_tuple=True`.
- """
- __slots__ = ()
-
- @property
- def dtype(self):
- (c, h) = self
- if not c.dtype == h.dtype:
- raise TypeError("Inconsistent internal state: %s vs %s" %
- (str(c.dtype), str(h.dtype)))
- return c.dtype
-
-
-class BasicLSTMCell(RNNCell):
- """Basic LSTM recurrent network cell.
-
- The implementation is based on: http://arxiv.org/abs/1409.2329.
-
- We add forget_bias (default: 1) to the biases of the forget gate in order to
- reduce the scale of forgetting in the beginning of the training.
-
- It does not allow cell clipping, a projection layer, and does not
- use peep-hole connections: it is the basic baseline.
-
- For advanced models, please use the full LSTMCell that follows.
- """
-
- def __init__(self, num_units, forget_bias=1.0, input_size=None,
- state_is_tuple=True, activation=tanh, reuse=None):
- """Initialize the basic LSTM cell.
-
- Args:
- num_units: int, The number of units in the LSTM cell.
- forget_bias: float, The bias added to forget gates (see above).
- input_size: Deprecated and unused.
- state_is_tuple: If True, accepted and returned states are 2-tuples of
- the `c_state` and `m_state`. If False, they are concatenated
- along the column axis. The latter behavior will soon be deprecated.
- activation: Activation function of the inner states.
- reuse: (optional) Python boolean describing whether to reuse variables
- in an existing scope. If not `True`, and the existing scope already has
- the given variables, an error is raised.
- """
- super(BasicLSTMCell, self).__init__(_reuse=reuse)
- if not state_is_tuple:
- logging.warn("%s: Using a concatenated state is slower and will soon be "
- "deprecated. Use state_is_tuple=True.", self)
- if input_size is not None:
- logging.warn("%s: The input_size parameter is deprecated.", self)
- self._num_units = num_units
- self._forget_bias = forget_bias
- self._state_is_tuple = state_is_tuple
- self._activation = activation
-
- @property
- def state_size(self):
- return (LSTMStateTuple(self._num_units, self._num_units)
- if self._state_is_tuple else 2 * self._num_units)
-
- @property
- def output_size(self):
- return self._num_units
-
- def call(self, inputs, state):
- """Long short-term memory cell (LSTM)."""
- # Parameters of gates are concatenated into one multiply for efficiency.
- if self._state_is_tuple:
- c, h = state
- else:
- c, h = array_ops.split(value=state, num_or_size_splits=2, axis=1)
-
- concat = _linear([inputs, h], 4 * self._num_units, True)
-
- # i = input_gate, j = new_input, f = forget_gate, o = output_gate
- i, j, f, o = array_ops.split(value=concat, num_or_size_splits=4, axis=1)
-
- new_c = (
- c * sigmoid(f + self._forget_bias) + sigmoid(i) * self._activation(j))
- new_h = self._activation(new_c) * sigmoid(o)
-
- if self._state_is_tuple:
- new_state = LSTMStateTuple(new_c, new_h)
- else:
- new_state = array_ops.concat([new_c, new_h], 1)
- return new_h, new_state
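Likewise, the deleted `BasicLSTMCell.call` computes one fused matmul for the four gates; a NumPy sketch of a single step, weights and bias assumed given:

```python
import numpy as np

def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))

def basic_lstm_step(x, c, h, w, b, forget_bias=1.0):
    """One step; w: [d+n, 4n], b: [4n]. Mirrors the fused i, j, f, o split above."""
    i, j, f, o = np.split(np.concatenate([x, h], axis=1).dot(w) + b, 4, axis=1)
    new_c = c * sigmoid(f + forget_bias) + sigmoid(i) * np.tanh(j)
    new_h = np.tanh(new_c) * sigmoid(o)
    return new_h, (new_c, new_h)
```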
-
-
-class LSTMCell(RNNCell):
- """Long short-term memory unit (LSTM) recurrent network cell.
-
- The default non-peephole implementation is based on:
-
- http://deeplearning.cs.cmu.edu/pdfs/Hochreiter97_lstm.pdf
-
- S. Hochreiter and J. Schmidhuber.
- "Long Short-Term Memory". Neural Computation, 9(8):1735-1780, 1997.
-
- The peephole implementation is based on:
-
- https://research.google.com/pubs/archive/43905.pdf
-
- Hasim Sak, Andrew Senior, and Francoise Beaufays.
- "Long short-term memory recurrent neural network architectures for
- large scale acoustic modeling." INTERSPEECH, 2014.
-
- The class uses optional peep-hole connections, optional cell clipping, and
- an optional projection layer.
- """
-
- def __init__(self, num_units, input_size=None,
- use_peepholes=False, cell_clip=None,
- initializer=None, num_proj=None, proj_clip=None,
- num_unit_shards=None, num_proj_shards=None,
- forget_bias=1.0, state_is_tuple=True,
- activation=tanh, reuse=None):
- """Initialize the parameters for an LSTM cell.
-
- Args:
- num_units: int, The number of units in the LSTM cell
- input_size: Deprecated and unused.
- use_peepholes: bool, set True to enable diagonal/peephole connections.
- cell_clip: (optional) A float value, if provided the cell state is clipped
- by this value prior to the cell output activation.
- initializer: (optional) The initializer to use for the weight and
- projection matrices.
- num_proj: (optional) int, The output dimensionality for the projection
- matrices. If None, no projection is performed.
- proj_clip: (optional) A float value. If `num_proj > 0` and `proj_clip` is
- provided, then the projected values are clipped elementwise to within
- `[-proj_clip, proj_clip]`.
- num_unit_shards: Deprecated, will be removed by Jan. 2017.
- Use a variable_scope partitioner instead.
- num_proj_shards: Deprecated, will be removed by Jan. 2017.
- Use a variable_scope partitioner instead.
- forget_bias: Biases of the forget gate are initialized by default to 1
- in order to reduce the scale of forgetting at the beginning of
- the training.
- state_is_tuple: If True, accepted and returned states are 2-tuples of
- the `c_state` and `m_state`. If False, they are concatenated
- along the column axis. This latter behavior will soon be deprecated.
- activation: Activation function of the inner states.
- reuse: (optional) Python boolean describing whether to reuse variables
- in an existing scope. If not `True`, and the existing scope already has
- the given variables, an error is raised.
- """
- super(LSTMCell, self).__init__(_reuse=reuse)
- if not state_is_tuple:
- logging.warn("%s: Using a concatenated state is slower and will soon be "
- "deprecated. Use state_is_tuple=True.", self)
- if input_size is not None:
- logging.warn("%s: The input_size parameter is deprecated.", self)
- if num_unit_shards is not None or num_proj_shards is not None:
- logging.warn(
-        "%s: The num_unit_shards and num_proj_shards parameters are "
-        "deprecated and will be removed in Jan 2017. "
- "Use a variable scope with a partitioner instead.", self)
-
- self._num_units = num_units
- self._use_peepholes = use_peepholes
- self._cell_clip = cell_clip
- self._initializer = initializer
- self._num_proj = num_proj
- self._proj_clip = proj_clip
- self._num_unit_shards = num_unit_shards
- self._num_proj_shards = num_proj_shards
- self._forget_bias = forget_bias
- self._state_is_tuple = state_is_tuple
- self._activation = activation
-
- if num_proj:
- self._state_size = (
- LSTMStateTuple(num_units, num_proj)
- if state_is_tuple else num_units + num_proj)
- self._output_size = num_proj
- else:
- self._state_size = (
- LSTMStateTuple(num_units, num_units)
- if state_is_tuple else 2 * num_units)
- self._output_size = num_units
-
- @property
- def state_size(self):
- return self._state_size
-
- @property
- def output_size(self):
- return self._output_size
-
- def call(self, inputs, state):
- """Run one step of LSTM.
-
- Args:
- inputs: input Tensor, 2D, batch x num_units.
- state: if `state_is_tuple` is False, this must be a state Tensor,
- `2-D, batch x state_size`. If `state_is_tuple` is True, this must be a
- tuple of state Tensors, both `2-D`, with column sizes `c_state` and
- `m_state`.
-
- Returns:
- A tuple containing:
-
- - A `2-D, [batch x output_dim]`, Tensor representing the output of the
- LSTM after reading `inputs` when previous state was `state`.
- Here output_dim is:
- num_proj if num_proj was set,
- num_units otherwise.
- - Tensor(s) representing the new state of LSTM after reading `inputs` when
- the previous state was `state`. Same type and shape(s) as `state`.
-
- Raises:
- ValueError: If input size cannot be inferred from inputs via
- static shape inference.
- """
- num_proj = self._num_units if self._num_proj is None else self._num_proj
-
- if self._state_is_tuple:
- (c_prev, m_prev) = state
- else:
- c_prev = array_ops.slice(state, [0, 0], [-1, self._num_units])
- m_prev = array_ops.slice(state, [0, self._num_units], [-1, num_proj])
-
- dtype = inputs.dtype
- input_size = inputs.get_shape().with_rank(2)[1]
- if input_size.value is None:
- raise ValueError("Could not infer input size from inputs.get_shape()[-1]")
- scope = vs.get_variable_scope()
- with vs.variable_scope(scope, initializer=self._initializer) as unit_scope:
- if self._num_unit_shards is not None:
- unit_scope.set_partitioner(
- partitioned_variables.fixed_size_partitioner(
- self._num_unit_shards))
- # i = input_gate, j = new_input, f = forget_gate, o = output_gate
- lstm_matrix = _linear([inputs, m_prev], 4 * self._num_units, bias=True)
- i, j, f, o = array_ops.split(
- value=lstm_matrix, num_or_size_splits=4, axis=1)
- # Diagonal connections
- if self._use_peepholes:
- with vs.variable_scope(unit_scope) as projection_scope:
- if self._num_unit_shards is not None:
- projection_scope.set_partitioner(None)
- w_f_diag = vs.get_variable(
- "w_f_diag", shape=[self._num_units], dtype=dtype)
- w_i_diag = vs.get_variable(
- "w_i_diag", shape=[self._num_units], dtype=dtype)
- w_o_diag = vs.get_variable(
- "w_o_diag", shape=[self._num_units], dtype=dtype)
-
- if self._use_peepholes:
- c = (sigmoid(f + self._forget_bias + w_f_diag * c_prev) * c_prev +
- sigmoid(i + w_i_diag * c_prev) * self._activation(j))
- else:
- c = (sigmoid(f + self._forget_bias) * c_prev + sigmoid(i) *
- self._activation(j))
-
- if self._cell_clip is not None:
- # pylint: disable=invalid-unary-operand-type
- c = clip_ops.clip_by_value(c, -self._cell_clip, self._cell_clip)
- # pylint: enable=invalid-unary-operand-type
- if self._use_peepholes:
- m = sigmoid(o + w_o_diag * c) * self._activation(c)
- else:
- m = sigmoid(o) * self._activation(c)
-
- if self._num_proj is not None:
- with vs.variable_scope("projection") as proj_scope:
- if self._num_proj_shards is not None:
- proj_scope.set_partitioner(
- partitioned_variables.fixed_size_partitioner(
- self._num_proj_shards))
- m = _linear(m, self._num_proj, bias=False)
-
- if self._proj_clip is not None:
- # pylint: disable=invalid-unary-operand-type
- m = clip_ops.clip_by_value(m, -self._proj_clip, self._proj_clip)
- # pylint: enable=invalid-unary-operand-type
-
- new_state = (LSTMStateTuple(c, m) if self._state_is_tuple else
- array_ops.concat([c, m], 1))
- return m, new_state
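The peephole path of the deleted `LSTMCell.call`, with optional cell clipping and the output projection omitted, reduces to the following NumPy sketch (weights assumed given):

```python
import numpy as np

def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))

def lstm_step_peephole(x, c_prev, m_prev, w, b, w_i, w_f, w_o,
                       forget_bias=1.0, cell_clip=None):
    """w: [d+n, 4n]; w_i, w_f, w_o: per-unit peephole vectors of length n."""
    i, j, f, o = np.split(
        np.concatenate([x, m_prev], axis=1).dot(w) + b, 4, axis=1)
    c = (sigmoid(f + forget_bias + w_f * c_prev) * c_prev
         + sigmoid(i + w_i * c_prev) * np.tanh(j))
    if cell_clip is not None:
        c = np.clip(c, -cell_clip, cell_clip)
    m = sigmoid(o + w_o * c) * np.tanh(c)  # an optional projection may follow
    return m, (c, m)
```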
-
-
-class OutputProjectionWrapper(RNNCell):
- """Operator adding an output projection to the given cell.
-
- Note: in many cases it may be more efficient to not use this wrapper,
- but instead concatenate the whole sequence of your outputs in time,
- do the projection on this batch-concatenated sequence, then split it
- if needed or directly feed into a softmax.
- """
-
- def __init__(self, cell, output_size, activation=None, reuse=None):
- """Create a cell with output projection.
-
- Args:
- cell: an RNNCell, a projection to output_size is added to it.
- output_size: integer, the size of the output after projection.
- activation: (optional) an optional activation function.
- reuse: (optional) Python boolean describing whether to reuse variables
- in an existing scope. If not `True`, and the existing scope already has
- the given variables, an error is raised.
-
- Raises:
- TypeError: if cell is not an RNNCell.
- ValueError: if output_size is not positive.
- """
- super(OutputProjectionWrapper, self).__init__(_reuse=reuse)
- if not _like_rnncell(cell):
- raise TypeError("The parameter cell is not RNNCell.")
- if output_size < 1:
- raise ValueError("Parameter output_size must be > 0: %d." % output_size)
- self._cell = cell
- self._output_size = output_size
- self._activation = activation
-
- @property
- def state_size(self):
- return self._cell.state_size
-
- @property
- def output_size(self):
- return self._output_size
-
- def zero_state(self, batch_size, dtype):
- with ops.name_scope(type(self).__name__ + "ZeroState", values=[batch_size]):
- return self._cell.zero_state(batch_size, dtype)
-
- def call(self, inputs, state):
- """Run the cell and output projection on inputs, starting from state."""
- output, res_state = self._cell(inputs, state)
- projected = _linear(output, self._output_size, True)
- if self._activation:
- projected = self._activation(projected)
- return projected, res_state
-
-
-class InputProjectionWrapper(RNNCell):
- """Operator adding an input projection to the given cell.
-
- Note: in many cases it may be more efficient to not use this wrapper,
- but instead concatenate the whole sequence of your inputs in time,
- do the projection on this batch-concatenated sequence, then split it.
- """
-
- def __init__(self, cell, num_proj, activation=None, input_size=None,
- reuse=None):
- """Create a cell with input projection.
-
- Args:
- cell: an RNNCell, a projection of inputs is added before it.
- num_proj: Python integer. The dimension to project to.
- activation: (optional) an optional activation function.
- input_size: Deprecated and unused.
- reuse: (optional) Python boolean describing whether to reuse variables
- in an existing scope. If not `True`, and the existing scope already has
- the given variables, an error is raised.
-
- Raises:
- TypeError: if cell is not an RNNCell.
- """
- super(InputProjectionWrapper, self).__init__(_reuse=reuse)
- if input_size is not None:
- logging.warn("%s: The input_size parameter is deprecated.", self)
- if not _like_rnncell(cell):
- raise TypeError("The parameter cell is not RNNCell.")
- self._cell = cell
- self._num_proj = num_proj
- self._activation = activation
-
- @property
- def state_size(self):
- return self._cell.state_size
-
- @property
- def output_size(self):
- return self._cell.output_size
-
- def zero_state(self, batch_size, dtype):
- with ops.name_scope(type(self).__name__ + "ZeroState", values=[batch_size]):
- return self._cell.zero_state(batch_size, dtype)
-
- def call(self, inputs, state):
- """Run the input projection and then the cell."""
- # Default scope: "InputProjectionWrapper"
- projected = _linear(inputs, self._num_proj, True)
- if self._activation:
- projected = self._activation(projected)
- return self._cell(projected, state)
-
-
-def _enumerated_map_structure(map_fn, *args, **kwargs):
- ix = [0]
- def enumerated_fn(*inner_args, **inner_kwargs):
- r = map_fn(ix[0], *inner_args, **inner_kwargs)
- ix[0] += 1
- return r
- return nest.map_structure(enumerated_fn, *args, **kwargs)
-
-
-class DropoutWrapper(RNNCell):
- """Operator adding dropout to inputs and outputs of the given cell."""
-
- def __init__(self, cell, input_keep_prob=1.0, output_keep_prob=1.0,
- state_keep_prob=1.0, variational_recurrent=False,
- input_size=None, dtype=None, seed=None):
- """Create a cell with added input, state, and/or output dropout.
-
- If `variational_recurrent` is set to `True` (**NOT** the default behavior),
-    then the same dropout mask is applied at every step, as described in:
-
- Y. Gal, Z Ghahramani. "A Theoretically Grounded Application of Dropout in
- Recurrent Neural Networks". https://arxiv.org/abs/1512.05287
-
- Otherwise a different dropout mask is applied at every time step.
-
- Args:
- cell: an RNNCell, a projection to output_size is added to it.
- input_keep_prob: unit Tensor or float between 0 and 1, input keep
- probability; if it is constant and 1, no input dropout will be added.
- output_keep_prob: unit Tensor or float between 0 and 1, output keep
- probability; if it is constant and 1, no output dropout will be added.
-      state_keep_prob: unit Tensor or float between 0 and 1, state keep
-        probability; if it is constant and 1, no state dropout will be added.
- State dropout is performed on the *output* states of the cell.
- variational_recurrent: Python bool. If `True`, then the same
- dropout pattern is applied across all time steps per run call.
- If this parameter is set, `input_size` **must** be provided.
- input_size: (optional) (possibly nested tuple of) `TensorShape` objects
- containing the depth(s) of the input tensors expected to be passed in to
- the `DropoutWrapper`. Required and used **iff**
- `variational_recurrent = True` and `input_keep_prob < 1`.
- dtype: (optional) The `dtype` of the input, state, and output tensors.
- Required and used **iff** `variational_recurrent = True`.
- seed: (optional) integer, the randomness seed.
-
- Raises:
- TypeError: if cell is not an RNNCell.
- ValueError: if any of the keep_probs are not between 0 and 1.
- """
- if not _like_rnncell(cell):
- raise TypeError("The parameter cell is not a RNNCell.")
- with ops.name_scope("DropoutWrapperInit"):
- def tensor_and_const_value(v):
- tensor_value = ops.convert_to_tensor(v)
- const_value = tensor_util.constant_value(tensor_value)
- return (tensor_value, const_value)
- for prob, attr in [(input_keep_prob, "input_keep_prob"),
- (state_keep_prob, "state_keep_prob"),
- (output_keep_prob, "output_keep_prob")]:
- tensor_prob, const_prob = tensor_and_const_value(prob)
- if const_prob is not None:
- if const_prob < 0 or const_prob > 1:
- raise ValueError("Parameter %s must be between 0 and 1: %d"
- % (attr, const_prob))
- setattr(self, "_%s" % attr, float(const_prob))
- else:
- setattr(self, "_%s" % attr, tensor_prob)
-
- # Set cell, variational_recurrent, seed before running the code below
- self._cell = cell
- self._variational_recurrent = variational_recurrent
- self._seed = seed
-
- self._recurrent_input_noise = None
- self._recurrent_state_noise = None
- self._recurrent_output_noise = None
-
- if variational_recurrent:
- if dtype is None:
- raise ValueError(
- "When variational_recurrent=True, dtype must be provided")
-
- def convert_to_batch_shape(s):
- # Prepend a 1 for the batch dimension; for recurrent
- # variational dropout we use the same dropout mask for all
- # batch elements.
- return array_ops.concat(
- ([1], tensor_shape.TensorShape(s).as_list()), 0)
-
- def batch_noise(s, inner_seed):
- shape = convert_to_batch_shape(s)
- return random_ops.random_uniform(shape, seed=inner_seed, dtype=dtype)
-
- if (not isinstance(self._input_keep_prob, numbers.Real) or
- self._input_keep_prob < 1.0):
- if input_size is None:
- raise ValueError(
- "When variational_recurrent=True and input_keep_prob < 1.0 or "
- "is unknown, input_size must be provided")
- self._recurrent_input_noise = _enumerated_map_structure(
- lambda i, s: batch_noise(s, inner_seed=self._gen_seed("input", i)),
- input_size)
- self._recurrent_state_noise = _enumerated_map_structure(
- lambda i, s: batch_noise(s, inner_seed=self._gen_seed("state", i)),
- cell.state_size)
- self._recurrent_output_noise = _enumerated_map_structure(
- lambda i, s: batch_noise(s, inner_seed=self._gen_seed("output", i)),
- cell.output_size)
-
- def _gen_seed(self, salt_prefix, index):
- if self._seed is None:
- return None
- salt = "%s_%d" % (salt_prefix, index)
- string = (str(self._seed) + salt).encode("utf-8")
- return int(hashlib.md5(string).hexdigest()[:8], 16) & 0x7FFFFFFF
-
- @property
- def state_size(self):
- return self._cell.state_size
-
- @property
- def output_size(self):
- return self._cell.output_size
-
- def zero_state(self, batch_size, dtype):
- with ops.name_scope(type(self).__name__ + "ZeroState", values=[batch_size]):
- return self._cell.zero_state(batch_size, dtype)
-
- def _variational_recurrent_dropout_value(
- self, index, value, noise, keep_prob):
- """Performs dropout given the pre-calculated noise tensor."""
- # uniform [keep_prob, 1.0 + keep_prob)
- random_tensor = keep_prob + noise
-
- # 0. if [keep_prob, 1.0) and 1. if [1.0, 1.0 + keep_prob)
- binary_tensor = math_ops.floor(random_tensor)
- ret = math_ops.div(value, keep_prob) * binary_tensor
- ret.set_shape(value.get_shape())
- return ret
-
- def _dropout(self, values, salt_prefix, recurrent_noise, keep_prob):
- """Decides whether to perform standard dropout or recurrent dropout."""
- if not self._variational_recurrent:
- def dropout(i, v):
- return nn_ops.dropout(
- v, keep_prob=keep_prob, seed=self._gen_seed(salt_prefix, i))
- return _enumerated_map_structure(dropout, values)
- else:
- def dropout(i, v, n):
- return self._variational_recurrent_dropout_value(i, v, n, keep_prob)
- return _enumerated_map_structure(dropout, values, recurrent_noise)
-
- def __call__(self, inputs, state, scope=None):
- """Run the cell with the declared dropouts."""
- def _should_dropout(p):
- return (not isinstance(p, float)) or p < 1
-
- if _should_dropout(self._input_keep_prob):
- inputs = self._dropout(inputs, "input",
- self._recurrent_input_noise,
- self._input_keep_prob)
- output, new_state = self._cell(inputs, state, scope)
- if _should_dropout(self._state_keep_prob):
- new_state = self._dropout(new_state, "state",
- self._recurrent_state_noise,
- self._state_keep_prob)
- if _should_dropout(self._output_keep_prob):
- output = self._dropout(output, "output",
- self._recurrent_output_noise,
- self._output_keep_prob)
- return output, new_state
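A minimal usage sketch of `DropoutWrapper` with variational recurrent dropout, reflecting the requirements stated in the docstring above (`dtype` required, `input_size` required when `input_keep_prob < 1`); it assumes the wrapper keeps this signature under `tf.nn.rnn_cell` after the move:

```python
import tensorflow as tf

batch_size, input_depth, num_units = 8, 16, 32
inputs = tf.placeholder(tf.float32, [batch_size, input_depth])
cell = tf.nn.rnn_cell.DropoutWrapper(
    tf.nn.rnn_cell.LSTMCell(num_units),
    input_keep_prob=0.8,
    output_keep_prob=0.8,
    state_keep_prob=0.8,
    variational_recurrent=True,              # same mask at every time step
    input_size=tf.TensorShape([input_depth]),  # needed since input_keep_prob < 1
    dtype=tf.float32)                        # needed when variational_recurrent=True
state = cell.zero_state(batch_size, tf.float32)
output, new_state = cell(inputs, state)
```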
-
-
-class ResidualWrapper(RNNCell):
- """RNNCell wrapper that ensures cell inputs are added to the outputs."""
-
- def __init__(self, cell):
- """Constructs a `ResidualWrapper` for `cell`.
-
- Args:
- cell: An instance of `RNNCell`.
- """
- self._cell = cell
-
- @property
- def state_size(self):
- return self._cell.state_size
-
- @property
- def output_size(self):
- return self._cell.output_size
-
- def zero_state(self, batch_size, dtype):
- with ops.name_scope(type(self).__name__ + "ZeroState", values=[batch_size]):
- return self._cell.zero_state(batch_size, dtype)
-
- def __call__(self, inputs, state, scope=None):
- """Run the cell and add its inputs to its outputs.
-
- Args:
- inputs: cell inputs.
- state: cell state.
- scope: optional cell scope.
-
- Returns:
- Tuple of cell outputs and new state.
-
- Raises:
- TypeError: If cell inputs and outputs have different structure (type).
- ValueError: If cell inputs and outputs have different structure (value).
- """
- outputs, new_state = self._cell(inputs, state, scope=scope)
- nest.assert_same_structure(inputs, outputs)
- # Ensure shapes match
- def assert_shape_match(inp, out):
- inp.get_shape().assert_is_compatible_with(out.get_shape())
- nest.map_structure(assert_shape_match, inputs, outputs)
- res_outputs = nest.map_structure(
- lambda inp, out: inp + out, inputs, outputs)
- return (res_outputs, new_state)
-
-
-class DeviceWrapper(RNNCell):
- """Operator that ensures an RNNCell runs on a particular device."""
-
- def __init__(self, cell, device):
- """Construct a `DeviceWrapper` for `cell` with device `device`.
-
- Ensures the wrapped `cell` is called with `tf.device(device)`.
-
- Args:
- cell: An instance of `RNNCell`.
- device: A device string or function, for passing to `tf.device`.
- """
- self._cell = cell
- self._device = device
-
- @property
- def state_size(self):
- return self._cell.state_size
-
- @property
- def output_size(self):
- return self._cell.output_size
-
- def zero_state(self, batch_size, dtype):
- with ops.name_scope(type(self).__name__ + "ZeroState", values=[batch_size]):
- with ops.device(self._device):
- return self._cell.zero_state(batch_size, dtype)
-
- def __call__(self, inputs, state, scope=None):
- """Run the cell on specified device."""
- with ops.device(self._device):
- return self._cell(inputs, state, scope=scope)
-
-
-class EmbeddingWrapper(RNNCell):
- """Operator adding input embedding to the given cell.
-
- Note: in many cases it may be more efficient to not use this wrapper,
- but instead concatenate the whole sequence of your inputs in time,
- do the embedding on this batch-concatenated sequence, then split it and
- feed into your RNN.
- """
-
- def __init__(self, cell, embedding_classes, embedding_size, initializer=None,
- reuse=None):
- """Create a cell with an added input embedding.
-
- Args:
- cell: an RNNCell, an embedding will be put before its inputs.
- embedding_classes: integer, how many symbols will be embedded.
- embedding_size: integer, the size of the vectors we embed into.
- initializer: an initializer to use when creating the embedding;
- if None, the initializer from variable scope or a default one is used.
- reuse: (optional) Python boolean describing whether to reuse variables
- in an existing scope. If not `True`, and the existing scope already has
- the given variables, an error is raised.
-
- Raises:
- TypeError: if cell is not an RNNCell.
- ValueError: if embedding_classes is not positive.
- """
- super(EmbeddingWrapper, self).__init__(_reuse=reuse)
- if not _like_rnncell(cell):
- raise TypeError("The parameter cell is not RNNCell.")
- if embedding_classes <= 0 or embedding_size <= 0:
- raise ValueError("Both embedding_classes and embedding_size must be > 0: "
- "%d, %d." % (embedding_classes, embedding_size))
- self._cell = cell
- self._embedding_classes = embedding_classes
- self._embedding_size = embedding_size
- self._initializer = initializer
-
- @property
- def state_size(self):
- return self._cell.state_size
-
- @property
- def output_size(self):
- return self._cell.output_size
-
- def zero_state(self, batch_size, dtype):
- with ops.name_scope(type(self).__name__ + "ZeroState", values=[batch_size]):
- return self._cell.zero_state(batch_size, dtype)
-
- def call(self, inputs, state):
- """Run the cell on embedded inputs."""
- with ops.device("/cpu:0"):
- if self._initializer:
- initializer = self._initializer
- elif vs.get_variable_scope().initializer:
- initializer = vs.get_variable_scope().initializer
- else:
- # Default initializer for embeddings should have variance=1.
- sqrt3 = math.sqrt(3) # Uniform(-sqrt(3), sqrt(3)) has variance=1.
- initializer = init_ops.random_uniform_initializer(-sqrt3, sqrt3)
-
- if type(state) is tuple:
- data_type = state[0].dtype
- else:
- data_type = state.dtype
-
- embedding = vs.get_variable(
- "embedding", [self._embedding_classes, self._embedding_size],
- initializer=initializer,
- dtype=data_type)
- embedded = embedding_ops.embedding_lookup(embedding,
- array_ops.reshape(inputs, [-1]))
-
- return self._cell(embedded, state)
-
-
-class MultiRNNCell(RNNCell):
- """RNN cell composed sequentially of multiple simple cells."""
-
- def __init__(self, cells, state_is_tuple=True):
- """Create a RNN cell composed sequentially of a number of RNNCells.
-
- Args:
- cells: list of RNNCells that will be composed in this order.
- state_is_tuple: If True, accepted and returned states are n-tuples, where
- `n = len(cells)`. If False, the states are all
- concatenated along the column axis. This latter behavior will soon be
- deprecated.
-
- Raises:
- ValueError: if cells is empty (not allowed), or at least one of the cells
- returns a state tuple but the flag `state_is_tuple` is `False`.
- """
- super(MultiRNNCell, self).__init__()
- if not cells:
- raise ValueError("Must specify at least one cell for MultiRNNCell.")
- if not nest.is_sequence(cells):
- raise TypeError(
- "cells must be a list or tuple, but saw: %s." % cells)
-
- self._cells = cells
- self._state_is_tuple = state_is_tuple
- if not state_is_tuple:
- if any(nest.is_sequence(c.state_size) for c in self._cells):
- raise ValueError("Some cells return tuples of states, but the flag "
- "state_is_tuple is not set. State sizes are: %s"
- % str([c.state_size for c in self._cells]))
-
- @property
- def state_size(self):
- if self._state_is_tuple:
- return tuple(cell.state_size for cell in self._cells)
- else:
- return sum([cell.state_size for cell in self._cells])
-
- @property
- def output_size(self):
- return self._cells[-1].output_size
-
- def zero_state(self, batch_size, dtype):
- with ops.name_scope(type(self).__name__ + "ZeroState", values=[batch_size]):
- if self._state_is_tuple:
- return tuple(cell.zero_state(batch_size, dtype) for cell in self._cells)
- else:
- # We know here that state_size of each cell is not a tuple and
- # presumably does not contain TensorArrays or anything else fancy
- return super(MultiRNNCell, self).zero_state(batch_size, dtype)
-
- def call(self, inputs, state):
- """Run this multi-layer cell on inputs, starting from state."""
- cur_state_pos = 0
- cur_inp = inputs
- new_states = []
- for i, cell in enumerate(self._cells):
- with vs.variable_scope("cell_%d" % i):
- if self._state_is_tuple:
- if not nest.is_sequence(state):
- raise ValueError(
- "Expected state to be a tuple of length %d, but received: %s" %
- (len(self.state_size), state))
- cur_state = state[i]
- else:
- cur_state = array_ops.slice(state, [0, cur_state_pos],
- [-1, cell.state_size])
- cur_state_pos += cell.state_size
- cur_inp, new_state = cell(cur_inp, cur_state)
- new_states.append(new_state)
-
- new_states = (tuple(new_states) if self._state_is_tuple else
- array_ops.concat(new_states, 1))
-
- return cur_inp, new_states
-
-
-class _SlimRNNCell(RNNCell):
- """A simple wrapper for slim.rnn_cells."""
-
- def __init__(self, cell_fn):
- """Create a SlimRNNCell from a cell_fn.
-
- Args:
- cell_fn: a function which takes (inputs, state, scope) and produces the
- outputs and the new_state. Additionally when called with inputs=None and
- state=None it should return (initial_outputs, initial_state).
-
- Raises:
- TypeError: if cell_fn is not callable
- ValueError: if cell_fn cannot produce a valid initial state.
- """
- if not callable(cell_fn):
- raise TypeError("cell_fn %s needs to be callable", cell_fn)
- self._cell_fn = cell_fn
- self._cell_name = cell_fn.func.__name__
- init_output, init_state = self._cell_fn(None, None)
- output_shape = init_output.get_shape()
- state_shape = init_state.get_shape()
- self._output_size = output_shape.with_rank(2)[1].value
- self._state_size = state_shape.with_rank(2)[1].value
- if self._output_size is None:
- raise ValueError("Initial output created by %s has invalid shape %s" %
- (self._cell_name, output_shape))
- if self._state_size is None:
- raise ValueError("Initial state created by %s has invalid shape %s" %
- (self._cell_name, state_shape))
-
- @property
- def state_size(self):
- return self._state_size
-
- @property
- def output_size(self):
- return self._output_size
-
- def __call__(self, inputs, state, scope=None):
- scope = scope or self._cell_name
- output, state = self._cell_fn(inputs, state, scope=scope)
- return output, state
-
-
-def _linear(args,
- output_size,
- bias,
- bias_initializer=None,
- kernel_initializer=None):
- """Linear map: sum_i(args[i] * W[i]), where W[i] is a variable.
-
- Args:
- args: a 2D Tensor or a list of 2D, batch x n, Tensors.
- output_size: int, second dimension of W[i].
- bias: boolean, whether to add a bias term or not.
- bias_initializer: starting value to initialize the bias; None by default.
- kernel_initializer: starting value to initialize the weight; None by default.
-
- Returns:
- A 2D Tensor with shape [batch x output_size] equal to
- sum_i(args[i] * W[i]), where W[i]s are newly created matrices.
-
- Raises:
- ValueError: if some of the arguments has unspecified or wrong shape.
- """
- if args is None or (nest.is_sequence(args) and not args):
- raise ValueError("`args` must be specified")
- if not nest.is_sequence(args):
- args = [args]
-
- # Calculate the total size of arguments on dimension 1.
- total_arg_size = 0
- shapes = [a.get_shape() for a in args]
- for shape in shapes:
- if shape.ndims != 2:
- raise ValueError("linear is expecting 2D arguments: %s" % shapes)
- if shape[1].value is None:
- raise ValueError("linear expects shape[1] to be provided for shape %s, "
- "but saw %s" % (shape, shape[1]))
- else:
- total_arg_size += shape[1].value
-
- dtype = [a.dtype for a in args][0]
-
- # Now the computation.
- scope = vs.get_variable_scope()
- with vs.variable_scope(scope) as outer_scope:
- weights = vs.get_variable(
- _WEIGHTS_VARIABLE_NAME, [total_arg_size, output_size],
- dtype=dtype,
- initializer=kernel_initializer)
- if len(args) == 1:
- res = math_ops.matmul(args[0], weights)
- else:
- res = math_ops.matmul(array_ops.concat(args, 1), weights)
- if not bias:
- return res
- with vs.variable_scope(outer_scope) as inner_scope:
- inner_scope.set_partitioner(None)
- if bias_initializer is None:
- bias_initializer = init_ops.constant_initializer(0.0, dtype=dtype)
- biases = vs.get_variable(
- _BIAS_VARIABLE_NAME, [output_size],
- dtype=dtype,
- initializer=bias_initializer)
- return nn_ops.bias_add(res, biases)
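
The `_linear` helper deleted above computes sum_i(args[i] * W[i]) by concatenating its 2-D arguments along axis 1 and applying a single kernel. A minimal NumPy sketch of that equivalence, with made-up shapes; illustrative only, not part of the patch:

```python
import numpy as np

batch, n_x, n_h, output_size = 4, 3, 5, 2
x = np.random.randn(batch, n_x).astype(np.float32)   # first 2-D argument
h = np.random.randn(batch, n_h).astype(np.float32)   # second 2-D argument
kernel = np.random.randn(n_x + n_h, output_size).astype(np.float32)
bias = np.zeros(output_size, dtype=np.float32)

# What _linear builds: matmul(concat(args, 1), kernel) + bias ...
fused = np.concatenate([x, h], axis=1).dot(kernel) + bias
# ... which equals sum_i(args[i] * W[i]) with the kernel split row-wise.
split = x.dot(kernel[:n_x]) + h.dot(kernel[n_x:]) + bias
assert np.allclose(fused, split, atol=1e-5)
```
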
diff --git a/tensorflow/contrib/rnn/python/ops/fused_rnn_cell.py b/tensorflow/contrib/rnn/python/ops/fused_rnn_cell.py
index 65e8705d1e..b7393d8b98 100644
--- a/tensorflow/contrib/rnn/python/ops/fused_rnn_cell.py
+++ b/tensorflow/contrib/rnn/python/ops/fused_rnn_cell.py
@@ -20,7 +20,6 @@ from __future__ import print_function
import abc
-from tensorflow.contrib.rnn.python.ops import core_rnn as contrib_rnn
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import rnn
@@ -116,12 +115,13 @@ class FusedRNNCellAdaptor(FusedRNNCell):
else: # non-dynamic rnn
if not is_list:
inputs = array_ops.unstack(inputs)
- outputs, state = contrib_rnn.static_rnn(self._cell,
- inputs,
- initial_state=initial_state,
- dtype=dtype,
- sequence_length=sequence_length,
- scope=scope)
+ outputs, state = rnn.static_rnn(
+ self._cell,
+ inputs,
+ initial_state=initial_state,
+ dtype=dtype,
+ sequence_length=sequence_length,
+ scope=scope)
if not is_list:
# Convert outputs back to tensor
outputs = array_ops.stack(outputs)
diff --git a/tensorflow/contrib/rnn/python/ops/gru_ops.py b/tensorflow/contrib/rnn/python/ops/gru_ops.py
index 071a570e2d..de57e7d81e 100644
--- a/tensorflow/contrib/rnn/python/ops/gru_ops.py
+++ b/tensorflow/contrib/rnn/python/ops/gru_ops.py
@@ -18,13 +18,13 @@ from __future__ import division
from __future__ import print_function
from tensorflow.contrib.rnn.ops import gen_gru_ops
-from tensorflow.contrib.rnn.python.ops import core_rnn_cell
from tensorflow.contrib.util import loader
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import init_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import nn_ops
+from tensorflow.python.ops import rnn_cell_impl
from tensorflow.python.ops import variable_scope as vs
from tensorflow.python.platform import resource_loader
@@ -94,7 +94,7 @@ def _GRUBlockCellGrad(op, *grad):
return d_x, d_h_prev, d_w_ru, d_w_c, d_b_ru, d_b_c
-class GRUBlockCell(core_rnn_cell.RNNCell):
+class GRUBlockCell(rnn_cell_impl.RNNCell):
r"""Block GRU cell implementation.
The implementation is based on: http://arxiv.org/abs/1406.1078
diff --git a/tensorflow/contrib/rnn/python/ops/lstm_ops.py b/tensorflow/contrib/rnn/python/ops/lstm_ops.py
index 0e70939cce..c41b5793fc 100644
--- a/tensorflow/contrib/rnn/python/ops/lstm_ops.py
+++ b/tensorflow/contrib/rnn/python/ops/lstm_ops.py
@@ -20,7 +20,6 @@ from __future__ import print_function
import abc
from tensorflow.contrib.rnn.ops import gen_lstm_ops
-from tensorflow.contrib.rnn.python.ops import core_rnn_cell
from tensorflow.contrib.rnn.python.ops import fused_rnn_cell
from tensorflow.contrib.util import loader
from tensorflow.python.framework import dtypes
@@ -29,6 +28,7 @@ from tensorflow.python.ops import array_ops
from tensorflow.python.ops import init_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import nn_ops
+from tensorflow.python.ops import rnn_cell_impl
from tensorflow.python.ops import variable_scope as vs
from tensorflow.python.platform import resource_loader
@@ -325,7 +325,7 @@ def _BlockLSTMGrad(op, *grad):
wcf_grad, b_grad]
-class LSTMBlockCell(core_rnn_cell.RNNCell):
+class LSTMBlockCell(rnn_cell_impl.RNNCell):
"""Basic LSTM recurrent network cell.
The implementation is based on: http://arxiv.org/abs/1409.2329.
@@ -333,7 +333,7 @@ class LSTMBlockCell(core_rnn_cell.RNNCell):
We add `forget_bias` (default: 1) to the biases of the forget gate in order to
reduce the scale of forgetting in the beginning of the training.
- Unlike `core_rnn_cell.LSTMCell`, this is a monolithic op and should be much
+ Unlike `rnn_cell_impl.LSTMCell`, this is a monolithic op and should be much
faster. The weight and bias matrices should be compatible as long as the
variable scope matches.
"""
@@ -363,7 +363,7 @@ class LSTMBlockCell(core_rnn_cell.RNNCell):
@property
def state_size(self):
- return core_rnn_cell.LSTMStateTuple(self._num_units, self._num_units)
+ return rnn_cell_impl.LSTMStateTuple(self._num_units, self._num_units)
@property
def output_size(self):
@@ -402,7 +402,7 @@ class LSTMBlockCell(core_rnn_cell.RNNCell):
forget_bias=self._forget_bias,
use_peephole=self._use_peephole)
- new_state = core_rnn_cell.LSTMStateTuple(cs, h)
+ new_state = rnn_cell_impl.LSTMStateTuple(cs, h)
return h, new_state
@@ -546,8 +546,7 @@ class LSTMBlockWrapper(fused_rnn_cell.FusedRNNCell):
# Input was a list, so return a list
outputs = array_ops.unstack(outputs)
- final_state = core_rnn_cell.LSTMStateTuple(final_cell_state,
- final_output)
+ final_state = rnn_cell_impl.LSTMStateTuple(final_cell_state, final_output)
return outputs, final_state
def _gather_states(self, data, indices, batch_size):
@@ -569,7 +568,7 @@ class LSTMBlockFusedCell(LSTMBlockWrapper):
We add forget_bias (default: 1) to the biases of the forget gate in order to
reduce the scale of forgetting in the beginning of the training.
- The variable naming is consistent with `core_rnn_cell.LSTMCell`.
+ The variable naming is consistent with `rnn_cell_impl.LSTMCell`.
"""
def __init__(self,
diff --git a/tensorflow/contrib/rnn/python/ops/rnn.py b/tensorflow/contrib/rnn/python/ops/rnn.py
index 1757a2148c..676441b4fc 100644
--- a/tensorflow/contrib/rnn/python/ops/rnn.py
+++ b/tensorflow/contrib/rnn/python/ops/rnn.py
@@ -17,7 +17,6 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
-from tensorflow.contrib.rnn.python.ops import core_rnn as contrib_rnn
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import rnn
from tensorflow.python.ops import variable_scope as vs
@@ -106,7 +105,7 @@ def stack_bidirectional_rnn(cells_fw,
initial_state_bw = initial_states_bw[i]
with vs.variable_scope("cell_%d" % i) as cell_scope:
- prev_layer, state_fw, state_bw = contrib_rnn.static_bidirectional_rnn(
+ prev_layer, state_fw, state_bw = rnn.static_bidirectional_rnn(
cell_fw,
cell_bw,
prev_layer,
diff --git a/tensorflow/contrib/rnn/python/ops/rnn_cell.py b/tensorflow/contrib/rnn/python/ops/rnn_cell.py
index 0898e78837..3dc8abb8b8 100644
--- a/tensorflow/contrib/rnn/python/ops/rnn_cell.py
+++ b/tensorflow/contrib/rnn/python/ops/rnn_cell.py
@@ -23,8 +23,6 @@ import math
from tensorflow.contrib.compiler import jit
from tensorflow.contrib.layers.python.layers import layers
-from tensorflow.contrib.rnn.python.ops import core_rnn_cell
-from tensorflow.contrib.rnn.python.ops import core_rnn_cell_impl
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import op_def_registry
from tensorflow.python.framework import ops
@@ -76,7 +74,7 @@ def _get_sharded_variable(name, shape, dtype, num_shards):
return shards
-class CoupledInputForgetGateLSTMCell(core_rnn_cell.RNNCell):
+class CoupledInputForgetGateLSTMCell(rnn_cell_impl.RNNCell):
"""Long short-term memory unit (LSTM) recurrent network cell.
The default non-peephole implementation is based on:
@@ -154,14 +152,12 @@ class CoupledInputForgetGateLSTMCell(core_rnn_cell.RNNCell):
self._reuse = reuse
if num_proj:
- self._state_size = (
- core_rnn_cell.LSTMStateTuple(num_units, num_proj)
- if state_is_tuple else num_units + num_proj)
+ self._state_size = (rnn_cell_impl.LSTMStateTuple(num_units, num_proj)
+ if state_is_tuple else num_units + num_proj)
self._output_size = num_proj
else:
- self._state_size = (
- core_rnn_cell.LSTMStateTuple(num_units, num_units)
- if state_is_tuple else 2 * num_units)
+ self._state_size = (rnn_cell_impl.LSTMStateTuple(num_units, num_units)
+ if state_is_tuple else 2 * num_units)
self._output_size = num_units
@property
@@ -254,12 +250,12 @@ class CoupledInputForgetGateLSTMCell(core_rnn_cell.RNNCell):
m = clip_ops.clip_by_value(m, -self._proj_clip, self._proj_clip)
# pylint: enable=invalid-unary-operand-type
- new_state = (core_rnn_cell.LSTMStateTuple(c, m) if self._state_is_tuple else
- array_ops.concat([c, m], 1))
+ new_state = (rnn_cell_impl.LSTMStateTuple(c, m)
+ if self._state_is_tuple else array_ops.concat([c, m], 1))
return m, new_state
-class TimeFreqLSTMCell(core_rnn_cell.RNNCell):
+class TimeFreqLSTMCell(rnn_cell_impl.RNNCell):
"""Time-Frequency Long short-term memory unit (LSTM) recurrent network cell.
This implementation is based on:
@@ -427,7 +423,7 @@ class TimeFreqLSTMCell(core_rnn_cell.RNNCell):
return freq_inputs
-class GridLSTMCell(core_rnn_cell.RNNCell):
+class GridLSTMCell(rnn_cell_impl.RNNCell):
"""Grid Long short-term memory unit (LSTM) recurrent network cell.
The default is based on:
@@ -1020,11 +1016,11 @@ class BidirectionalGridLSTMCell(GridLSTMCell):
# pylint: disable=protected-access
-_linear = core_rnn_cell_impl._linear
+_linear = rnn_cell_impl._linear
# pylint: enable=protected-access
-class AttentionCellWrapper(core_rnn_cell.RNNCell):
+class AttentionCellWrapper(rnn_cell_impl.RNNCell):
"""Basic attention cell wrapper.
Implementation based on https://arxiv.org/abs/1409.0473.
@@ -1155,7 +1151,7 @@ class AttentionCellWrapper(core_rnn_cell.RNNCell):
return new_attns, new_attn_states
-class HighwayWrapper(core_rnn_cell.RNNCell):
+class HighwayWrapper(rnn_cell_impl.RNNCell):
"""RNNCell wrapper that adds highway connection on cell input and output.
Based on:
@@ -1238,7 +1234,7 @@ class HighwayWrapper(core_rnn_cell.RNNCell):
return (res_outputs, new_state)
-class LayerNormBasicLSTMCell(core_rnn_cell.RNNCell):
+class LayerNormBasicLSTMCell(rnn_cell_impl.RNNCell):
"""LSTM unit with layer normalization and recurrent dropout.
This class adds layer normalization and recurrent dropout to a
@@ -1300,7 +1296,7 @@ class LayerNormBasicLSTMCell(core_rnn_cell.RNNCell):
@property
def state_size(self):
- return core_rnn_cell.LSTMStateTuple(self._num_units, self._num_units)
+ return rnn_cell_impl.LSTMStateTuple(self._num_units, self._num_units)
@property
def output_size(self):
@@ -1350,11 +1346,11 @@ class LayerNormBasicLSTMCell(core_rnn_cell.RNNCell):
new_c = self._norm(new_c, "state")
new_h = self._activation(new_c) * math_ops.sigmoid(o)
- new_state = core_rnn_cell.LSTMStateTuple(new_c, new_h)
+ new_state = rnn_cell_impl.LSTMStateTuple(new_c, new_h)
return new_h, new_state
-class NASCell(core_rnn_cell.RNNCell):
+class NASCell(rnn_cell_impl.RNNCell):
"""Neural Architecture Search (NAS) recurrent network cell.
This implements the recurrent cell from the paper:
@@ -1388,10 +1384,10 @@ class NASCell(core_rnn_cell.RNNCell):
self._reuse = reuse
if num_proj is not None:
- self._state_size = core_rnn_cell.LSTMStateTuple(num_units, num_proj)
+ self._state_size = rnn_cell_impl.LSTMStateTuple(num_units, num_proj)
self._output_size = num_proj
else:
- self._state_size = core_rnn_cell.LSTMStateTuple(num_units, num_units)
+ self._state_size = rnn_cell_impl.LSTMStateTuple(num_units, num_units)
self._output_size = num_units
@property
@@ -1498,11 +1494,11 @@ class NASCell(core_rnn_cell.RNNCell):
dtype)
new_m = math_ops.matmul(new_m, concat_w_proj)
- new_state = core_rnn_cell.LSTMStateTuple(new_c, new_m)
+ new_state = rnn_cell_impl.LSTMStateTuple(new_c, new_m)
return new_m, new_state
-class UGRNNCell(core_rnn_cell.RNNCell):
+class UGRNNCell(rnn_cell_impl.RNNCell):
"""Update Gate Recurrent Neural Network (UGRNN) cell.
Compromise between a LSTM/GRU and a vanilla RNN. There is only one
@@ -1589,7 +1585,7 @@ class UGRNNCell(core_rnn_cell.RNNCell):
return new_output, new_state
-class IntersectionRNNCell(core_rnn_cell.RNNCell):
+class IntersectionRNNCell(rnn_cell_impl.RNNCell):
"""Intersection Recurrent Neural Network (+RNN) cell.
Architecture with coupled recurrent gate as well as coupled depth
@@ -1712,7 +1708,7 @@ class IntersectionRNNCell(core_rnn_cell.RNNCell):
_REGISTERED_OPS = None
-class CompiledWrapper(core_rnn_cell.RNNCell):
+class CompiledWrapper(rnn_cell_impl.RNNCell):
"""Wraps step execution in an XLA JIT scope."""
def __init__(self, cell, compile_stateful=False):
@@ -1783,7 +1779,7 @@ def _random_exp_initializer(minval,
return _initializer
-class PhasedLSTMCell(core_rnn_cell.RNNCell):
+class PhasedLSTMCell(rnn_cell_impl.RNNCell):
"""Phased LSTM recurrent network cell.
https://arxiv.org/pdf/1610.09513v1.pdf
@@ -1831,7 +1827,7 @@ class PhasedLSTMCell(core_rnn_cell.RNNCell):
@property
def state_size(self):
- return core_rnn_cell.LSTMStateTuple(self._num_units, self._num_units)
+ return rnn_cell_impl.LSTMStateTuple(self._num_units, self._num_units)
@property
def output_size(self):
@@ -1858,13 +1854,13 @@ class PhasedLSTMCell(core_rnn_cell.RNNCell):
It stores the time.
The second Tensor has shape [batch, features_size], and type float32.
It stores the features.
- state: core_rnn_cell.LSTMStateTuple, state from previous timestep.
+ state: rnn_cell_impl.LSTMStateTuple, state from previous timestep.
Returns:
A tuple containing:
- A Tensor of float32, and shape [batch_size, num_units], representing the
output of the cell.
- - A core_rnn_cell.LSTMStateTuple, containing 2 Tensors of float32, shape
+ - A rnn_cell_impl.LSTMStateTuple, containing 2 Tensors of float32, shape
[batch_size, num_units], representing the new state and the output.
"""
(c_prev, h_prev) = state
@@ -1921,12 +1917,12 @@ class PhasedLSTMCell(core_rnn_cell.RNNCell):
new_c = k * new_c + (1 - k) * c_prev
new_h = k * new_h + (1 - k) * h_prev
- new_state = core_rnn_cell.LSTMStateTuple(new_c, new_h)
+ new_state = rnn_cell_impl.LSTMStateTuple(new_c, new_h)
return new_h, new_state
-class GLSTMCell(core_rnn_cell.RNNCell):
+class GLSTMCell(rnn_cell_impl.RNNCell):
"""Group LSTM cell (G-LSTM).
The implementation is based on:
@@ -1982,10 +1978,10 @@ class GLSTMCell(core_rnn_cell.RNNCell):
int(self._num_units / self._number_of_groups)]
if num_proj:
- self._state_size = core_rnn_cell.LSTMStateTuple(num_units, num_proj)
+ self._state_size = rnn_cell_impl.LSTMStateTuple(num_units, num_proj)
self._output_size = num_proj
else:
- self._state_size = core_rnn_cell.LSTMStateTuple(num_units, num_units)
+ self._state_size = rnn_cell_impl.LSTMStateTuple(num_units, num_units)
self._output_size = num_units
@property
@@ -2097,5 +2093,5 @@ class GLSTMCell(core_rnn_cell.RNNCell):
with vs.variable_scope("projection"):
m = _linear(m, self._num_proj, bias=False)
- new_state = core_rnn_cell.LSTMStateTuple(c, m)
+ new_state = rnn_cell_impl.LSTMStateTuple(c, m)
return m, new_state
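
The contrib cells above now construct the same `LSTMStateTuple` that core defines in `rnn_cell_impl` (see the `rnn_cell_impl.py` hunk below). A short sketch of how that state tuple behaves, assuming a TensorFlow build containing this change; illustrative only, not part of the patch:

```python
import tensorflow as tf

c = tf.zeros([32, 128])   # cell state, shape [batch_size, num_units]
h = tf.zeros([32, 128])   # hidden/output state, same shape
state = tf.nn.rnn_cell.LSTMStateTuple(c, h)

# LSTMStateTuple is a plain namedtuple: fields work by name and by index.
assert state.c is c and state[0] is c
assert state.h is h and state[1] is h
# .dtype is defined only when both parts agree; it raises TypeError otherwise.
print(state.dtype)  # <dtype: 'float32'>
```
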
diff --git a/tensorflow/contrib/seq2seq/python/kernel_tests/attention_wrapper_test.py b/tensorflow/contrib/seq2seq/python/kernel_tests/attention_wrapper_test.py
index b8b420e10a..ea34333360 100644
--- a/tensorflow/contrib/seq2seq/python/kernel_tests/attention_wrapper_test.py
+++ b/tensorflow/contrib/seq2seq/python/kernel_tests/attention_wrapper_test.py
@@ -24,13 +24,13 @@ import functools
import numpy as np
-from tensorflow.contrib.rnn import core_rnn_cell
from tensorflow.contrib.seq2seq.python.ops import decoder
from tensorflow.contrib.seq2seq.python.ops import attention_wrapper as wrapper
from tensorflow.contrib.seq2seq.python.ops import helper as helper_py
from tensorflow.contrib.seq2seq.python.ops import basic_decoder
from tensorflow.python.framework import dtypes
from tensorflow.python.ops import init_ops
+from tensorflow.python.ops import rnn_cell
from tensorflow.python.ops import variables
from tensorflow.python.ops import variable_scope as vs
from tensorflow.python.platform import test
@@ -41,7 +41,7 @@ from tensorflow.python.util import nest
# for testing
AttentionWrapperState = wrapper.AttentionWrapperState # pylint: disable=invalid-name
-LSTMStateTuple = core_rnn_cell.LSTMStateTuple # pylint: disable=invalid-name
+LSTMStateTuple = rnn_cell.LSTMStateTuple # pylint: disable=invalid-name
BasicDecoderOutput = basic_decoder.BasicDecoderOutput # pylint: disable=invalid-name
float32 = np.float32
int32 = np.int32
@@ -112,7 +112,7 @@ class AttentionWrapperTest(test.TestCase):
with vs.variable_scope(
'root',
initializer=init_ops.random_normal_initializer(stddev=0.01, seed=3)):
- cell = core_rnn_cell.LSTMCell(cell_depth)
+ cell = rnn_cell.LSTMCell(cell_depth)
cell = wrapper.AttentionWrapper(
cell,
attention_mechanism,
@@ -133,7 +133,7 @@ class AttentionWrapperTest(test.TestCase):
self.assertTrue(
isinstance(final_state, wrapper.AttentionWrapperState))
self.assertTrue(
- isinstance(final_state.cell_state, core_rnn_cell.LSTMStateTuple))
+ isinstance(final_state.cell_state, rnn_cell.LSTMStateTuple))
self.assertEqual((batch_size, None, attention_depth),
tuple(final_outputs.rnn_output.get_shape().as_list()))
diff --git a/tensorflow/contrib/seq2seq/python/kernel_tests/basic_decoder_test.py b/tensorflow/contrib/seq2seq/python/kernel_tests/basic_decoder_test.py
index 8fc4ecfc82..600adea189 100644
--- a/tensorflow/contrib/seq2seq/python/kernel_tests/basic_decoder_test.py
+++ b/tensorflow/contrib/seq2seq/python/kernel_tests/basic_decoder_test.py
@@ -21,13 +21,13 @@ from __future__ import print_function
import numpy as np
-from tensorflow.contrib.rnn import core_rnn_cell
from tensorflow.contrib.seq2seq.python.ops import helper as helper_py
from tensorflow.contrib.seq2seq.python.ops import basic_decoder
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import tensor_shape
from tensorflow.python.layers import core as layers_core
+from tensorflow.python.ops import rnn_cell
from tensorflow.python.ops import variables
from tensorflow.python.platform import test
# pylint: enable=g-import-not-at-top
@@ -46,7 +46,7 @@ class BasicDecoderTest(test.TestCase):
with self.test_session(use_gpu=True) as sess:
inputs = np.random.randn(batch_size, max_time,
input_depth).astype(np.float32)
- cell = core_rnn_cell.LSTMCell(cell_depth)
+ cell = rnn_cell.LSTMCell(cell_depth)
helper = helper_py.TrainingHelper(
inputs, sequence_length, time_major=False)
if use_output_layer:
@@ -77,8 +77,8 @@ class BasicDecoderTest(test.TestCase):
constant_op.constant(0), first_inputs, first_state)
batch_size_t = my_decoder.batch_size
- self.assertTrue(isinstance(first_state, core_rnn_cell.LSTMStateTuple))
- self.assertTrue(isinstance(step_state, core_rnn_cell.LSTMStateTuple))
+ self.assertTrue(isinstance(first_state, rnn_cell.LSTMStateTuple))
+ self.assertTrue(isinstance(step_state, rnn_cell.LSTMStateTuple))
self.assertTrue(
isinstance(step_outputs, basic_decoder.BasicDecoderOutput))
self.assertEqual((batch_size, expected_output_depth),
@@ -130,7 +130,7 @@ class BasicDecoderTest(test.TestCase):
with self.test_session(use_gpu=True) as sess:
embeddings = np.random.randn(vocabulary_size,
input_depth).astype(np.float32)
- cell = core_rnn_cell.LSTMCell(vocabulary_size)
+ cell = rnn_cell.LSTMCell(vocabulary_size)
helper = helper_py.GreedyEmbeddingHelper(embeddings, start_tokens,
end_token)
my_decoder = basic_decoder.BasicDecoder(
@@ -154,8 +154,8 @@ class BasicDecoderTest(test.TestCase):
constant_op.constant(0), first_inputs, first_state)
batch_size_t = my_decoder.batch_size
- self.assertTrue(isinstance(first_state, core_rnn_cell.LSTMStateTuple))
- self.assertTrue(isinstance(step_state, core_rnn_cell.LSTMStateTuple))
+ self.assertTrue(isinstance(first_state, rnn_cell.LSTMStateTuple))
+ self.assertTrue(isinstance(step_state, rnn_cell.LSTMStateTuple))
self.assertTrue(
isinstance(step_outputs, basic_decoder.BasicDecoderOutput))
self.assertEqual((batch_size, cell_depth), step_outputs[0].get_shape())
@@ -202,7 +202,7 @@ class BasicDecoderTest(test.TestCase):
embeddings = np.random.randn(
vocabulary_size, input_depth).astype(np.float32)
half = constant_op.constant(0.5)
- cell = core_rnn_cell.LSTMCell(vocabulary_size)
+ cell = rnn_cell.LSTMCell(vocabulary_size)
helper = helper_py.ScheduledEmbeddingTrainingHelper(
inputs=inputs,
sequence_length=sequence_length,
@@ -230,8 +230,8 @@ class BasicDecoderTest(test.TestCase):
constant_op.constant(0), first_inputs, first_state)
batch_size_t = my_decoder.batch_size
- self.assertTrue(isinstance(first_state, core_rnn_cell.LSTMStateTuple))
- self.assertTrue(isinstance(step_state, core_rnn_cell.LSTMStateTuple))
+ self.assertTrue(isinstance(first_state, rnn_cell.LSTMStateTuple))
+ self.assertTrue(isinstance(step_state, rnn_cell.LSTMStateTuple))
self.assertTrue(
isinstance(step_outputs, basic_decoder.BasicDecoderOutput))
self.assertEqual((batch_size, vocabulary_size),
@@ -293,7 +293,7 @@ class BasicDecoderTest(test.TestCase):
with self.test_session(use_gpu=True) as sess:
inputs = np.random.randn(batch_size, max_time,
input_depth).astype(np.float32)
- cell = core_rnn_cell.LSTMCell(cell_depth)
+ cell = rnn_cell.LSTMCell(cell_depth)
sampling_probability = constant_op.constant(sampling_probability)
next_input_layer = None
@@ -335,8 +335,8 @@ class BasicDecoderTest(test.TestCase):
batch_size_t = my_decoder.batch_size
- self.assertTrue(isinstance(first_state, core_rnn_cell.LSTMStateTuple))
- self.assertTrue(isinstance(step_state, core_rnn_cell.LSTMStateTuple))
+ self.assertTrue(isinstance(first_state, rnn_cell.LSTMStateTuple))
+ self.assertTrue(isinstance(step_state, rnn_cell.LSTMStateTuple))
self.assertTrue(
isinstance(step_outputs, basic_decoder.BasicDecoderOutput))
self.assertEqual((batch_size, cell_depth), step_outputs[0].get_shape())
diff --git a/tensorflow/contrib/seq2seq/python/kernel_tests/beam_search_decoder_test.py b/tensorflow/contrib/seq2seq/python/kernel_tests/beam_search_decoder_test.py
index cb0cb4f8c3..873a39154f 100644
--- a/tensorflow/contrib/seq2seq/python/kernel_tests/beam_search_decoder_test.py
+++ b/tensorflow/contrib/seq2seq/python/kernel_tests/beam_search_decoder_test.py
@@ -21,7 +21,6 @@ from __future__ import print_function
import numpy as np
-from tensorflow.contrib.rnn import core_rnn_cell
from tensorflow.contrib.seq2seq.python.ops import attention_wrapper
from tensorflow.contrib.seq2seq.python.ops import beam_search_decoder
from tensorflow.contrib.seq2seq.python.ops import beam_search_ops
@@ -32,6 +31,7 @@ from tensorflow.python.framework import ops
from tensorflow.python.layers import core as layers_core
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import nn_ops
+from tensorflow.python.ops import rnn_cell
from tensorflow.python.ops import variables
from tensorflow.python.platform import test
@@ -241,7 +241,7 @@ class BeamSearchDecoderTest(test.TestCase):
with self.test_session() as sess:
embedding = np.random.randn(vocab_size, embedding_dim).astype(np.float32)
- cell = core_rnn_cell.LSTMCell(cell_depth)
+ cell = rnn_cell.LSTMCell(cell_depth)
if has_attention:
inputs = np.random.randn(batch_size, decoder_max_time,
input_depth).astype(np.float32)
diff --git a/tensorflow/contrib/seq2seq/python/kernel_tests/decoder_test.py b/tensorflow/contrib/seq2seq/python/kernel_tests/decoder_test.py
index 96dc7b4bee..ac830ae98e 100644
--- a/tensorflow/contrib/seq2seq/python/kernel_tests/decoder_test.py
+++ b/tensorflow/contrib/seq2seq/python/kernel_tests/decoder_test.py
@@ -21,12 +21,12 @@ from __future__ import print_function
import numpy as np
-from tensorflow.contrib.rnn import core_rnn_cell
from tensorflow.contrib.seq2seq.python.ops import decoder
from tensorflow.contrib.seq2seq.python.ops import helper as helper_py
from tensorflow.contrib.seq2seq.python.ops import basic_decoder
from tensorflow.python.framework import dtypes
from tensorflow.python.ops import rnn
+from tensorflow.python.ops import rnn_cell
from tensorflow.python.ops import variables
from tensorflow.python.ops import variable_scope as vs
from tensorflow.python.platform import test
@@ -51,7 +51,7 @@ class DynamicDecodeRNNTest(test.TestCase):
else:
inputs = np.random.randn(batch_size, max_time,
input_depth).astype(np.float32)
- cell = core_rnn_cell.LSTMCell(cell_depth)
+ cell = rnn_cell.LSTMCell(cell_depth)
helper = helper_py.TrainingHelper(
inputs, sequence_length, time_major=time_major)
my_decoder = basic_decoder.BasicDecoder(
@@ -71,7 +71,7 @@ class DynamicDecodeRNNTest(test.TestCase):
self.assertTrue(
isinstance(final_outputs, basic_decoder.BasicDecoderOutput))
- self.assertTrue(isinstance(final_state, core_rnn_cell.LSTMStateTuple))
+ self.assertTrue(isinstance(final_state, rnn_cell.LSTMStateTuple))
self.assertEqual(
(batch_size,),
@@ -126,7 +126,7 @@ class DynamicDecodeRNNTest(test.TestCase):
inputs = np.random.randn(batch_size, max_time,
input_depth).astype(np.float32)
- cell = core_rnn_cell.LSTMCell(cell_depth)
+ cell = rnn_cell.LSTMCell(cell_depth)
zero_state = cell.zero_state(dtype=dtypes.float32, batch_size=batch_size)
helper = helper_py.TrainingHelper(inputs, sequence_length)
my_decoder = basic_decoder.BasicDecoder(
diff --git a/tensorflow/contrib/seq2seq/python/ops/attention_wrapper.py b/tensorflow/contrib/seq2seq/python/ops/attention_wrapper.py
index 79940c3362..bdf47a7b2c 100644
--- a/tensorflow/contrib/seq2seq/python/ops/attention_wrapper.py
+++ b/tensorflow/contrib/seq2seq/python/ops/attention_wrapper.py
@@ -21,7 +21,6 @@ from __future__ import print_function
import collections
import math
-from tensorflow.contrib.rnn import core_rnn_cell
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_shape
@@ -500,7 +499,7 @@ def hardmax(logits, name=None):
math_ops.argmax(logits, -1), depth, dtype=logits.dtype)
-class AttentionWrapper(core_rnn_cell.RNNCell):
+class AttentionWrapper(rnn_cell_impl.RNNCell):
"""Wraps another `RNNCell` with attention.
"""
diff --git a/tensorflow/contrib/tfprof/python/tools/tfprof/model_analyzer_testlib.py b/tensorflow/contrib/tfprof/python/tools/tfprof/model_analyzer_testlib.py
index ed26f001c2..1234b15199 100644
--- a/tensorflow/contrib/tfprof/python/tools/tfprof/model_analyzer_testlib.py
+++ b/tensorflow/contrib/tfprof/python/tools/tfprof/model_analyzer_testlib.py
@@ -17,13 +17,13 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
-from tensorflow.contrib.rnn.python.ops.core_rnn_cell import BasicRNNCell
from tensorflow.python.framework import dtypes
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import init_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import nn_ops
from tensorflow.python.ops import rnn
+from tensorflow.python.ops import rnn_cell
from tensorflow.python.ops import variable_scope
from tensorflow.python.training import gradient_descent
@@ -55,7 +55,7 @@ def BuildFullModel():
with variable_scope.variable_scope('inp_%d' % i):
seq.append(array_ops.reshape(BuildSmallModel(), [2, 1, -1]))
- cell = BasicRNNCell(16, 48)
+ cell = rnn_cell.BasicRNNCell(16)
out = rnn.dynamic_rnn(
cell, array_ops.concat(seq, axis=1), dtype=dtypes.float32)[0]
@@ -63,5 +63,3 @@ def BuildFullModel():
loss = nn_ops.l2_loss(math_ops.reduce_mean(target - out))
sgd_op = gradient_descent.GradientDescentOptimizer(1e-2)
return sgd_op.minimize(loss)
-
-
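
The call-site change above reflects the core `BasicRNNCell` constructor, which takes only `num_units` (the old contrib call also passed an input size). A hedged sketch of the updated usage, assuming a TensorFlow build with this change; the shapes are made up for illustration:

```python
import numpy as np
import tensorflow as tf

inputs = tf.constant(np.random.randn(2, 5, 8).astype(np.float32))  # [batch, time, depth]
cell = tf.nn.rnn_cell.BasicRNNCell(16)          # 16 hidden units, no input-size argument
outputs, state = tf.nn.dynamic_rnn(cell, inputs, dtype=tf.float32)

with tf.Session() as sess:
  sess.run(tf.global_variables_initializer())
  print(sess.run(outputs).shape)  # (2, 5, 16)
```
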
diff --git a/tensorflow/python/BUILD b/tensorflow/python/BUILD
index a20b86a235..97a1fab1ee 100644
--- a/tensorflow/python/BUILD
+++ b/tensorflow/python/BUILD
@@ -1709,14 +1709,23 @@ py_library(
py_library(
name = "rnn_cell",
srcs = [
+ "ops/rnn_cell.py",
"ops/rnn_cell_impl.py",
],
srcs_version = "PY2AND3",
deps = [
":array_ops",
+ ":clip_ops",
":framework_for_generated_wrappers",
+ ":init_ops",
":layers_base",
+ ":math_ops",
+ ":nn_ops",
+ ":partitioned_variables",
+ ":random_ops",
":util",
+ ":variable_scope",
+ ":variables",
],
)
diff --git a/tensorflow/python/debug/lib/session_debug_testlib.py b/tensorflow/python/debug/lib/session_debug_testlib.py
index deb8249343..ca223ef895 100644
--- a/tensorflow/python/debug/lib/session_debug_testlib.py
+++ b/tensorflow/python/debug/lib/session_debug_testlib.py
@@ -53,7 +53,7 @@ from tensorflow.python.platform import test
from tensorflow.python.training import gradient_descent
-class _RNNCellForTest(rnn_cell_impl._RNNCell): # pylint: disable=protected-access
+class _RNNCellForTest(rnn_cell_impl.RNNCell): # pylint: disable=protected-access
"""RNN cell for testing."""
def __init__(self, input_output_size, state_size):
diff --git a/tensorflow/python/kernel_tests/rnn_test.py b/tensorflow/python/kernel_tests/rnn_test.py
index 934cef8d6f..a644e6a44f 100644
--- a/tensorflow/python/kernel_tests/rnn_test.py
+++ b/tensorflow/python/kernel_tests/rnn_test.py
@@ -43,7 +43,7 @@ import tensorflow.python.ops.tensor_array_grad # pylint: disable=unused-import
from tensorflow.python.platform import test
-class Plus1RNNCell(rnn_cell_impl._RNNCell):
+class Plus1RNNCell(rnn_cell_impl.RNNCell):
"""RNN Cell generating (output, new_state) = (input + 1, state + 1)."""
@property
@@ -58,7 +58,7 @@ class Plus1RNNCell(rnn_cell_impl._RNNCell):
return (input_ + 1, state + 1)
-class ScalarStateRNNCell(rnn_cell_impl._RNNCell):
+class ScalarStateRNNCell(rnn_cell_impl.RNNCell):
"""RNN Cell generating (output, new_state) = (input + 1, state + 1)."""
@property
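
With `RNNCell` now public (no leading underscore), test cells like the ones above subclass it directly. A minimal sketch of a custom cell against the public base class, assuming a build with this change; `Plus1Cell` is a made-up name for illustration:

```python
import tensorflow as tf

class Plus1Cell(tf.nn.rnn_cell.RNNCell):
  """Emits input + 1 and carries state + 1, mirroring the test cell above."""

  def __init__(self, num_units):
    super(Plus1Cell, self).__init__()
    self._num_units = num_units

  @property
  def state_size(self):
    return self._num_units

  @property
  def output_size(self):
    return self._num_units

  def call(self, inputs, state):
    # The base class handles variable scoping; subclasses implement call().
    return inputs + 1, state + 1
```
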
diff --git a/tensorflow/python/ops/nn.py b/tensorflow/python/ops/nn.py
index 7b6494e0c9..d05cba2e93 100644
--- a/tensorflow/python/ops/nn.py
+++ b/tensorflow/python/ops/nn.py
@@ -78,6 +78,9 @@ See the @{$python/nn} guide.
@@dynamic_rnn
@@bidirectional_dynamic_rnn
@@raw_rnn
+@@static_rnn
+@@static_state_saving_rnn
+@@static_bidirectional_rnn
@@ctc_loss
@@ctc_greedy_decoder
@@ctc_beam_search_decoder
@@ -113,14 +116,15 @@ from tensorflow.python.util.all_util import remove_undocumented
# Bring more nn-associated functionality into this package.
# go/tf-wildcard-import
-# pylint: disable=wildcard-import
+# pylint: disable=wildcard-import,unused-import
from tensorflow.python.ops.ctc_ops import *
from tensorflow.python.ops.nn_impl import *
from tensorflow.python.ops.nn_ops import *
from tensorflow.python.ops.candidate_sampling_ops import *
from tensorflow.python.ops.embedding_ops import *
from tensorflow.python.ops.rnn import *
-# pylint: enable=wildcard-import
+from tensorflow.python.ops import rnn_cell
+# pylint: enable=wildcard-import,unused-import
# TODO(cwhipkey): sigmoid and tanh should not be exposed from tf.nn.
@@ -135,6 +139,7 @@ _allowed_symbols = [
"lrn", # Excluded in gen_docs_combined.
"relu_layer", # Excluded in gen_docs_combined.
"xw_plus_b", # Excluded in gen_docs_combined.
+ "rnn_cell", # rnn_cell is a submodule of tf.nn.
]
remove_undocumented(__name__, _allowed_symbols,
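
With `static_rnn` and the `rnn_cell` submodule documented in `tf.nn`, the symbols are reachable without going through contrib. A small usage sketch, assuming a TensorFlow build with this change; the sizes are made up for illustration:

```python
import tensorflow as tf

batch_size, num_steps, input_size, num_units = 4, 3, 8, 16
# static_rnn expects a length-T list of [batch_size, input_size] tensors.
inputs = [tf.placeholder(tf.float32, [batch_size, input_size])
          for _ in range(num_steps)]
cell = tf.nn.rnn_cell.BasicLSTMCell(num_units)
outputs, final_state = tf.nn.static_rnn(cell, inputs, dtype=tf.float32)

print(len(outputs))        # num_steps tensors, each [batch_size, num_units]
print(type(final_state))   # an LSTMStateTuple(c=..., h=...)
```
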
diff --git a/tensorflow/python/ops/rnn.py b/tensorflow/python/ops/rnn.py
index 1120b3c394..ca72734707 100644
--- a/tensorflow/python/ops/rnn.py
+++ b/tensorflow/python/ops/rnn.py
@@ -13,8 +13,16 @@
# limitations under the License.
# ==============================================================================
-"""RNN helpers for TensorFlow models."""
+"""RNN helpers for TensorFlow models.
+
+@@bidirectional_dynamic_rnn
+@@dynamic_rnn
+@@raw_rnn
+@@static_rnn
+@@static_state_saving_rnn
+@@static_bidirectional_rnn
+"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
@@ -1062,3 +1070,351 @@ def raw_rnn(cell, loop_fn,
final_loop_state = None
return (emit_ta, final_state, final_loop_state)
+
+
+def static_rnn(cell,
+ inputs,
+ initial_state=None,
+ dtype=None,
+ sequence_length=None,
+ scope=None):
+ """Creates a recurrent neural network specified by RNNCell `cell`.
+
+ The simplest form of RNN network generated is:
+
+ ```python
+ state = cell.zero_state(...)
+ outputs = []
+ for input_ in inputs:
+ output, state = cell(input_, state)
+ outputs.append(output)
+ return (outputs, state)
+ ```
+ However, a few other options are available:
+
+ An initial state can be provided.
+ If the sequence_length vector is provided, dynamic calculation is performed.
+ This method of calculation does not compute the RNN steps past the maximum
+ sequence length of the minibatch (thus saving computational time),
+ and properly propagates the state at an example's sequence length
+ to the final state output.
+
+ The dynamic calculation performed is, at time `t` for batch row `b`,
+
+ ```python
+ (output, state)(b, t) =
+ (t >= sequence_length(b))
+ ? (zeros(cell.output_size), states(b, sequence_length(b) - 1))
+ : cell(input(b, t), state(b, t - 1))
+ ```
+
+ Args:
+ cell: An instance of RNNCell.
+ inputs: A length T list of inputs, each a `Tensor` of shape
+ `[batch_size, input_size]`, or a nested tuple of such elements.
+ initial_state: (optional) An initial state for the RNN.
+ If `cell.state_size` is an integer, this must be
+ a `Tensor` of appropriate type and shape `[batch_size, cell.state_size]`.
+ If `cell.state_size` is a tuple, this should be a tuple of
+ tensors having shapes `[batch_size, s] for s in cell.state_size`.
+ dtype: (optional) The data type for the initial state and expected output.
+ Required if initial_state is not provided or RNN state has a heterogeneous
+ dtype.
+ sequence_length: Specifies the length of each sequence in inputs.
+ An int32 or int64 vector (tensor) size `[batch_size]`, values in `[0, T)`.
+ scope: VariableScope for the created subgraph; defaults to "rnn".
+
+ Returns:
+ A pair (outputs, state) where:
+
+ - outputs is a length T list of outputs (one for each input), or a nested
+ tuple of such elements.
+ - state is the final state
+
+ Raises:
+ TypeError: If `cell` is not an instance of RNNCell.
+ ValueError: If `inputs` is `None` or an empty list, or if the input depth
+ (column size) cannot be inferred from inputs via shape inference.
+ """
+
+ if not _like_rnncell(cell):
+ raise TypeError("cell must be an instance of RNNCell")
+ if not nest.is_sequence(inputs):
+ raise TypeError("inputs must be a sequence")
+ if not inputs:
+ raise ValueError("inputs must not be empty")
+
+ outputs = []
+ # Create a new scope in which the caching device is either
+ # determined by the parent scope, or is set to place the cached
+ # Variable using the same placement as for the rest of the RNN.
+ with vs.variable_scope(scope or "rnn") as varscope:
+ if varscope.caching_device is None:
+ varscope.set_caching_device(lambda op: op.device)
+
+ # Obtain the first sequence of the input
+ first_input = inputs
+ while nest.is_sequence(first_input):
+ first_input = first_input[0]
+
+ # Temporarily avoid EmbeddingWrapper and seq2seq badness
+ # TODO(lukaszkaiser): remove EmbeddingWrapper
+ if first_input.get_shape().ndims != 1:
+
+ input_shape = first_input.get_shape().with_rank_at_least(2)
+ fixed_batch_size = input_shape[0]
+
+ flat_inputs = nest.flatten(inputs)
+ for flat_input in flat_inputs:
+ input_shape = flat_input.get_shape().with_rank_at_least(2)
+ batch_size, input_size = input_shape[0], input_shape[1:]
+ fixed_batch_size.merge_with(batch_size)
+ for i, size in enumerate(input_size):
+ if size.value is None:
+ raise ValueError(
+ "Input size (dimension %d of inputs) must be accessible via "
+ "shape inference, but saw value None." % i)
+ else:
+ fixed_batch_size = first_input.get_shape().with_rank_at_least(1)[0]
+
+ if fixed_batch_size.value:
+ batch_size = fixed_batch_size.value
+ else:
+ batch_size = array_ops.shape(first_input)[0]
+ if initial_state is not None:
+ state = initial_state
+ else:
+ if not dtype:
+ raise ValueError("If no initial_state is provided, "
+ "dtype must be specified")
+ state = cell.zero_state(batch_size, dtype)
+
+ if sequence_length is not None: # Prepare variables
+ sequence_length = ops.convert_to_tensor(
+ sequence_length, name="sequence_length")
+ if sequence_length.get_shape().ndims not in (None, 1):
+ raise ValueError(
+ "sequence_length must be a vector of length batch_size")
+
+ def _create_zero_output(output_size):
+ # convert int to TensorShape if necessary
+ size = _concat(batch_size, output_size)
+ output = array_ops.zeros(
+ array_ops.stack(size), _infer_state_dtype(dtype, state))
+ shape = _concat(fixed_batch_size.value, output_size, static=True)
+ output.set_shape(tensor_shape.TensorShape(shape))
+ return output
+
+ output_size = cell.output_size
+ flat_output_size = nest.flatten(output_size)
+ flat_zero_output = tuple(
+ _create_zero_output(size) for size in flat_output_size)
+ zero_output = nest.pack_sequence_as(
+ structure=output_size, flat_sequence=flat_zero_output)
+
+ sequence_length = math_ops.to_int32(sequence_length)
+ min_sequence_length = math_ops.reduce_min(sequence_length)
+ max_sequence_length = math_ops.reduce_max(sequence_length)
+
+ for time, input_ in enumerate(inputs):
+ if time > 0:
+ varscope.reuse_variables()
+ # pylint: disable=cell-var-from-loop
+ call_cell = lambda: cell(input_, state)
+ # pylint: enable=cell-var-from-loop
+ if sequence_length is not None:
+ (output, state) = _rnn_step(
+ time=time,
+ sequence_length=sequence_length,
+ min_sequence_length=min_sequence_length,
+ max_sequence_length=max_sequence_length,
+ zero_output=zero_output,
+ state=state,
+ call_cell=call_cell,
+ state_size=cell.state_size)
+ else:
+ (output, state) = call_cell()
+
+ outputs.append(output)
+
+ return (outputs, state)
+
+
+def static_state_saving_rnn(cell,
+ inputs,
+ state_saver,
+ state_name,
+ sequence_length=None,
+ scope=None):
+ """RNN that accepts a state saver for time-truncated RNN calculation.
+
+ Args:
+ cell: An instance of `RNNCell`.
+ inputs: A length T list of inputs, each a `Tensor` of shape
+ `[batch_size, input_size]`.
+ state_saver: A state saver object with methods `state` and `save_state`.
+ state_name: Python string or tuple of strings. The name to use with the
+ state_saver. If the cell returns tuples of states (i.e.,
+ `cell.state_size` is a tuple) then `state_name` should be a tuple of
+ strings having the same length as `cell.state_size`. Otherwise it should
+ be a single string.
+ sequence_length: (optional) An int32/int64 vector size [batch_size].
+ See the documentation for rnn() for more details about sequence_length.
+ scope: VariableScope for the created subgraph; defaults to "rnn".
+
+ Returns:
+ A pair (outputs, state) where:
+ outputs is a length T list of outputs (one for each input)
+      state is the final state
+
+ Raises:
+ TypeError: If `cell` is not an instance of RNNCell.
+ ValueError: If `inputs` is `None` or an empty list, or if the arity and
+ type of `state_name` does not match that of `cell.state_size`.
+ """
+ state_size = cell.state_size
+ state_is_tuple = nest.is_sequence(state_size)
+ state_name_tuple = nest.is_sequence(state_name)
+
+ if state_is_tuple != state_name_tuple:
+ raise ValueError("state_name should be the same type as cell.state_size. "
+ "state_name: %s, cell.state_size: %s" % (str(state_name),
+ str(state_size)))
+
+ if state_is_tuple:
+ state_name_flat = nest.flatten(state_name)
+ state_size_flat = nest.flatten(state_size)
+
+ if len(state_name_flat) != len(state_size_flat):
+ raise ValueError("#elems(state_name) != #elems(state_size): %d vs. %d" %
+ (len(state_name_flat), len(state_size_flat)))
+
+ initial_state = nest.pack_sequence_as(
+ structure=state_size,
+ flat_sequence=[state_saver.state(s) for s in state_name_flat])
+ else:
+ initial_state = state_saver.state(state_name)
+
+ (outputs, state) = static_rnn(
+ cell,
+ inputs,
+ initial_state=initial_state,
+ sequence_length=sequence_length,
+ scope=scope)
+
+ if state_is_tuple:
+ flat_state = nest.flatten(state)
+ state_name = nest.flatten(state_name)
+ save_state = [
+ state_saver.save_state(name, substate)
+ for name, substate in zip(state_name, flat_state)
+ ]
+ else:
+ save_state = [state_saver.save_state(state_name, state)]
+
+ with ops.control_dependencies(save_state):
+ last_output = outputs[-1]
+ flat_last_output = nest.flatten(last_output)
+ flat_last_output = [
+ array_ops.identity(output) for output in flat_last_output
+ ]
+ outputs[-1] = nest.pack_sequence_as(
+ structure=last_output, flat_sequence=flat_last_output)
+
+ return (outputs, state)
+
+
+def static_bidirectional_rnn(cell_fw,
+ cell_bw,
+ inputs,
+ initial_state_fw=None,
+ initial_state_bw=None,
+ dtype=None,
+ sequence_length=None,
+ scope=None):
+ """Creates a bidirectional recurrent neural network.
+
+ Similar to the unidirectional case above (rnn) but takes input and builds
+ independent forward and backward RNNs with the final forward and backward
+ outputs depth-concatenated, such that the output will have the format
+ [time][batch][cell_fw.output_size + cell_bw.output_size]. The input_size of
+ forward and backward cell must match. The initial state for both directions
+ is zero by default (but can be set optionally) and no intermediate states are
+ ever returned -- the network is fully unrolled for the given (passed in)
+ length(s) of the sequence(s) or completely unrolled if length(s) is not given.
+
+ Args:
+ cell_fw: An instance of RNNCell, to be used for forward direction.
+ cell_bw: An instance of RNNCell, to be used for backward direction.
+ inputs: A length T list of inputs, each a tensor of shape
+ [batch_size, input_size], or a nested tuple of such elements.
+ initial_state_fw: (optional) An initial state for the forward RNN.
+ This must be a tensor of appropriate type and shape
+ `[batch_size, cell_fw.state_size]`.
+ If `cell_fw.state_size` is a tuple, this should be a tuple of
+ tensors having shapes `[batch_size, s] for s in cell_fw.state_size`.
+ initial_state_bw: (optional) Same as for `initial_state_fw`, but using
+ the corresponding properties of `cell_bw`.
+ dtype: (optional) The data type for the initial state. Required if
+ either of the initial states are not provided.
+ sequence_length: (optional) An int32/int64 vector, size `[batch_size]`,
+ containing the actual lengths for each of the sequences.
+ scope: VariableScope for the created subgraph; defaults to
+ "bidirectional_rnn"
+
+ Returns:
+ A tuple (outputs, output_state_fw, output_state_bw) where:
+ outputs is a length `T` list of outputs (one for each input), which
+ are depth-concatenated forward and backward outputs.
+ output_state_fw is the final state of the forward rnn.
+ output_state_bw is the final state of the backward rnn.
+
+ Raises:
+ TypeError: If `cell_fw` or `cell_bw` is not an instance of `RNNCell`.
+ ValueError: If inputs is None or an empty list.
+ """
+
+ if not _like_rnncell(cell_fw):
+ raise TypeError("cell_fw must be an instance of RNNCell")
+ if not _like_rnncell(cell_bw):
+ raise TypeError("cell_bw must be an instance of RNNCell")
+ if not nest.is_sequence(inputs):
+ raise TypeError("inputs must be a sequence")
+ if not inputs:
+ raise ValueError("inputs must not be empty")
+
+ with vs.variable_scope(scope or "bidirectional_rnn"):
+ # Forward direction
+ with vs.variable_scope("fw") as fw_scope:
+ output_fw, output_state_fw = static_rnn(
+ cell_fw,
+ inputs,
+ initial_state_fw,
+ dtype,
+ sequence_length,
+ scope=fw_scope)
+
+ # Backward direction
+ with vs.variable_scope("bw") as bw_scope:
+ reversed_inputs = _reverse_seq(inputs, sequence_length)
+ tmp, output_state_bw = static_rnn(
+ cell_bw,
+ reversed_inputs,
+ initial_state_bw,
+ dtype,
+ sequence_length,
+ scope=bw_scope)
+
+ output_bw = _reverse_seq(tmp, sequence_length)
+ # Concat each of the forward/backward outputs
+ flat_output_fw = nest.flatten(output_fw)
+ flat_output_bw = nest.flatten(output_bw)
+
+ flat_outputs = tuple(
+ array_ops.concat([fw, bw], 1)
+ for fw, bw in zip(flat_output_fw, flat_output_bw))
+
+ outputs = nest.pack_sequence_as(
+ structure=output_fw, flat_sequence=flat_outputs)
+
+ return (outputs, output_state_fw, output_state_bw)
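
As the docstring above states, `static_bidirectional_rnn` depth-concatenates the forward and backward outputs, so each per-timestep output has size `cell_fw.output_size + cell_bw.output_size`. A short sketch of that contract, assuming a build with this change; the unit counts are made up:

```python
import tensorflow as tf

batch_size, num_steps, input_size = 4, 3, 8
inputs = [tf.placeholder(tf.float32, [batch_size, input_size])
          for _ in range(num_steps)]
cell_fw = tf.nn.rnn_cell.GRUCell(10)   # forward cell: 10 units
cell_bw = tf.nn.rnn_cell.GRUCell(12)   # backward cell: 12 units
outputs, state_fw, state_bw = tf.nn.static_bidirectional_rnn(
    cell_fw, cell_bw, inputs, dtype=tf.float32)

print(outputs[0].get_shape().as_list())  # [4, 22]: 10 forward + 12 backward
```
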
diff --git a/tensorflow/python/ops/rnn_cell.py b/tensorflow/python/ops/rnn_cell.py
new file mode 100644
index 0000000000..c0dac8fb01
--- /dev/null
+++ b/tensorflow/python/ops/rnn_cell.py
@@ -0,0 +1,51 @@
+# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Module for constructing RNN Cells.
+
+## Base interface for all RNN Cells
+
+@@RNNCell
+
+## RNN Cells for use with TensorFlow's core RNN methods
+
+@@BasicRNNCell
+@@BasicLSTMCell
+@@GRUCell
+@@LSTMCell
+
+## Classes storing split `RNNCell` state
+
+@@LSTMStateTuple
+
+## RNN Cell wrappers (RNNCells that wrap other RNNCells)
+
+@@MultiRNNCell
+@@DropoutWrapper
+@@DeviceWrapper
+@@ResidualWrapper
+"""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+# go/tf-wildcard-import
+# pylint: disable=wildcard-import
+from tensorflow.python.ops.rnn_cell_impl import *
+# pylint: enable=wildcard-import
+from tensorflow.python.util.all_util import remove_undocumented
+
+_allowed_symbols = []
+
+remove_undocumented(__name__, _allowed_symbols)
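
The new module's docstring lists the wrapper cells (`MultiRNNCell`, `DropoutWrapper`, `DeviceWrapper`, `ResidualWrapper`) alongside the basic cells. A hedged sketch of composing them through the public `tf.nn.rnn_cell` names, assuming a build with this change; layer sizes and keep probabilities are made up:

```python
import tensorflow as tf

def make_layer(num_units, keep_prob):
  cell = tf.nn.rnn_cell.BasicLSTMCell(num_units)
  return tf.nn.rnn_cell.DropoutWrapper(cell, output_keep_prob=keep_prob)

# Two stacked LSTM layers; the state is a 2-tuple of LSTMStateTuples.
stacked = tf.nn.rnn_cell.MultiRNNCell([make_layer(64, 0.9), make_layer(64, 0.9)])

inputs = tf.placeholder(tf.float32, [None, 20, 32])   # [batch, time, depth]
outputs, state = tf.nn.dynamic_rnn(stacked, inputs, dtype=tf.float32)
```
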
diff --git a/tensorflow/python/ops/rnn_cell_impl.py b/tensorflow/python/ops/rnn_cell_impl.py
index 9c0fb1db23..500e3b7859 100644
--- a/tensorflow/python/ops/rnn_cell_impl.py
+++ b/tensorflow/python/ops/rnn_cell_impl.py
@@ -12,18 +12,22 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
-
"""Module implementing RNN Cells.
-This module contains the abstract definition of a RNN cell: `_RNNCell`.
-Actual implementations of various types of RNN cells are located in
-`tensorflow.contrib`.
+This module provides a number of basic, commonly used RNN cells, such as LSTM
+(Long Short Term Memory) or GRU (Gated Recurrent Unit), and a number of
+operators that allow adding dropouts, projections, or embeddings for inputs.
+Constructing multi-layer cells is supported by the class `MultiRNNCell`, or by
+calling the `rnn` ops several times.
"""
-
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
+import collections
+import hashlib
+import numbers
+
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
@@ -31,11 +35,22 @@ from tensorflow.python.framework import tensor_shape
from tensorflow.python.framework import tensor_util
from tensorflow.python.layers import base as base_layer
from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import clip_ops
+from tensorflow.python.ops import init_ops
+from tensorflow.python.ops import math_ops
+from tensorflow.python.ops import nn_ops
+from tensorflow.python.ops import partitioned_variables
+from tensorflow.python.ops import random_ops
from tensorflow.python.ops import variable_scope as vs
from tensorflow.python.ops import variables as tf_variables
+from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.util import nest
+_BIAS_VARIABLE_NAME = "bias"
+_WEIGHTS_VARIABLE_NAME = "kernel"
+
+
def _like_rnncell(cell):
"""Checks that a given object is an RNNCell by using duck typing."""
conditions = [hasattr(cell, "output_size"), hasattr(cell, "state_size"),
@@ -115,7 +130,7 @@ def _zero_state_tensors(state_size, batch_size, dtype):
return nest.map_structure(get_state_shape, state_size)
-class _RNNCell(base_layer.Layer):
+class RNNCell(base_layer.Layer):
"""Abstract object representing an RNN cell.
Every `RNNCell` must have the properties below and implement `call` with
@@ -158,11 +173,11 @@ class _RNNCell(base_layer.Layer):
if scope is not None:
with vs.variable_scope(scope,
custom_getter=self._rnn_get_variable) as scope:
- return super(_RNNCell, self).__call__(inputs, state, scope=scope)
+ return super(RNNCell, self).__call__(inputs, state, scope=scope)
else:
with vs.variable_scope(vs.get_variable_scope(),
custom_getter=self._rnn_get_variable):
- return super(_RNNCell, self).__call__(inputs, state)
+ return super(RNNCell, self).__call__(inputs, state)
def _rnn_get_variable(self, getter, *args, **kwargs):
variable = getter(*args, **kwargs)
@@ -212,3 +227,806 @@ class _RNNCell(base_layer.Layer):
with ops.name_scope(type(self).__name__ + "ZeroState", values=[batch_size]):
state_size = self.state_size
return _zero_state_tensors(state_size, batch_size, dtype)
+
+
+class BasicRNNCell(RNNCell):
+ """The most basic RNN cell.
+
+ Args:
+    num_units: int, The number of units in the RNN cell.
+ activation: Nonlinearity to use. Default: `tanh`.
+ reuse: (optional) Python boolean describing whether to reuse variables
+ in an existing scope. If not `True`, and the existing scope already has
+ the given variables, an error is raised.
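+
+  A minimal single-step sketch (illustrative only; the batch size of 32 and
+  input depth of 128 are arbitrary assumptions):
+
+    import tensorflow as tf
+
+    cell = tf.nn.rnn_cell.BasicRNNCell(num_units=64)
+    inputs = tf.placeholder(tf.float32, [32, 128])
+    state = cell.zero_state(32, tf.float32)
+    output, new_state = cell(inputs, state)  # both are [32, 64] Tensors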
+ """
+
+ def __init__(self, num_units, activation=None, reuse=None):
+ super(BasicRNNCell, self).__init__(_reuse=reuse)
+ self._num_units = num_units
+ self._activation = activation or math_ops.tanh
+
+ @property
+ def state_size(self):
+ return self._num_units
+
+ @property
+ def output_size(self):
+ return self._num_units
+
+ def call(self, inputs, state):
+ """Most basic RNN: output = new_state = act(W * input + U * state + B)."""
+ output = self._activation(_linear([inputs, state], self._num_units, True))
+ return output, output
+
+
+class GRUCell(RNNCell):
+ """Gated Recurrent Unit cell (cf. http://arxiv.org/abs/1406.1078)."""
+
+ def __init__(self,
+ num_units,
+ activation=None,
+ reuse=None,
+ kernel_initializer=None,
+ bias_initializer=None):
+ super(GRUCell, self).__init__(_reuse=reuse)
+ self._num_units = num_units
+ self._activation = activation or math_ops.tanh
+ self._kernel_initializer = kernel_initializer
+ self._bias_initializer = bias_initializer
+
+ @property
+ def state_size(self):
+ return self._num_units
+
+ @property
+ def output_size(self):
+ return self._num_units
+
+ def call(self, inputs, state):
+ """Gated recurrent unit (GRU) with nunits cells."""
+ with vs.variable_scope("gates"): # Reset gate and update gate.
+ # We start with bias of 1.0 to not reset and not update.
+ bias_ones = self._bias_initializer
+ if self._bias_initializer is None:
+ dtype = [a.dtype for a in [inputs, state]][0]
+ bias_ones = init_ops.constant_initializer(1.0, dtype=dtype)
+ value = math_ops.sigmoid(
+ _linear([inputs, state], 2 * self._num_units, True, bias_ones,
+ self._kernel_initializer))
+ r, u = array_ops.split(value=value, num_or_size_splits=2, axis=1)
+ with vs.variable_scope("candidate"):
+ c = self._activation(
+ _linear([inputs, r * state], self._num_units, True,
+ self._bias_initializer, self._kernel_initializer))
+ new_h = u * state + (1 - u) * c
+ return new_h, new_h
+
+
+_LSTMStateTuple = collections.namedtuple("LSTMStateTuple", ("c", "h"))
+
+
+class LSTMStateTuple(_LSTMStateTuple):
+ """Tuple used by LSTM Cells for `state_size`, `zero_state`, and output state.
+
+ Stores two elements: `(c, h)`, in that order.
+
+ Only used when `state_is_tuple=True`.
+ """
+ __slots__ = ()
+
+ @property
+ def dtype(self):
+ (c, h) = self
+ if c.dtype != h.dtype:
+ raise TypeError("Inconsistent internal state: %s vs %s" %
+ (str(c.dtype), str(h.dtype)))
+ return c.dtype
+
+
+class BasicLSTMCell(RNNCell):
+ """Basic LSTM recurrent network cell.
+
+ The implementation is based on: http://arxiv.org/abs/1409.2329.
+
+ We add forget_bias (default: 1) to the biases of the forget gate in order to
+ reduce the scale of forgetting in the beginning of the training.
+
+  It does not allow cell clipping, a projection layer, or peep-hole
+  connections: it is the basic baseline.
+
+ For advanced models, please use the full @{tf.nn.rnn_cell.LSTMCell}
+ that follows.
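+
+  An illustrative usage sketch (not part of this change; sizes and the use of
+  `tf.nn.static_rnn` are assumptions chosen only for demonstration):
+
+    import tensorflow as tf
+
+    cell = tf.nn.rnn_cell.BasicLSTMCell(num_units=64)
+    # Three time steps of [batch x depth] inputs.
+    inputs = [tf.placeholder(tf.float32, [32, 128]) for _ in range(3)]
+    outputs, final_state = tf.nn.static_rnn(cell, inputs, dtype=tf.float32)
+    # With the default state_is_tuple=True, final_state is an
+    # LSTMStateTuple(c, h), each of shape [32, 64].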
+ """
+
+ def __init__(self, num_units, forget_bias=1.0,
+ state_is_tuple=True, activation=None, reuse=None):
+ """Initialize the basic LSTM cell.
+
+ Args:
+ num_units: int, The number of units in the LSTM cell.
+ forget_bias: float, The bias added to forget gates (see above).
+ state_is_tuple: If True, accepted and returned states are 2-tuples of
+ the `c_state` and `m_state`. If False, they are concatenated
+ along the column axis. The latter behavior will soon be deprecated.
+ activation: Activation function of the inner states. Default: `tanh`.
+ reuse: (optional) Python boolean describing whether to reuse variables
+ in an existing scope. If not `True`, and the existing scope already has
+ the given variables, an error is raised.
+ """
+ super(BasicLSTMCell, self).__init__(_reuse=reuse)
+ if not state_is_tuple:
+ logging.warn("%s: Using a concatenated state is slower and will soon be "
+ "deprecated. Use state_is_tuple=True.", self)
+ self._num_units = num_units
+ self._forget_bias = forget_bias
+ self._state_is_tuple = state_is_tuple
+ self._activation = activation or math_ops.tanh
+
+ @property
+ def state_size(self):
+ return (LSTMStateTuple(self._num_units, self._num_units)
+ if self._state_is_tuple else 2 * self._num_units)
+
+ @property
+ def output_size(self):
+ return self._num_units
+
+ def call(self, inputs, state):
+ """Long short-term memory cell (LSTM)."""
+ sigmoid = math_ops.sigmoid
+ # Parameters of gates are concatenated into one multiply for efficiency.
+ if self._state_is_tuple:
+ c, h = state
+ else:
+ c, h = array_ops.split(value=state, num_or_size_splits=2, axis=1)
+
+ concat = _linear([inputs, h], 4 * self._num_units, True)
+
+ # i = input_gate, j = new_input, f = forget_gate, o = output_gate
+ i, j, f, o = array_ops.split(value=concat, num_or_size_splits=4, axis=1)
+
+ new_c = (
+ c * sigmoid(f + self._forget_bias) + sigmoid(i) * self._activation(j))
+ new_h = self._activation(new_c) * sigmoid(o)
+
+ if self._state_is_tuple:
+ new_state = LSTMStateTuple(new_c, new_h)
+ else:
+ new_state = array_ops.concat([new_c, new_h], 1)
+ return new_h, new_state
+
+
+class LSTMCell(RNNCell):
+ """Long short-term memory unit (LSTM) recurrent network cell.
+
+ The default non-peephole implementation is based on:
+
+ http://deeplearning.cs.cmu.edu/pdfs/Hochreiter97_lstm.pdf
+
+ S. Hochreiter and J. Schmidhuber.
+ "Long Short-Term Memory". Neural Computation, 9(8):1735-1780, 1997.
+
+ The peephole implementation is based on:
+
+ https://research.google.com/pubs/archive/43905.pdf
+
+ Hasim Sak, Andrew Senior, and Francoise Beaufays.
+ "Long short-term memory recurrent neural network architectures for
+ large scale acoustic modeling." INTERSPEECH, 2014.
+
+ The class uses optional peep-hole connections, optional cell clipping, and
+ an optional projection layer.
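+
+  An illustrative construction sketch (not part of this change; all sizes are
+  arbitrary assumptions):
+
+    import tensorflow as tf
+
+    cell = tf.nn.rnn_cell.LSTMCell(
+        num_units=128, use_peepholes=True, cell_clip=10.0, num_proj=64)
+    inputs = tf.placeholder(tf.float32, [32, 20, 256])
+    outputs, state = tf.nn.dynamic_rnn(cell, inputs, dtype=tf.float32)
+    # outputs has depth num_proj (64); state is an LSTMStateTuple whose c has
+    # depth num_units (128) and whose h has depth num_proj (64).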
+ """
+
+ def __init__(self, num_units,
+ use_peepholes=False, cell_clip=None,
+ initializer=None, num_proj=None, proj_clip=None,
+ num_unit_shards=None, num_proj_shards=None,
+ forget_bias=1.0, state_is_tuple=True,
+ activation=None, reuse=None):
+ """Initialize the parameters for an LSTM cell.
+
+ Args:
+ num_units: int, The number of units in the LSTM cell
+ use_peepholes: bool, set True to enable diagonal/peephole connections.
+      cell_clip: (optional) A float value. If provided, the cell state is
+        clipped by this value prior to the cell output activation.
+ initializer: (optional) The initializer to use for the weight and
+ projection matrices.
+ num_proj: (optional) int, The output dimensionality for the projection
+ matrices. If None, no projection is performed.
+ proj_clip: (optional) A float value. If `num_proj > 0` and `proj_clip` is
+ provided, then the projected values are clipped elementwise to within
+ `[-proj_clip, proj_clip]`.
+ num_unit_shards: Deprecated, will be removed by Jan. 2017.
+ Use a variable_scope partitioner instead.
+ num_proj_shards: Deprecated, will be removed by Jan. 2017.
+ Use a variable_scope partitioner instead.
+ forget_bias: Biases of the forget gate are initialized by default to 1
+ in order to reduce the scale of forgetting at the beginning of
+ the training.
+ state_is_tuple: If True, accepted and returned states are 2-tuples of
+ the `c_state` and `m_state`. If False, they are concatenated
+ along the column axis. This latter behavior will soon be deprecated.
+ activation: Activation function of the inner states. Default: `tanh`.
+ reuse: (optional) Python boolean describing whether to reuse variables
+ in an existing scope. If not `True`, and the existing scope already has
+ the given variables, an error is raised.
+ """
+ super(LSTMCell, self).__init__(_reuse=reuse)
+ if not state_is_tuple:
+ logging.warn("%s: Using a concatenated state is slower and will soon be "
+ "deprecated. Use state_is_tuple=True.", self)
+ if num_unit_shards is not None or num_proj_shards is not None:
+ logging.warn(
+ "%s: The num_unit_shards and proj_unit_shards parameters are "
+ "deprecated and will be removed in Jan 2017. "
+ "Use a variable scope with a partitioner instead.", self)
+
+ self._num_units = num_units
+ self._use_peepholes = use_peepholes
+ self._cell_clip = cell_clip
+ self._initializer = initializer
+ self._num_proj = num_proj
+ self._proj_clip = proj_clip
+ self._num_unit_shards = num_unit_shards
+ self._num_proj_shards = num_proj_shards
+ self._forget_bias = forget_bias
+ self._state_is_tuple = state_is_tuple
+ self._activation = activation or math_ops.tanh
+
+ if num_proj:
+ self._state_size = (
+ LSTMStateTuple(num_units, num_proj)
+ if state_is_tuple else num_units + num_proj)
+ self._output_size = num_proj
+ else:
+ self._state_size = (
+ LSTMStateTuple(num_units, num_units)
+ if state_is_tuple else 2 * num_units)
+ self._output_size = num_units
+
+ @property
+ def state_size(self):
+ return self._state_size
+
+ @property
+ def output_size(self):
+ return self._output_size
+
+ def call(self, inputs, state):
+ """Run one step of LSTM.
+
+ Args:
+      inputs: input Tensor, 2D, batch x input_size.
+ state: if `state_is_tuple` is False, this must be a state Tensor,
+ `2-D, batch x state_size`. If `state_is_tuple` is True, this must be a
+ tuple of state Tensors, both `2-D`, with column sizes `c_state` and
+ `m_state`.
+
+ Returns:
+ A tuple containing:
+
+ - A `2-D, [batch x output_dim]`, Tensor representing the output of the
+ LSTM after reading `inputs` when previous state was `state`.
+ Here output_dim is:
+ num_proj if num_proj was set,
+ num_units otherwise.
+ - Tensor(s) representing the new state of LSTM after reading `inputs` when
+ the previous state was `state`. Same type and shape(s) as `state`.
+
+ Raises:
+ ValueError: If input size cannot be inferred from inputs via
+ static shape inference.
+ """
+ num_proj = self._num_units if self._num_proj is None else self._num_proj
+ sigmoid = math_ops.sigmoid
+
+ if self._state_is_tuple:
+ (c_prev, m_prev) = state
+ else:
+ c_prev = array_ops.slice(state, [0, 0], [-1, self._num_units])
+ m_prev = array_ops.slice(state, [0, self._num_units], [-1, num_proj])
+
+ dtype = inputs.dtype
+ input_size = inputs.get_shape().with_rank(2)[1]
+ if input_size.value is None:
+ raise ValueError("Could not infer input size from inputs.get_shape()[-1]")
+ scope = vs.get_variable_scope()
+ with vs.variable_scope(scope, initializer=self._initializer) as unit_scope:
+ if self._num_unit_shards is not None:
+ unit_scope.set_partitioner(
+ partitioned_variables.fixed_size_partitioner(
+ self._num_unit_shards))
+ # i = input_gate, j = new_input, f = forget_gate, o = output_gate
+ lstm_matrix = _linear([inputs, m_prev], 4 * self._num_units, bias=True)
+ i, j, f, o = array_ops.split(
+ value=lstm_matrix, num_or_size_splits=4, axis=1)
+ # Diagonal connections
+ if self._use_peepholes:
+ with vs.variable_scope(unit_scope) as projection_scope:
+ if self._num_unit_shards is not None:
+ projection_scope.set_partitioner(None)
+ w_f_diag = vs.get_variable(
+ "w_f_diag", shape=[self._num_units], dtype=dtype)
+ w_i_diag = vs.get_variable(
+ "w_i_diag", shape=[self._num_units], dtype=dtype)
+ w_o_diag = vs.get_variable(
+ "w_o_diag", shape=[self._num_units], dtype=dtype)
+
+ if self._use_peepholes:
+ c = (sigmoid(f + self._forget_bias + w_f_diag * c_prev) * c_prev +
+ sigmoid(i + w_i_diag * c_prev) * self._activation(j))
+ else:
+ c = (sigmoid(f + self._forget_bias) * c_prev + sigmoid(i) *
+ self._activation(j))
+
+ if self._cell_clip is not None:
+ # pylint: disable=invalid-unary-operand-type
+ c = clip_ops.clip_by_value(c, -self._cell_clip, self._cell_clip)
+ # pylint: enable=invalid-unary-operand-type
+ if self._use_peepholes:
+ m = sigmoid(o + w_o_diag * c) * self._activation(c)
+ else:
+ m = sigmoid(o) * self._activation(c)
+
+ if self._num_proj is not None:
+ with vs.variable_scope("projection") as proj_scope:
+ if self._num_proj_shards is not None:
+ proj_scope.set_partitioner(
+ partitioned_variables.fixed_size_partitioner(
+ self._num_proj_shards))
+ m = _linear(m, self._num_proj, bias=False)
+
+ if self._proj_clip is not None:
+ # pylint: disable=invalid-unary-operand-type
+ m = clip_ops.clip_by_value(m, -self._proj_clip, self._proj_clip)
+ # pylint: enable=invalid-unary-operand-type
+
+ new_state = (LSTMStateTuple(c, m) if self._state_is_tuple else
+ array_ops.concat([c, m], 1))
+ return m, new_state
+
+
+def _enumerated_map_structure(map_fn, *args, **kwargs):
+ ix = [0]
+ def enumerated_fn(*inner_args, **inner_kwargs):
+ r = map_fn(ix[0], *inner_args, **inner_kwargs)
+ ix[0] += 1
+ return r
+ return nest.map_structure(enumerated_fn, *args, **kwargs)
+
+
+class DropoutWrapper(RNNCell):
+ """Operator adding dropout to inputs and outputs of the given cell."""
+
+ def __init__(self, cell, input_keep_prob=1.0, output_keep_prob=1.0,
+ state_keep_prob=1.0, variational_recurrent=False,
+ input_size=None, dtype=None, seed=None):
+ """Create a cell with added input, state, and/or output dropout.
+
+ If `variational_recurrent` is set to `True` (**NOT** the default behavior),
+    then the same dropout mask is applied at every step, as described in:
+
+ Y. Gal, Z Ghahramani. "A Theoretically Grounded Application of Dropout in
+ Recurrent Neural Networks". https://arxiv.org/abs/1512.05287
+
+ Otherwise a different dropout mask is applied at every time step.
+
+ Args:
+      cell: an RNNCell.
+ input_keep_prob: unit Tensor or float between 0 and 1, input keep
+ probability; if it is constant and 1, no input dropout will be added.
+ output_keep_prob: unit Tensor or float between 0 and 1, output keep
+ probability; if it is constant and 1, no output dropout will be added.
+      state_keep_prob: unit Tensor or float between 0 and 1, state keep
+        probability; if it is constant and 1, no state dropout will be added.
+ State dropout is performed on the *output* states of the cell.
+ variational_recurrent: Python bool. If `True`, then the same
+ dropout pattern is applied across all time steps per run call.
+ If this parameter is set, `input_size` **must** be provided.
+ input_size: (optional) (possibly nested tuple of) `TensorShape` objects
+ containing the depth(s) of the input tensors expected to be passed in to
+ the `DropoutWrapper`. Required and used **iff**
+ `variational_recurrent = True` and `input_keep_prob < 1`.
+ dtype: (optional) The `dtype` of the input, state, and output tensors.
+ Required and used **iff** `variational_recurrent = True`.
+ seed: (optional) integer, the randomness seed.
+
+ Raises:
+ TypeError: if cell is not an RNNCell.
+ ValueError: if any of the keep_probs are not between 0 and 1.
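+
+    An illustrative sketch of variational recurrent dropout (not part of this
+    change; the keep probabilities and sizes are arbitrary assumptions):
+
+      import tensorflow as tf
+
+      base_cell = tf.nn.rnn_cell.LSTMCell(num_units=64)
+      cell = tf.nn.rnn_cell.DropoutWrapper(
+          base_cell,
+          input_keep_prob=0.8,
+          output_keep_prob=0.8,
+          state_keep_prob=0.8,
+          variational_recurrent=True,
+          input_size=tf.TensorShape([128]),  # depth of the wrapper's inputs
+          dtype=tf.float32)
+      inputs = tf.placeholder(tf.float32, [32, 20, 128])
+      outputs, state = tf.nn.dynamic_rnn(cell, inputs, dtype=tf.float32)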
+ """
+ if not _like_rnncell(cell):
+ raise TypeError("The parameter cell is not a RNNCell.")
+ with ops.name_scope("DropoutWrapperInit"):
+ def tensor_and_const_value(v):
+ tensor_value = ops.convert_to_tensor(v)
+ const_value = tensor_util.constant_value(tensor_value)
+ return (tensor_value, const_value)
+ for prob, attr in [(input_keep_prob, "input_keep_prob"),
+ (state_keep_prob, "state_keep_prob"),
+ (output_keep_prob, "output_keep_prob")]:
+ tensor_prob, const_prob = tensor_and_const_value(prob)
+ if const_prob is not None:
+ if const_prob < 0 or const_prob > 1:
+ raise ValueError("Parameter %s must be between 0 and 1: %d"
+ % (attr, const_prob))
+ setattr(self, "_%s" % attr, float(const_prob))
+ else:
+ setattr(self, "_%s" % attr, tensor_prob)
+
+ # Set cell, variational_recurrent, seed before running the code below
+ self._cell = cell
+ self._variational_recurrent = variational_recurrent
+ self._seed = seed
+
+ self._recurrent_input_noise = None
+ self._recurrent_state_noise = None
+ self._recurrent_output_noise = None
+
+ if variational_recurrent:
+ if dtype is None:
+ raise ValueError(
+ "When variational_recurrent=True, dtype must be provided")
+
+ def convert_to_batch_shape(s):
+ # Prepend a 1 for the batch dimension; for recurrent
+ # variational dropout we use the same dropout mask for all
+ # batch elements.
+ return array_ops.concat(
+ ([1], tensor_shape.TensorShape(s).as_list()), 0)
+
+ def batch_noise(s, inner_seed):
+ shape = convert_to_batch_shape(s)
+ return random_ops.random_uniform(shape, seed=inner_seed, dtype=dtype)
+
+ if (not isinstance(self._input_keep_prob, numbers.Real) or
+ self._input_keep_prob < 1.0):
+ if input_size is None:
+ raise ValueError(
+ "When variational_recurrent=True and input_keep_prob < 1.0 or "
+ "is unknown, input_size must be provided")
+ self._recurrent_input_noise = _enumerated_map_structure(
+ lambda i, s: batch_noise(s, inner_seed=self._gen_seed("input", i)),
+ input_size)
+ self._recurrent_state_noise = _enumerated_map_structure(
+ lambda i, s: batch_noise(s, inner_seed=self._gen_seed("state", i)),
+ cell.state_size)
+ self._recurrent_output_noise = _enumerated_map_structure(
+ lambda i, s: batch_noise(s, inner_seed=self._gen_seed("output", i)),
+ cell.output_size)
+
+ def _gen_seed(self, salt_prefix, index):
+ if self._seed is None:
+ return None
+ salt = "%s_%d" % (salt_prefix, index)
+ string = (str(self._seed) + salt).encode("utf-8")
+ return int(hashlib.md5(string).hexdigest()[:8], 16) & 0x7FFFFFFF
+
+ @property
+ def state_size(self):
+ return self._cell.state_size
+
+ @property
+ def output_size(self):
+ return self._cell.output_size
+
+ def zero_state(self, batch_size, dtype):
+ with ops.name_scope(type(self).__name__ + "ZeroState", values=[batch_size]):
+ return self._cell.zero_state(batch_size, dtype)
+
+ def _variational_recurrent_dropout_value(
+ self, index, value, noise, keep_prob):
+ """Performs dropout given the pre-calculated noise tensor."""
+ # uniform [keep_prob, 1.0 + keep_prob)
+ random_tensor = keep_prob + noise
+
+ # 0. if [keep_prob, 1.0) and 1. if [1.0, 1.0 + keep_prob)
+ binary_tensor = math_ops.floor(random_tensor)
+ ret = math_ops.div(value, keep_prob) * binary_tensor
+ ret.set_shape(value.get_shape())
+ return ret
+
+ def _dropout(self, values, salt_prefix, recurrent_noise, keep_prob):
+ """Decides whether to perform standard dropout or recurrent dropout."""
+ if not self._variational_recurrent:
+ def dropout(i, v):
+ return nn_ops.dropout(
+ v, keep_prob=keep_prob, seed=self._gen_seed(salt_prefix, i))
+ return _enumerated_map_structure(dropout, values)
+ else:
+ def dropout(i, v, n):
+ return self._variational_recurrent_dropout_value(i, v, n, keep_prob)
+ return _enumerated_map_structure(dropout, values, recurrent_noise)
+
+ def __call__(self, inputs, state, scope=None):
+ """Run the cell with the declared dropouts."""
+ def _should_dropout(p):
+ return (not isinstance(p, float)) or p < 1
+
+ if _should_dropout(self._input_keep_prob):
+ inputs = self._dropout(inputs, "input",
+ self._recurrent_input_noise,
+ self._input_keep_prob)
+ output, new_state = self._cell(inputs, state, scope)
+ if _should_dropout(self._state_keep_prob):
+ new_state = self._dropout(new_state, "state",
+ self._recurrent_state_noise,
+ self._state_keep_prob)
+ if _should_dropout(self._output_keep_prob):
+ output = self._dropout(output, "output",
+ self._recurrent_output_noise,
+ self._output_keep_prob)
+ return output, new_state
+
+
+class ResidualWrapper(RNNCell):
+ """RNNCell wrapper that ensures cell inputs are added to the outputs."""
+
+ def __init__(self, cell):
+ """Constructs a `ResidualWrapper` for `cell`.
+
+ Args:
+ cell: An instance of `RNNCell`.
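+
+    An illustrative sketch (not part of this change; sizes are arbitrary
+    assumptions). The wrapped cell's output depth must match the input depth
+    so the two can be added:
+
+      import tensorflow as tf
+
+      cell = tf.nn.rnn_cell.ResidualWrapper(
+          tf.nn.rnn_cell.GRUCell(num_units=128))
+      inputs = tf.placeholder(tf.float32, [32, 20, 128])  # depth == num_units
+      outputs, state = tf.nn.dynamic_rnn(cell, inputs, dtype=tf.float32)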
+ """
+ self._cell = cell
+
+ @property
+ def state_size(self):
+ return self._cell.state_size
+
+ @property
+ def output_size(self):
+ return self._cell.output_size
+
+ def zero_state(self, batch_size, dtype):
+ with ops.name_scope(type(self).__name__ + "ZeroState", values=[batch_size]):
+ return self._cell.zero_state(batch_size, dtype)
+
+ def __call__(self, inputs, state, scope=None):
+ """Run the cell and add its inputs to its outputs.
+
+ Args:
+ inputs: cell inputs.
+ state: cell state.
+ scope: optional cell scope.
+
+ Returns:
+ Tuple of cell outputs and new state.
+
+ Raises:
+ TypeError: If cell inputs and outputs have different structure (type).
+ ValueError: If cell inputs and outputs have different structure (value).
+ """
+ outputs, new_state = self._cell(inputs, state, scope=scope)
+ nest.assert_same_structure(inputs, outputs)
+ # Ensure shapes match
+ def assert_shape_match(inp, out):
+ inp.get_shape().assert_is_compatible_with(out.get_shape())
+ nest.map_structure(assert_shape_match, inputs, outputs)
+ res_outputs = nest.map_structure(
+ lambda inp, out: inp + out, inputs, outputs)
+ return (res_outputs, new_state)
+
+
+class DeviceWrapper(RNNCell):
+ """Operator that ensures an RNNCell runs on a particular device."""
+
+ def __init__(self, cell, device):
+ """Construct a `DeviceWrapper` for `cell` with device `device`.
+
+ Ensures the wrapped `cell` is called with `tf.device(device)`.
+
+ Args:
+ cell: An instance of `RNNCell`.
+ device: A device string or function, for passing to `tf.device`.
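+
+    An illustrative sketch (not part of this change; the device string is an
+    arbitrary assumption):
+
+      import tensorflow as tf
+
+      cell = tf.nn.rnn_cell.DeviceWrapper(
+          tf.nn.rnn_cell.GRUCell(num_units=64), "/cpu:0")
+      inputs = tf.placeholder(tf.float32, [32, 20, 128])
+      outputs, state = tf.nn.dynamic_rnn(cell, inputs, dtype=tf.float32)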
+ """
+ self._cell = cell
+ self._device = device
+
+ @property
+ def state_size(self):
+ return self._cell.state_size
+
+ @property
+ def output_size(self):
+ return self._cell.output_size
+
+ def zero_state(self, batch_size, dtype):
+ with ops.name_scope(type(self).__name__ + "ZeroState", values=[batch_size]):
+ with ops.device(self._device):
+ return self._cell.zero_state(batch_size, dtype)
+
+ def __call__(self, inputs, state, scope=None):
+ """Run the cell on specified device."""
+ with ops.device(self._device):
+ return self._cell(inputs, state, scope=scope)
+
+
+class MultiRNNCell(RNNCell):
+ """RNN cell composed sequentially of multiple simple cells."""
+
+ def __init__(self, cells, state_is_tuple=True):
+ """Create a RNN cell composed sequentially of a number of RNNCells.
+
+ Args:
+ cells: list of RNNCells that will be composed in this order.
+ state_is_tuple: If True, accepted and returned states are n-tuples, where
+ `n = len(cells)`. If False, the states are all
+ concatenated along the column axis. This latter behavior will soon be
+ deprecated.
+
+ Raises:
+ ValueError: if cells is empty (not allowed), or at least one of the cells
+ returns a state tuple but the flag `state_is_tuple` is `False`.
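+
+    An illustrative two-layer stack (not part of this change; sizes are
+    arbitrary assumptions):
+
+      import tensorflow as tf
+
+      cells = [tf.nn.rnn_cell.BasicLSTMCell(num_units=64) for _ in range(2)]
+      stacked = tf.nn.rnn_cell.MultiRNNCell(cells)
+      inputs = tf.placeholder(tf.float32, [32, 20, 128])
+      outputs, state = tf.nn.dynamic_rnn(stacked, inputs, dtype=tf.float32)
+      # With the default state_is_tuple=True, `state` is a 2-tuple of
+      # LSTMStateTuples, one per layer.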
+ """
+ super(MultiRNNCell, self).__init__()
+ if not cells:
+ raise ValueError("Must specify at least one cell for MultiRNNCell.")
+ if not nest.is_sequence(cells):
+ raise TypeError(
+ "cells must be a list or tuple, but saw: %s." % cells)
+
+ self._cells = cells
+ self._state_is_tuple = state_is_tuple
+ if not state_is_tuple:
+ if any(nest.is_sequence(c.state_size) for c in self._cells):
+ raise ValueError("Some cells return tuples of states, but the flag "
+ "state_is_tuple is not set. State sizes are: %s"
+ % str([c.state_size for c in self._cells]))
+
+ @property
+ def state_size(self):
+ if self._state_is_tuple:
+ return tuple(cell.state_size for cell in self._cells)
+ else:
+ return sum([cell.state_size for cell in self._cells])
+
+ @property
+ def output_size(self):
+ return self._cells[-1].output_size
+
+ def zero_state(self, batch_size, dtype):
+ with ops.name_scope(type(self).__name__ + "ZeroState", values=[batch_size]):
+ if self._state_is_tuple:
+ return tuple(cell.zero_state(batch_size, dtype) for cell in self._cells)
+ else:
+ # We know here that state_size of each cell is not a tuple and
+ # presumably does not contain TensorArrays or anything else fancy
+ return super(MultiRNNCell, self).zero_state(batch_size, dtype)
+
+ def call(self, inputs, state):
+ """Run this multi-layer cell on inputs, starting from state."""
+ cur_state_pos = 0
+ cur_inp = inputs
+ new_states = []
+ for i, cell in enumerate(self._cells):
+ with vs.variable_scope("cell_%d" % i):
+ if self._state_is_tuple:
+ if not nest.is_sequence(state):
+ raise ValueError(
+ "Expected state to be a tuple of length %d, but received: %s" %
+ (len(self.state_size), state))
+ cur_state = state[i]
+ else:
+ cur_state = array_ops.slice(state, [0, cur_state_pos],
+ [-1, cell.state_size])
+ cur_state_pos += cell.state_size
+ cur_inp, new_state = cell(cur_inp, cur_state)
+ new_states.append(new_state)
+
+ new_states = (tuple(new_states) if self._state_is_tuple else
+ array_ops.concat(new_states, 1))
+
+ return cur_inp, new_states
+
+
+class _SlimRNNCell(RNNCell):
+ """A simple wrapper for slim.rnn_cells."""
+
+ def __init__(self, cell_fn):
+ """Create a SlimRNNCell from a cell_fn.
+
+ Args:
+ cell_fn: a function which takes (inputs, state, scope) and produces the
+ outputs and the new_state. Additionally when called with inputs=None and
+ state=None it should return (initial_outputs, initial_state).
+
+ Raises:
+ TypeError: if cell_fn is not callable
+ ValueError: if cell_fn cannot produce a valid initial state.
+ """
+ if not callable(cell_fn):
+ raise TypeError("cell_fn %s needs to be callable", cell_fn)
+ self._cell_fn = cell_fn
+ self._cell_name = cell_fn.func.__name__
+ init_output, init_state = self._cell_fn(None, None)
+ output_shape = init_output.get_shape()
+ state_shape = init_state.get_shape()
+ self._output_size = output_shape.with_rank(2)[1].value
+ self._state_size = state_shape.with_rank(2)[1].value
+ if self._output_size is None:
+ raise ValueError("Initial output created by %s has invalid shape %s" %
+ (self._cell_name, output_shape))
+ if self._state_size is None:
+ raise ValueError("Initial state created by %s has invalid shape %s" %
+ (self._cell_name, state_shape))
+
+ @property
+ def state_size(self):
+ return self._state_size
+
+ @property
+ def output_size(self):
+ return self._output_size
+
+ def __call__(self, inputs, state, scope=None):
+ scope = scope or self._cell_name
+ output, state = self._cell_fn(inputs, state, scope=scope)
+ return output, state
+
+
+def _linear(args,
+ output_size,
+ bias,
+ bias_initializer=None,
+ kernel_initializer=None):
+ """Linear map: sum_i(args[i] * W[i]), where W[i] is a variable.
+
+ Args:
+ args: a 2D Tensor or a list of 2D, batch x n, Tensors.
+ output_size: int, second dimension of W[i].
+ bias: boolean, whether to add a bias term or not.
+ bias_initializer: starting value to initialize the bias
+ (default is all zeros).
+ kernel_initializer: starting value to initialize the weight.
+
+ Returns:
+ A 2D Tensor with shape [batch x output_size] equal to
+ sum_i(args[i] * W[i]), where W[i]s are newly created matrices.
+
+ Raises:
+    ValueError: if any of the arguments has an unspecified or wrong shape.
+ """
+ if args is None or (nest.is_sequence(args) and not args):
+ raise ValueError("`args` must be specified")
+ if not nest.is_sequence(args):
+ args = [args]
+
+ # Calculate the total size of arguments on dimension 1.
+ total_arg_size = 0
+ shapes = [a.get_shape() for a in args]
+ for shape in shapes:
+ if shape.ndims != 2:
+ raise ValueError("linear is expecting 2D arguments: %s" % shapes)
+ if shape[1].value is None:
+ raise ValueError("linear expects shape[1] to be provided for shape %s, "
+ "but saw %s" % (shape, shape[1]))
+ else:
+ total_arg_size += shape[1].value
+
+ dtype = [a.dtype for a in args][0]
+
+ # Now the computation.
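+  # sum_i(args[i] * W[i]) is realized as a single matmul: the args are
+  # concatenated along axis 1 and multiplied by one
+  # [total_arg_size, output_size] kernel, whose row blocks correspond to the
+  # individual W[i].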
+ scope = vs.get_variable_scope()
+ with vs.variable_scope(scope) as outer_scope:
+ weights = vs.get_variable(
+ _WEIGHTS_VARIABLE_NAME, [total_arg_size, output_size],
+ dtype=dtype,
+ initializer=kernel_initializer)
+ if len(args) == 1:
+ res = math_ops.matmul(args[0], weights)
+ else:
+ res = math_ops.matmul(array_ops.concat(args, 1), weights)
+ if not bias:
+ return res
+ with vs.variable_scope(outer_scope) as inner_scope:
+ inner_scope.set_partitioner(None)
+ if bias_initializer is None:
+ bias_initializer = init_ops.constant_initializer(0.0, dtype=dtype)
+ biases = vs.get_variable(
+ _BIAS_VARIABLE_NAME, [output_size],
+ dtype=dtype,
+ initializer=bias_initializer)
+ return nn_ops.bias_add(res, biases)
diff --git a/tensorflow/tools/api/golden/tensorflow.nn.pbtxt b/tensorflow/tools/api/golden/tensorflow.nn.pbtxt
index b1b60fbdcb..9f817beafd 100644
--- a/tensorflow/tools/api/golden/tensorflow.nn.pbtxt
+++ b/tensorflow/tools/api/golden/tensorflow.nn.pbtxt
@@ -1,5 +1,9 @@
path: "tensorflow.nn"
tf_module {
+ member {
+ name: "rnn_cell"
+ mtype: "<type \'module\'>"
+ }
member_method {
name: "all_candidate_sampler"
argspec: "args=[\'true_classes\', \'num_true\', \'num_sampled\', \'unique\', \'seed\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], "
@@ -285,6 +289,18 @@ tf_module {
argspec: "args=[\'_sentinel\', \'labels\', \'logits\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\'], "
}
member_method {
+ name: "static_bidirectional_rnn"
+ argspec: "args=[\'cell_fw\', \'cell_bw\', \'inputs\', \'initial_state_fw\', \'initial_state_bw\', \'dtype\', \'sequence_length\', \'scope\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\'], "
+ }
+ member_method {
+ name: "static_rnn"
+ argspec: "args=[\'cell\', \'inputs\', \'initial_state\', \'dtype\', \'sequence_length\', \'scope\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\'], "
+ }
+ member_method {
+ name: "static_state_saving_rnn"
+ argspec: "args=[\'cell\', \'inputs\', \'state_saver\', \'state_name\', \'sequence_length\', \'scope\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], "
+ }
+ member_method {
name: "sufficient_statistics"
argspec: "args=[\'x\', \'axes\', \'shift\', \'keep_dims\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'False\', \'None\'], "
}
diff --git a/tensorflow/tools/api/golden/tensorflow.nn.rnn_cell.-basic-l-s-t-m-cell.pbtxt b/tensorflow/tools/api/golden/tensorflow.nn.rnn_cell.-basic-l-s-t-m-cell.pbtxt
new file mode 100644
index 0000000000..fbf68c50a1
--- /dev/null
+++ b/tensorflow/tools/api/golden/tensorflow.nn.rnn_cell.-basic-l-s-t-m-cell.pbtxt
@@ -0,0 +1,95 @@
+path: "tensorflow.nn.rnn_cell.BasicLSTMCell"
+tf_class {
+ is_instance: "<class \'tensorflow.python.ops.rnn_cell_impl.BasicLSTMCell\'>"
+ is_instance: "<class \'tensorflow.python.ops.rnn_cell_impl.RNNCell\'>"
+ is_instance: "<class \'tensorflow.python.layers.base.Layer\'>"
+ is_instance: "<type \'object\'>"
+ member {
+ name: "graph"
+ mtype: "<type \'property\'>"
+ }
+ member {
+ name: "losses"
+ mtype: "<type \'property\'>"
+ }
+ member {
+ name: "non_trainable_variables"
+ mtype: "<type \'property\'>"
+ }
+ member {
+ name: "non_trainable_weights"
+ mtype: "<type \'property\'>"
+ }
+ member {
+ name: "output_size"
+ mtype: "<type \'property\'>"
+ }
+ member {
+ name: "scope_name"
+ mtype: "<type \'property\'>"
+ }
+ member {
+ name: "state_size"
+ mtype: "<type \'property\'>"
+ }
+ member {
+ name: "trainable_variables"
+ mtype: "<type \'property\'>"
+ }
+ member {
+ name: "trainable_weights"
+ mtype: "<type \'property\'>"
+ }
+ member {
+ name: "updates"
+ mtype: "<type \'property\'>"
+ }
+ member {
+ name: "variables"
+ mtype: "<type \'property\'>"
+ }
+ member {
+ name: "weights"
+ mtype: "<type \'property\'>"
+ }
+ member_method {
+ name: "__init__"
+ argspec: "args=[\'self\', \'num_units\', \'forget_bias\', \'state_is_tuple\', \'activation\', \'reuse\'], varargs=None, keywords=None, defaults=[\'1.0\', \'True\', \'None\', \'None\'], "
+ }
+ member_method {
+ name: "add_loss"
+ argspec: "args=[\'self\', \'losses\', \'inputs\'], varargs=None, keywords=None, defaults=[\'None\'], "
+ }
+ member_method {
+ name: "add_update"
+ argspec: "args=[\'self\', \'updates\', \'inputs\'], varargs=None, keywords=None, defaults=[\'None\'], "
+ }
+ member_method {
+ name: "add_variable"
+ argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'True\'], "
+ }
+ member_method {
+ name: "apply"
+ argspec: "args=[\'self\', \'inputs\'], varargs=args, keywords=kwargs, defaults=None"
+ }
+ member_method {
+ name: "build"
+ argspec: "args=[\'self\', \'_\'], varargs=None, keywords=None, defaults=None"
+ }
+ member_method {
+ name: "call"
+ argspec: "args=[\'self\', \'inputs\', \'state\'], varargs=None, keywords=None, defaults=None"
+ }
+ member_method {
+ name: "get_losses_for"
+ argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None"
+ }
+ member_method {
+ name: "get_updates_for"
+ argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None"
+ }
+ member_method {
+ name: "zero_state"
+ argspec: "args=[\'self\', \'batch_size\', \'dtype\'], varargs=None, keywords=None, defaults=None"
+ }
+}
diff --git a/tensorflow/tools/api/golden/tensorflow.nn.rnn_cell.-basic-r-n-n-cell.pbtxt b/tensorflow/tools/api/golden/tensorflow.nn.rnn_cell.-basic-r-n-n-cell.pbtxt
new file mode 100644
index 0000000000..606d20d8f0
--- /dev/null
+++ b/tensorflow/tools/api/golden/tensorflow.nn.rnn_cell.-basic-r-n-n-cell.pbtxt
@@ -0,0 +1,95 @@
+path: "tensorflow.nn.rnn_cell.BasicRNNCell"
+tf_class {
+ is_instance: "<class \'tensorflow.python.ops.rnn_cell_impl.BasicRNNCell\'>"
+ is_instance: "<class \'tensorflow.python.ops.rnn_cell_impl.RNNCell\'>"
+ is_instance: "<class \'tensorflow.python.layers.base.Layer\'>"
+ is_instance: "<type \'object\'>"
+ member {
+ name: "graph"
+ mtype: "<type \'property\'>"
+ }
+ member {
+ name: "losses"
+ mtype: "<type \'property\'>"
+ }
+ member {
+ name: "non_trainable_variables"
+ mtype: "<type \'property\'>"
+ }
+ member {
+ name: "non_trainable_weights"
+ mtype: "<type \'property\'>"
+ }
+ member {
+ name: "output_size"
+ mtype: "<type \'property\'>"
+ }
+ member {
+ name: "scope_name"
+ mtype: "<type \'property\'>"
+ }
+ member {
+ name: "state_size"
+ mtype: "<type \'property\'>"
+ }
+ member {
+ name: "trainable_variables"
+ mtype: "<type \'property\'>"
+ }
+ member {
+ name: "trainable_weights"
+ mtype: "<type \'property\'>"
+ }
+ member {
+ name: "updates"
+ mtype: "<type \'property\'>"
+ }
+ member {
+ name: "variables"
+ mtype: "<type \'property\'>"
+ }
+ member {
+ name: "weights"
+ mtype: "<type \'property\'>"
+ }
+ member_method {
+ name: "__init__"
+ argspec: "args=[\'self\', \'num_units\', \'activation\', \'reuse\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], "
+ }
+ member_method {
+ name: "add_loss"
+ argspec: "args=[\'self\', \'losses\', \'inputs\'], varargs=None, keywords=None, defaults=[\'None\'], "
+ }
+ member_method {
+ name: "add_update"
+ argspec: "args=[\'self\', \'updates\', \'inputs\'], varargs=None, keywords=None, defaults=[\'None\'], "
+ }
+ member_method {
+ name: "add_variable"
+ argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'True\'], "
+ }
+ member_method {
+ name: "apply"
+ argspec: "args=[\'self\', \'inputs\'], varargs=args, keywords=kwargs, defaults=None"
+ }
+ member_method {
+ name: "build"
+ argspec: "args=[\'self\', \'_\'], varargs=None, keywords=None, defaults=None"
+ }
+ member_method {
+ name: "call"
+ argspec: "args=[\'self\', \'inputs\', \'state\'], varargs=None, keywords=None, defaults=None"
+ }
+ member_method {
+ name: "get_losses_for"
+ argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None"
+ }
+ member_method {
+ name: "get_updates_for"
+ argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None"
+ }
+ member_method {
+ name: "zero_state"
+ argspec: "args=[\'self\', \'batch_size\', \'dtype\'], varargs=None, keywords=None, defaults=None"
+ }
+}
diff --git a/tensorflow/tools/api/golden/tensorflow.nn.rnn_cell.-device-wrapper.pbtxt b/tensorflow/tools/api/golden/tensorflow.nn.rnn_cell.-device-wrapper.pbtxt
new file mode 100644
index 0000000000..ead1d0cfc5
--- /dev/null
+++ b/tensorflow/tools/api/golden/tensorflow.nn.rnn_cell.-device-wrapper.pbtxt
@@ -0,0 +1,95 @@
+path: "tensorflow.nn.rnn_cell.DeviceWrapper"
+tf_class {
+ is_instance: "<class \'tensorflow.python.ops.rnn_cell_impl.DeviceWrapper\'>"
+ is_instance: "<class \'tensorflow.python.ops.rnn_cell_impl.RNNCell\'>"
+ is_instance: "<class \'tensorflow.python.layers.base.Layer\'>"
+ is_instance: "<type \'object\'>"
+ member {
+ name: "graph"
+ mtype: "<type \'property\'>"
+ }
+ member {
+ name: "losses"
+ mtype: "<type \'property\'>"
+ }
+ member {
+ name: "non_trainable_variables"
+ mtype: "<type \'property\'>"
+ }
+ member {
+ name: "non_trainable_weights"
+ mtype: "<type \'property\'>"
+ }
+ member {
+ name: "output_size"
+ mtype: "<type \'property\'>"
+ }
+ member {
+ name: "scope_name"
+ mtype: "<type \'property\'>"
+ }
+ member {
+ name: "state_size"
+ mtype: "<type \'property\'>"
+ }
+ member {
+ name: "trainable_variables"
+ mtype: "<type \'property\'>"
+ }
+ member {
+ name: "trainable_weights"
+ mtype: "<type \'property\'>"
+ }
+ member {
+ name: "updates"
+ mtype: "<type \'property\'>"
+ }
+ member {
+ name: "variables"
+ mtype: "<type \'property\'>"
+ }
+ member {
+ name: "weights"
+ mtype: "<type \'property\'>"
+ }
+ member_method {
+ name: "__init__"
+ argspec: "args=[\'self\', \'cell\', \'device\'], varargs=None, keywords=None, defaults=None"
+ }
+ member_method {
+ name: "add_loss"
+ argspec: "args=[\'self\', \'losses\', \'inputs\'], varargs=None, keywords=None, defaults=[\'None\'], "
+ }
+ member_method {
+ name: "add_update"
+ argspec: "args=[\'self\', \'updates\', \'inputs\'], varargs=None, keywords=None, defaults=[\'None\'], "
+ }
+ member_method {
+ name: "add_variable"
+ argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'True\'], "
+ }
+ member_method {
+ name: "apply"
+ argspec: "args=[\'self\', \'inputs\'], varargs=args, keywords=kwargs, defaults=None"
+ }
+ member_method {
+ name: "build"
+ argspec: "args=[\'self\', \'_\'], varargs=None, keywords=None, defaults=None"
+ }
+ member_method {
+ name: "call"
+ argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=kwargs, defaults=None"
+ }
+ member_method {
+ name: "get_losses_for"
+ argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None"
+ }
+ member_method {
+ name: "get_updates_for"
+ argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None"
+ }
+ member_method {
+ name: "zero_state"
+ argspec: "args=[\'self\', \'batch_size\', \'dtype\'], varargs=None, keywords=None, defaults=None"
+ }
+}
diff --git a/tensorflow/tools/api/golden/tensorflow.nn.rnn_cell.-dropout-wrapper.pbtxt b/tensorflow/tools/api/golden/tensorflow.nn.rnn_cell.-dropout-wrapper.pbtxt
new file mode 100644
index 0000000000..2db4996b2a
--- /dev/null
+++ b/tensorflow/tools/api/golden/tensorflow.nn.rnn_cell.-dropout-wrapper.pbtxt
@@ -0,0 +1,95 @@
+path: "tensorflow.nn.rnn_cell.DropoutWrapper"
+tf_class {
+ is_instance: "<class \'tensorflow.python.ops.rnn_cell_impl.DropoutWrapper\'>"
+ is_instance: "<class \'tensorflow.python.ops.rnn_cell_impl.RNNCell\'>"
+ is_instance: "<class \'tensorflow.python.layers.base.Layer\'>"
+ is_instance: "<type \'object\'>"
+ member {
+ name: "graph"
+ mtype: "<type \'property\'>"
+ }
+ member {
+ name: "losses"
+ mtype: "<type \'property\'>"
+ }
+ member {
+ name: "non_trainable_variables"
+ mtype: "<type \'property\'>"
+ }
+ member {
+ name: "non_trainable_weights"
+ mtype: "<type \'property\'>"
+ }
+ member {
+ name: "output_size"
+ mtype: "<type \'property\'>"
+ }
+ member {
+ name: "scope_name"
+ mtype: "<type \'property\'>"
+ }
+ member {
+ name: "state_size"
+ mtype: "<type \'property\'>"
+ }
+ member {
+ name: "trainable_variables"
+ mtype: "<type \'property\'>"
+ }
+ member {
+ name: "trainable_weights"
+ mtype: "<type \'property\'>"
+ }
+ member {
+ name: "updates"
+ mtype: "<type \'property\'>"
+ }
+ member {
+ name: "variables"
+ mtype: "<type \'property\'>"
+ }
+ member {
+ name: "weights"
+ mtype: "<type \'property\'>"
+ }
+ member_method {
+ name: "__init__"
+ argspec: "args=[\'self\', \'cell\', \'input_keep_prob\', \'output_keep_prob\', \'state_keep_prob\', \'variational_recurrent\', \'input_size\', \'dtype\', \'seed\'], varargs=None, keywords=None, defaults=[\'1.0\', \'1.0\', \'1.0\', \'False\', \'None\', \'None\', \'None\'], "
+ }
+ member_method {
+ name: "add_loss"
+ argspec: "args=[\'self\', \'losses\', \'inputs\'], varargs=None, keywords=None, defaults=[\'None\'], "
+ }
+ member_method {
+ name: "add_update"
+ argspec: "args=[\'self\', \'updates\', \'inputs\'], varargs=None, keywords=None, defaults=[\'None\'], "
+ }
+ member_method {
+ name: "add_variable"
+ argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'True\'], "
+ }
+ member_method {
+ name: "apply"
+ argspec: "args=[\'self\', \'inputs\'], varargs=args, keywords=kwargs, defaults=None"
+ }
+ member_method {
+ name: "build"
+ argspec: "args=[\'self\', \'_\'], varargs=None, keywords=None, defaults=None"
+ }
+ member_method {
+ name: "call"
+ argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=kwargs, defaults=None"
+ }
+ member_method {
+ name: "get_losses_for"
+ argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None"
+ }
+ member_method {
+ name: "get_updates_for"
+ argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None"
+ }
+ member_method {
+ name: "zero_state"
+ argspec: "args=[\'self\', \'batch_size\', \'dtype\'], varargs=None, keywords=None, defaults=None"
+ }
+}
diff --git a/tensorflow/tools/api/golden/tensorflow.nn.rnn_cell.-g-r-u-cell.pbtxt b/tensorflow/tools/api/golden/tensorflow.nn.rnn_cell.-g-r-u-cell.pbtxt
new file mode 100644
index 0000000000..101f6df1d8
--- /dev/null
+++ b/tensorflow/tools/api/golden/tensorflow.nn.rnn_cell.-g-r-u-cell.pbtxt
@@ -0,0 +1,95 @@
+path: "tensorflow.nn.rnn_cell.GRUCell"
+tf_class {
+ is_instance: "<class \'tensorflow.python.ops.rnn_cell_impl.GRUCell\'>"
+ is_instance: "<class \'tensorflow.python.ops.rnn_cell_impl.RNNCell\'>"
+ is_instance: "<class \'tensorflow.python.layers.base.Layer\'>"
+ is_instance: "<type \'object\'>"
+ member {
+ name: "graph"
+ mtype: "<type \'property\'>"
+ }
+ member {
+ name: "losses"
+ mtype: "<type \'property\'>"
+ }
+ member {
+ name: "non_trainable_variables"
+ mtype: "<type \'property\'>"
+ }
+ member {
+ name: "non_trainable_weights"
+ mtype: "<type \'property\'>"
+ }
+ member {
+ name: "output_size"
+ mtype: "<type \'property\'>"
+ }
+ member {
+ name: "scope_name"
+ mtype: "<type \'property\'>"
+ }
+ member {
+ name: "state_size"
+ mtype: "<type \'property\'>"
+ }
+ member {
+ name: "trainable_variables"
+ mtype: "<type \'property\'>"
+ }
+ member {
+ name: "trainable_weights"
+ mtype: "<type \'property\'>"
+ }
+ member {
+ name: "updates"
+ mtype: "<type \'property\'>"
+ }
+ member {
+ name: "variables"
+ mtype: "<type \'property\'>"
+ }
+ member {
+ name: "weights"
+ mtype: "<type \'property\'>"
+ }
+ member_method {
+ name: "__init__"
+ argspec: "args=[\'self\', \'num_units\', \'activation\', \'reuse\', \'kernel_initializer\', \'bias_initializer\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\'], "
+ }
+ member_method {
+ name: "add_loss"
+ argspec: "args=[\'self\', \'losses\', \'inputs\'], varargs=None, keywords=None, defaults=[\'None\'], "
+ }
+ member_method {
+ name: "add_update"
+ argspec: "args=[\'self\', \'updates\', \'inputs\'], varargs=None, keywords=None, defaults=[\'None\'], "
+ }
+ member_method {
+ name: "add_variable"
+ argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'True\'], "
+ }
+ member_method {
+ name: "apply"
+ argspec: "args=[\'self\', \'inputs\'], varargs=args, keywords=kwargs, defaults=None"
+ }
+ member_method {
+ name: "build"
+ argspec: "args=[\'self\', \'_\'], varargs=None, keywords=None, defaults=None"
+ }
+ member_method {
+ name: "call"
+ argspec: "args=[\'self\', \'inputs\', \'state\'], varargs=None, keywords=None, defaults=None"
+ }
+ member_method {
+ name: "get_losses_for"
+ argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None"
+ }
+ member_method {
+ name: "get_updates_for"
+ argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None"
+ }
+ member_method {
+ name: "zero_state"
+ argspec: "args=[\'self\', \'batch_size\', \'dtype\'], varargs=None, keywords=None, defaults=None"
+ }
+}
diff --git a/tensorflow/tools/api/golden/tensorflow.nn.rnn_cell.-l-s-t-m-cell.pbtxt b/tensorflow/tools/api/golden/tensorflow.nn.rnn_cell.-l-s-t-m-cell.pbtxt
new file mode 100644
index 0000000000..c87546d528
--- /dev/null
+++ b/tensorflow/tools/api/golden/tensorflow.nn.rnn_cell.-l-s-t-m-cell.pbtxt
@@ -0,0 +1,95 @@
+path: "tensorflow.nn.rnn_cell.LSTMCell"
+tf_class {
+ is_instance: "<class \'tensorflow.python.ops.rnn_cell_impl.LSTMCell\'>"
+ is_instance: "<class \'tensorflow.python.ops.rnn_cell_impl.RNNCell\'>"
+ is_instance: "<class \'tensorflow.python.layers.base.Layer\'>"
+ is_instance: "<type \'object\'>"
+ member {
+ name: "graph"
+ mtype: "<type \'property\'>"
+ }
+ member {
+ name: "losses"
+ mtype: "<type \'property\'>"
+ }
+ member {
+ name: "non_trainable_variables"
+ mtype: "<type \'property\'>"
+ }
+ member {
+ name: "non_trainable_weights"
+ mtype: "<type \'property\'>"
+ }
+ member {
+ name: "output_size"
+ mtype: "<type \'property\'>"
+ }
+ member {
+ name: "scope_name"
+ mtype: "<type \'property\'>"
+ }
+ member {
+ name: "state_size"
+ mtype: "<type \'property\'>"
+ }
+ member {
+ name: "trainable_variables"
+ mtype: "<type \'property\'>"
+ }
+ member {
+ name: "trainable_weights"
+ mtype: "<type \'property\'>"
+ }
+ member {
+ name: "updates"
+ mtype: "<type \'property\'>"
+ }
+ member {
+ name: "variables"
+ mtype: "<type \'property\'>"
+ }
+ member {
+ name: "weights"
+ mtype: "<type \'property\'>"
+ }
+ member_method {
+ name: "__init__"
+ argspec: "args=[\'self\', \'num_units\', \'use_peepholes\', \'cell_clip\', \'initializer\', \'num_proj\', \'proj_clip\', \'num_unit_shards\', \'num_proj_shards\', \'forget_bias\', \'state_is_tuple\', \'activation\', \'reuse\'], varargs=None, keywords=None, defaults=[\'False\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'1.0\', \'True\', \'None\', \'None\'], "
+ }
+ member_method {
+ name: "add_loss"
+ argspec: "args=[\'self\', \'losses\', \'inputs\'], varargs=None, keywords=None, defaults=[\'None\'], "
+ }
+ member_method {
+ name: "add_update"
+ argspec: "args=[\'self\', \'updates\', \'inputs\'], varargs=None, keywords=None, defaults=[\'None\'], "
+ }
+ member_method {
+ name: "add_variable"
+ argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'True\'], "
+ }
+ member_method {
+ name: "apply"
+ argspec: "args=[\'self\', \'inputs\'], varargs=args, keywords=kwargs, defaults=None"
+ }
+ member_method {
+ name: "build"
+ argspec: "args=[\'self\', \'_\'], varargs=None, keywords=None, defaults=None"
+ }
+ member_method {
+ name: "call"
+ argspec: "args=[\'self\', \'inputs\', \'state\'], varargs=None, keywords=None, defaults=None"
+ }
+ member_method {
+ name: "get_losses_for"
+ argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None"
+ }
+ member_method {
+ name: "get_updates_for"
+ argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None"
+ }
+ member_method {
+ name: "zero_state"
+ argspec: "args=[\'self\', \'batch_size\', \'dtype\'], varargs=None, keywords=None, defaults=None"
+ }
+}
diff --git a/tensorflow/tools/api/golden/tensorflow.nn.rnn_cell.-l-s-t-m-state-tuple.pbtxt b/tensorflow/tools/api/golden/tensorflow.nn.rnn_cell.-l-s-t-m-state-tuple.pbtxt
new file mode 100644
index 0000000000..1de8a55dcc
--- /dev/null
+++ b/tensorflow/tools/api/golden/tensorflow.nn.rnn_cell.-l-s-t-m-state-tuple.pbtxt
@@ -0,0 +1,27 @@
+path: "tensorflow.nn.rnn_cell.LSTMStateTuple"
+tf_class {
+ is_instance: "<class \'tensorflow.python.ops.rnn_cell_impl.LSTMStateTuple\'>"
+ is_instance: "<class \'tensorflow.python.ops.rnn_cell_impl.LSTMStateTuple\'>"
+ is_instance: "<type \'tuple\'>"
+ member {
+ name: "c"
+ mtype: "<type \'property\'>"
+ }
+ member {
+ name: "dtype"
+ mtype: "<type \'property\'>"
+ }
+ member {
+ name: "h"
+ mtype: "<type \'property\'>"
+ }
+ member_method {
+ name: "__init__"
+ }
+ member_method {
+ name: "count"
+ }
+ member_method {
+ name: "index"
+ }
+}
diff --git a/tensorflow/tools/api/golden/tensorflow.nn.rnn_cell.-multi-r-n-n-cell.pbtxt b/tensorflow/tools/api/golden/tensorflow.nn.rnn_cell.-multi-r-n-n-cell.pbtxt
new file mode 100644
index 0000000000..bc01ccfa64
--- /dev/null
+++ b/tensorflow/tools/api/golden/tensorflow.nn.rnn_cell.-multi-r-n-n-cell.pbtxt
@@ -0,0 +1,95 @@
+path: "tensorflow.nn.rnn_cell.MultiRNNCell"
+tf_class {
+ is_instance: "<class \'tensorflow.python.ops.rnn_cell_impl.MultiRNNCell\'>"
+ is_instance: "<class \'tensorflow.python.ops.rnn_cell_impl.RNNCell\'>"
+ is_instance: "<class \'tensorflow.python.layers.base.Layer\'>"
+ is_instance: "<type \'object\'>"
+ member {
+ name: "graph"
+ mtype: "<type \'property\'>"
+ }
+ member {
+ name: "losses"
+ mtype: "<type \'property\'>"
+ }
+ member {
+ name: "non_trainable_variables"
+ mtype: "<type \'property\'>"
+ }
+ member {
+ name: "non_trainable_weights"
+ mtype: "<type \'property\'>"
+ }
+ member {
+ name: "output_size"
+ mtype: "<type \'property\'>"
+ }
+ member {
+ name: "scope_name"
+ mtype: "<type \'property\'>"
+ }
+ member {
+ name: "state_size"
+ mtype: "<type \'property\'>"
+ }
+ member {
+ name: "trainable_variables"
+ mtype: "<type \'property\'>"
+ }
+ member {
+ name: "trainable_weights"
+ mtype: "<type \'property\'>"
+ }
+ member {
+ name: "updates"
+ mtype: "<type \'property\'>"
+ }
+ member {
+ name: "variables"
+ mtype: "<type \'property\'>"
+ }
+ member {
+ name: "weights"
+ mtype: "<type \'property\'>"
+ }
+ member_method {
+ name: "__init__"
+ argspec: "args=[\'self\', \'cells\', \'state_is_tuple\'], varargs=None, keywords=None, defaults=[\'True\'], "
+ }
+ member_method {
+ name: "add_loss"
+ argspec: "args=[\'self\', \'losses\', \'inputs\'], varargs=None, keywords=None, defaults=[\'None\'], "
+ }
+ member_method {
+ name: "add_update"
+ argspec: "args=[\'self\', \'updates\', \'inputs\'], varargs=None, keywords=None, defaults=[\'None\'], "
+ }
+ member_method {
+ name: "add_variable"
+ argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'True\'], "
+ }
+ member_method {
+ name: "apply"
+ argspec: "args=[\'self\', \'inputs\'], varargs=args, keywords=kwargs, defaults=None"
+ }
+ member_method {
+ name: "build"
+ argspec: "args=[\'self\', \'_\'], varargs=None, keywords=None, defaults=None"
+ }
+ member_method {
+ name: "call"
+ argspec: "args=[\'self\', \'inputs\', \'state\'], varargs=None, keywords=None, defaults=None"
+ }
+ member_method {
+ name: "get_losses_for"
+ argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None"
+ }
+ member_method {
+ name: "get_updates_for"
+ argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None"
+ }
+ member_method {
+ name: "zero_state"
+ argspec: "args=[\'self\', \'batch_size\', \'dtype\'], varargs=None, keywords=None, defaults=None"
+ }
+}
diff --git a/tensorflow/tools/api/golden/tensorflow.nn.rnn_cell.-r-n-n-cell.pbtxt b/tensorflow/tools/api/golden/tensorflow.nn.rnn_cell.-r-n-n-cell.pbtxt
new file mode 100644
index 0000000000..b19ee18b40
--- /dev/null
+++ b/tensorflow/tools/api/golden/tensorflow.nn.rnn_cell.-r-n-n-cell.pbtxt
@@ -0,0 +1,94 @@
+path: "tensorflow.nn.rnn_cell.RNNCell"
+tf_class {
+ is_instance: "<class \'tensorflow.python.ops.rnn_cell_impl.RNNCell\'>"
+ is_instance: "<class \'tensorflow.python.layers.base.Layer\'>"
+ is_instance: "<type \'object\'>"
+ member {
+ name: "graph"
+ mtype: "<type \'property\'>"
+ }
+ member {
+ name: "losses"
+ mtype: "<type \'property\'>"
+ }
+ member {
+ name: "non_trainable_variables"
+ mtype: "<type \'property\'>"
+ }
+ member {
+ name: "non_trainable_weights"
+ mtype: "<type \'property\'>"
+ }
+ member {
+ name: "output_size"
+ mtype: "<type \'property\'>"
+ }
+ member {
+ name: "scope_name"
+ mtype: "<type \'property\'>"
+ }
+ member {
+ name: "state_size"
+ mtype: "<type \'property\'>"
+ }
+ member {
+ name: "trainable_variables"
+ mtype: "<type \'property\'>"
+ }
+ member {
+ name: "trainable_weights"
+ mtype: "<type \'property\'>"
+ }
+ member {
+ name: "updates"
+ mtype: "<type \'property\'>"
+ }
+ member {
+ name: "variables"
+ mtype: "<type \'property\'>"
+ }
+ member {
+ name: "weights"
+ mtype: "<type \'property\'>"
+ }
+ member_method {
+ name: "__init__"
+ argspec: "args=[\'self\', \'trainable\', \'name\', \'dtype\'], varargs=None, keywords=kwargs, defaults=[\'True\', \'None\', \"<dtype: \'float32\'>\"], "
+ }
+ member_method {
+ name: "add_loss"
+ argspec: "args=[\'self\', \'losses\', \'inputs\'], varargs=None, keywords=None, defaults=[\'None\'], "
+ }
+ member_method {
+ name: "add_update"
+ argspec: "args=[\'self\', \'updates\', \'inputs\'], varargs=None, keywords=None, defaults=[\'None\'], "
+ }
+ member_method {
+ name: "add_variable"
+ argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'True\'], "
+ }
+ member_method {
+ name: "apply"
+ argspec: "args=[\'self\', \'inputs\'], varargs=args, keywords=kwargs, defaults=None"
+ }
+ member_method {
+ name: "build"
+ argspec: "args=[\'self\', \'_\'], varargs=None, keywords=None, defaults=None"
+ }
+ member_method {
+ name: "call"
+ argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=kwargs, defaults=None"
+ }
+ member_method {
+ name: "get_losses_for"
+ argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None"
+ }
+ member_method {
+ name: "get_updates_for"
+ argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None"
+ }
+ member_method {
+ name: "zero_state"
+ argspec: "args=[\'self\', \'batch_size\', \'dtype\'], varargs=None, keywords=None, defaults=None"
+ }
+}
diff --git a/tensorflow/tools/api/golden/tensorflow.nn.rnn_cell.-residual-wrapper.pbtxt b/tensorflow/tools/api/golden/tensorflow.nn.rnn_cell.-residual-wrapper.pbtxt
new file mode 100644
index 0000000000..b21d9a8ee3
--- /dev/null
+++ b/tensorflow/tools/api/golden/tensorflow.nn.rnn_cell.-residual-wrapper.pbtxt
@@ -0,0 +1,95 @@
+path: "tensorflow.nn.rnn_cell.ResidualWrapper"
+tf_class {
+ is_instance: "<class \'tensorflow.python.ops.rnn_cell_impl.ResidualWrapper\'>"
+ is_instance: "<class \'tensorflow.python.ops.rnn_cell_impl.RNNCell\'>"
+ is_instance: "<class \'tensorflow.python.layers.base.Layer\'>"
+ is_instance: "<type \'object\'>"
+ member {
+ name: "graph"
+ mtype: "<type \'property\'>"
+ }
+ member {
+ name: "losses"
+ mtype: "<type \'property\'>"
+ }
+ member {
+ name: "non_trainable_variables"
+ mtype: "<type \'property\'>"
+ }
+ member {
+ name: "non_trainable_weights"
+ mtype: "<type \'property\'>"
+ }
+ member {
+ name: "output_size"
+ mtype: "<type \'property\'>"
+ }
+ member {
+ name: "scope_name"
+ mtype: "<type \'property\'>"
+ }
+ member {
+ name: "state_size"
+ mtype: "<type \'property\'>"
+ }
+ member {
+ name: "trainable_variables"
+ mtype: "<type \'property\'>"
+ }
+ member {
+ name: "trainable_weights"
+ mtype: "<type \'property\'>"
+ }
+ member {
+ name: "updates"
+ mtype: "<type \'property\'>"
+ }
+ member {
+ name: "variables"
+ mtype: "<type \'property\'>"
+ }
+ member {
+ name: "weights"
+ mtype: "<type \'property\'>"
+ }
+ member_method {
+ name: "__init__"
+ argspec: "args=[\'self\', \'cell\'], varargs=None, keywords=None, defaults=None"
+ }
+ member_method {
+ name: "add_loss"
+ argspec: "args=[\'self\', \'losses\', \'inputs\'], varargs=None, keywords=None, defaults=[\'None\'], "
+ }
+ member_method {
+ name: "add_update"
+ argspec: "args=[\'self\', \'updates\', \'inputs\'], varargs=None, keywords=None, defaults=[\'None\'], "
+ }
+ member_method {
+ name: "add_variable"
+ argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'True\'], "
+ }
+ member_method {
+ name: "apply"
+ argspec: "args=[\'self\', \'inputs\'], varargs=args, keywords=kwargs, defaults=None"
+ }
+ member_method {
+ name: "build"
+ argspec: "args=[\'self\', \'_\'], varargs=None, keywords=None, defaults=None"
+ }
+ member_method {
+ name: "call"
+ argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=kwargs, defaults=None"
+ }
+ member_method {
+ name: "get_losses_for"
+ argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None"
+ }
+ member_method {
+ name: "get_updates_for"
+ argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None"
+ }
+ member_method {
+ name: "zero_state"
+ argspec: "args=[\'self\', \'batch_size\', \'dtype\'], varargs=None, keywords=None, defaults=None"
+ }
+}
diff --git a/tensorflow/tools/api/golden/tensorflow.nn.rnn_cell.pbtxt b/tensorflow/tools/api/golden/tensorflow.nn.rnn_cell.pbtxt
new file mode 100644
index 0000000000..64697e8a02
--- /dev/null
+++ b/tensorflow/tools/api/golden/tensorflow.nn.rnn_cell.pbtxt
@@ -0,0 +1,43 @@
+path: "tensorflow.nn.rnn_cell"
+tf_module {
+ member {
+ name: "BasicLSTMCell"
+ mtype: "<type \'type\'>"
+ }
+ member {
+ name: "BasicRNNCell"
+ mtype: "<type \'type\'>"
+ }
+ member {
+ name: "DeviceWrapper"
+ mtype: "<type \'type\'>"
+ }
+ member {
+ name: "DropoutWrapper"
+ mtype: "<type \'type\'>"
+ }
+ member {
+ name: "GRUCell"
+ mtype: "<type \'type\'>"
+ }
+ member {
+ name: "LSTMCell"
+ mtype: "<type \'type\'>"
+ }
+ member {
+ name: "LSTMStateTuple"
+ mtype: "<type \'type\'>"
+ }
+ member {
+ name: "MultiRNNCell"
+ mtype: "<type \'type\'>"
+ }
+ member {
+ name: "RNNCell"
+ mtype: "<type \'type\'>"
+ }
+ member {
+ name: "ResidualWrapper"
+ mtype: "<type \'type\'>"
+ }
+}