aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/grid_rnn
diff options
context:
space:
mode:
authorGravatar Jianwei Xie <xiejw@google.com>2016-12-07 16:54:18 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2016-12-07 17:03:23 -0800
commit5e96fa748d0d4c15b262c57bbbbcd417b1f71a10 (patch)
treed4434bf11e800a4a05aa0ae2c3b896c4c4a58638 /tensorflow/contrib/grid_rnn
parent3e1124995ca297c9bd1da943e5da8f13609a8e15 (diff)
Import tf.contrib.rnn for cells instead of tf.nn.rnn_cell.
Change: 141376440
Diffstat (limited to 'tensorflow/contrib/grid_rnn')
-rw-r--r--tensorflow/contrib/grid_rnn/python/ops/grid_rnn_cell.py24
1 files changed, 12 insertions, 12 deletions
diff --git a/tensorflow/contrib/grid_rnn/python/ops/grid_rnn_cell.py b/tensorflow/contrib/grid_rnn/python/ops/grid_rnn_cell.py
index 502eae5d20..6247cad380 100644
--- a/tensorflow/contrib/grid_rnn/python/ops/grid_rnn_cell.py
+++ b/tensorflow/contrib/grid_rnn/python/ops/grid_rnn_cell.py
@@ -24,11 +24,11 @@ from tensorflow.python.ops import array_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import nn
from tensorflow.python.ops import variable_scope as vs
-from tensorflow.python.ops import rnn_cell
from tensorflow.contrib import layers
+from tensorflow.contrib import rnn
-class GridRNNCell(rnn_cell.RNNCell):
+class GridRNNCell(rnn.RNNCell):
"""Grid recurrent cell.
This implementation is based on:
@@ -94,11 +94,11 @@ class GridRNNCell(rnn_cell.RNNCell):
cell_input_size = (self._config.num_dims - 1) * num_units
if cell_fn is None:
- self._cell = rnn_cell.LSTMCell(
+ self._cell = rnn.LSTMCell(
num_units=num_units, input_size=cell_input_size, state_is_tuple=False)
else:
self._cell = cell_fn(num_units, cell_input_size)
- if not isinstance(self._cell, rnn_cell.RNNCell):
+ if not isinstance(self._cell, rnn.RNNCell):
raise ValueError('cell_fn must return an object of type RNNCell')
@property
@@ -213,7 +213,7 @@ class Grid1BasicRNNCell(GridRNNCell):
super(Grid1BasicRNNCell, self).__init__(
num_units=num_units, num_dims=1,
input_dims=0, output_dims=0, priority_dims=0, tied=False,
- cell_fn=lambda n, i: rnn_cell.BasicRNNCell(num_units=n, input_size=i))
+ cell_fn=lambda n, i: rnn.BasicRNNCell(num_units=n, input_size=i))
class Grid2BasicRNNCell(GridRNNCell):
@@ -231,7 +231,7 @@ class Grid2BasicRNNCell(GridRNNCell):
num_units=num_units, num_dims=2,
input_dims=0, output_dims=0, priority_dims=0, tied=tied,
non_recurrent_dims=None if non_recurrent_fn is None else 0,
- cell_fn=lambda n, i: rnn_cell.BasicRNNCell(num_units=n, input_size=i),
+ cell_fn=lambda n, i: rnn.BasicRNNCell(num_units=n, input_size=i),
non_recurrent_fn=non_recurrent_fn)
@@ -242,7 +242,7 @@ class Grid1BasicLSTMCell(GridRNNCell):
super(Grid1BasicLSTMCell, self).__init__(
num_units=num_units, num_dims=1,
input_dims=0, output_dims=0, priority_dims=0, tied=False,
- cell_fn=lambda n, i: rnn_cell.BasicLSTMCell(
+ cell_fn=lambda n, i: rnn.BasicLSTMCell(
num_units=n,
forget_bias=forget_bias, input_size=i,
state_is_tuple=False))
@@ -267,7 +267,7 @@ class Grid2BasicLSTMCell(GridRNNCell):
num_units=num_units, num_dims=2,
input_dims=0, output_dims=0, priority_dims=0, tied=tied,
non_recurrent_dims=None if non_recurrent_fn is None else 0,
- cell_fn=lambda n, i: rnn_cell.BasicLSTMCell(
+ cell_fn=lambda n, i: rnn.BasicLSTMCell(
num_units=n, forget_bias=forget_bias, input_size=i,
state_is_tuple=False),
non_recurrent_fn=non_recurrent_fn)
@@ -284,7 +284,7 @@ class Grid1LSTMCell(GridRNNCell):
super(Grid1LSTMCell, self).__init__(
num_units=num_units, num_dims=1,
input_dims=0, output_dims=0, priority_dims=0,
- cell_fn=lambda n, i: rnn_cell.LSTMCell(
+ cell_fn=lambda n, i: rnn.LSTMCell(
num_units=n, input_size=i, use_peepholes=use_peepholes,
forget_bias=forget_bias, state_is_tuple=False))
@@ -308,7 +308,7 @@ class Grid2LSTMCell(GridRNNCell):
num_units=num_units, num_dims=2,
input_dims=0, output_dims=0, priority_dims=0, tied=tied,
non_recurrent_dims=None if non_recurrent_fn is None else 0,
- cell_fn=lambda n, i: rnn_cell.LSTMCell(
+ cell_fn=lambda n, i: rnn.LSTMCell(
num_units=n, input_size=i, forget_bias=forget_bias,
use_peepholes=use_peepholes, state_is_tuple=False),
non_recurrent_fn=non_recurrent_fn)
@@ -334,7 +334,7 @@ class Grid3LSTMCell(GridRNNCell):
num_units=num_units, num_dims=3,
input_dims=0, output_dims=0, priority_dims=0, tied=tied,
non_recurrent_dims=None if non_recurrent_fn is None else 0,
- cell_fn=lambda n, i: rnn_cell.LSTMCell(
+ cell_fn=lambda n, i: rnn.LSTMCell(
num_units=n, input_size=i, forget_bias=forget_bias,
use_peepholes=use_peepholes, state_is_tuple=False),
non_recurrent_fn=non_recurrent_fn)
@@ -354,7 +354,7 @@ class Grid2GRUCell(GridRNNCell):
num_units=num_units, num_dims=2,
input_dims=0, output_dims=0, priority_dims=0, tied=tied,
non_recurrent_dims=None if non_recurrent_fn is None else 0,
- cell_fn=lambda n, i: rnn_cell.GRUCell(num_units=n, input_size=i),
+ cell_fn=lambda n, i: rnn.GRUCell(num_units=n, input_size=i),
non_recurrent_fn=non_recurrent_fn)