aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python/ops/rnn_cell_impl.py
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/python/ops/rnn_cell_impl.py')
-rw-r--r--tensorflow/python/ops/rnn_cell_impl.py61
1 files changed, 5 insertions, 56 deletions
diff --git a/tensorflow/python/ops/rnn_cell_impl.py b/tensorflow/python/ops/rnn_cell_impl.py
index 82a044a0d4..42806ba6ec 100644
--- a/tensorflow/python/ops/rnn_cell_impl.py
+++ b/tensorflow/python/ops/rnn_cell_impl.py
@@ -47,7 +47,6 @@ from tensorflow.python.ops import variable_scope as vs
from tensorflow.python.ops import variables as tf_variables
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.training.checkpointable import base as checkpointable
-from tensorflow.python.training.checkpointable import tracking as checkpointable_tracking
from tensorflow.python.util import nest
from tensorflow.python.util.tf_export import tf_export
@@ -55,16 +54,6 @@ from tensorflow.python.util.tf_export import tf_export
_BIAS_VARIABLE_NAME = "bias"
_WEIGHTS_VARIABLE_NAME = "kernel"
-
-# TODO(jblespiau): Remove this function when we are sure there are no longer
-# any usage (even if protected, it is being used). Prefer assert_like_rnncell.
-def _like_rnncell(cell):
- """Checks that a given object is an RNNCell by using duck typing."""
- conditions = [hasattr(cell, "output_size"), hasattr(cell, "state_size"),
- hasattr(cell, "zero_state"), callable(cell)]
- return all(conditions)
-
-
# This can be used with self.assertRaisesRegexp for assert_like_rnncell.
ASSERT_LIKE_RNNCELL_ERROR_REGEXP = "is not an RNNCell"
@@ -1272,6 +1261,11 @@ class MultiRNNCell(RNNCell):
raise TypeError(
"cells must be a list or tuple, but saw: %s." % cells)
+ if len(set([id(cell) for cell in cells])) < len(cells):
+ logging.log_first_n(logging.WARN,
+ "At least two cells provided to MultiRNNCell "
+ "are the same object and will share weights.", 1)
+
self._cells = cells
for cell_number, cell in enumerate(self._cells):
# Add Checkpointable dependencies on these cells so their variables get
@@ -1330,48 +1324,3 @@ class MultiRNNCell(RNNCell):
array_ops.concat(new_states, 1))
return cur_inp, new_states
-
-
-class _SlimRNNCell(RNNCell, checkpointable_tracking.NotCheckpointable):
- """A simple wrapper for slim.rnn_cells."""
-
- def __init__(self, cell_fn):
- """Create a SlimRNNCell from a cell_fn.
-
- Args:
- cell_fn: a function which takes (inputs, state, scope) and produces the
- outputs and the new_state. Additionally when called with inputs=None and
- state=None it should return (initial_outputs, initial_state).
-
- Raises:
- TypeError: if cell_fn is not callable
- ValueError: if cell_fn cannot produce a valid initial state.
- """
- if not callable(cell_fn):
- raise TypeError("cell_fn %s needs to be callable", cell_fn)
- self._cell_fn = cell_fn
- self._cell_name = cell_fn.func.__name__
- init_output, init_state = self._cell_fn(None, None)
- output_shape = init_output.get_shape()
- state_shape = init_state.get_shape()
- self._output_size = output_shape.with_rank(2)[1].value
- self._state_size = state_shape.with_rank(2)[1].value
- if self._output_size is None:
- raise ValueError("Initial output created by %s has invalid shape %s" %
- (self._cell_name, output_shape))
- if self._state_size is None:
- raise ValueError("Initial state created by %s has invalid shape %s" %
- (self._cell_name, state_shape))
-
- @property
- def state_size(self):
- return self._state_size
-
- @property
- def output_size(self):
- return self._output_size
-
- def __call__(self, inputs, state, scope=None):
- scope = scope or self._cell_name
- output, state = self._cell_fn(inputs, state, scope=scope)
- return output, state