Diffstat (limited to 'tensorflow/contrib/rnn/python/kernel_tests/core_rnn_cell_test.py')
-rw-r--r-- tensorflow/contrib/rnn/python/kernel_tests/core_rnn_cell_test.py | 165
1 file changed, 156 insertions(+), 9 deletions(-)
diff --git a/tensorflow/contrib/rnn/python/kernel_tests/core_rnn_cell_test.py b/tensorflow/contrib/rnn/python/kernel_tests/core_rnn_cell_test.py
index 0d9285ccb8..8090743e6c 100644
--- a/tensorflow/contrib/rnn/python/kernel_tests/core_rnn_cell_test.py
+++ b/tensorflow/contrib/rnn/python/kernel_tests/core_rnn_cell_test.py
@@ -19,6 +19,7 @@ from __future__ import division
 from __future__ import print_function

 import functools
+import itertools
 import sys

 # TODO: #6568 Remove this hack that makes dlopen() not crash.
@@ -33,9 +34,14 @@ import numpy as np

 from tensorflow.contrib.rnn.python.ops import core_rnn_cell_impl
 from tensorflow.contrib.rnn.python.ops.core_rnn_cell_impl import _linear as linear
+from tensorflow.core.protobuf import config_pb2
+from tensorflow.python.client import session
 from tensorflow.python.framework import dtypes
 from tensorflow.python.framework import ops
+from tensorflow.python.framework import random_seed
 from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import control_flow_ops
+from tensorflow.python.ops import gradients_impl
 from tensorflow.python.ops import init_ops
 from tensorflow.python.ops import math_ops
 from tensorflow.python.ops import random_ops
@@ -43,10 +49,41 @@ from tensorflow.python.ops import rnn
 from tensorflow.python.ops import variable_scope
 from tensorflow.python.ops import variables as variables_lib
 from tensorflow.python.platform import test
+from tensorflow.python.util import nest

 # pylint: enable=protected-access


+def _CreateMultiLSTMCellOps(batch_size, num_units, input_depth,
+                            num_layers, max_time, compiled):
+  with variable_scope.variable_scope(
+      "root",
+      initializer=init_ops.random_uniform_initializer(-0.1, 0.1, seed=2)):
+    inputs = random_ops.random_uniform(
+        (max_time, batch_size, input_depth), seed=1)
+    rnn_cell = core_rnn_cell_impl.MultiRNNCell(
+        [core_rnn_cell_impl.LSTMCell(num_units, compiled=compiled)
+         for _ in range(num_layers)])
+    initial_state = rnn_cell.zero_state(
+        batch_size=batch_size, dtype=dtypes.float32)
+    outputs, final_state = rnn.dynamic_rnn(
+        cell=rnn_cell, inputs=inputs, initial_state=initial_state,
+        time_major=True)
+    flat_final_state = nest.flatten(final_state)
+    trainable_variables = variables_lib.trainable_variables()
+    outputs_grad = gradients_impl.gradients(
+        [outputs],
+        trainable_variables + [inputs] + nest.flatten(initial_state))
+    final_state_grad = gradients_impl.gradients(
+        flat_final_state,
+        trainable_variables + [inputs] + nest.flatten(initial_state))
+
+    return {"outputs": outputs,
+            "final_state": flat_final_state,
+            "outputs_grad": outputs_grad,
+            "final_state_grad": final_state_grad}
+
+
 class RNNCellTest(test.TestCase):

   def testLinear(self):
@@ -117,8 +154,8 @@ class RNNCellTest(test.TestCase):
         x = array_ops.zeros([1, 2])
         m = array_ops.zeros([1, 8])
         g, out_m = core_rnn_cell_impl.MultiRNNCell(
-            [core_rnn_cell_impl.BasicLSTMCell(
-                2, state_is_tuple=False)] * 2,
+            [core_rnn_cell_impl.BasicLSTMCell(2, state_is_tuple=False)
+             for _ in range(2)],
             state_is_tuple=False)(x, m)
         sess.run([variables_lib.global_variables_initializer()])
         res = sess.run(
@@ -165,7 +202,8 @@ class RNNCellTest(test.TestCase):
         m0 = (array_ops.zeros([1, 2]),) * 2
         m1 = (array_ops.zeros([1, 2]),) * 2
         cell = core_rnn_cell_impl.MultiRNNCell(
-            [core_rnn_cell_impl.BasicLSTMCell(2)] * 2, state_is_tuple=True)
+            [core_rnn_cell_impl.BasicLSTMCell(2) for _ in range(2)],
+            state_is_tuple=True)
         self.assertTrue(isinstance(cell.state_size, tuple))
         self.assertTrue(
             isinstance(cell.state_size[0], core_rnn_cell_impl.LSTMStateTuple))
@@ -197,8 +235,8 @@ class RNNCellTest(test.TestCase):
         m0 = array_ops.zeros([1, 4])
         m1 = array_ops.zeros([1, 4])
         cell = core_rnn_cell_impl.MultiRNNCell(
-            [core_rnn_cell_impl.BasicLSTMCell(
-                2, state_is_tuple=False)] * 2,
+            [core_rnn_cell_impl.BasicLSTMCell(2, state_is_tuple=False)
+             for _ in range(2)],
             state_is_tuple=True)
         g, (out_m0, out_m1) = cell(x, (m0, m1))
         sess.run([variables_lib.global_variables_initializer()])
@@ -407,7 +445,8 @@ class RNNCellTest(test.TestCase):
         x = array_ops.zeros([1, 2])
         m = array_ops.zeros([1, 4])
         _, ml = core_rnn_cell_impl.MultiRNNCell(
-            [core_rnn_cell_impl.GRUCell(2)] * 2, state_is_tuple=False)(x, m)
+            [core_rnn_cell_impl.GRUCell(2) for _ in range(2)],
+            state_is_tuple=False)(x, m)
         sess.run([variables_lib.global_variables_initializer()])
         res = sess.run(ml, {
             x.name: np.array([[1., 1.]]),
@@ -416,6 +455,48 @@

         # The numbers in results were not calculated, this is just a smoke test.
         self.assertAllClose(res, [[0.175991, 0.175991, 0.13248, 0.13248]])

+  def testMultiRNNCellWithLSTMCellAndXLA(self):
+    # TODO(b/34735319): Don't run this test if XLA is not available.
+    batch_size = 16
+    num_units = 32
+    input_depth = 12
+    num_layers = 2
+    max_time = 20
+
+    random_seed.set_random_seed(1234)
+    with self.test_session(graph=ops.Graph()) as sess:
+      xla_ops = _CreateMultiLSTMCellOps(
+          batch_size=batch_size, num_units=num_units,
+          input_depth=input_depth, num_layers=num_layers,
+          max_time=max_time,
+          compiled=True)
+      sess.run([variables_lib.global_variables_initializer()])
+      xla_results = sess.run(xla_ops)
+
+    random_seed.set_random_seed(1234)
+    with self.test_session(graph=ops.Graph()) as sess:
+      non_xla_ops = _CreateMultiLSTMCellOps(
+          batch_size=batch_size, num_units=num_units,
+          input_depth=input_depth, num_layers=num_layers,
+          max_time=max_time,
+          compiled=False)
+      sess.run([variables_lib.global_variables_initializer()])
+      non_xla_results = sess.run(non_xla_ops)
+
+    self.assertAllClose(non_xla_results["outputs"], xla_results["outputs"])
+
+    for xla_value, non_xla_value in zip(
+        xla_results["final_state"], non_xla_results["final_state"]):
+      self.assertAllClose(xla_value, non_xla_value)
+
+    for xla_g, non_xla_g in zip(
+        xla_results["outputs_grad"], non_xla_results["outputs_grad"]):
+      self.assertAllClose(xla_g, non_xla_g)
+
+    for xla_g, non_xla_g in zip(
+        xla_results["final_state_grad"], non_xla_results["final_state_grad"]):
+      self.assertAllClose(xla_g, non_xla_g)
+
   def testMultiRNNCellWithStateTuple(self):
     with self.test_session() as sess:
       with variable_scope.variable_scope(
@@ -427,11 +508,12 @@
         # Test incorrectness of state
         with self.assertRaisesRegexp(ValueError, "Expected state .* a tuple"):
           core_rnn_cell_impl.MultiRNNCell(
-              [core_rnn_cell_impl.GRUCell(2)] * 2,
+              [core_rnn_cell_impl.GRUCell(2) for _ in range(2)],
               state_is_tuple=True)(x, m_bad)

         _, ml = core_rnn_cell_impl.MultiRNNCell(
-            [core_rnn_cell_impl.GRUCell(2)] * 2, state_is_tuple=True)(x, m_good)
+            [core_rnn_cell_impl.GRUCell(2) for _ in range(2)],
+            state_is_tuple=True)(x, m_good)

         sess.run([variables_lib.global_variables_initializer()])
         res = sess.run(ml, {
@@ -490,7 +572,7 @@ class SlimRNNCellTest(test.TestCase):
       self.assertAllClose(res[1], res[3])


-def basic_rnn_cell(inputs, state, num_units, scope=None):
+def basic_rnn_cell(inputs, state, num_units, scope=None):  # pylint: disable=invalid-name
   if state is None:
     if inputs is not None:
       batch_size = inputs.get_shape()[0]
@@ -512,5 +594,70 @@ def basic_rnn_cell(inputs, state, num_units, scope=None):
   return output, output


+class BenchmarkLSTMCellXLA(test.Benchmark):
+
+  def benchmarkDynamicRNNWithMultiLSTMCell(self):
+    num_layers = 3
+    max_time = 50
+    print("benchmarkDynamicRNNWithMultiLSTMCell")
+    print("\t" +
+          "\t".join(["inter_th", "intra_th",
+                     "batch_size", "num_units", "input_depth", "device",
+                     "compiled", "wall_time"]))
+
+    warmup_run = True
+    for (threads,
+         device,
+         num_units,
+         batch_size,
+         input_depth,
+         compiled) in itertools.product(
+             [{"inter": 0, "intra": 0}, {"inter": 1, "intra": 4}],
+             ["cpu", "gpu"],
+             [32, 512],
+             [1, 32, 256],
+             [32, 512],
+             [False, True]):
+      if threads["inter"] != 0:
+        # We only care about testing inter/intra op limitations on
+        # CPU with small batch size, to mimic embedded devices.
+        if device != "cpu" or batch_size != 1:
+          continue
+      if device == "cpu" and batch_size > 32:
+        continue
+      random_seed.set_random_seed(1234)
+      config = config_pb2.ConfigProto(
+          inter_op_parallelism_threads=threads["inter"],
+          intra_op_parallelism_threads=threads["intra"],
+          allow_soft_placement=False)
+      with session.Session(config=config, graph=ops.Graph()) as sess:
+        with ops.device("/%s:0" % device):
+          ops_dict = _CreateMultiLSTMCellOps(
+              batch_size=batch_size, num_units=num_units,
+              input_depth=input_depth, num_layers=num_layers,
+              max_time=max_time,
+              compiled=compiled)
+        sess.run([variables_lib.global_variables_initializer()])
+        all_ops = nest.flatten(ops_dict.values())
+        all_ops_group = control_flow_ops.group(*all_ops)
+        name_suffix = (
+            "inter_th_%d_intra_th_%d_bs_%d_units_%d_inputdepth_%d"
+            "_device_%s_xla_%s" % (
+                threads["inter"], threads["intra"],
+                batch_size, num_units, input_depth, device, compiled))
+        if warmup_run:
+          self.run_op_benchmark(
+              sess, all_ops_group, min_iters=30, name="ignore_warmup")
+          warmup_run = False
+        benchmark_results = self.run_op_benchmark(
+            sess, all_ops_group, min_iters=30,
+            name="benchmarkDynamicRNNWithMultiLSTMCell_%s" % name_suffix)
+        print("\t" +
+              "\t".join(["%s" % x for x in [
+                  threads["inter"], threads["intra"],
+                  batch_size, num_units, input_depth, device, compiled,
+                  benchmark_results["wall_time"]]]))
+
+
 if __name__ == "__main__":
   test.main()
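
Reviewer note, not part of the commit: testMultiRNNCellWithLSTMCellAndXLA depends on rebuilding identical graphs under fixed seeds, so that any divergence between the compiled and uncompiled runs is attributable to the cell implementation rather than to the random inputs. Below is a minimal standalone sketch of that seeded rebuild-and-compare pattern, using only modules already imported by this file; the helper name and tensor shapes are invented for illustration.

import numpy as np

from tensorflow.python.client import session
from tensorflow.python.framework import ops
from tensorflow.python.framework import random_seed
from tensorflow.python.ops import random_ops


def build_and_run():
  graph = ops.Graph()
  with graph.as_default():
    # Fixing both the graph-level seed and the op-level seed makes the
    # draw reproducible across freshly constructed graphs.
    random_seed.set_random_seed(1234)
    draw = random_ops.random_uniform((2, 3), seed=1)
  with session.Session(graph=graph) as sess:
    return sess.run(draw)


first = build_and_run()
second = build_and_run()
# Identical inputs in both graphs; in the real test, any remaining
# XLA/non-XLA mismatch therefore comes from the compiled cell itself.
np.testing.assert_allclose(first, second)

To actually execute BenchmarkLSTMCellXLA, the usual tf.test benchmark entry point should apply (an assumption, not stated in this diff): run the test binary with a --benchmarks filter, e.g. "python core_rnn_cell_test.py --benchmarks=BenchmarkLSTMCellXLA".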