diff options
Diffstat (limited to 'tensorflow/contrib/tensorrt/test/vgg_block_test.py')
-rw-r--r-- | tensorflow/contrib/tensorrt/test/vgg_block_test.py | 21 |
1 files changed, 11 insertions, 10 deletions
diff --git a/tensorflow/contrib/tensorrt/test/vgg_block_test.py b/tensorflow/contrib/tensorrt/test/vgg_block_test.py index 56bdf848ea..d7c165784b 100644 --- a/tensorflow/contrib/tensorrt/test/vgg_block_test.py +++ b/tensorflow/contrib/tensorrt/test/vgg_block_test.py @@ -38,15 +38,14 @@ class VGGBlockTest(trt_test.TfTrtIntegrationTestBase): dtype = dtypes.float32 input_name = "input" input_dims = [5, 8, 8, 2] + output_name = "output" g = ops.Graph() with g.as_default(): x = array_ops.placeholder(dtype=dtype, shape=input_dims, name=input_name) x, _, _ = nn_impl.fused_batch_norm( - x, - np.random.randn(2).astype(np.float32), - np.random.randn(2).astype(np.float32), - mean=np.random.randn(2).astype(np.float32), - variance=np.random.randn(2).astype(np.float32), + x, [1.0, 1.0], [0.0, 0.0], + mean=[0.5, 0.5], + variance=[1.0, 1.0], is_training=False) e = constant_op.constant( np.random.randn(1, 1, 2, 6), name="weights", dtype=dtype) @@ -58,15 +57,17 @@ class VGGBlockTest(trt_test.TfTrtIntegrationTestBase): idty = array_ops.identity(relu, "ID") v = nn_ops.max_pool( idty, [1, 2, 2, 1], [1, 2, 2, 1], "VALID", name="max_pool") - array_ops.squeeze(v, name="output") + array_ops.squeeze(v, name=output_name) return trt_test.TfTrtIntegrationTestParams( gdef=g.as_graph_def(), input_names=[input_name], input_dims=[input_dims], - expected_engines=["my_trt_op_0"], - expected_output_dims=(5, 2, 2, 6), - allclose_atol=1.e-03, - allclose_rtol=1.e-03) + output_names=[output_name], + expected_output_dims=[(5, 2, 2, 6)]) + + def ExpectedEnginesToBuild(self, run_params): + """Return the expected engines to build.""" + return ["my_trt_op_0"] if __name__ == "__main__": |