diff options
-rw-r--r-- | tensorflow/contrib/BUILD | 6 | ||||
-rw-r--r-- | tensorflow/contrib/tensorrt/BUILD | 32 | ||||
-rw-r--r-- | tensorflow/contrib/tensorrt/convert/convert_graph.cc | 8 | ||||
-rw-r--r-- | tensorflow/contrib/tensorrt/convert/convert_nodes.cc | 9 | ||||
-rw-r--r-- | tensorflow/contrib/tensorrt/convert/utils.cc | 35 | ||||
-rw-r--r-- | tensorflow/contrib/tensorrt/convert/utils.h | 2 | ||||
-rw-r--r-- | tensorflow/contrib/tensorrt/python/__init__.py | 1 | ||||
-rw-r--r-- | tensorflow/contrib/tensorrt/python/trt_convert.py | 1 | ||||
-rw-r--r-- | tensorflow/contrib/tensorrt/resources/trt_allocator.h | 3 | ||||
-rw-r--r-- | tensorflow/contrib/tensorrt/test/tf_trt_integration_test.py | 5 | ||||
-rw-r--r-- | tensorflow/contrib/tensorrt/trt_conversion.i | 12 | ||||
-rw-r--r-- | tensorflow/tools/pip_package/BUILD | 5 |
12 files changed, 81 insertions, 38 deletions
diff --git a/tensorflow/contrib/BUILD b/tensorflow/contrib/BUILD index 60be9db263..1322056d80 100644 --- a/tensorflow/contrib/BUILD +++ b/tensorflow/contrib/BUILD @@ -7,7 +7,6 @@ package(default_visibility = ["//tensorflow:__subpackages__"]) load("//third_party/mpi:mpi.bzl", "if_mpi") load("@local_config_cuda//cuda:build_defs.bzl", "if_cuda") -load("@local_config_tensorrt//:build_defs.bzl", "if_tensorrt") load("//tensorflow:tensorflow.bzl", "if_not_windows") load("//tensorflow:tensorflow.bzl", "if_not_windows_cuda") @@ -103,6 +102,7 @@ py_library( "//tensorflow/contrib/summary:summary", "//tensorflow/contrib/tensor_forest:init_py", "//tensorflow/contrib/tensorboard", + "//tensorflow/contrib/tensorrt:init_py", "//tensorflow/contrib/testing:testing_py", "//tensorflow/contrib/text:text_py", "//tensorflow/contrib/tfprof", @@ -113,9 +113,7 @@ py_library( "//tensorflow/contrib/util:util_py", "//tensorflow/python:util", "//tensorflow/python/estimator:estimator_py", - ] + if_mpi(["//tensorflow/contrib/mpi_collectives:mpi_collectives_py"]) + if_tensorrt([ - "//tensorflow/contrib/tensorrt:init_py", - ]) + select({ + ] + if_mpi(["//tensorflow/contrib/mpi_collectives:mpi_collectives_py"]) + select({ "//tensorflow:with_kafka_support_windows_override": [], "//tensorflow:with_kafka_support": [ "//tensorflow/contrib/kafka", diff --git a/tensorflow/contrib/tensorrt/BUILD b/tensorflow/contrib/tensorrt/BUILD index adda0b758b..cb2daa7b12 100644 --- a/tensorflow/contrib/tensorrt/BUILD +++ b/tensorflow/contrib/tensorrt/BUILD @@ -11,7 +11,7 @@ exports_files(["LICENSE"]) load( "//tensorflow:tensorflow.bzl", - "py_test", + "cuda_py_test", "tf_cc_test", "tf_copts", "tf_cuda_library", @@ -32,10 +32,7 @@ tf_cuda_cc_test( name = "tensorrt_test_cc", size = "small", srcs = ["tensorrt_test.cc"], - tags = [ - "manual", - "notap", - ], + tags = ["no_windows"], deps = [ "//tensorflow/core:lib", "//tensorflow/core:test", @@ -185,6 +182,9 @@ tf_py_wrap_cc( name = "wrap_conversion", srcs = 
["trt_conversion.i"], copts = tf_copts(), + swig_includes = [ + "//tensorflow/python:platform/base.i", + ], deps = [ ":trt_conversion", ":trt_engine_op_kernel", @@ -275,6 +275,7 @@ tf_cc_test( name = "segment_test", size = "small", srcs = ["segment/segment_test.cc"], + tags = ["no_windows"], deps = [ ":segment", "//tensorflow/c:c_api", @@ -310,10 +311,6 @@ tf_cuda_cc_test( name = "trt_plugin_factory_test", size = "small", srcs = ["plugin/trt_plugin_factory_test.cc"], - tags = [ - "manual", - "notap", - ], deps = [ ":trt_plugins", "//tensorflow/core:lib", @@ -325,23 +322,24 @@ tf_cuda_cc_test( ]), ) -py_test( +cuda_py_test( name = "tf_trt_integration_test", srcs = ["test/tf_trt_integration_test.py"], - main = "test/tf_trt_integration_test.py", - srcs_version = "PY2AND3", - tags = [ - "manual", - "notap", - ], - deps = [ + additional_deps = [ ":init_py", "//tensorflow/python:client_testlib", "//tensorflow/python:framework_test_lib", ], + main = "test/tf_trt_integration_test.py", + tags = [ + "no_windows", + "nomac", + ], ) cc_library( name = "utils", + srcs = ["convert/utils.cc"], hdrs = ["convert/utils.h"], + copts = tf_copts(), ) diff --git a/tensorflow/contrib/tensorrt/convert/convert_graph.cc b/tensorflow/contrib/tensorrt/convert/convert_graph.cc index 63d8eec7db..089b03dcb5 100644 --- a/tensorflow/contrib/tensorrt/convert/convert_graph.cc +++ b/tensorflow/contrib/tensorrt/convert/convert_graph.cc @@ -624,7 +624,9 @@ tensorflow::Status RegisterSegmentFunctionToFunctionLibrary( edge->src()->output_type(edge->src_output())); VLOG(1) << " input " << nout.node << ":" << nout.index << " dtype=" << tensorflow::DataTypeString(nout.data_type); - node_builder.Input({nout}); + // nvcc complains that Input(<brace-enclosed initializer list>) is + // ambiguous, so do not use Input({nout}). 
+ node_builder.Input(nout); TF_RETURN_IF_ERROR(node_builder.Attr("T", node->output_type(0)) .Attr("index", i) .Finalize(&nd)); @@ -829,7 +831,9 @@ tensorflow::Status ConvertAfterShapes(ConversionParams& params) { // The allocator is used to build the engine. The build and the built engine // will be destroyed after we get the serialized engine string, so it's fine // to use unique_ptr here. - std::unique_ptr<nvinfer1::IGpuAllocator> alloc; + // TODO(aaroey): nvinfer1::IGpuAllocator doesn't have a virtual destructor + // and destructing the unique_ptr will result in segfault, fix it. + std::unique_ptr<TRTDeviceAllocator> alloc; auto device_alloc = GetDeviceAndAllocator(params, engine); int cuda_device_id = 0; if (device_alloc.first >= 0) { diff --git a/tensorflow/contrib/tensorrt/convert/convert_nodes.cc b/tensorflow/contrib/tensorrt/convert/convert_nodes.cc index 0ee708bc1c..65fef27533 100644 --- a/tensorflow/contrib/tensorrt/convert/convert_nodes.cc +++ b/tensorflow/contrib/tensorrt/convert/convert_nodes.cc @@ -630,6 +630,7 @@ class Converter { const string& op = node_def.op(); std::vector<TRT_TensorOrWeights> outputs; if (PluginFactoryTensorRT::GetInstance()->IsPlugin(op)) { + // TODO(aaroey): plugin_converter_ is not set, fix it. TF_RETURN_IF_ERROR(plugin_converter_(*this, node_def, inputs, &outputs)); } else { if (!op_registry_.count(op)) { @@ -1756,7 +1757,7 @@ tensorflow::Status ConvertBinary(Converter& ctx, } else { #else } - if (inputs.at(0).is_tensor() && inputs.at(1).is_tensor() || !status.ok()) { + if ((inputs.at(0).is_tensor() && inputs.at(1).is_tensor()) || !status.ok()) { #endif status = BinaryTensorOpTensor(ctx, node_def, inputs.at(0), inputs.at(1), outputs); @@ -2371,10 +2372,7 @@ tensorflow::Status ConvertMatMul(Converter& ctx, node_def.name()); } - const nvinfer1::ITensor* tensor = inputs.at(0).tensor(); - TFAttrs attrs(node_def); - // TODO(jie): INT32 should be converted? 
tensorflow::DataType tf_dtype = attrs.get<tensorflow::DataType>("T"); if (tf_dtype != tensorflow::DataType::DT_FLOAT && @@ -2383,12 +2381,9 @@ tensorflow::Status ConvertMatMul(Converter& ctx, "data type is not supported, for node " + node_def.name() + " got " + tensorflow::DataTypeString(tf_dtype)); } - bool transpose_a = attrs.get<bool>("transpose_a"); bool transpose_b = attrs.get<bool>("transpose_b"); - nvinfer1::ITensor* output_tensor; - // FullyConnected: if (transpose_a) { return tensorflow::errors::Internal( diff --git a/tensorflow/contrib/tensorrt/convert/utils.cc b/tensorflow/contrib/tensorrt/convert/utils.cc new file mode 100644 index 0000000000..24591cf84b --- /dev/null +++ b/tensorflow/contrib/tensorrt/convert/utils.cc @@ -0,0 +1,35 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/contrib/tensorrt/convert/utils.h" + +namespace tensorflow { +namespace tensorrt { + +bool IsGoogleTensorRTEnabled() { + // TODO(laigd): consider also checking if tensorrt shared libraries are + // accessible. We can then direct users to this function to make sure they can + // safely write code that uses tensorrt conditionally. E.g. if it does not + // check for tensorrt, and user mistakenly uses tensorrt, they will just + // crash and burn. 
+#ifdef GOOGLE_TENSORRT + return true; +#else + return false; +#endif +} + +} // namespace tensorrt +} // namespace tensorflow diff --git a/tensorflow/contrib/tensorrt/convert/utils.h b/tensorflow/contrib/tensorrt/convert/utils.h index f601c06701..8b5f4d614a 100644 --- a/tensorflow/contrib/tensorrt/convert/utils.h +++ b/tensorflow/contrib/tensorrt/convert/utils.h @@ -31,6 +31,8 @@ struct TrtDestroyer { template <typename T> using TrtUniquePtrType = std::unique_ptr<T, TrtDestroyer<T>>; +bool IsGoogleTensorRTEnabled(); + } // namespace tensorrt } // namespace tensorflow diff --git a/tensorflow/contrib/tensorrt/python/__init__.py b/tensorflow/contrib/tensorrt/python/__init__.py index 0b2321b5fc..fe4fa166a1 100644 --- a/tensorflow/contrib/tensorrt/python/__init__.py +++ b/tensorflow/contrib/tensorrt/python/__init__.py @@ -22,4 +22,5 @@ from __future__ import print_function from tensorflow.contrib.tensorrt.python.ops import trt_engine_op from tensorflow.contrib.tensorrt.python.trt_convert import calib_graph_to_infer_graph from tensorflow.contrib.tensorrt.python.trt_convert import create_inference_graph +from tensorflow.contrib.tensorrt.python.trt_convert import is_tensorrt_enabled # pylint: enable=unused-import,line-too-long diff --git a/tensorflow/contrib/tensorrt/python/trt_convert.py b/tensorflow/contrib/tensorrt/python/trt_convert.py index 79f512dbcf..2b67931661 100644 --- a/tensorflow/contrib/tensorrt/python/trt_convert.py +++ b/tensorflow/contrib/tensorrt/python/trt_convert.py @@ -23,6 +23,7 @@ import six as _six from tensorflow.contrib.tensorrt.wrap_conversion import calib_convert from tensorflow.contrib.tensorrt.wrap_conversion import get_linked_tensorrt_version from tensorflow.contrib.tensorrt.wrap_conversion import get_loaded_tensorrt_version +from tensorflow.contrib.tensorrt.wrap_conversion import is_tensorrt_enabled from tensorflow.contrib.tensorrt.wrap_conversion import trt_convert from tensorflow.core.framework import graph_pb2 from 
tensorflow.core.protobuf import rewriter_config_pb2 diff --git a/tensorflow/contrib/tensorrt/resources/trt_allocator.h b/tensorflow/contrib/tensorrt/resources/trt_allocator.h index c5d2cec730..97ac82ca5d 100644 --- a/tensorflow/contrib/tensorrt/resources/trt_allocator.h +++ b/tensorflow/contrib/tensorrt/resources/trt_allocator.h @@ -51,6 +51,9 @@ class TRTDeviceAllocator : public nvinfer1::IGpuAllocator { // Allocator implementation wrapping TF device allocators. public: TRTDeviceAllocator(tensorflow::Allocator* allocator); + + // TODO(aaroey): base class doesn't have a virtual destructor, work with + // Nvidia to fix it. virtual ~TRTDeviceAllocator() { VLOG(1) << "Destroying allocator attached to " << allocator_->Name(); } diff --git a/tensorflow/contrib/tensorrt/test/tf_trt_integration_test.py b/tensorflow/contrib/tensorrt/test/tf_trt_integration_test.py index 3c68c6e4e9..7c3ef498c9 100644 --- a/tensorflow/contrib/tensorrt/test/tf_trt_integration_test.py +++ b/tensorflow/contrib/tensorrt/test/tf_trt_integration_test.py @@ -347,6 +347,7 @@ def GetTests(): if __name__ == "__main__": - for index, t in enumerate(GetTests()): - setattr(TfTrtIntegrationTest, "testTfTRT_" + str(index), t) + if trt.is_tensorrt_enabled(): + for index, t in enumerate(GetTests()): + setattr(TfTrtIntegrationTest, "testTfTRT_" + str(index), t) test.main() diff --git a/tensorflow/contrib/tensorrt/trt_conversion.i b/tensorflow/contrib/tensorrt/trt_conversion.i index d6628cd1eb..422740fdf6 100644 --- a/tensorflow/contrib/tensorrt/trt_conversion.i +++ b/tensorflow/contrib/tensorrt/trt_conversion.i @@ -100,6 +100,7 @@ _LIST_OUTPUT_TYPEMAP(int, PyLong_FromLong); #include "tensorflow/core/lib/core/status.h" #include "tensorflow/core/util/stat_summarizer.h" #include "tensorflow/contrib/tensorrt/convert/convert_graph.h" +#include "tensorflow/contrib/tensorrt/convert/utils.h" %} %ignoreall @@ -108,6 +109,7 @@ _LIST_OUTPUT_TYPEMAP(int, PyLong_FromLong); %unignore calib_convert; %unignore 
get_linked_tensorrt_version; %unignore get_loaded_tensorrt_version; +%unignore is_tensorrt_enabled; %{ @@ -140,7 +142,7 @@ std::pair<string, string> trt_convert( return std::pair<string, string>{out_status, ""}; } - if(precision_mode < 0 || precision_mode > 2){ + if (precision_mode < 0 || precision_mode > 2) { out_status = "InvalidArgument;Invalid precision_mode"; return std::pair<string, string>{out_status, ""}; } @@ -232,7 +234,8 @@ version_struct get_linked_tensorrt_version() { #endif // GOOGLE_CUDA && GOOGLE_TENSORRT return s; } -version_struct get_loaded_tensorrt_version(){ + +version_struct get_loaded_tensorrt_version() { // Return the version from the loaded library. version_struct s; #if GOOGLE_CUDA && GOOGLE_TENSORRT @@ -244,6 +247,10 @@ version_struct get_loaded_tensorrt_version(){ return s; } +bool is_tensorrt_enabled() { + return tensorflow::tensorrt::IsGoogleTensorRTEnabled(); +} + %} std::pair<string, string> calib_convert(string graph_def_string, bool is_dyn_op); @@ -258,5 +265,6 @@ std::pair<string, string> trt_convert(string graph_def_string, std::vector<int> cached_engine_batches); version_struct get_linked_tensorrt_version(); version_struct get_loaded_tensorrt_version(); +bool is_tensorrt_enabled(); %unignoreall diff --git a/tensorflow/tools/pip_package/BUILD b/tensorflow/tools/pip_package/BUILD index 6d876b786a..e661fb1adc 100644 --- a/tensorflow/tools/pip_package/BUILD +++ b/tensorflow/tools/pip_package/BUILD @@ -11,7 +11,6 @@ load( ) load("//third_party/mkl:build_defs.bzl", "if_mkl") load("//tensorflow:tensorflow.bzl", "if_cuda") -load("@local_config_tensorrt//:build_defs.bzl", "if_tensorrt") load("@local_config_syslibs//:build_defs.bzl", "if_not_system_lib") load("//tensorflow/core:platform/default/build_config_root.bzl", "tf_additional_license_deps") @@ -190,9 +189,7 @@ sh_binary( "//tensorflow/contrib/lite/python:tflite_convert", "//tensorflow/contrib/lite/toco/python:toco_from_protos", ], - }) + 
if_mkl(["//third_party/mkl:intel_binary_blob"]) + if_tensorrt([ - "//tensorflow/contrib/tensorrt:init_py", - ]), + }) + if_mkl(["//third_party/mkl:intel_binary_blob"]), ) # A genrule for generating a marker file for the pip package on Windows |