aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/grid_rnn
diff options
context:
space:
mode:
authorGravatar Eugene Brevdo <ebrevdo@google.com>2017-05-22 17:32:50 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2017-05-22 17:36:11 -0700
commit827d2e4b9180db67853f60c125e548d83986b96c (patch)
tree1ccaf8f20bf678ec755330b488eb28946dbe38e6 /tensorflow/contrib/grid_rnn
parent95719e869c61c78a4b0ac0407e1fb04e60daca35 (diff)
Move many of the "core" RNNCells and rnn functions back to TF core.
Unit test files will move in a followup PR. This is the big API change. The old behavior (using tf.contrib.rnn....) will continue to work for backwards compatibility. PiperOrigin-RevId: 156809677
Diffstat (limited to 'tensorflow/contrib/grid_rnn')
-rw-r--r--tensorflow/contrib/grid_rnn/python/kernel_tests/grid_rnn_test.py12
1 files changed, 6 insertions, 6 deletions
diff --git a/tensorflow/contrib/grid_rnn/python/kernel_tests/grid_rnn_test.py b/tensorflow/contrib/grid_rnn/python/kernel_tests/grid_rnn_test.py
index 280271a42d..fed8a771cc 100644
--- a/tensorflow/contrib/grid_rnn/python/kernel_tests/grid_rnn_test.py
+++ b/tensorflow/contrib/grid_rnn/python/kernel_tests/grid_rnn_test.py
@@ -21,11 +21,11 @@ from __future__ import print_function
import numpy as np
from tensorflow.contrib.grid_rnn.python.ops import grid_rnn_cell
-from tensorflow.contrib.rnn.python.ops import core_rnn
from tensorflow.python.framework import dtypes
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import init_ops
from tensorflow.python.ops import nn_ops
+from tensorflow.python.ops import rnn
from tensorflow.python.ops import variable_scope
from tensorflow.python.ops import variables
from tensorflow.python.platform import test
@@ -527,7 +527,7 @@ class GridRNNCellTest(test.TestCase):
dtypes.float32, shape=(batch_size, input_size))
]
- outputs, state = core_rnn.static_rnn(cell, inputs, dtype=dtypes.float32)
+ outputs, state = rnn.static_rnn(cell, inputs, dtype=dtypes.float32)
self.assertEqual(len(outputs), len(inputs))
self.assertEqual(state[0].c.get_shape(), (batch_size, 2))
@@ -569,7 +569,7 @@ class GridRNNCellTest(test.TestCase):
array_ops.placeholder(dtypes.float32, shape=(batch_size, input_size))
]
- outputs, state = core_rnn.static_rnn(cell, inputs, dtype=dtypes.float32)
+ outputs, state = rnn.static_rnn(cell, inputs, dtype=dtypes.float32)
self.assertEqual(len(outputs), len(inputs))
self.assertEqual(state[0].c.get_shape(), (batch_size, 2))
@@ -609,7 +609,7 @@ class GridRNNCellTest(test.TestCase):
array_ops.placeholder(dtypes.float32, shape=(batch_size, input_size))
]
- outputs, state = core_rnn.static_rnn(cell, inputs, dtype=dtypes.float32)
+ outputs, state = rnn.static_rnn(cell, inputs, dtype=dtypes.float32)
self.assertEqual(len(outputs), len(inputs))
self.assertEqual(state[0].c.get_shape(), (batch_size, 2))
@@ -652,7 +652,7 @@ class GridRNNCellTest(test.TestCase):
dtypes.float32, shape=(batch_size, input_size))
] + (max_length - 1) * [array_ops.zeros([batch_size, input_size])])
- outputs, state = core_rnn.static_rnn(cell, inputs, dtype=dtypes.float32)
+ outputs, state = rnn.static_rnn(cell, inputs, dtype=dtypes.float32)
self.assertEqual(len(outputs), len(inputs))
self.assertEqual(state[0].c.get_shape(), (batch_size, 2))
@@ -690,7 +690,7 @@ class GridRNNCellTest(test.TestCase):
array_ops.placeholder(dtypes.float32, shape=(None, input_size))
]
- outputs, state = core_rnn.static_rnn(cell, inputs, dtype=dtypes.float32)
+ outputs, state = rnn.static_rnn(cell, inputs, dtype=dtypes.float32)
self.assertEqual(len(outputs), len(inputs))