diff options
Diffstat (limited to 'tensorflow/contrib/rnn/python/kernel_tests/core_rnn_cell_test.py')
-rw-r--r-- | tensorflow/contrib/rnn/python/kernel_tests/core_rnn_cell_test.py | 165 |
1 files changed, 156 insertions, 9 deletions
diff --git a/tensorflow/contrib/rnn/python/kernel_tests/core_rnn_cell_test.py b/tensorflow/contrib/rnn/python/kernel_tests/core_rnn_cell_test.py index 0d9285ccb8..8090743e6c 100644 --- a/tensorflow/contrib/rnn/python/kernel_tests/core_rnn_cell_test.py +++ b/tensorflow/contrib/rnn/python/kernel_tests/core_rnn_cell_test.py @@ -19,6 +19,7 @@ from __future__ import division from __future__ import print_function import functools +import itertools import sys # TODO: #6568 Remove this hack that makes dlopen() not crash. @@ -33,9 +34,14 @@ import numpy as np from tensorflow.contrib.rnn.python.ops import core_rnn_cell_impl from tensorflow.contrib.rnn.python.ops.core_rnn_cell_impl import _linear as linear +from tensorflow.core.protobuf import config_pb2 +from tensorflow.python.client import session from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops +from tensorflow.python.framework import random_seed from tensorflow.python.ops import array_ops +from tensorflow.python.ops import control_flow_ops +from tensorflow.python.ops import gradients_impl from tensorflow.python.ops import init_ops from tensorflow.python.ops import math_ops from tensorflow.python.ops import random_ops @@ -43,10 +49,41 @@ from tensorflow.python.ops import rnn from tensorflow.python.ops import variable_scope from tensorflow.python.ops import variables as variables_lib from tensorflow.python.platform import test +from tensorflow.python.util import nest # pylint: enable=protected-access +def _CreateMultiLSTMCellOps(batch_size, num_units, input_depth, + num_layers, max_time, compiled): + with variable_scope.variable_scope( + "root", + initializer=init_ops.random_uniform_initializer(-0.1, 0.1, seed=2)): + inputs = random_ops.random_uniform( + (max_time, batch_size, input_depth), seed=1) + rnn_cell = core_rnn_cell_impl.MultiRNNCell( + [core_rnn_cell_impl.LSTMCell(num_units, compiled=compiled) + for _ in range(num_layers)]) + initial_state = rnn_cell.zero_state( + batch_size=batch_size, dtype=dtypes.float32) + outputs, final_state = rnn.dynamic_rnn( + cell=rnn_cell, inputs=inputs, initial_state=initial_state, + time_major=True) + flat_final_state = nest.flatten(final_state) + trainable_variables = variables_lib.trainable_variables() + outputs_grad = gradients_impl.gradients( + [outputs], + trainable_variables + [inputs] + nest.flatten(initial_state)) + final_state_grad = gradients_impl.gradients( + flat_final_state, + trainable_variables + [inputs] + nest.flatten(initial_state)) + + return {"outputs": outputs, + "final_state": flat_final_state, + "outputs_grad": outputs_grad, + "final_state_grad": final_state_grad} + + class RNNCellTest(test.TestCase): def testLinear(self): @@ -117,8 +154,8 @@ class RNNCellTest(test.TestCase): x = array_ops.zeros([1, 2]) m = array_ops.zeros([1, 8]) g, out_m = core_rnn_cell_impl.MultiRNNCell( - [core_rnn_cell_impl.BasicLSTMCell( - 2, state_is_tuple=False)] * 2, + [core_rnn_cell_impl.BasicLSTMCell(2, state_is_tuple=False) + for _ in range(2)], state_is_tuple=False)(x, m) sess.run([variables_lib.global_variables_initializer()]) res = sess.run( @@ -165,7 +202,8 @@ class RNNCellTest(test.TestCase): m0 = (array_ops.zeros([1, 2]),) * 2 m1 = (array_ops.zeros([1, 2]),) * 2 cell = core_rnn_cell_impl.MultiRNNCell( - [core_rnn_cell_impl.BasicLSTMCell(2)] * 2, state_is_tuple=True) + [core_rnn_cell_impl.BasicLSTMCell(2) for _ in range(2)], + state_is_tuple=True) self.assertTrue(isinstance(cell.state_size, tuple)) self.assertTrue( isinstance(cell.state_size[0], core_rnn_cell_impl.LSTMStateTuple)) @@ -197,8 +235,8 @@ class RNNCellTest(test.TestCase): m0 = array_ops.zeros([1, 4]) m1 = array_ops.zeros([1, 4]) cell = core_rnn_cell_impl.MultiRNNCell( - [core_rnn_cell_impl.BasicLSTMCell( - 2, state_is_tuple=False)] * 2, + [core_rnn_cell_impl.BasicLSTMCell(2, state_is_tuple=False) + for _ in range(2)], state_is_tuple=True) g, (out_m0, out_m1) = cell(x, (m0, m1)) sess.run([variables_lib.global_variables_initializer()]) @@ -407,7 +445,8 @@ class RNNCellTest(test.TestCase): x = array_ops.zeros([1, 2]) m = array_ops.zeros([1, 4]) _, ml = core_rnn_cell_impl.MultiRNNCell( - [core_rnn_cell_impl.GRUCell(2)] * 2, state_is_tuple=False)(x, m) + [core_rnn_cell_impl.GRUCell(2) for _ in range(2)], + state_is_tuple=False)(x, m) sess.run([variables_lib.global_variables_initializer()]) res = sess.run(ml, { x.name: np.array([[1., 1.]]), @@ -416,6 +455,48 @@ class RNNCellTest(test.TestCase): # The numbers in results were not calculated, this is just a smoke test. self.assertAllClose(res, [[0.175991, 0.175991, 0.13248, 0.13248]]) + def testMultiRNNCellWithLSTMCellAndXLA(self): + # TODO(b/34735319): Don't run this test if XLA is not available. + batch_size = 16 + num_units = 32 + input_depth = 12 + num_layers = 2 + max_time = 20 + + random_seed.set_random_seed(1234) + with self.test_session(graph=ops.Graph()) as sess: + xla_ops = _CreateMultiLSTMCellOps( + batch_size=batch_size, num_units=num_units, + input_depth=input_depth, num_layers=num_layers, + max_time=max_time, + compiled=True) + sess.run([variables_lib.global_variables_initializer()]) + xla_results = sess.run(xla_ops) + + random_seed.set_random_seed(1234) + with self.test_session(graph=ops.Graph()) as sess: + non_xla_ops = _CreateMultiLSTMCellOps( + batch_size=batch_size, num_units=num_units, + input_depth=input_depth, num_layers=num_layers, + max_time=max_time, + compiled=False) + sess.run([variables_lib.global_variables_initializer()]) + non_xla_results = sess.run(non_xla_ops) + + self.assertAllClose(non_xla_results["outputs"], xla_results["outputs"]) + + for xla_value, non_xla_value in zip( + xla_results["final_state"], non_xla_results["final_state"]): + self.assertAllClose(xla_value, non_xla_value) + + for xla_g, non_xla_g in zip( + xla_results["outputs_grad"], non_xla_results["outputs_grad"]): + self.assertAllClose(xla_g, non_xla_g) + + for xla_g, non_xla_g in zip( + xla_results["final_state_grad"], non_xla_results["final_state_grad"]): + self.assertAllClose(xla_g, non_xla_g) + def testMultiRNNCellWithStateTuple(self): with self.test_session() as sess: with variable_scope.variable_scope( @@ -427,11 +508,12 @@ class RNNCellTest(test.TestCase): # Test incorrectness of state with self.assertRaisesRegexp(ValueError, "Expected state .* a tuple"): core_rnn_cell_impl.MultiRNNCell( - [core_rnn_cell_impl.GRUCell(2)] * 2, + [core_rnn_cell_impl.GRUCell(2) for _ in range(2)], state_is_tuple=True)(x, m_bad) _, ml = core_rnn_cell_impl.MultiRNNCell( - [core_rnn_cell_impl.GRUCell(2)] * 2, state_is_tuple=True)(x, m_good) + [core_rnn_cell_impl.GRUCell(2) for _ in range(2)], + state_is_tuple=True)(x, m_good) sess.run([variables_lib.global_variables_initializer()]) res = sess.run(ml, { @@ -490,7 +572,7 @@ class SlimRNNCellTest(test.TestCase): self.assertAllClose(res[1], res[3]) -def basic_rnn_cell(inputs, state, num_units, scope=None): +def basic_rnn_cell(inputs, state, num_units, scope=None): # pylint: disable=invalid-name if state is None: if inputs is not None: batch_size = inputs.get_shape()[0] @@ -512,5 +594,70 @@ def basic_rnn_cell(inputs, state, num_units, scope=None): return output, output +class BenchmarkLSTMCellXLA(test.Benchmark): + + def benchmarkDynamicRNNWithMultiLSTMCell(self): + num_layers = 3 + max_time = 50 + print("benchmarkDynamicRNNWithMultiLSTMCell") + print("\t" + + "\t".join(["inter_th", "intra_th", + "batch_size", "num_units", "input_depth", "device", + "compiled", "wall_time"])) + + warmup_run = True + for (threads, + device, + num_units, + batch_size, + input_depth, + compiled) in itertools.product( + [{"inter": 0, "intra": 0}, {"inter": 1, "intra": 4}], + ["cpu", "gpu"], + [32, 512], + [1, 32, 256], + [32, 512], + [False, True]): + if threads["inter"] != 0: + # We only care about testing inter/intra op limitations on + # CPU with small batch size, to mimic embedded devices. + if device != "cpu" or batch_size != 1: + continue + if device == "cpu" and batch_size > 32: + continue + random_seed.set_random_seed(1234) + config = config_pb2.ConfigProto( + inter_op_parallelism_threads=threads["inter"], + intra_op_parallelism_threads=threads["intra"], + allow_soft_placement=False) + with session.Session(config=config, graph=ops.Graph()) as sess: + with ops.device("/%s:0" % device): + ops_dict = _CreateMultiLSTMCellOps( + batch_size=batch_size, num_units=num_units, + input_depth=input_depth, num_layers=num_layers, + max_time=max_time, + compiled=compiled) + sess.run([variables_lib.global_variables_initializer()]) + all_ops = nest.flatten(ops_dict.values()) + all_ops_group = control_flow_ops.group(*all_ops) + name_suffix = ( + "inter_th_%d_intra_th_%d_bs_%d_units_%d_inputdepth_%d" + "_device_%s_xla_%s" % ( + threads["inter"], threads["intra"], + batch_size, num_units, input_depth, device, compiled)) + if warmup_run: + self.run_op_benchmark( + sess, all_ops_group, min_iters=30, name="ignore_warmup") + warmup_run = False + benchmark_results = self.run_op_benchmark( + sess, all_ops_group, min_iters=30, + name="benchmarkDynamicRNNWithMultiLSTMCell_%s" % name_suffix) + print("\t" + + "\t".join(["%s" % x for x in [ + threads["inter"], threads["intra"], + batch_size, num_units, input_depth, device, compiled, + benchmark_results["wall_time"]]])) + + if __name__ == "__main__": test.main() |