diff options
Diffstat (limited to 'tensorflow/contrib/tensorrt/test/neighboring_engine_test.py')
-rw-r--r-- | tensorflow/contrib/tensorrt/test/neighboring_engine_test.py | 13 |
1 files changed, 8 insertions, 5 deletions
diff --git a/tensorflow/contrib/tensorrt/test/neighboring_engine_test.py b/tensorflow/contrib/tensorrt/test/neighboring_engine_test.py index 50265c0845..51c905a50b 100644 --- a/tensorflow/contrib/tensorrt/test/neighboring_engine_test.py +++ b/tensorflow/contrib/tensorrt/test/neighboring_engine_test.py @@ -25,7 +25,7 @@ from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops from tensorflow.python.ops import array_ops -from tensorflow.python.ops import gen_math_ops +from tensorflow.python.ops import math_ops from tensorflow.python.ops import nn from tensorflow.python.platform import test @@ -51,15 +51,18 @@ class NeighboringEngineTest(trt_test.TfTrtIntegrationTestBase): name="conv") b = constant_op.constant( np.random.normal(1.0, 1.0, [1, 4, 1, 1]), name="bias", dtype=dtype) - t = conv * b - e = gen_math_ops.tan(conv) - t = t - e + t = math_ops.mul(conv, b, name="mul") + e = self.trt_incompatible_op(conv, name="incompatible") + t = math_ops.sub(t, e, name="sub") array_ops.squeeze(t, name=self.output_name) return trt_test.TfTrtIntegrationTestParams( gdef=g.as_graph_def(), input_names=[input_name], input_dims=[input_dims], - num_expected_engines=2, + expected_engines={ + "my_trt_op_0": ["bias", "mul", "sub"], + "my_trt_op_1": ["weights", "conv"] + }, expected_output_dims=(2, 4, 5, 4), allclose_atol=1.e-03, allclose_rtol=1.e-03) |