aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/tensorrt/BUILD
diff options
context:
space:
mode:
authorGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-07-20 12:55:42 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-07-20 12:55:49 -0700
commit50e11adf67d3bb79a653f423aac7fb00747951e8 (patch)
tree2191d6bc9973c72b35d5428a93f61dc9ddc27e02 /tensorflow/contrib/tensorrt/BUILD
parent9f8256a61fcd44eeef7c0bf41c9bb4fddc505ae0 (diff)
parenta103552156432bcda7e29e5588e83c62d5154b88 (diff)
Merge pull request #20774 from jjsjann123:test_pr
PiperOrigin-RevId: 205439601
Diffstat (limited to 'tensorflow/contrib/tensorrt/BUILD')
-rw-r--r--tensorflow/contrib/tensorrt/BUILD33
1 files changed, 28 insertions, 5 deletions
diff --git a/tensorflow/contrib/tensorrt/BUILD b/tensorflow/contrib/tensorrt/BUILD
index a9378e9ad6..7999f718e3 100644
--- a/tensorflow/contrib/tensorrt/BUILD
+++ b/tensorflow/contrib/tensorrt/BUILD
@@ -11,7 +11,6 @@ exports_files(["LICENSE"])
load(
"//tensorflow:tensorflow.bzl",
- "cuda_py_test",
"tf_cc_test",
"tf_copts",
"tf_cuda_library",
@@ -20,6 +19,7 @@ load(
"tf_gen_op_libs",
"tf_gen_op_wrapper_py",
)
+load("//tensorflow:tensorflow.bzl", "cuda_py_tests")
load("//tensorflow:tensorflow.bzl", "tf_cuda_cc_test")
load("//tensorflow:tensorflow.bzl", "tf_custom_op_py_library")
load("//tensorflow:tensorflow.bzl", "tf_py_wrap_cc")
@@ -154,6 +154,7 @@ py_library(
],
srcs_version = "PY2AND3",
deps = [
+ ":tf_trt_integration_test_base",
":trt_convert_py",
":trt_ops_py",
"//tensorflow/python:errors",
@@ -349,15 +350,37 @@ tf_cuda_cc_test(
]),
)
-cuda_py_test(
+py_library(
+ name = "tf_trt_integration_test_base",
+ srcs = ["test/tf_trt_integration_test_base.py"],
+ deps = [
+ ":trt_convert_py",
+ ":trt_ops_py",
+ "//tensorflow/python:client_testlib",
+ "//tensorflow/python:framework_test_lib",
+ ],
+)
+
+cuda_py_tests(
name = "tf_trt_integration_test",
- srcs = ["test/tf_trt_integration_test.py"],
+ srcs = [
+ "test/base_test.py",
+ # "test/batch_matmul_test.py",
+ # "test/biasadd_matmul_test.py",
+ "test/binary_tensor_weight_broadcast_test.py",
+ "test/concatenation_test.py",
+ "test/const_broadcast_test.py",
+ "test/multi_connection_neighbor_engine_test.py",
+ "test/neighboring_engine_test.py",
+ "test/unary_test.py",
+ # "test/vgg_block_nchw_test.py",
+ # "test/vgg_block_test.py",
+ ],
additional_deps = [
- ":init_py",
+ ":tf_trt_integration_test_base",
"//tensorflow/python:client_testlib",
"//tensorflow/python:framework_test_lib",
],
- main = "test/tf_trt_integration_test.py",
tags = [
"no_windows",
"nomac",