diff options
author | gracehoney <31743510+aaroey@users.noreply.github.com> | 2018-08-09 12:54:15 -0700 |
---|---|---|
committer | gracehoney <31743510+aaroey@users.noreply.github.com> | 2018-08-09 12:54:15 -0700 |
commit | 8945e0f1fb1cdc026ce7cf91b339b0b6a21f6dc6 (patch) | |
tree | 2eda36f799aa56214035a5f85e3bdfade3b509f9 /tensorflow/contrib/tensorrt | |
parent | 728422d1eee62374b3221676a1826660473897bc (diff) |
Fix rank_two_test.
Diffstat (limited to 'tensorflow/contrib/tensorrt')
-rw-r--r-- | tensorflow/contrib/tensorrt/test/rank_two_test.py | 65 |
1 files changed, 36 insertions, 29 deletions
diff --git a/tensorflow/contrib/tensorrt/test/rank_two_test.py b/tensorflow/contrib/tensorrt/test/rank_two_test.py index a0c18da265..fbed1ac4e8 100644 --- a/tensorflow/contrib/tensorrt/test/rank_two_test.py +++ b/tensorflow/contrib/tensorrt/test/rank_two_test.py @@ -21,6 +21,7 @@ from __future__ import print_function import numpy as np from tensorflow.contrib.tensorrt.test import tf_trt_integration_test_base as trt_test +from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops from tensorflow.python.ops import array_ops @@ -33,40 +34,46 @@ class RankTwoTest(trt_test.TfTrtIntegrationTestBase): def GetParams(self): """Test for rank 2 input in TF-TRT.""" - dtype = dtypes.float32 - input_name = "input" - input_dims = [12, 5] - input2_name = "input2" - input2_dims = [12, 5, 2, 2] + input_names = ["input", "input2"] + input_dims = [[12, 5], [12, 5, 2, 2]] g = ops.Graph() with g.as_default(): - # path 1 with rank 2 input - x = array_ops.placeholder(dtype=dtype, shape=input_dims, name=input_name) - q = x + 1.0 - q = math_ops.abs(q) - q = q + 2.2 - q = math_ops.abs(q) - q = q + 3.0 - q = array_ops.expand_dims(q, -1) - q = array_ops.expand_dims(q, -1) - a = gen_math_ops.reciprocal(q) - # path 2 with rank 4 input - x = array_ops.placeholder(dtype=dtype, shape=input2_dims, name=input2_name) - q = x + 1.0 - q = math_ops.abs(q) - q = q + 2.2 - q = math_ops.abs(q) - q = q + 3.0 - b = gen_math_ops.reciprocal(q) - # combine path 1 & 2 - q = a + b + # Path 1 with rank 2 input + outputs = [] + for i in range(2): + x = array_ops.placeholder( + dtype=dtypes.float32, shape=input_dims[i], name=input_names[i]) + c = constant_op.constant(1.0, name="c%d_1" % i) + q = math_ops.add(x, c, name="add%d_1" % i) + q = math_ops.abs(q, name="abs%d_1" % i) + c = constant_op.constant(2.2, name="c%d_2" % i) + q = math_ops.add(q, c, name="add%d_2" % i) + q = math_ops.abs(q, name="abs%d_2" % i) + c = constant_op.constant(3.0, name="c%d_3" % i) + q = math_ops.add(q, c, name="add%d_3" % i) + if i == 0: + for j in range(2): + q = array_ops.expand_dims(q, -1, name="expand%d_%d" % (i, j)) + q = gen_math_ops.reciprocal(q, name="reciprocal%d" % i) + outputs.append(q) + # Combine path 1 & 2 + q = math_ops.add(outputs[0], outputs[1], name="add") array_ops.squeeze(q, name=self.output_name) return trt_test.TfTrtIntegrationTestParams( gdef=g.as_graph_def(), - input_names=[input_name, input2_name], - input_dims=[input_dims, input2_dims], - num_expected_engines=2, - expected_output_dims=(12, 5, 2, 2), + input_names=input_names, + input_dims=input_dims, + expected_engines={ + "my_trt_op_0": [ + "add0_1", "add0_2", "add0_3", "c0_1", "c0_2", "c0_3", "abs0_1", + "abs0_2" + ], + "my_trt_op_1": [ + "add", "add1_1", "add1_2", "add1_3", "c1_1", "c1_2", "c1_3", + "abs1_1", "abs1_2", "reciprocal0", "reciprocal1" + ], + }, + expected_output_dims=tuple(input_dims[1]), allclose_atol=1.e-03, allclose_rtol=1.e-03) |