diff options
author | 2017-12-13 11:17:07 -0800 | |
---|---|---|
committer | 2017-12-13 11:20:43 -0800 | |
commit | 52a44f28174f3a08fa92c3d43a9531c7c1101666 (patch) | |
tree | 19bb5adadde3e9ff6d6c63366c9d550a245bf437 /tensorflow/python/ops/rnn_cell_impl.py | |
parent | dcbf6c972d7b4203735bca04f4d33d575ef7b22b (diff) |
Convert LSTMFusedBlockCell to a plain Layer; it is not really an RNNCell.
This allows us to revert a change to the public API for most RNNCells.
That breaking change was introduced yesterday (wherein scope argument had to be
passed by keyword arg).
PiperOrigin-RevId: 178930316
Diffstat (limited to 'tensorflow/python/ops/rnn_cell_impl.py')
-rw-r--r-- | tensorflow/python/ops/rnn_cell_impl.py | 15 |
1 files changed, 8 insertions, 7 deletions
diff --git a/tensorflow/python/ops/rnn_cell_impl.py b/tensorflow/python/ops/rnn_cell_impl.py index 7c759d852c..7cb9f7762d 100644 --- a/tensorflow/python/ops/rnn_cell_impl.py +++ b/tensorflow/python/ops/rnn_cell_impl.py @@ -265,18 +265,18 @@ class _LayerRNNCell(RNNCell): `call` methods do not access Variables `tf.get_variable`. """ - def __call__(self, inputs, *args, **kwargs): + def __call__(self, inputs, state, scope=None, *args, **kwargs): """Run this RNN cell on inputs, starting from the given state. Args: inputs: `2-D` tensor with shape `[batch_size, input_size]`. - *args: Additional positional arguments. - Usually composesed of `[state]`: if `self.state_size` is an integer, - this should be a `2-D Tensor` with shape - `[batch_size, self.state_size]`. Otherwise, if + state: if `self.state_size` is an integer, this should be a `2-D Tensor` + with shape `[batch_size, self.state_size]`. Otherwise, if `self.state_size` is a tuple of integers, this should be a tuple with shapes `[batch_size, s] for s in self.state_size`. - **kwargs: Additional keyword arguments. Common keys include `scope`. + scope: optional cell scope. + *args: Additional positional arguments. + **kwargs: Additional keyword arguments. Returns: A pair containing: @@ -288,7 +288,8 @@ class _LayerRNNCell(RNNCell): # Bypass RNNCell's variable capturing semantics for LayerRNNCell. # Instead, it is up to subclasses to provide a proper build # method. See the class docstring for more details. - return base_layer.Layer.__call__(self, inputs, *args, **kwargs) + return base_layer.Layer.__call__(self, inputs, state, scope=scope, + *args, **kwargs) class BasicRNNCell(_LayerRNNCell): |