diff options
author | Eugene Brevdo <ebrevdo@google.com> | 2017-05-22 17:32:50 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2017-05-22 17:36:11 -0700 |
commit | 827d2e4b9180db67853f60c125e548d83986b96c (patch) | |
tree | 1ccaf8f20bf678ec755330b488eb28946dbe38e6 /tensorflow/contrib/grid_rnn | |
parent | 95719e869c61c78a4b0ac0407e1fb04e60daca35 (diff) |
Move many of the "core" RNNCells and rnn functions back to TF core.
Unit test files will move in a followup PR. This is the big API change.
The old behavior (using tf.contrib.rnn....) will continue to work for
backwards compatibility.
PiperOrigin-RevId: 156809677
Diffstat (limited to 'tensorflow/contrib/grid_rnn')
-rw-r--r-- | tensorflow/contrib/grid_rnn/python/kernel_tests/grid_rnn_test.py | 12 |
1 files changed, 6 insertions, 6 deletions
diff --git a/tensorflow/contrib/grid_rnn/python/kernel_tests/grid_rnn_test.py b/tensorflow/contrib/grid_rnn/python/kernel_tests/grid_rnn_test.py index 280271a42d..fed8a771cc 100644 --- a/tensorflow/contrib/grid_rnn/python/kernel_tests/grid_rnn_test.py +++ b/tensorflow/contrib/grid_rnn/python/kernel_tests/grid_rnn_test.py @@ -21,11 +21,11 @@ from __future__ import print_function import numpy as np from tensorflow.contrib.grid_rnn.python.ops import grid_rnn_cell -from tensorflow.contrib.rnn.python.ops import core_rnn from tensorflow.python.framework import dtypes from tensorflow.python.ops import array_ops from tensorflow.python.ops import init_ops from tensorflow.python.ops import nn_ops +from tensorflow.python.ops import rnn from tensorflow.python.ops import variable_scope from tensorflow.python.ops import variables from tensorflow.python.platform import test @@ -527,7 +527,7 @@ class GridRNNCellTest(test.TestCase): dtypes.float32, shape=(batch_size, input_size)) ] - outputs, state = core_rnn.static_rnn(cell, inputs, dtype=dtypes.float32) + outputs, state = rnn.static_rnn(cell, inputs, dtype=dtypes.float32) self.assertEqual(len(outputs), len(inputs)) self.assertEqual(state[0].c.get_shape(), (batch_size, 2)) @@ -569,7 +569,7 @@ class GridRNNCellTest(test.TestCase): array_ops.placeholder(dtypes.float32, shape=(batch_size, input_size)) ] - outputs, state = core_rnn.static_rnn(cell, inputs, dtype=dtypes.float32) + outputs, state = rnn.static_rnn(cell, inputs, dtype=dtypes.float32) self.assertEqual(len(outputs), len(inputs)) self.assertEqual(state[0].c.get_shape(), (batch_size, 2)) @@ -609,7 +609,7 @@ class GridRNNCellTest(test.TestCase): array_ops.placeholder(dtypes.float32, shape=(batch_size, input_size)) ] - outputs, state = core_rnn.static_rnn(cell, inputs, dtype=dtypes.float32) + outputs, state = rnn.static_rnn(cell, inputs, dtype=dtypes.float32) self.assertEqual(len(outputs), len(inputs)) self.assertEqual(state[0].c.get_shape(), (batch_size, 2)) @@ -652,7 +652,7 @@ class GridRNNCellTest(test.TestCase): dtypes.float32, shape=(batch_size, input_size)) ] + (max_length - 1) * [array_ops.zeros([batch_size, input_size])]) - outputs, state = core_rnn.static_rnn(cell, inputs, dtype=dtypes.float32) + outputs, state = rnn.static_rnn(cell, inputs, dtype=dtypes.float32) self.assertEqual(len(outputs), len(inputs)) self.assertEqual(state[0].c.get_shape(), (batch_size, 2)) @@ -690,7 +690,7 @@ class GridRNNCellTest(test.TestCase): array_ops.placeholder(dtypes.float32, shape=(None, input_size)) ] - outputs, state = core_rnn.static_rnn(cell, inputs, dtype=dtypes.float32) + outputs, state = rnn.static_rnn(cell, inputs, dtype=dtypes.float32) self.assertEqual(len(outputs), len(inputs)) |