diff options
author | Jianwei Xie <xiejw@google.com> | 2016-12-07 16:54:18 -0800 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2016-12-07 17:03:23 -0800 |
commit | 5e96fa748d0d4c15b262c57bbbbcd417b1f71a10 (patch) | |
tree | d4434bf11e800a4a05aa0ae2c3b896c4c4a58638 /tensorflow/contrib/grid_rnn | |
parent | 3e1124995ca297c9bd1da943e5da8f13609a8e15 (diff) |
Import tf.contrib.rnn for cells instead of tf.nn.rnn_cell.
Change: 141376440
Diffstat (limited to 'tensorflow/contrib/grid_rnn')
-rw-r--r-- | tensorflow/contrib/grid_rnn/python/ops/grid_rnn_cell.py | 24 |
1 files changed, 12 insertions, 12 deletions
diff --git a/tensorflow/contrib/grid_rnn/python/ops/grid_rnn_cell.py b/tensorflow/contrib/grid_rnn/python/ops/grid_rnn_cell.py index 502eae5d20..6247cad380 100644 --- a/tensorflow/contrib/grid_rnn/python/ops/grid_rnn_cell.py +++ b/tensorflow/contrib/grid_rnn/python/ops/grid_rnn_cell.py @@ -24,11 +24,11 @@ from tensorflow.python.ops import array_ops from tensorflow.python.ops import math_ops from tensorflow.python.ops import nn from tensorflow.python.ops import variable_scope as vs -from tensorflow.python.ops import rnn_cell from tensorflow.contrib import layers +from tensorflow.contrib import rnn -class GridRNNCell(rnn_cell.RNNCell): +class GridRNNCell(rnn.RNNCell): """Grid recurrent cell. This implementation is based on: @@ -94,11 +94,11 @@ class GridRNNCell(rnn_cell.RNNCell): cell_input_size = (self._config.num_dims - 1) * num_units if cell_fn is None: - self._cell = rnn_cell.LSTMCell( + self._cell = rnn.LSTMCell( num_units=num_units, input_size=cell_input_size, state_is_tuple=False) else: self._cell = cell_fn(num_units, cell_input_size) - if not isinstance(self._cell, rnn_cell.RNNCell): + if not isinstance(self._cell, rnn.RNNCell): raise ValueError('cell_fn must return an object of type RNNCell') @property @@ -213,7 +213,7 @@ class Grid1BasicRNNCell(GridRNNCell): super(Grid1BasicRNNCell, self).__init__( num_units=num_units, num_dims=1, input_dims=0, output_dims=0, priority_dims=0, tied=False, - cell_fn=lambda n, i: rnn_cell.BasicRNNCell(num_units=n, input_size=i)) + cell_fn=lambda n, i: rnn.BasicRNNCell(num_units=n, input_size=i)) class Grid2BasicRNNCell(GridRNNCell): @@ -231,7 +231,7 @@ class Grid2BasicRNNCell(GridRNNCell): num_units=num_units, num_dims=2, input_dims=0, output_dims=0, priority_dims=0, tied=tied, non_recurrent_dims=None if non_recurrent_fn is None else 0, - cell_fn=lambda n, i: rnn_cell.BasicRNNCell(num_units=n, input_size=i), + cell_fn=lambda n, i: rnn.BasicRNNCell(num_units=n, input_size=i), non_recurrent_fn=non_recurrent_fn) @@ -242,7 +242,7 @@ class Grid1BasicLSTMCell(GridRNNCell): super(Grid1BasicLSTMCell, self).__init__( num_units=num_units, num_dims=1, input_dims=0, output_dims=0, priority_dims=0, tied=False, - cell_fn=lambda n, i: rnn_cell.BasicLSTMCell( + cell_fn=lambda n, i: rnn.BasicLSTMCell( num_units=n, forget_bias=forget_bias, input_size=i, state_is_tuple=False)) @@ -267,7 +267,7 @@ class Grid2BasicLSTMCell(GridRNNCell): num_units=num_units, num_dims=2, input_dims=0, output_dims=0, priority_dims=0, tied=tied, non_recurrent_dims=None if non_recurrent_fn is None else 0, - cell_fn=lambda n, i: rnn_cell.BasicLSTMCell( + cell_fn=lambda n, i: rnn.BasicLSTMCell( num_units=n, forget_bias=forget_bias, input_size=i, state_is_tuple=False), non_recurrent_fn=non_recurrent_fn) @@ -284,7 +284,7 @@ class Grid1LSTMCell(GridRNNCell): super(Grid1LSTMCell, self).__init__( num_units=num_units, num_dims=1, input_dims=0, output_dims=0, priority_dims=0, - cell_fn=lambda n, i: rnn_cell.LSTMCell( + cell_fn=lambda n, i: rnn.LSTMCell( num_units=n, input_size=i, use_peepholes=use_peepholes, forget_bias=forget_bias, state_is_tuple=False)) @@ -308,7 +308,7 @@ class Grid2LSTMCell(GridRNNCell): num_units=num_units, num_dims=2, input_dims=0, output_dims=0, priority_dims=0, tied=tied, non_recurrent_dims=None if non_recurrent_fn is None else 0, - cell_fn=lambda n, i: rnn_cell.LSTMCell( + cell_fn=lambda n, i: rnn.LSTMCell( num_units=n, input_size=i, forget_bias=forget_bias, use_peepholes=use_peepholes, state_is_tuple=False), non_recurrent_fn=non_recurrent_fn) @@ -334,7 +334,7 @@ class Grid3LSTMCell(GridRNNCell): num_units=num_units, num_dims=3, input_dims=0, output_dims=0, priority_dims=0, tied=tied, non_recurrent_dims=None if non_recurrent_fn is None else 0, - cell_fn=lambda n, i: rnn_cell.LSTMCell( + cell_fn=lambda n, i: rnn.LSTMCell( num_units=n, input_size=i, forget_bias=forget_bias, use_peepholes=use_peepholes, state_is_tuple=False), non_recurrent_fn=non_recurrent_fn) @@ -354,7 +354,7 @@ class Grid2GRUCell(GridRNNCell): num_units=num_units, num_dims=2, input_dims=0, output_dims=0, priority_dims=0, tied=tied, non_recurrent_dims=None if non_recurrent_fn is None else 0, - cell_fn=lambda n, i: rnn_cell.GRUCell(num_units=n, input_size=i), + cell_fn=lambda n, i: rnn.GRUCell(num_units=n, input_size=i), non_recurrent_fn=non_recurrent_fn) |