aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/tensorrt/test/vgg_block_test.py
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/contrib/tensorrt/test/vgg_block_test.py')
-rw-r--r--tensorflow/contrib/tensorrt/test/vgg_block_test.py21
1 files changed, 11 insertions, 10 deletions
diff --git a/tensorflow/contrib/tensorrt/test/vgg_block_test.py b/tensorflow/contrib/tensorrt/test/vgg_block_test.py
index 56bdf848ea..d7c165784b 100644
--- a/tensorflow/contrib/tensorrt/test/vgg_block_test.py
+++ b/tensorflow/contrib/tensorrt/test/vgg_block_test.py
@@ -38,15 +38,14 @@ class VGGBlockTest(trt_test.TfTrtIntegrationTestBase):
dtype = dtypes.float32
input_name = "input"
input_dims = [5, 8, 8, 2]
+ output_name = "output"
g = ops.Graph()
with g.as_default():
x = array_ops.placeholder(dtype=dtype, shape=input_dims, name=input_name)
x, _, _ = nn_impl.fused_batch_norm(
- x,
- np.random.randn(2).astype(np.float32),
- np.random.randn(2).astype(np.float32),
- mean=np.random.randn(2).astype(np.float32),
- variance=np.random.randn(2).astype(np.float32),
+ x, [1.0, 1.0], [0.0, 0.0],
+ mean=[0.5, 0.5],
+ variance=[1.0, 1.0],
is_training=False)
e = constant_op.constant(
np.random.randn(1, 1, 2, 6), name="weights", dtype=dtype)
@@ -58,15 +57,17 @@ class VGGBlockTest(trt_test.TfTrtIntegrationTestBase):
idty = array_ops.identity(relu, "ID")
v = nn_ops.max_pool(
idty, [1, 2, 2, 1], [1, 2, 2, 1], "VALID", name="max_pool")
- array_ops.squeeze(v, name="output")
+ array_ops.squeeze(v, name=output_name)
return trt_test.TfTrtIntegrationTestParams(
gdef=g.as_graph_def(),
input_names=[input_name],
input_dims=[input_dims],
- expected_engines=["my_trt_op_0"],
- expected_output_dims=(5, 2, 2, 6),
- allclose_atol=1.e-03,
- allclose_rtol=1.e-03)
+ output_names=[output_name],
+ expected_output_dims=[(5, 2, 2, 6)])
+
+ def ExpectedEnginesToBuild(self, run_params):
+ """Return the expected engines to build."""
+ return ["my_trt_op_0"]
if __name__ == "__main__":