diff options
Diffstat (limited to 'tensorflow/python/ops/rnn_cell_impl.py')
-rw-r--r-- | tensorflow/python/ops/rnn_cell_impl.py | 61 |
1 files changed, 5 insertions, 56 deletions
diff --git a/tensorflow/python/ops/rnn_cell_impl.py b/tensorflow/python/ops/rnn_cell_impl.py index 82a044a0d4..42806ba6ec 100644 --- a/tensorflow/python/ops/rnn_cell_impl.py +++ b/tensorflow/python/ops/rnn_cell_impl.py @@ -47,7 +47,6 @@ from tensorflow.python.ops import variable_scope as vs from tensorflow.python.ops import variables as tf_variables from tensorflow.python.platform import tf_logging as logging from tensorflow.python.training.checkpointable import base as checkpointable -from tensorflow.python.training.checkpointable import tracking as checkpointable_tracking from tensorflow.python.util import nest from tensorflow.python.util.tf_export import tf_export @@ -55,16 +54,6 @@ from tensorflow.python.util.tf_export import tf_export _BIAS_VARIABLE_NAME = "bias" _WEIGHTS_VARIABLE_NAME = "kernel" - -# TODO(jblespiau): Remove this function when we are sure there are no longer -# any usage (even if protected, it is being used). Prefer assert_like_rnncell. -def _like_rnncell(cell): - """Checks that a given object is an RNNCell by using duck typing.""" - conditions = [hasattr(cell, "output_size"), hasattr(cell, "state_size"), - hasattr(cell, "zero_state"), callable(cell)] - return all(conditions) - - # This can be used with self.assertRaisesRegexp for assert_like_rnncell. ASSERT_LIKE_RNNCELL_ERROR_REGEXP = "is not an RNNCell" @@ -1272,6 +1261,11 @@ class MultiRNNCell(RNNCell): raise TypeError( "cells must be a list or tuple, but saw: %s." % cells) + if len(set([id(cell) for cell in cells])) < len(cells): + logging.log_first_n(logging.WARN, + "At least two cells provided to MultiRNNCell " + "are the same object and will share weights.", 1) + self._cells = cells for cell_number, cell in enumerate(self._cells): # Add Checkpointable dependencies on these cells so their variables get @@ -1330,48 +1324,3 @@ class MultiRNNCell(RNNCell): array_ops.concat(new_states, 1)) return cur_inp, new_states - - -class _SlimRNNCell(RNNCell, checkpointable_tracking.NotCheckpointable): - """A simple wrapper for slim.rnn_cells.""" - - def __init__(self, cell_fn): - """Create a SlimRNNCell from a cell_fn. - - Args: - cell_fn: a function which takes (inputs, state, scope) and produces the - outputs and the new_state. Additionally when called with inputs=None and - state=None it should return (initial_outputs, initial_state). - - Raises: - TypeError: if cell_fn is not callable - ValueError: if cell_fn cannot produce a valid initial state. - """ - if not callable(cell_fn): - raise TypeError("cell_fn %s needs to be callable", cell_fn) - self._cell_fn = cell_fn - self._cell_name = cell_fn.func.__name__ - init_output, init_state = self._cell_fn(None, None) - output_shape = init_output.get_shape() - state_shape = init_state.get_shape() - self._output_size = output_shape.with_rank(2)[1].value - self._state_size = state_shape.with_rank(2)[1].value - if self._output_size is None: - raise ValueError("Initial output created by %s has invalid shape %s" % - (self._cell_name, output_shape)) - if self._state_size is None: - raise ValueError("Initial state created by %s has invalid shape %s" % - (self._cell_name, state_shape)) - - @property - def state_size(self): - return self._state_size - - @property - def output_size(self): - return self._output_size - - def __call__(self, inputs, state, scope=None): - scope = scope or self._cell_name - output, state = self._cell_fn(inputs, state, scope=scope) - return output, state |