diff options
author | 2018-07-20 12:55:42 -0700 | |
---|---|---|
committer | 2018-07-20 12:55:49 -0700 | |
commit | 50e11adf67d3bb79a653f423aac7fb00747951e8 (patch) | |
tree | 2191d6bc9973c72b35d5428a93f61dc9ddc27e02 /tensorflow/contrib/tensorrt/BUILD | |
parent | 9f8256a61fcd44eeef7c0bf41c9bb4fddc505ae0 (diff) | |
parent | a103552156432bcda7e29e5588e83c62d5154b88 (diff) |
Merge pull request #20774 from jjsjann123:test_pr
PiperOrigin-RevId: 205439601
Diffstat (limited to 'tensorflow/contrib/tensorrt/BUILD')
-rw-r--r-- | tensorflow/contrib/tensorrt/BUILD | 33 |
1 files changed, 28 insertions, 5 deletions
diff --git a/tensorflow/contrib/tensorrt/BUILD b/tensorflow/contrib/tensorrt/BUILD index a9378e9ad6..7999f718e3 100644 --- a/tensorflow/contrib/tensorrt/BUILD +++ b/tensorflow/contrib/tensorrt/BUILD @@ -11,7 +11,6 @@ exports_files(["LICENSE"]) load( "//tensorflow:tensorflow.bzl", - "cuda_py_test", "tf_cc_test", "tf_copts", "tf_cuda_library", @@ -20,6 +19,7 @@ load( "tf_gen_op_libs", "tf_gen_op_wrapper_py", ) +load("//tensorflow:tensorflow.bzl", "cuda_py_tests") load("//tensorflow:tensorflow.bzl", "tf_cuda_cc_test") load("//tensorflow:tensorflow.bzl", "tf_custom_op_py_library") load("//tensorflow:tensorflow.bzl", "tf_py_wrap_cc") @@ -154,6 +154,7 @@ py_library( ], srcs_version = "PY2AND3", deps = [ + ":tf_trt_integration_test_base", ":trt_convert_py", ":trt_ops_py", "//tensorflow/python:errors", @@ -349,15 +350,37 @@ tf_cuda_cc_test( ]), ) -cuda_py_test( +py_library( + name = "tf_trt_integration_test_base", + srcs = ["test/tf_trt_integration_test_base.py"], + deps = [ + ":trt_convert_py", + ":trt_ops_py", + "//tensorflow/python:client_testlib", + "//tensorflow/python:framework_test_lib", + ], +) + +cuda_py_tests( name = "tf_trt_integration_test", - srcs = ["test/tf_trt_integration_test.py"], + srcs = [ + "test/base_test.py", + # "test/batch_matmul_test.py", + # "test/biasadd_matmul_test.py", + "test/binary_tensor_weight_broadcast_test.py", + "test/concatenation_test.py", + "test/const_broadcast_test.py", + "test/multi_connection_neighbor_engine_test.py", + "test/neighboring_engine_test.py", + "test/unary_test.py", + # "test/vgg_block_nchw_test.py", + # "test/vgg_block_test.py", + ], additional_deps = [ - ":init_py", + ":tf_trt_integration_test_base", "//tensorflow/python:client_testlib", "//tensorflow/python:framework_test_lib", ], - main = "test/tf_trt_integration_test.py", tags = [ "no_windows", "nomac", |