aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python/ops/rnn_cell_impl.py
diff options
context:
space:
mode:
authorGravatar Eugene Brevdo <ebrevdo@google.com>2017-12-13 11:17:07 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2017-12-13 11:20:43 -0800
commit52a44f28174f3a08fa92c3d43a9531c7c1101666 (patch)
tree19bb5adadde3e9ff6d6c63366c9d550a245bf437 /tensorflow/python/ops/rnn_cell_impl.py
parentdcbf6c972d7b4203735bca04f4d33d575ef7b22b (diff)
Convert LSTMFusedBlockCell to a plain Layer; it is not really an RNNCell.
This allows us to revert a change to the public API for most RNNCells. That breaking change was introduced yesterday (wherein scope argument had to be passed by keyword arg). PiperOrigin-RevId: 178930316
Diffstat (limited to 'tensorflow/python/ops/rnn_cell_impl.py')
-rw-r--r--tensorflow/python/ops/rnn_cell_impl.py15
1 files changed, 8 insertions, 7 deletions
diff --git a/tensorflow/python/ops/rnn_cell_impl.py b/tensorflow/python/ops/rnn_cell_impl.py
index 7c759d852c..7cb9f7762d 100644
--- a/tensorflow/python/ops/rnn_cell_impl.py
+++ b/tensorflow/python/ops/rnn_cell_impl.py
@@ -265,18 +265,18 @@ class _LayerRNNCell(RNNCell):
`call` methods do not access Variables `tf.get_variable`.
"""
- def __call__(self, inputs, *args, **kwargs):
+ def __call__(self, inputs, state, scope=None, *args, **kwargs):
"""Run this RNN cell on inputs, starting from the given state.
Args:
inputs: `2-D` tensor with shape `[batch_size, input_size]`.
- *args: Additional positional arguments.
- Usually composesed of `[state]`: if `self.state_size` is an integer,
- this should be a `2-D Tensor` with shape
- `[batch_size, self.state_size]`. Otherwise, if
+ state: if `self.state_size` is an integer, this should be a `2-D Tensor`
+ with shape `[batch_size, self.state_size]`. Otherwise, if
`self.state_size` is a tuple of integers, this should be a tuple
with shapes `[batch_size, s] for s in self.state_size`.
- **kwargs: Additional keyword arguments. Common keys include `scope`.
+ scope: optional cell scope.
+ *args: Additional positional arguments.
+ **kwargs: Additional keyword arguments.
Returns:
A pair containing:
@@ -288,7 +288,8 @@ class _LayerRNNCell(RNNCell):
# Bypass RNNCell's variable capturing semantics for LayerRNNCell.
# Instead, it is up to subclasses to provide a proper build
# method. See the class docstring for more details.
- return base_layer.Layer.__call__(self, inputs, *args, **kwargs)
+ return base_layer.Layer.__call__(self, inputs, state, scope=scope,
+ *args, **kwargs)
class BasicRNNCell(_LayerRNNCell):