aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/tensorrt
diff options
context:
space:
mode:
authorGravatar gracehoney <31743510+aaroey@users.noreply.github.com>2018-08-09 12:54:15 -0700
committerGravatar gracehoney <31743510+aaroey@users.noreply.github.com>2018-08-09 12:54:15 -0700
commit8945e0f1fb1cdc026ce7cf91b339b0b6a21f6dc6 (patch)
tree2eda36f799aa56214035a5f85e3bdfade3b509f9 /tensorflow/contrib/tensorrt
parent728422d1eee62374b3221676a1826660473897bc (diff)
Fix rank_two_test.
Diffstat (limited to 'tensorflow/contrib/tensorrt')
-rw-r--r--tensorflow/contrib/tensorrt/test/rank_two_test.py65
1 files changed, 36 insertions, 29 deletions
diff --git a/tensorflow/contrib/tensorrt/test/rank_two_test.py b/tensorflow/contrib/tensorrt/test/rank_two_test.py
index a0c18da265..fbed1ac4e8 100644
--- a/tensorflow/contrib/tensorrt/test/rank_two_test.py
+++ b/tensorflow/contrib/tensorrt/test/rank_two_test.py
@@ -21,6 +21,7 @@ from __future__ import print_function
import numpy as np
from tensorflow.contrib.tensorrt.test import tf_trt_integration_test_base as trt_test
+from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
@@ -33,40 +34,46 @@ class RankTwoTest(trt_test.TfTrtIntegrationTestBase):
def GetParams(self):
"""Test for rank 2 input in TF-TRT."""
- dtype = dtypes.float32
- input_name = "input"
- input_dims = [12, 5]
- input2_name = "input2"
- input2_dims = [12, 5, 2, 2]
+ input_names = ["input", "input2"]
+ input_dims = [[12, 5], [12, 5, 2, 2]]
g = ops.Graph()
with g.as_default():
- # path 1 with rank 2 input
- x = array_ops.placeholder(dtype=dtype, shape=input_dims, name=input_name)
- q = x + 1.0
- q = math_ops.abs(q)
- q = q + 2.2
- q = math_ops.abs(q)
- q = q + 3.0
- q = array_ops.expand_dims(q, -1)
- q = array_ops.expand_dims(q, -1)
- a = gen_math_ops.reciprocal(q)
- # path 2 with rank 4 input
- x = array_ops.placeholder(dtype=dtype, shape=input2_dims, name=input2_name)
- q = x + 1.0
- q = math_ops.abs(q)
- q = q + 2.2
- q = math_ops.abs(q)
- q = q + 3.0
- b = gen_math_ops.reciprocal(q)
- # combine path 1 & 2
- q = a + b
+ # Path 1 with rank 2 input
+ outputs = []
+ for i in range(2):
+ x = array_ops.placeholder(
+ dtype=dtypes.float32, shape=input_dims[i], name=input_names[i])
+ c = constant_op.constant(1.0, name="c%d_1" % i)
+ q = math_ops.add(x, c, name="add%d_1" % i)
+ q = math_ops.abs(q, name="abs%d_1" % i)
+ c = constant_op.constant(2.2, name="c%d_2" % i)
+ q = math_ops.add(q, c, name="add%d_2" % i)
+ q = math_ops.abs(q, name="abs%d_2" % i)
+ c = constant_op.constant(3.0, name="c%d_3" % i)
+ q = math_ops.add(q, c, name="add%d_3" % i)
+ if i == 0:
+ for j in range(2):
+ q = array_ops.expand_dims(q, -1, name="expand%d_%d" % (i, j))
+ q = gen_math_ops.reciprocal(q, name="reciprocal%d" % i)
+ outputs.append(q)
+ # Combine path 1 & 2
+ q = math_ops.add(outputs[0], outputs[1], name="add")
array_ops.squeeze(q, name=self.output_name)
return trt_test.TfTrtIntegrationTestParams(
gdef=g.as_graph_def(),
- input_names=[input_name, input2_name],
- input_dims=[input_dims, input2_dims],
- num_expected_engines=2,
- expected_output_dims=(12, 5, 2, 2),
+ input_names=input_names,
+ input_dims=input_dims,
+ expected_engines={
+ "my_trt_op_0": [
+ "add0_1", "add0_2", "add0_3", "c0_1", "c0_2", "c0_3", "abs0_1",
+ "abs0_2"
+ ],
+ "my_trt_op_1": [
+ "add", "add1_1", "add1_2", "add1_3", "c1_1", "c1_2", "c1_3",
+ "abs1_1", "abs1_2", "reciprocal0", "reciprocal1"
+ ],
+ },
+ expected_output_dims=tuple(input_dims[1]),
allclose_atol=1.e-03,
allclose_rtol=1.e-03)