aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/tensorrt/test/neighboring_engine_test.py
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/contrib/tensorrt/test/neighboring_engine_test.py')
-rw-r--r--tensorflow/contrib/tensorrt/test/neighboring_engine_test.py13
1 files changed, 8 insertions, 5 deletions
diff --git a/tensorflow/contrib/tensorrt/test/neighboring_engine_test.py b/tensorflow/contrib/tensorrt/test/neighboring_engine_test.py
index 50265c0845..51c905a50b 100644
--- a/tensorflow/contrib/tensorrt/test/neighboring_engine_test.py
+++ b/tensorflow/contrib/tensorrt/test/neighboring_engine_test.py
@@ -25,7 +25,7 @@ from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
-from tensorflow.python.ops import gen_math_ops
+from tensorflow.python.ops import math_ops
from tensorflow.python.ops import nn
from tensorflow.python.platform import test
@@ -51,15 +51,18 @@ class NeighboringEngineTest(trt_test.TfTrtIntegrationTestBase):
name="conv")
b = constant_op.constant(
np.random.normal(1.0, 1.0, [1, 4, 1, 1]), name="bias", dtype=dtype)
- t = conv * b
- e = gen_math_ops.tan(conv)
- t = t - e
+ t = math_ops.mul(conv, b, name="mul")
+ e = self.trt_incompatible_op(conv, name="incompatible")
+ t = math_ops.sub(t, e, name="sub")
array_ops.squeeze(t, name=self.output_name)
return trt_test.TfTrtIntegrationTestParams(
gdef=g.as_graph_def(),
input_names=[input_name],
input_dims=[input_dims],
- num_expected_engines=2,
+ expected_engines={
+ "my_trt_op_0": ["bias", "mul", "sub"],
+ "my_trt_op_1": ["weights", "conv"]
+ },
expected_output_dims=(2, 4, 5, 4),
allclose_atol=1.e-03,
allclose_rtol=1.e-03)