author    gracehoney <31743510+aaroey@users.noreply.github.com>  2018-07-16 14:48:20 -0700
committer gracehoney <31743510+aaroey@users.noreply.github.com>  2018-07-16 14:48:20 -0700
commit    bd6c04a86cd77d1b969d88fd243c80b7780b6db3 (patch)
tree      53d959b0d598d6f1d40959c0b03101fa5f822b72 /tensorflow/contrib/tensorrt
parent    e3e434d966ba5f800ba73ca688a851aa878c5463 (diff)
Refactor tf_trt_integration_test so we can extend it to other graphs.
Diffstat (limited to 'tensorflow/contrib/tensorrt')
-rw-r--r--  tensorflow/contrib/tensorrt/BUILD                                 |  28
-rw-r--r--  tensorflow/contrib/tensorrt/test/base_test.py                     | 125
-rw-r--r--  tensorflow/contrib/tensorrt/test/base_unit_test.py                | 118
-rw-r--r--  tensorflow/contrib/tensorrt/test/run_test.py                      | 184
-rw-r--r--  tensorflow/contrib/tensorrt/test/tf_trt_integration_test.py       | 347
-rw-r--r--  tensorflow/contrib/tensorrt/test/tf_trt_integration_test_base.py  | 293
-rw-r--r--  tensorflow/contrib/tensorrt/test/unit_tests.py                    |  67
-rw-r--r--  tensorflow/contrib/tensorrt/test/utilities.py                     |  30
8 files changed, 433 insertions(+), 759 deletions(-)
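
The payoff of this refactor is that covering a new graph no longer requires copying the whole test harness: a new test module only builds a GraphDef, wraps it in TfTrtIntegrationTestParams, and registers it with AddTests. Below is a minimal sketch of such a module; the module name, graph body, and num_expected_engines value are assumptions for illustration, while the params fields and the AddTests/TfTrtIntegrationTestBase API are taken from tf_trt_integration_test_base.py as added in this change.

# Hypothetical bias_add_relu_test.py -- illustrative only, not part of this patch.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import nn
from tensorflow.python.ops import nn_ops
from tensorflow.python.platform import test
from tensorflow.contrib.tensorrt.test import tf_trt_integration_test_base as trt_test


def _GetBiasAddReluGraphDef(dtype=dtypes.float32):
  """A small single-segment graph: bias_add -> relu -> max_pool."""
  input_dims = [8, 24, 24, 2]
  g = ops.Graph()
  with g.as_default():
    inp = array_ops.placeholder(
        dtype=dtype, shape=[None] + input_dims[1:], name=trt_test.INPUT_NAME)
    with g.device("/GPU:0"):
      bias = constant_op.constant([1.5, -0.5], name="bias", dtype=dtype)
      added = nn.bias_add(inp, bias, name="bias_add")
      relu = nn.relu(added, "relu")
      pool = nn_ops.max_pool(
          relu, [1, 2, 2, 1], [1, 2, 2, 1], "VALID", name="max_pool")
      array_ops.squeeze(pool, name=trt_test.OUTPUT_NAME)
  return trt_test.TfTrtIntegrationTestParams(
      graph_name="BiasAddRelu",
      gdef=g.as_graph_def(),
      input_dims=input_dims,
      num_expected_engines=1,  # assumption: all three ops land in one engine
      expected_output_dims=(8, 12, 12, 2),
      allclose_atol=1.e-03,
      allclose_rtol=1.e-03)


class BiasAddReluTest(trt_test.TfTrtIntegrationTestBase):
  """Runs the parameterized TF-TRT tests against the graph above."""
  pass


if __name__ == "__main__":
  trt_test.AddTests(BiasAddReluTest, [_GetBiasAddReluGraphDef()])
  test.main()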
diff --git a/tensorflow/contrib/tensorrt/BUILD b/tensorflow/contrib/tensorrt/BUILD
index d957ca0861..7aed241fd0 100644
--- a/tensorflow/contrib/tensorrt/BUILD
+++ b/tensorflow/contrib/tensorrt/BUILD
@@ -11,7 +11,6 @@ exports_files(["LICENSE"])
load(
"//tensorflow:tensorflow.bzl",
- "py_test",
"tf_cc_test",
"tf_copts",
"tf_cuda_library",
@@ -20,6 +19,7 @@ load(
"tf_gen_op_libs",
"tf_gen_op_wrapper_py",
)
+load("//tensorflow:tensorflow.bzl", "cuda_py_tests")
load("//tensorflow:tensorflow.bzl", "tf_cuda_cc_test")
load("//tensorflow:tensorflow.bzl", "tf_custom_op_py_library")
load("//tensorflow:tensorflow.bzl", "tf_py_wrap_cc")
@@ -33,7 +33,6 @@ tf_cuda_cc_test(
size = "small",
srcs = ["tensorrt_test.cc"],
tags = [
- "manual",
"notap",
],
deps = [
@@ -311,7 +310,6 @@ tf_cuda_cc_test(
size = "small",
srcs = ["plugin/trt_plugin_factory_test.cc"],
tags = [
- "manual",
"notap",
],
deps = [
@@ -325,15 +323,9 @@ tf_cuda_cc_test(
]),
)
-py_test(
- name = "tf_trt_integration_test",
- srcs = ["test/tf_trt_integration_test.py"],
- main = "test/tf_trt_integration_test.py",
- srcs_version = "PY2AND3",
- tags = [
- "manual",
- "notap",
- ],
+py_library(
+ name = "tf_trt_integration_test_base",
+ srcs = ["test/tf_trt_integration_test_base.py"],
deps = [
":init_py",
"//tensorflow/python:client_testlib",
@@ -341,6 +333,17 @@ py_test(
],
)
+cuda_py_tests(
+ name = "tf_trt_integration_test",
+ srcs = ["test/base_test.py"],
+ additional_deps = [
+ ":tf_trt_integration_test_base",
+ "//tensorflow/python:client_testlib",
+ "//tensorflow/python:framework_test_lib",
+ ],
+ prefix = "integration_test",
+)
+
py_test(
name = "converter_unit_tests",
srcs = [
@@ -362,7 +365,6 @@ py_test(
main = "test/unit_tests.py",
srcs_version = "PY2AND3",
tags = [
- "manual",
"notap",
],
deps = [
diff --git a/tensorflow/contrib/tensorrt/test/base_test.py b/tensorflow/contrib/tensorrt/test/base_test.py
new file mode 100644
index 0000000000..4b9e6d668f
--- /dev/null
+++ b/tensorflow/contrib/tensorrt/test/base_test.py
@@ -0,0 +1,125 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Basic tests for TF-TensorRT integration."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import numpy as np
+
+from tensorflow.python.framework import constant_op
+from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import ops
+from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import nn
+from tensorflow.python.ops import nn_ops
+from tensorflow.python.platform import test
+from tensorflow.contrib.tensorrt.test import tf_trt_integration_test_base as trt_test
+
+
+# TODO(aaroey): test graph with different dtypes.
+def _GetSingleEngineGraphDef(dtype=dtypes.float32):
+ """Create a graph containing single segment."""
+ input_dims = [100, 24, 24, 2]
+ g = ops.Graph()
+ with g.as_default():
+ inp = array_ops.placeholder(
+ dtype=dtype, shape=[None] + input_dims[1:], name=trt_test.INPUT_NAME)
+ with g.device("/GPU:0"):
+ conv_filter = constant_op.constant(
+ [[[[1., 0.5, 4., 6., 0.5, 1.], [1., 0.5, 1., 1., 0.5, 1.]]]],
+ name="weights",
+ dtype=dtype)
+ conv = nn.conv2d(
+ input=inp,
+ filter=conv_filter,
+ strides=[1, 2, 2, 1],
+ padding="SAME",
+ name="conv")
+ bias = constant_op.constant(
+ [4., 1.5, 2., 3., 5., 7.], name="bias", dtype=dtype)
+ added = nn.bias_add(conv, bias, name="bias_add")
+ relu = nn.relu(added, "relu")
+ identity = array_ops.identity(relu, "identity")
+ pool = nn_ops.max_pool(
+ identity, [1, 2, 2, 1], [1, 2, 2, 1], "VALID", name="max_pool")
+ array_ops.squeeze(pool, name=trt_test.OUTPUT_NAME)
+ return trt_test.TfTrtIntegrationTestParams(
+ graph_name="SimpleSingleEngine",
+ gdef=g.as_graph_def(),
+ input_dims=input_dims,
+ num_expected_engines=1,
+ expected_output_dims=(100, 6, 6, 6),
+ allclose_atol=1.e-03,
+ allclose_rtol=1.e-03)
+
+
+# TODO(aaroey): test graph with different dtypes.
+def _GetMultiEngineGraphDef(dtype=dtypes.float32):
+ """Create a graph containing multiple segment."""
+ input_dims = [100, 24, 24, 2]
+ g = ops.Graph()
+ with g.as_default():
+ inp = array_ops.placeholder(
+ dtype=dtype, shape=[None] + input_dims[1:], name=trt_test.INPUT_NAME)
+ with g.device("/GPU:0"):
+ conv_filter = constant_op.constant(
+ [[[[1., 0.5, 4., 6., 0.5, 1.], [1., 0.5, 1., 1., 0.5, 1.]]]],
+ name="weights",
+ dtype=dtype)
+ conv = nn.conv2d(
+ input=inp,
+ filter=conv_filter,
+ strides=[1, 2, 2, 1],
+ padding="SAME",
+ name="conv")
+ c1 = constant_op.constant(
+ np.random.randn(input_dims[0], 12, 12, 6), dtype=dtype)
+ p = conv * c1
+ c2 = constant_op.constant(
+ np.random.randn(input_dims[0], 12, 12, 6), dtype=dtype)
+ q = conv / c2
+
+ edge = trt_test.TRT_INCOMPATIBLE_OP(q)
+ edge /= edge
+ r = edge + edge
+
+ p -= edge
+ q *= edge
+ s = p + q
+ s -= r
+ array_ops.squeeze(s, name=trt_test.OUTPUT_NAME)
+ return trt_test.TfTrtIntegrationTestParams(
+ graph_name="SimpleMultipleEngines",
+ gdef=g.as_graph_def(),
+ input_dims=input_dims,
+ num_expected_engines=2,
+ expected_output_dims=(100, 12, 12, 6),
+ allclose_atol=1.e-03,
+ allclose_rtol=1.e-03)
+
+
+class BaseTest(trt_test.TfTrtIntegrationTestBase):
+ """Class to test Tensorflow-TensorRT integration."""
+ pass
+
+
+if __name__ == "__main__":
+ # TODO(aaroey): add a large complex graph to test.
+ trt_test.AddTests(BaseTest,
+ [_GetSingleEngineGraphDef(),
+ _GetMultiEngineGraphDef()])
+ test.main()
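
Why num_expected_engines=2 for the multi-engine graph: TRT_INCOMPATIBLE_OP is math_ops.sin in the base module, and routing q through it forces the segmenter to cut the graph there, leaving one TensorRT-compatible segment before the sin and one after it. A quick way to check that count offline is sketched below, under the assumption of a TensorRT-enabled build with a GPU (run inside this module, since _GetMultiEngineGraphDef is module-private); it uses the same create_inference_graph call the base class wraps.

from tensorflow.contrib import tensorrt as trt

params = _GetMultiEngineGraphDef()
trt_gdef = trt.create_inference_graph(
    input_graph_def=params.gdef,
    outputs=[trt_test.OUTPUT_NAME],
    max_batch_size=params.input_dims[0],
    max_workspace_size_bytes=1 << 25,
    precision_mode="FP32",
    minimum_segment_size=2)
# Count the fused TRTEngineOp nodes, as _VerifyGraphDef does in the base class.
num_engines = len([n for n in trt_gdef.node if n.op == "TRTEngineOp"])
assert num_engines == params.num_expected_engines  # expect 2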
diff --git a/tensorflow/contrib/tensorrt/test/base_unit_test.py b/tensorflow/contrib/tensorrt/test/base_unit_test.py
deleted file mode 100644
index 8a6c648ab6..0000000000
--- a/tensorflow/contrib/tensorrt/test/base_unit_test.py
+++ /dev/null
@@ -1,118 +0,0 @@
-# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ==============================================================================
-"""Base class to facilitate development of integration tests."""
-
-from __future__ import absolute_import
-from __future__ import division
-from __future__ import print_function
-
-import numpy as np
-
-
-class BaseUnitTest(object):
- """Base class for unit tests in TF-TRT"""
-
- def __init__(self, log_file='log.txt'):
- self.static_mode_list = {}
- self.dynamic_mode_list = {}
- self.dummy_input = None
- self.get_network = None
- self.expect_nb_nodes = None
- self.test_name = None
- self.log_file = log_file
- self.ckpt = None
- self.allclose_rtol = 0.01
- self.allclose_atol = 0.01
- self.allclose_equal_nan = True
- # saves out graphdef
- self.debug = False
- # require node count check fail leads to test failure
- self.check_node_count = False
-
- def run(self, run_test_context):
- run_test_context.run_test(self.get_network, self.static_mode_list,
- self.dynamic_mode_list, self.dummy_input,
- self.ckpt)
- return self.log_result(run_test_context)
-
- def log_result(self, run_test_result):
- log = open(self.log_file, 'a')
- log.write(("================= model: %s\n") % (self.test_name))
-
- if self.debug:
- open(self.test_name + "_native.pb",
- 'wb').write(run_test_result.native_network.SerializeToString())
- all_success = True
- if len(run_test_result.tftrt_conversion_flag) != 0:
- log.write(" -- static_mode\n")
- for static_mode in run_test_result.tftrt_conversion_flag:
- if self.debug:
- open(self.test_name + "_" + static_mode + ".pb",
- 'wb').write(run_test_result.tftrt[static_mode].SerializeToString())
- log.write(" ----\n")
- log.write((" mode: [%s]\n") % (static_mode))
- if run_test_result.tftrt_conversion_flag[static_mode]:
- if run_test_result.tftrt_nb_nodes[static_mode] != self.expect_nb_nodes:
- log.write(
- ("[WARNING]: converted node number does not match (%d,%d,%d)!!!\n"
- ) % (run_test_result.tftrt_nb_nodes[static_mode],
- self.expect_nb_nodes, run_test_result.native_nb_nodes))
- if self.check_node_count:
- all_success = False
-
- if np.array_equal(run_test_result.tftrt_result[static_mode],
- run_test_result.native_result):
- log.write(" output: equal\n")
- elif np.allclose(
- run_test_result.tftrt_result[static_mode],
- run_test_result.native_result,
- atol=self.allclose_atol,
- rtol=self.allclose_rtol,
- equal_nan=self.allclose_equal_nan):
- log.write(" output: allclose\n")
- else:
- diff = run_test_result.tftrt_result[static_mode] - run_test_result.native_result
- log.write("[ERROR]: output does not match!!!\n")
- log.write("max diff: " + str(np.max(diff)))
- log.write("\ntftrt:\n")
- log.write(str(run_test_result.tftrt_result[static_mode]))
- log.write("\nnative:\n")
- log.write(str(run_test_result.native_result))
- log.write("\ndiff:\n")
- log.write(str(diff))
- all_success = False
- else:
- log.write("[ERROR]: conversion failed!!!\n")
- all_success = False
-
- if len(run_test_result.tftrt_dynamic_conversion_flag) != 0:
- log.write(" -- dynamic_mode\n")
- for dynamic_mode in run_test_result.tftrt_dynamic_conversion_flag:
- log.write("\n ----\n")
- log.write((" mode: [%s]\n") % (dynamic_mode))
- if run_test_result.tftrt_dynamic_conversion_flag[dynamic_mode]:
- if np.array_equal(run_test_result.tftrt_dynamic_result[dynamic_mode],
- run_test_result.native_result):
- log.write(" output: equal\n")
- elif np.allclose(run_test_result.tftrt_dynamic_result[dynamic_mode],
- run_test_result.native_result):
- log.write(" output: allclose\n")
- else:
- log.write("[ERROR]: output does not match!!!\n")
- all_success = False
- else:
- log.write("[ERROR]: conversion failed!!!\n")
- all_success = False
- return all_success
diff --git a/tensorflow/contrib/tensorrt/test/run_test.py b/tensorflow/contrib/tensorrt/test/run_test.py
deleted file mode 100644
index 4d109cc378..0000000000
--- a/tensorflow/contrib/tensorrt/test/run_test.py
+++ /dev/null
@@ -1,184 +0,0 @@
-# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ==============================================================================
-"""script to convert and execute TF-TensorRT graph."""
-
-from __future__ import absolute_import
-from __future__ import division
-from __future__ import print_function
-
-from tensorflow.contrib import tensorrt as trt
-from tensorflow.core.protobuf import config_pb2
-from tensorflow.core.protobuf import rewriter_config_pb2
-from tensorflow.python.client import session
-from tensorflow.python.framework import importer
-from tensorflow.python.framework import ops
-from tensorflow.python.training import training
-from tensorflow.contrib.tensorrt.test.utilities import get_all_variables
-
-OUTPUT_NODE = "output"
-INPUT_NODE = "input"
-CALIB_COUNT = 5 # calibration iteration
-
-
-class RunTest:
- """base class to run TR-TRT conversion and execution"""
-
- def __init__(self):
- self.clean()
-
- def __enter__(self):
- return self
-
- def __exit__(self, exc_type, exc_val, exc_tb):
- self.clean()
-
- def clean(self):
- self.tftrt = {}
- self.tftrt_conversion_flag = {}
- self.tftrt_nb_nodes = {}
- self.tftrt_result = {}
- self.tftrt_dynamic_conversion_flag = {}
- self.tftrt_dynamic_result = {}
- self.check_file = None
- self.native_network = None
-
- def run_test(self,
- network,
- static_mode_list,
- dynamic_mode_list,
- dummy_input,
- file_name=None):
- self.native_network = network()
- success = True
- initialization = False
- if file_name != None:
- initialization = True
- self.check_file = file_name
- self.native_result, self.native_nb_nodes = self.execute_graph(
- self.native_network, dummy_input, initialization)
- for mode in static_mode_list:
- try:
- self.run_static_convert_network(mode, dummy_input, initialization)
- self.tftrt_conversion_flag[mode] = True
- except Exception as inst:
- self.tftrt_conversion_flag[mode] = False
- success = False
- for mode in dynamic_mode_list:
- try:
- self.run_dynamic_convert_network(mode, dummy_input, initialization)
- self.tftrt_dynamic_conversion_flag[mode] = True
- except Exception as inst:
- self.tftrt_dynamic_conversion_flag[mode] = False
- success = False
- return success
-
- def run_dynamic_convert_network(self, mode, dummy_input, initialization=True):
- inp_dims = dummy_input.shape
- if mode == "FP32" or mode == "FP16":
- opt_config = rewriter_config_pb2.RewriterConfig()
- opt_config.optimizers.extend(["constfold", "layout"])
- custom_op = opt_config.custom_optimizers.add()
- custom_op.name = "TensorRTOptimizer"
- custom_op.parameter_map["minimum_segment_size"].i = 3
- custom_op.parameter_map["precision_mode"].s = mode
- custom_op.parameter_map["max_batch_size"].i = inp_dims[0]
- custom_op.parameter_map["max_workspace_size_bytes"].i = 1 << 25
- print(custom_op)
- gpu_options = config_pb2.GPUOptions(per_process_gpu_memory_fraction=0.50)
- graph_options = config_pb2.GraphOptions(rewrite_options=opt_config)
- sessconfig = config_pb2.ConfigProto(
- gpu_options=gpu_options, graph_options=graph_options)
- print(sessconfig)
- g = ops.Graph()
- ops.reset_default_graph()
- with g.as_default():
- inp, out = importer.import_graph_def(
- graph_def=self.native_network, return_elements=["input", "output"])
- inp = inp.outputs[0]
- out = out.outputs[0]
- with session.Session(config=sessconfig, graph=g) as sess:
- if (initialization):
- names_var_list = get_all_variables(sess)
- saver = training.Saver(names_var_list)
- saver.restore(sess, self.check_file)
- self.tftrt_dynamic_result[mode] = sess.run(out, {inp: dummy_input})
- else:
- raise Exception("dynamic op mode: " + mode + " not supported")
-
- def run_static_convert_network(self, mode, dummy_input, initialization=True):
- inp_dims = dummy_input.shape
- if mode == "FP32" or mode == "FP16" or mode == "INT8":
- trt_graph = trt.create_inference_graph(
- input_graph_def=self.native_network,
- outputs=[OUTPUT_NODE],
- max_batch_size=inp_dims[0],
- max_workspace_size_bytes=1 << 25,
- precision_mode=mode, # TRT Engine precision "FP32","FP16" or "INT8"
- minimum_segment_size=2 # minimum number of nodes in an engine
- )
- if mode == "INT8":
- _ = self.execute_calibration(trt_graph, dummy_input, initialization)
- trt_graph = trt.calib_graph_to_infer_graph(trt_graph)
- trt_result, nb_nodes = self.execute_graph(trt_graph, dummy_input,
- initialization)
- self.tftrt[mode] = trt_graph
- self.tftrt_nb_nodes[mode] = nb_nodes
- self.tftrt_result[mode] = trt_result
- else:
- raise Exception("mode: " + mode + " not supported")
-
- def execute_graph(self, gdef, dummy_input, initialization=True):
- """Run given graphdef once."""
- gpu_options = config_pb2.GPUOptions()
- sessconfig = config_pb2.ConfigProto(gpu_options=gpu_options)
- ops.reset_default_graph()
- g = ops.Graph()
- nb_nodes = 0
- with g.as_default():
- inp, out = importer.import_graph_def(
- graph_def=gdef, return_elements=[INPUT_NODE, OUTPUT_NODE], name="")
- nb_nodes = len(g.get_operations())
- inp = inp.outputs[0]
- out = out.outputs[0]
- with session.Session(config=sessconfig, graph=g) as sess:
- if (initialization):
- names_var_list = get_all_variables(sess)
- saver = training.Saver(names_var_list)
- saver.restore(sess, self.check_file)
- val = sess.run(out, {inp: dummy_input})
- return val, nb_nodes
-
- # Use real data that is representative of the inference dataset
- # for calibration. For this test script it is random data.
- def execute_calibration(self, gdef, dummy_input, initialization=True):
- """Run given calibration graph multiple times."""
- gpu_options = config_pb2.GPUOptions()
- ops.reset_default_graph()
- g = ops.Graph()
- with g.as_default():
- inp, out = importer.import_graph_def(
- graph_def=gdef, return_elements=[INPUT_NODE, OUTPUT_NODE], name="")
- inp = inp.outputs[0]
- out = out.outputs[0]
- with session.Session(
- config=config_pb2.ConfigProto(gpu_options=gpu_options),
- graph=g) as sess:
- if (initialization):
- names_var_list = get_all_variables(sess)
- saver = training.Saver(names_var_list)
- saver.restore(sess, self.check_file)
- for _ in range(CALIB_COUNT):
- val = sess.run(out, {inp: dummy_input})
- return val
diff --git a/tensorflow/contrib/tensorrt/test/tf_trt_integration_test.py b/tensorflow/contrib/tensorrt/test/tf_trt_integration_test.py
deleted file mode 100644
index d9c41f90d0..0000000000
--- a/tensorflow/contrib/tensorrt/test/tf_trt_integration_test.py
+++ /dev/null
@@ -1,347 +0,0 @@
-# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ==============================================================================
-"""Script to test TF-TensorRT integration."""
-
-from __future__ import absolute_import
-from __future__ import division
-from __future__ import print_function
-
-from collections import namedtuple
-import itertools
-import warnings
-import numpy as np
-import six
-
-from tensorflow.contrib import tensorrt as trt
-from tensorflow.core.protobuf import config_pb2
-from tensorflow.core.protobuf import rewriter_config_pb2
-from tensorflow.python.framework import constant_op
-from tensorflow.python.framework import dtypes
-from tensorflow.python.framework import importer
-from tensorflow.python.framework import ops
-from tensorflow.python.framework import test_util
-from tensorflow.python.ops import array_ops
-from tensorflow.python.ops import math_ops
-from tensorflow.python.ops import nn
-from tensorflow.python.ops import nn_ops
-from tensorflow.python.platform import test
-
-INPUT_NAME = "input"
-OUTPUT_NAME = "output"
-INPUT_DIMS = [100, 24, 24, 2]
-MODE_FP32 = "FP32"
-MODE_FP16 = "FP16"
-MODE_INT8 = "INT8"
-
-if six.PY2:
- to_bytes = lambda s: s
- to_string = lambda s: s
-else:
- to_bytes = lambda s: s.encode("utf-8", errors="surrogateescape")
- to_string = lambda s: s.decode("utf-8")
-
-
-# TODO(aaroey): test graph with different dtypes.
-def GetSingleEngineGraphDef(dtype=dtypes.float32):
- """Create a graph containing single segment."""
- g = ops.Graph()
- with g.as_default():
- inp = array_ops.placeholder(
- dtype=dtype, shape=[None] + INPUT_DIMS[1:], name=INPUT_NAME)
- with g.device("/GPU:0"):
- conv_filter = constant_op.constant(
- [[[[1., 0.5, 4., 6., 0.5, 1.], [1., 0.5, 1., 1., 0.5, 1.]]]],
- name="weights",
- dtype=dtype)
- conv = nn.conv2d(
- input=inp,
- filter=conv_filter,
- strides=[1, 2, 2, 1],
- padding="SAME",
- name="conv")
- bias = constant_op.constant(
- [4., 1.5, 2., 3., 5., 7.], name="bias", dtype=dtype)
- added = nn.bias_add(conv, bias, name="bias_add")
- relu = nn.relu(added, "relu")
- identity = array_ops.identity(relu, "identity")
- pool = nn_ops.max_pool(
- identity, [1, 2, 2, 1], [1, 2, 2, 1], "VALID", name="max_pool")
- array_ops.squeeze(pool, name=OUTPUT_NAME)
- return g.as_graph_def()
-
-
-# TODO(aaroey): test graph with different dtypes.
-def GetMultiEngineGraphDef(dtype=dtypes.float32):
- """Create a graph containing multiple segment."""
- g = ops.Graph()
- with g.as_default():
- inp = array_ops.placeholder(
- dtype=dtype, shape=[None] + INPUT_DIMS[1:], name=INPUT_NAME)
- with g.device("/GPU:0"):
- conv_filter = constant_op.constant(
- [[[[1., 0.5, 4., 6., 0.5, 1.], [1., 0.5, 1., 1., 0.5, 1.]]]],
- name="weights",
- dtype=dtype)
- conv = nn.conv2d(
- input=inp,
- filter=conv_filter,
- strides=[1, 2, 2, 1],
- padding="SAME",
- name="conv")
- c1 = constant_op.constant(
- np.random.randn(INPUT_DIMS[0], 12, 12, 6), dtype=dtype)
- p = conv * c1
- c2 = constant_op.constant(
- np.random.randn(INPUT_DIMS[0], 12, 12, 6), dtype=dtype)
- q = conv / c2
-
- edge = math_ops.sin(q)
- edge /= edge
- r = edge + edge
-
- p -= edge
- q *= edge
- s = p + q
- s -= r
- array_ops.squeeze(s, name=OUTPUT_NAME)
- return g.as_graph_def()
-
-
-TestGraph = namedtuple("TestGraph",
- ["gdef", "num_expected_engines", "expected_output_dims"])
-
-TEST_GRAPHS = {
- "SingleEngineGraph":
- TestGraph(
- gdef=GetSingleEngineGraphDef(),
- num_expected_engines=1,
- expected_output_dims=(100, 6, 6, 6)),
- "MultiEngineGraph":
- TestGraph(
- gdef=GetMultiEngineGraphDef(),
- num_expected_engines=2,
- expected_output_dims=(100, 12, 12, 6)),
- # TODO(aaroey): add a large complex graph to test.
-}
-
-
-class TfTrtIntegrationTest(test_util.TensorFlowTestCase):
- """Class to test Tensorflow-TensorRT integration."""
-
- def setUp(self):
- """Setup method."""
- super(TfTrtIntegrationTest, self).setUp()
- warnings.simplefilter("always")
- self._input = np.random.random_sample(INPUT_DIMS)
-
- def _GetConfigProto(self,
- use_optimizer,
- precision_mode=None,
- is_dynamic_op=None):
- if use_optimizer:
- rewriter_cfg = rewriter_config_pb2.RewriterConfig()
- rewriter_cfg.optimizers.extend(["constfold", "layout"])
- custom_op = rewriter_cfg.custom_optimizers.add()
- custom_op.name = "TensorRTOptimizer"
- custom_op.parameter_map["minimum_segment_size"].i = 3
- custom_op.parameter_map["max_batch_size"].i = self._input.shape[0]
- custom_op.parameter_map["is_dynamic_op"].b = is_dynamic_op
- custom_op.parameter_map["max_workspace_size_bytes"].i = 1 << 25
- custom_op.parameter_map["precision_mode"].s = to_bytes(precision_mode)
- graph_options = config_pb2.GraphOptions(rewrite_options=rewriter_cfg)
- else:
- graph_options = config_pb2.GraphOptions()
-
- gpu_options = config_pb2.GPUOptions()
- if trt.trt_convert.get_linked_tensorrt_version()[0] == 3:
- gpu_options.per_process_gpu_memory_fraction = 0.50
-
- config = config_pb2.ConfigProto(
- gpu_options=gpu_options, graph_options=graph_options)
- return config
-
- def _RunGraph(self, graph_key, gdef, input_data, config, num_runs=2):
- """Run given graphdef multiple times."""
- g = ops.Graph()
- with g.as_default():
- inp, out = importer.import_graph_def(
- graph_def=gdef, return_elements=[INPUT_NAME, OUTPUT_NAME], name="")
- inp = inp.outputs[0]
- out = out.outputs[0]
- with self.test_session(
- graph=g, config=config, use_gpu=True, force_gpu=True) as sess:
- val = None
- # Defaults to 2 runs to verify result across multiple runs is same.
- for _ in range(num_runs):
- new_val = sess.run(out, {inp: input_data})
- self.assertEquals(TEST_GRAPHS[graph_key].expected_output_dims,
- new_val.shape)
- if val is not None:
- self.assertAllEqual(new_val, val)
- val = new_val
- return val
-
- # Use real data that is representative of the inference dataset
- # for calibration. For this test script it is random data.
- def _RunCalibration(self, graph_key, gdef, input_data, config):
- """Run calibration on given graph."""
- return self._RunGraph(graph_key, gdef, input_data, config, 30)
-
- def _GetTrtGraph(self, gdef, precision_mode, is_dynamic_op):
- """Return trt converted graph."""
- return trt.create_inference_graph(
- input_graph_def=gdef,
- outputs=[OUTPUT_NAME],
- max_batch_size=self._input.shape[0],
- max_workspace_size_bytes=1 << 25,
- precision_mode=precision_mode,
- minimum_segment_size=2,
- is_dynamic_op=is_dynamic_op)
-
- def _VerifyGraphDef(self,
- graph_key,
- gdef,
- precision_mode=None,
- is_calibrated=None,
- dynamic_engine=None):
- num_engines = 0
- for n in gdef.node:
- if n.op == "TRTEngineOp":
- num_engines += 1
- self.assertNotEqual("", n.attr["serialized_segment"].s)
- self.assertNotEqual("", n.attr["segment_funcdef_name"].s)
- self.assertEquals(n.attr["precision_mode"].s, precision_mode)
- self.assertEquals(n.attr["static_engine"].b, not dynamic_engine)
- if precision_mode == MODE_INT8 and is_calibrated:
- self.assertNotEqual("", n.attr["calibration_data"].s)
- else:
- self.assertEquals("", n.attr["calibration_data"].s)
- if precision_mode is None:
- self.assertEquals(num_engines, 0)
- else:
- self.assertEquals(num_engines,
- TEST_GRAPHS[graph_key].num_expected_engines)
-
- def _RunTest(self, graph_key, use_optimizer, precision_mode,
- dynamic_infer_engine, dynamic_calib_engine):
- assert precision_mode in [MODE_FP32, MODE_FP16, MODE_INT8]
- input_gdef = TEST_GRAPHS[graph_key].gdef
- self._VerifyGraphDef(graph_key, input_gdef)
-
- # Get reference result without running trt.
- config_no_trt = self._GetConfigProto(False)
- print("Running original graph w/o trt, config:\n%s" % str(config_no_trt))
- ref_result = self._RunGraph(graph_key, input_gdef, self._input,
- config_no_trt)
-
- # Run calibration if necessary.
- if precision_mode == MODE_INT8:
-
- calib_config = self._GetConfigProto(use_optimizer, precision_mode,
- dynamic_calib_engine)
- print("Running calibration graph, config:\n%s" % str(calib_config))
- if use_optimizer:
- self.assertTrue(False)
- # TODO(aaroey): uncomment this and get infer_gdef when this mode is
- # supported.
- # result = self._RunCalibration(graph_key, input_gdef, self._input,
- # calib_config)
- else:
- calib_gdef = self._GetTrtGraph(input_gdef, precision_mode,
- dynamic_calib_engine)
- self._VerifyGraphDef(graph_key, calib_gdef, precision_mode, False,
- dynamic_calib_engine)
- result = self._RunCalibration(graph_key, calib_gdef, self._input,
- calib_config)
- infer_gdef = trt.calib_graph_to_infer_graph(calib_gdef)
- self._VerifyGraphDef(graph_key, infer_gdef, precision_mode, True,
- dynamic_calib_engine)
- self.assertAllClose(ref_result, result, rtol=1.e-03)
- else:
- infer_gdef = input_gdef
-
- # Run inference.
- infer_config = self._GetConfigProto(use_optimizer, precision_mode,
- dynamic_infer_engine)
- print("Running final inference graph, config:\n%s" % str(infer_config))
- if use_optimizer:
- result = self._RunGraph(graph_key, infer_gdef, self._input, infer_config)
- else:
- trt_infer_gdef = self._GetTrtGraph(infer_gdef, precision_mode,
- dynamic_infer_engine)
- self._VerifyGraphDef(graph_key, trt_infer_gdef, precision_mode, True,
- dynamic_infer_engine)
- result = self._RunGraph(graph_key, trt_infer_gdef, self._input,
- infer_config)
- self.assertAllClose(ref_result, result, rtol=1.e-03)
-
- def testIdempotence(self):
- # Test that applying tensorrt optimizer or offline conversion tools multiple
- # times to the same graph will result in same graph.
- # TODO(aaroey): implement this.
- pass
-
-
-def GetTests():
-
- def _GetTest(g, u, p, i, c):
-
- def _Test(self):
- print("Running test with parameters: graph_key=%s, use_optimizer=%s, "
- "precision_mode=%s, dynamic_infer_engine=%s, "
- "dynamic_calib_engine=%s" % (g, u, p, i, c))
- self._RunTest(g, u, p, i, c)
-
- return _Test
-
- use_optimizer_options = [False, True]
- precision_mode_options = [MODE_FP32, MODE_FP16, MODE_INT8]
- dynamic_infer_engine_options = [False, True]
- dynamic_calib_engine_options = [False, True]
- for (graph_key, use_optimizer, precision_mode,
- dynamic_infer_engine, dynamic_calib_engine) in itertools.product(
- TEST_GRAPHS, use_optimizer_options, precision_mode_options,
- dynamic_infer_engine_options, dynamic_calib_engine_options):
- if precision_mode == MODE_INT8:
- if not dynamic_calib_engine and dynamic_infer_engine:
- # TODO(aaroey): test this case, the conversion from static calibration
- # engine to dynamic inference engine should be a noop.
- continue
- if use_optimizer:
- # TODO(aaroey): if use_optimizer is True we need to get the inference
- # graphdef using custom python wrapper class, which is not currently
- # supported yet.
- continue
- if not dynamic_calib_engine:
- # TODO(aaroey): construction of static calibration engine is not
- # supported yet.
- continue
- if dynamic_calib_engine and not dynamic_infer_engine:
- # TODO(aaroey): construction of static inference engine using dynamic
- # calibration engine is not supported yet.
- continue
- else: # In non int8 mode.
- if dynamic_calib_engine:
- # dynamic_calib_engine doesn't affect non-int8 modes, so just let
- # related tests run once on dynamic_calib_engine=False.
- continue
- yield _GetTest(graph_key, use_optimizer, precision_mode,
- dynamic_infer_engine, dynamic_calib_engine)
-
-
-if __name__ == "__main__":
- for index, t in enumerate(GetTests()):
- setattr(TfTrtIntegrationTest, "testTfTRT_" + str(index), t)
- test.main()
diff --git a/tensorflow/contrib/tensorrt/test/tf_trt_integration_test_base.py b/tensorflow/contrib/tensorrt/test/tf_trt_integration_test_base.py
new file mode 100644
index 0000000000..980cc87366
--- /dev/null
+++ b/tensorflow/contrib/tensorrt/test/tf_trt_integration_test_base.py
@@ -0,0 +1,293 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Utilities to test TF-TensorRT integration."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from collections import namedtuple
+import itertools
+import warnings
+import numpy as np
+import re
+import six
+
+from tensorflow.contrib import tensorrt as trt
+from tensorflow.core.protobuf import config_pb2
+from tensorflow.core.protobuf import rewriter_config_pb2
+from tensorflow.python.framework import importer
+from tensorflow.python.framework import ops
+from tensorflow.python.framework import test_util
+from tensorflow.python.ops import math_ops
+from tensorflow.python.platform import test
+from tensorflow.python.platform import tf_logging as logging
+
+TfTrtIntegrationTestParams = namedtuple("TfTrtIntegrationTestParams", [
+ "graph_name", "gdef", "input_dims", "num_expected_engines",
+ "expected_output_dims", "allclose_atol", "allclose_rtol"
+])
+
+INPUT_NAME = "input"
+OUTPUT_NAME = "output"
+TRT_INCOMPATIBLE_OP = math_ops.sin
+PRECISION_MODES = ["FP32", "FP16", "INT8"]
+
+
+def IsQuantizationMode(mode):
+ return mode == "INT8"
+
+
+class TfTrtIntegrationTestBase(test_util.TensorFlowTestCase):
+ """Class to test Tensorflow-TensorRT integration."""
+
+ def _ToBytes(self, s):
+ if six.PY2:
+ return s
+ else:
+ return s.encode("utf-8")
+
+ def _ToString(self, s):
+ if six.PY2:
+ return s
+ else:
+ return s.decode("utf-8")
+
+ def setUp(self):
+ """Setup method."""
+ super(TfTrtIntegrationTestBase, self).setUp()
+ warnings.simplefilter("always")
+
+ def _GetConfigProto(self,
+ params,
+ use_optimizer,
+ precision_mode=None,
+ is_dynamic_op=None):
+ """Get config proto based on specific settings."""
+ if use_optimizer:
+ rewriter_cfg = rewriter_config_pb2.RewriterConfig()
+ rewriter_cfg.optimizers.extend(["constfold", "layout"])
+ custom_op = rewriter_cfg.custom_optimizers.add()
+ custom_op.name = "TensorRTOptimizer"
+ custom_op.parameter_map["minimum_segment_size"].i = 3
+ custom_op.parameter_map["max_batch_size"].i = params.input_dims[0]
+ custom_op.parameter_map["is_dynamic_op"].b = is_dynamic_op
+ custom_op.parameter_map["max_workspace_size_bytes"].i = 1 << 25
+ custom_op.parameter_map["precision_mode"].s = self._ToBytes(
+ precision_mode)
+ graph_options = config_pb2.GraphOptions(rewrite_options=rewriter_cfg)
+ else:
+ graph_options = config_pb2.GraphOptions()
+
+ gpu_options = config_pb2.GPUOptions()
+ if trt.trt_convert.get_linked_tensorrt_version()[0] == 3:
+ gpu_options.per_process_gpu_memory_fraction = 0.50
+
+ config = config_pb2.ConfigProto(
+ gpu_options=gpu_options, graph_options=graph_options)
+ return config
+
+ def _RunGraph(self, params, gdef, input_data, config, num_runs=2):
+ """Run given graphdef multiple times."""
+ g = ops.Graph()
+ with g.as_default():
+ inp, out = importer.import_graph_def(
+ graph_def=gdef, return_elements=[INPUT_NAME, OUTPUT_NAME], name="")
+ inp = inp.outputs[0]
+ out = out.outputs[0]
+ with self.test_session(
+ graph=g, config=config, use_gpu=True, force_gpu=True) as sess:
+ val = None
+ # Defaults to 2 runs to verify that results are consistent across runs.
+ for _ in range(num_runs):
+ new_val = sess.run(out, {inp: input_data})
+ self.assertEquals(params.expected_output_dims, new_val.shape)
+ if val is not None:
+ self.assertAllEqual(new_val, val)
+ val = new_val
+ return val
+
+ # Use real data that is representative of the inference dataset
+ # for calibration. For this test script it is random data.
+ def _RunCalibration(self, params, gdef, input_data, config):
+ """Run calibration on given graph."""
+ return self._RunGraph(params, gdef, input_data, config, 30)
+
+ def _GetTrtGraphDef(self, params, gdef, precision_mode, is_dynamic_op):
+ """Return trt converted graphdef."""
+ return trt.create_inference_graph(
+ input_graph_def=gdef,
+ outputs=[OUTPUT_NAME],
+ max_batch_size=params.input_dims[0],
+ max_workspace_size_bytes=1 << 25,
+ precision_mode=precision_mode,
+ minimum_segment_size=2,
+ is_dynamic_op=is_dynamic_op)
+
+ def _VerifyGraphDef(self,
+ params,
+ gdef,
+ precision_mode=None,
+ is_calibrated=None,
+ dynamic_engine=None):
+ num_engines = 0
+ for n in gdef.node:
+ if n.op == "TRTEngineOp":
+ num_engines += 1
+ self.assertNotEqual("", n.attr["serialized_segment"].s)
+ self.assertNotEqual("", n.attr["segment_funcdef_name"].s)
+ self.assertEquals(n.attr["precision_mode"].s, precision_mode)
+ self.assertEquals(n.attr["static_engine"].b, not dynamic_engine)
+ if IsQuantizationMode(precision_mode) and is_calibrated:
+ self.assertNotEqual("", n.attr["calibration_data"].s)
+ else:
+ self.assertEquals("", n.attr["calibration_data"].s)
+ if precision_mode is None:
+ self.assertEquals(num_engines, 0)
+ else:
+ self.assertEquals(num_engines, params.num_expected_engines)
+
+ def _RunTest(self, params, use_optimizer, precision_mode,
+ dynamic_infer_engine, dynamic_calib_engine):
+ assert precision_mode in PRECISION_MODES
+ inp = np.random.random_sample(params.input_dims)
+ input_gdef = params.gdef
+ self._VerifyGraphDef(params, input_gdef)
+
+ # Get reference result without running trt.
+ config_no_trt = self._GetConfigProto(params, False)
+ logging.info("Running original graph w/o trt, config:\n%s",
+ str(config_no_trt))
+ ref_result = self._RunGraph(params, input_gdef, inp, config_no_trt)
+
+ # Run calibration if necessary.
+ if IsQuantizationMode(precision_mode):
+
+ calib_config = self._GetConfigProto(params, use_optimizer, precision_mode,
+ dynamic_calib_engine)
+ logging.info("Running calibration graph, config:\n%s", str(calib_config))
+ if use_optimizer:
+ self.assertTrue(False)
+ # TODO(aaroey): uncomment this and get infer_gdef when this mode is
+ # supported.
+ # result = self._RunCalibration(params, input_gdef, inp, calib_config)
+ else:
+ calib_gdef = self._GetTrtGraphDef(params, input_gdef, precision_mode,
+ dynamic_calib_engine)
+ self._VerifyGraphDef(params, calib_gdef, precision_mode, False,
+ dynamic_calib_engine)
+ result = self._RunCalibration(params, calib_gdef, inp, calib_config)
+ infer_gdef = trt.calib_graph_to_infer_graph(calib_gdef)
+ self._VerifyGraphDef(params, infer_gdef, precision_mode, True,
+ dynamic_calib_engine)
+
+ self.assertAllClose(
+ ref_result,
+ result,
+ atol=params.allclose_atol,
+ rtol=params.allclose_rtol)
+ else:
+ infer_gdef = input_gdef
+
+ # Run inference.
+ infer_config = self._GetConfigProto(params, use_optimizer, precision_mode,
+ dynamic_infer_engine)
+ logging.info("Running final inference graph, config:\n%s",
+ str(infer_config))
+ if use_optimizer:
+ result = self._RunGraph(params, infer_gdef, inp, infer_config)
+ else:
+ trt_infer_gdef = self._GetTrtGraphDef(params, infer_gdef, precision_mode,
+ dynamic_infer_engine)
+ self._VerifyGraphDef(params, trt_infer_gdef, precision_mode, True,
+ dynamic_infer_engine)
+ result = self._RunGraph(params, trt_infer_gdef, inp, infer_config)
+
+ self.assertAllClose(
+ ref_result,
+ result,
+ atol=params.allclose_atol,
+ rtol=params.allclose_rtol)
+
+ def testIdempotence(self):
+ # Test that applying the TensorRT optimizer or the offline conversion tool
+ # multiple times to the same graph results in the same graph.
+ # TODO(aaroey): implement this.
+ pass
+
+
+def AddTests(test_class, params_list):
+
+ def _GetTest(params, use_optimizer, precision_mode, dynamic_infer_engine,
+ dynamic_calib_engine):
+
+ def _Test(self):
+ logging.info(
+ "Running test with parameters: graph_name=%s, "
+ "use_optimizer=%s, precision_mode=%s, "
+ "dynamic_infer_engine=%s, dynamic_calib_engine=%s", params.graph_name,
+ use_optimizer, precision_mode, dynamic_infer_engine,
+ dynamic_calib_engine)
+ self._RunTest(params, use_optimizer, precision_mode, dynamic_infer_engine,
+ dynamic_calib_engine)
+
+ return _Test
+
+ use_optimizer_options = [False, True]
+ dynamic_infer_engine_options = [False, True]
+ dynamic_calib_engine_options = [False, True]
+ for (params, use_optimizer, precision_mode,
+ dynamic_infer_engine, dynamic_calib_engine) in itertools.product(
+ params_list, use_optimizer_options, PRECISION_MODES,
+ dynamic_infer_engine_options, dynamic_calib_engine_options):
+ if IsQuantizationMode(precision_mode):
+ if not dynamic_calib_engine and dynamic_infer_engine:
+ # TODO(aaroey): test this case, the conversion from static calibration
+ # engine to dynamic inference engine should be a noop.
+ continue
+ if use_optimizer:
+ # TODO(aaroey): if use_optimizer is True we need to get the inference
+ # graphdef using custom python wrapper class, which is not currently
+ # supported yet.
+ continue
+ if not dynamic_calib_engine:
+ # TODO(aaroey): construction of static calibration engine is not
+ # supported yet.
+ continue
+ if dynamic_calib_engine and not dynamic_infer_engine:
+ # TODO(aaroey): construction of static inference engine using dynamic
+ # calibration engine is not supported yet.
+ continue
+ else: # In non-INT8 mode.
+ if dynamic_calib_engine:
+ # dynamic_calib_engine doesn't affect non-int8 modes, so just let
+ # related tests run once on dynamic_calib_engine=False.
+ continue
+
+ conversion = "OptimizerConversion" if use_optimizer else "ToolConversion"
+ infer_engine_type = ("DynamicInferEngine"
+ if dynamic_infer_engine else "StaticInferEngine")
+ calib_engine_type = ""
+ if precision_mode == "INT8":
+ calib_engine_type = ("DynamicCalibEngine"
+ if dynamic_calib_engine else "StaticCalibEngine")
+ test_name = "%s_%s_%s_%s%s" % (re.sub(
+ "[^a-zA-Z0-9]+", "", params.graph_name), conversion, precision_mode,
+ infer_engine_type, ("_" + calib_engine_type)
+ if len(calib_engine_type) else "")
+ setattr(
+ test_class, "testTfTRT_" + test_name,
+ _GetTest(params, use_optimizer, precision_mode, dynamic_infer_engine,
+ dynamic_calib_engine))
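
AddTests turns each (params, conversion, precision, engine-type) combination that survives the skip logic into an individual test method, so a failure points at one specific configuration. The standalone helper below mirrors the name construction above to show what the generated method names look like; it is an illustration, not part of the patch.

import re


def _TestName(graph_name, use_optimizer, precision_mode,
              dynamic_infer_engine, dynamic_calib_engine):
  """Mirrors the test-name construction in AddTests."""
  conversion = "OptimizerConversion" if use_optimizer else "ToolConversion"
  infer = "DynamicInferEngine" if dynamic_infer_engine else "StaticInferEngine"
  calib = ""
  if precision_mode == "INT8":
    calib = "DynamicCalibEngine" if dynamic_calib_engine else "StaticCalibEngine"
  return "testTfTRT_%s_%s_%s_%s%s" % (
      re.sub("[^a-zA-Z0-9]+", "", graph_name), conversion, precision_mode,
      infer, "_" + calib if calib else "")


print(_TestName("SimpleSingleEngine", False, "FP32", False, False))
# -> testTfTRT_SimpleSingleEngine_ToolConversion_FP32_StaticInferEngine
print(_TestName("SimpleMultipleEngines", False, "INT8", True, True))
# -> testTfTRT_SimpleMultipleEngines_ToolConversion_INT8_DynamicInferEngine_DynamicCalibEngine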
diff --git a/tensorflow/contrib/tensorrt/test/unit_tests.py b/tensorflow/contrib/tensorrt/test/unit_tests.py
deleted file mode 100644
index ac6e3b13ee..0000000000
--- a/tensorflow/contrib/tensorrt/test/unit_tests.py
+++ /dev/null
@@ -1,67 +0,0 @@
-# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ==============================================================================
-"""Script to execute and log all integration tests."""
-
-from __future__ import absolute_import
-from __future__ import division
-from __future__ import print_function
-
-from tensorflow.contrib.tensorrt.test.batch_matmul_test import BatchMatMulTest
-from tensorflow.contrib.tensorrt.test.biasadd_matmul_test import BiasaddMatMulTest
-from tensorflow.contrib.tensorrt.test.binary_tensor_weight_broadcast_test import BinaryTensorWeightBroadcastTest
-from tensorflow.contrib.tensorrt.test.concatenation_test import ConcatenationTest
-from tensorflow.contrib.tensorrt.test.multi_connection_neighbor_engine_test import MultiConnectionNeighborEngineTest
-from tensorflow.contrib.tensorrt.test.neighboring_engine_test import NeighboringEngineTest
-from tensorflow.contrib.tensorrt.test.unary_test import UnaryTest
-from tensorflow.contrib.tensorrt.test.vgg_block_nchw_test import VGGBlockNCHWTest
-from tensorflow.contrib.tensorrt.test.vgg_block_test import VGGBlockTest
-from tensorflow.contrib.tensorrt.test.const_broadcast_test import ConstBroadcastTest
-
-from tensorflow.contrib.tensorrt.test.run_test import RunTest
-
-tests = 0
-passed_test = 0
-
-failed_list = []
-test_list = []
-
-test_list.append(BatchMatMulTest())
-test_list.append(BiasaddMatMulTest())
-test_list.append(BinaryTensorWeightBroadcastTest())
-test_list.append(ConcatenationTest())
-test_list.append(NeighboringEngineTest())
-test_list.append(UnaryTest())
-test_list.append(VGGBlockNCHWTest())
-test_list.append(VGGBlockTest())
-test_list.append(MultiConnectionNeighborEngineTest())
-test_list.append(ConstBroadcastTest())
-
-for test in test_list:
- test.debug = True
- test.check_node_count = False
- with RunTest() as context:
- tests += 1
- if test.run(context):
- passed_test += 1
- else:
- failed_list.append(test.test_name)
- print("Failed test: %s\n", test.test_name)
-
-if passed_test == tests:
- print("Passed\n")
-else:
- print(("%d out of %d passed\n -- failed list:") % (passed_test, tests))
- for test in failed_list:
- print(" - " + test)
diff --git a/tensorflow/contrib/tensorrt/test/utilities.py b/tensorflow/contrib/tensorrt/test/utilities.py
deleted file mode 100644
index 0ea5f5b883..0000000000
--- a/tensorflow/contrib/tensorrt/test/utilities.py
+++ /dev/null
@@ -1,30 +0,0 @@
-# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ==============================================================================
-"""Utilities script for TF-TensorRT integration tests."""
-
-from __future__ import absolute_import
-from __future__ import division
-from __future__ import print_function
-
-from tensorflow.python.ops import variables
-
-
-def get_all_variables(sess):
- var_names = sess.run(variables.report_uninitialized_variables())
- names_var_list = {}
- for name in var_names:
- names_var_list[name] = sess.graph.get_tensor_by_name(name + ":0")
- print(var_names)
- return names_var_list