aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/rnn/python/kernel_tests/core_rnn_cell_test.py
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/contrib/rnn/python/kernel_tests/core_rnn_cell_test.py')
-rw-r--r--tensorflow/contrib/rnn/python/kernel_tests/core_rnn_cell_test.py161
1 files changed, 115 insertions, 46 deletions
diff --git a/tensorflow/contrib/rnn/python/kernel_tests/core_rnn_cell_test.py b/tensorflow/contrib/rnn/python/kernel_tests/core_rnn_cell_test.py
index 86f1e27abd..85f0f8ced9 100644
--- a/tensorflow/contrib/rnn/python/kernel_tests/core_rnn_cell_test.py
+++ b/tensorflow/contrib/rnn/python/kernel_tests/core_rnn_cell_test.py
@@ -18,7 +18,6 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
-import functools
import os
import numpy as np
@@ -35,7 +34,6 @@ from tensorflow.python.framework import test_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import init_ops
from tensorflow.python.ops import math_ops
-from tensorflow.python.ops import random_ops
from tensorflow.python.ops import rnn
from tensorflow.python.ops import rnn_cell_impl
from tensorflow.python.ops import variable_scope
@@ -117,6 +115,27 @@ class RNNCellTest(test.TestCase):
})
self.assertEqual(res[0].shape, (1, 2))
+ def testIndRNNCell(self):
+ with self.test_session() as sess:
+ with variable_scope.variable_scope(
+ "root", initializer=init_ops.constant_initializer(0.5)):
+ x = array_ops.zeros([1, 2])
+ m = array_ops.zeros([1, 2])
+ cell = contrib_rnn_cell.IndRNNCell(2)
+ g, _ = cell(x, m)
+ self.assertEqual([
+ "root/ind_rnn_cell/%s_w:0" % rnn_cell_impl._WEIGHTS_VARIABLE_NAME,
+ "root/ind_rnn_cell/%s_u:0" % rnn_cell_impl._WEIGHTS_VARIABLE_NAME,
+ "root/ind_rnn_cell/%s:0" % rnn_cell_impl._BIAS_VARIABLE_NAME
+ ], [v.name for v in cell.trainable_variables])
+ self.assertFalse(cell.non_trainable_variables)
+ sess.run([variables_lib.global_variables_initializer()])
+ res = sess.run([g], {
+ x.name: np.array([[1., 1.]]),
+ m.name: np.array([[0.1, 0.1]])
+ })
+ self.assertEqual(res[0].shape, (1, 2))
+
def testGRUCell(self):
with self.test_session() as sess:
with variable_scope.variable_scope(
@@ -145,6 +164,34 @@ class RNNCellTest(test.TestCase):
# Smoke test
self.assertAllClose(res[0], [[0.156736, 0.156736]])
+ def testIndyGRUCell(self):
+ with self.test_session() as sess:
+ with variable_scope.variable_scope(
+ "root", initializer=init_ops.constant_initializer(0.5)):
+ x = array_ops.zeros([1, 2])
+ m = array_ops.zeros([1, 2])
+ g, _ = contrib_rnn_cell.IndyGRUCell(2)(x, m)
+ sess.run([variables_lib.global_variables_initializer()])
+ res = sess.run([g], {
+ x.name: np.array([[1., 1.]]),
+ m.name: np.array([[0.1, 0.1]])
+ })
+ # Smoke test
+ self.assertAllClose(res[0], [[0.185265, 0.17704]])
+ with variable_scope.variable_scope(
+ "other", initializer=init_ops.constant_initializer(0.5)):
+ # Test IndyGRUCell with input_size != num_units.
+ x = array_ops.zeros([1, 3])
+ m = array_ops.zeros([1, 2])
+ g, _ = contrib_rnn_cell.IndyGRUCell(2)(x, m)
+ sess.run([variables_lib.global_variables_initializer()])
+ res = sess.run([g], {
+ x.name: np.array([[1., 1., 1.]]),
+ m.name: np.array([[0.1, 0.1]])
+ })
+ # Smoke test
+ self.assertAllClose(res[0], [[0.155127, 0.157328]])
+
def testSRUCell(self):
with self.test_session() as sess:
with variable_scope.variable_scope(
@@ -345,6 +392,72 @@ class RNNCellTest(test.TestCase):
self.assertAllClose(res[1], expected_mem0)
self.assertAllClose(res[2], expected_mem1)
+ def testIndyLSTMCell(self):
+ for dtype in [dtypes.float16, dtypes.float32]:
+ np_dtype = dtype.as_numpy_dtype
+ with self.test_session(graph=ops.Graph()) as sess:
+ with variable_scope.variable_scope(
+ "root", initializer=init_ops.constant_initializer(0.5)):
+ x = array_ops.zeros([1, 2], dtype=dtype)
+ state_0 = (array_ops.zeros([1, 2], dtype=dtype),) * 2
+ state_1 = (array_ops.zeros([1, 2], dtype=dtype),) * 2
+ cell = rnn_cell_impl.MultiRNNCell(
+ [contrib_rnn_cell.IndyLSTMCell(2) for _ in range(2)])
+ self.assertEqual(cell.dtype, None)
+ self.assertEqual("cell-0", cell._checkpoint_dependencies[0].name)
+ self.assertEqual("cell-1", cell._checkpoint_dependencies[1].name)
+ cell.get_config() # Should not throw an error
+ g, (out_state_0, out_state_1) = cell(x, (state_0, state_1))
+ # Layer infers the input type.
+ self.assertEqual(cell.dtype, dtype.name)
+ expected_variable_names = [
+ "root/multi_rnn_cell/cell_0/indy_lstm_cell/%s_w:0" %
+ rnn_cell_impl._WEIGHTS_VARIABLE_NAME,
+ "root/multi_rnn_cell/cell_0/indy_lstm_cell/%s_u:0" %
+ rnn_cell_impl._WEIGHTS_VARIABLE_NAME,
+ "root/multi_rnn_cell/cell_0/indy_lstm_cell/%s:0" %
+ rnn_cell_impl._BIAS_VARIABLE_NAME,
+ "root/multi_rnn_cell/cell_1/indy_lstm_cell/%s_w:0" %
+ rnn_cell_impl._WEIGHTS_VARIABLE_NAME,
+ "root/multi_rnn_cell/cell_1/indy_lstm_cell/%s_u:0" %
+ rnn_cell_impl._WEIGHTS_VARIABLE_NAME,
+ "root/multi_rnn_cell/cell_1/indy_lstm_cell/%s:0" %
+ rnn_cell_impl._BIAS_VARIABLE_NAME
+ ]
+ self.assertEqual(expected_variable_names,
+ [v.name for v in cell.trainable_variables])
+ self.assertFalse(cell.non_trainable_variables)
+ sess.run([variables_lib.global_variables_initializer()])
+ res = sess.run(
+ [g, out_state_0, out_state_1], {
+ x.name: np.array([[1., 1.]]),
+ state_0[0].name: 0.1 * np.ones([1, 2]),
+ state_0[1].name: 0.1 * np.ones([1, 2]),
+ state_1[0].name: 0.1 * np.ones([1, 2]),
+ state_1[1].name: 0.1 * np.ones([1, 2]),
+ })
+ self.assertEqual(len(res), 3)
+ variables = variables_lib.global_variables()
+ self.assertEqual(expected_variable_names, [v.name for v in variables])
+ # Only check the range of outputs as this is just a smoke test.
+ self.assertAllInRange(res[0], -1.0, 1.0)
+ self.assertAllInRange(res[1], -1.0, 1.0)
+ self.assertAllInRange(res[2], -1.0, 1.0)
+ with variable_scope.variable_scope(
+ "other", initializer=init_ops.constant_initializer(0.5)):
+ # Test IndyLSTMCell with input_size != num_units.
+ x = array_ops.zeros([1, 3], dtype=dtype)
+ state = (array_ops.zeros([1, 2], dtype=dtype),) * 2
+ g, out_state = contrib_rnn_cell.IndyLSTMCell(2)(x, state)
+ sess.run([variables_lib.global_variables_initializer()])
+ res = sess.run(
+ [g, out_state], {
+ x.name: np.array([[1., 1., 1.]], dtype=np_dtype),
+ state[0].name: 0.1 * np.ones([1, 2], dtype=np_dtype),
+ state[1].name: 0.1 * np.ones([1, 2], dtype=np_dtype),
+ })
+ self.assertEqual(len(res), 2)
+
def testLSTMCell(self):
with self.test_session() as sess:
num_units = 8
@@ -935,50 +1048,6 @@ class DropoutWrapperTest(test.TestCase):
self.assertAllClose(res0[1].h, res1[1].h)
-class SlimRNNCellTest(test.TestCase):
-
- def testBasicRNNCell(self):
- with self.test_session() as sess:
- with variable_scope.variable_scope(
- "root", initializer=init_ops.constant_initializer(0.5)):
- x = array_ops.zeros([1, 2])
- m = array_ops.zeros([1, 2])
- my_cell = functools.partial(basic_rnn_cell, num_units=2)
- # pylint: disable=protected-access
- g, _ = rnn_cell_impl._SlimRNNCell(my_cell)(x, m)
- # pylint: enable=protected-access
- sess.run([variables_lib.global_variables_initializer()])
- res = sess.run([g], {
- x.name: np.array([[1., 1.]]),
- m.name: np.array([[0.1, 0.1]])
- })
- self.assertEqual(res[0].shape, (1, 2))
-
- def testBasicRNNCellMatch(self):
- batch_size = 32
- input_size = 100
- num_units = 10
- with self.test_session() as sess:
- with variable_scope.variable_scope(
- "root", initializer=init_ops.constant_initializer(0.5)):
- inputs = random_ops.random_uniform((batch_size, input_size))
- _, initial_state = basic_rnn_cell(inputs, None, num_units)
- rnn_cell = rnn_cell_impl.BasicRNNCell(num_units)
- outputs, state = rnn_cell(inputs, initial_state)
- variable_scope.get_variable_scope().reuse_variables()
- my_cell = functools.partial(basic_rnn_cell, num_units=num_units)
- # pylint: disable=protected-access
- slim_cell = rnn_cell_impl._SlimRNNCell(my_cell)
- # pylint: enable=protected-access
- slim_outputs, slim_state = slim_cell(inputs, initial_state)
- self.assertEqual(slim_outputs.get_shape(), outputs.get_shape())
- self.assertEqual(slim_state.get_shape(), state.get_shape())
- sess.run([variables_lib.global_variables_initializer()])
- res = sess.run([slim_outputs, slim_state, outputs, state])
- self.assertAllClose(res[0], res[2])
- self.assertAllClose(res[1], res[3])
-
-
def basic_rnn_cell(inputs, state, num_units, scope=None):
if state is None:
if inputs is not None: