-rw-r--r--  tensorflow/contrib/tensorrt/BUILD                         26
-rw-r--r--  tensorflow/contrib/tensorrt/convert/convert_graph.cc       4
-rw-r--r--  tensorflow/contrib/tensorrt/convert/convert_nodes.cc      84
-rw-r--r--  tensorflow/contrib/tensorrt/kernels/trt_engine_op.cc       4
-rw-r--r--  tensorflow/contrib/tensorrt/plugin/trt_plugin.cc          89
-rw-r--r--  tensorflow/contrib/tensorrt/plugin/trt_plugin.h           81
-rw-r--r--  tensorflow/contrib/tensorrt/plugin/trt_plugin_factory.cc  81
-rw-r--r--  tensorflow/contrib/tensorrt/plugin/trt_plugin_factory.h   83
-rw-r--r--  tensorflow/contrib/tensorrt/plugin/trt_plugin_utils.cc    36
-rw-r--r--  tensorflow/contrib/tensorrt/plugin/trt_plugin_utils.h     51
-rw-r--r--  tensorflow/contrib/tensorrt/shape_fn/trt_shfn.cc           4
11 files changed, 528 insertions(+), 15 deletions(-)
diff --git a/tensorflow/contrib/tensorrt/BUILD b/tensorflow/contrib/tensorrt/BUILD
index 2f316767b3..98f18835b0 100644
--- a/tensorflow/contrib/tensorrt/BUILD
+++ b/tensorflow/contrib/tensorrt/BUILD
@@ -67,6 +67,7 @@ tf_cuda_library(
visibility = ["//visibility:public"],
deps = [
":trt_logging",
+ ":trt_plugins",
] + if_tensorrt([
"@local_config_tensorrt//:nv_infer",
]) + tf_custom_op_library_additional_deps(),
@@ -86,6 +87,7 @@ cc_library(
visibility = ["//visibility:public"],
deps = [
":trt_logging",
+ ":trt_plugins",
":trt_resources",
"//tensorflow/core:gpu_headers_lib",
"//tensorflow/core:lib_proto_parsing",
@@ -222,6 +224,7 @@ tf_cuda_library(
],
deps = [
":segment",
+ ":trt_plugins",
":trt_logging",
":trt_resources",
"//tensorflow/core/grappler:grappler_item",
@@ -272,3 +275,26 @@ tf_cc_test(
"//tensorflow/core:test_main",
],
)
+
+# Library for the plugin factory
+tf_cuda_library(
+ name = "trt_plugins",
+ srcs = [
+ "plugin/trt_plugin.cc",
+ "plugin/trt_plugin_factory.cc",
+ "plugin/trt_plugin_utils.cc",
+ ],
+ hdrs = [
+ "plugin/trt_plugin.h",
+ "plugin/trt_plugin_factory.h",
+ "plugin/trt_plugin_utils.h",
+ ],
+ linkstatic = 1,
+ deps = if_tensorrt([
+ "@local_config_tensorrt//:nv_infer",
+ ]),
+)
+
diff --git a/tensorflow/contrib/tensorrt/convert/convert_graph.cc b/tensorflow/contrib/tensorrt/convert/convert_graph.cc
index b412b296e0..899e1721e6 100644
--- a/tensorflow/contrib/tensorrt/convert/convert_graph.cc
+++ b/tensorflow/contrib/tensorrt/convert/convert_graph.cc
@@ -14,6 +14,7 @@ limitations under the License.
==============================================================================*/
#include "tensorflow/contrib/tensorrt/convert/convert_graph.h"
+#include "tensorflow/contrib/tensorrt/plugin/trt_plugin_factory.h"
#include <list>
#include <map>
@@ -75,7 +76,8 @@ bool IsTensorRTCandidate(const tensorflow::Node* node) {
// TODO(ben,jie): ...
};
// LINT.ThenChange(//tensorflow/contrib/tensorrt/convert/convert_nodes.h)
- return candidate_ops.count(node->type_string());
+ return (candidate_ops.count(node->type_string()) ||
+ PluginFactoryTensorRT::GetInstance().IsPlugin(&node->type_string()));
}
void GetSubGraphIncomingEdges(const tensorflow::Graph& graph,
diff --git a/tensorflow/contrib/tensorrt/convert/convert_nodes.cc b/tensorflow/contrib/tensorrt/convert/convert_nodes.cc
index 567b4af88d..a03c1e224a 100644
--- a/tensorflow/contrib/tensorrt/convert/convert_nodes.cc
+++ b/tensorflow/contrib/tensorrt/convert/convert_nodes.cc
@@ -14,6 +14,7 @@ limitations under the License.
==============================================================================*/
#include "tensorflow/contrib/tensorrt/convert/convert_nodes.h"
+#include "tensorflow/contrib/tensorrt/plugin/trt_plugin_factory.h"
#include <algorithm>
#include <list>
@@ -246,6 +247,15 @@ class TFAttrs {
return attrs_.count(key) ? this->get<T>(key) : default_value;
}
+ std::vector<string> GetAllAttrKey() {
+ std::vector<string> attr_list;
+ for (AttrMap::iterator iter = attrs_.begin(); iter != attrs_.end();
+ iter++) {
+ attr_list.emplace_back(iter->first);
+ }
+ return attr_list;
+ }
+
private:
typedef std::map<string, tensorflow::AttrValue const*> AttrMap;
AttrMap attrs_;
@@ -263,6 +273,12 @@ std::vector<int> TFAttrs::get<std::vector<int>>(string key) const {
}
template <>
+std::vector<float> TFAttrs::get<std::vector<float>>(string key) const {
+ auto attr = this->at(key)->list().f();
+ return std::vector<float>(attr.begin(), attr.end());
+}
+
+template <>
std::vector<string> TFAttrs::get<std::vector<string>>(string key) const {
auto attr = this->at(key)->list().s();
return std::vector<string>(attr.begin(), attr.end());
@@ -424,6 +440,7 @@ using OpConverter =
class Converter {
std::unordered_map<string, TRT_TensorOrWeights> trt_tensors_;
std::unordered_map<string, OpConverter> op_registry_;
+ OpConverter plugin_converter_;
nvinfer1::INetworkDefinition* trt_network_;
std::list<std::vector<uint8_t>> temp_bufs_;
tensorflow::tensorrt::TRTWeightStore* weight_store_;
@@ -444,8 +461,8 @@ class Converter {
* remove this and annotate the edge as a control dependency.
************************************************************************/
// skip control nodes
- if (input_name[0] == '^' ) continue;
- string name = input_name;
+ if (input_name[0] == '^') continue;
+ string name = input_name;
auto first = name.find_first_of(':');
if (first != string::npos && first + 2 == name.size() &&
name[first + 1] == '0')
@@ -490,13 +507,17 @@ class Converter {
std::vector<TRT_TensorOrWeights> inputs;
TF_RETURN_IF_ERROR(this->get_inputs(node_def, &inputs));
string op = node_def.op();
- if (!op_registry_.count(op)) {
- return tensorflow::errors::Unimplemented(
- "No converter registered for op: " + op);
- }
- OpConverter op_converter = op_registry_.at(op);
std::vector<TRT_TensorOrWeights> outputs;
- TF_RETURN_IF_ERROR(op_converter(*this, node_def, inputs, &outputs));
+ if (PluginFactoryTensorRT::GetInstance().IsPlugin(&op)) {
+ TF_RETURN_IF_ERROR(plugin_converter_(*this, node_def, inputs, &outputs));
+ } else {
+ if (!op_registry_.count(op)) {
+ return tensorflow::errors::Unimplemented(
+ "No converter registered for op: " + op);
+ }
+ OpConverter op_converter = op_registry_.at(op);
+ TF_RETURN_IF_ERROR(op_converter(*this, node_def, inputs, &outputs));
+ }
for (size_t i = 0; i < outputs.size(); ++i) {
TRT_TensorOrWeights output = outputs.at(i);
// TODO(jie): tf protobuf seems to be omitting the :0 suffix
@@ -1158,9 +1179,9 @@ tensorflow::Status BinaryTensorOpTensor(
CHECK_EQ_TYPE(tensor_r->getType(), dtype);
auto op_pair = ops.find(node_def.op());
if (op_pair == ops.end())
- return tensorflow::errors::Unimplemented(
- "binary op: " + node_def.op() +
- " not supported at: " + node_def.name());
+ return tensorflow::errors::Unimplemented("binary op: " + node_def.op() +
+ " not supported at: " +
+ node_def.name());
nvinfer1::IElementWiseLayer* layer = ctx.network()->addElementWise(
*const_cast<nvinfer1::ITensor*>(tensor_l),
@@ -1173,6 +1194,43 @@ tensorflow::Status BinaryTensorOpTensor(
return tensorflow::Status::OK();
}
+tensorflow::Status ConvertPlugin(Converter& ctx,
+ const tensorflow::NodeDef& node_def,
+ const std::vector<TRT_TensorOrWeights>& inputs,
+ std::vector<TRT_TensorOrWeights>* outputs) {
+ // prepare input
+ std::vector<nvinfer1::ITensor*> all_inputs;
+ for (const auto& input : inputs) {
+ all_inputs.emplace_back(const_cast<nvinfer1::ITensor*>(input.tensor()));
+ }
+
+ // plugin is owned by PluginFactory
+ // TODO(jie): destroy plugins later (resource management)
+ PluginTensorRT* plugin =
+ PluginFactoryTensorRT::GetInstance().CreatePlugin(&node_def.op());
+
+ // passing attributes
+ // TODO(jie): support more general attribute
+ TFAttrs attrs(node_def);
+ auto attr_key_vector = attrs.GetAllAttrKey();
+ for (auto attr_key : attr_key_vector) {
+ VLOG(2) << "plugin attribute: " << attr_key;
+ // TODO(jie): support only list of float for toy example here.
+ auto data = attrs.get<std::vector<float>>(attr_key);
+ size_t size_data = data.size() * sizeof(float);
+ plugin->SetAttribute(attr_key, static_cast<void*>(data.data()), size_data);
+ }
+
+ nvinfer1::IPluginLayer* layer = ctx.network()->addPlugin(
+ &all_inputs[0], static_cast<int>(inputs.size()), *plugin);
+
+ for (int i = 0; i < layer->getNbOutputs(); i++) {
+ nvinfer1::ITensor* output_tensor = layer->getOutput(i);
+ outputs->push_back(TRT_TensorOrWeights(output_tensor));
+ }
+ return tensorflow::Status::OK();
+}
+
tensorflow::Status ConvertPlaceholder(
Converter& ctx, const tensorflow::NodeDef& node_def,
const std::vector<TRT_TensorOrWeights>& inputs,
@@ -2073,6 +2131,8 @@ void Converter::register_op_converters() {
op_registry_["Reshape"] = ConvertReshape;
op_registry_["FusedBatchNorm"] = ConvertFusedBatchNorm;
op_registry_["FusedBatchNormV2"] = ConvertFusedBatchNorm;
+
+ plugin_converter_ = ConvertPlugin;
}
} // namespace
@@ -2511,7 +2571,7 @@ tensorflow::Status ConvertSubGraphToTensorRTNodeDef(
std::vector<string> input_names;
std::vector<tensorflow::DataType> input_dtypes;
for (const std::pair<int, int>& input : s.input_inds) {
- VLOG(2) << "parsing input. Node id= " << input.first ;
+ VLOG(2) << "parsing input. Node id= " << input.first;
int node_id = input.first;
int output_idx = input.second;
tensorflow::Node* node = s.graph.FindNodeId(node_id);
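ConvertPlugin above forwards each node attribute to the plugin as a raw float list through SetAttribute(). A minimal sketch of the decoding side, assuming the plugin kept the bytes in attr_map_ via StoreAttribute(); ReadFloatAttr is a hypothetical helper, not part of this patch:

    #include <cstring>
    #include <vector>

    // Hypothetical helper: reinterpret the bytes stored by StoreAttribute()
    // as the float list that ConvertPlugin originally passed in.
    std::vector<float> ReadFloatAttr(const std::vector<char>& raw) {
      std::vector<float> values(raw.size() / sizeof(float));
      std::memcpy(values.data(), raw.data(), values.size() * sizeof(float));
      return values;
    }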
diff --git a/tensorflow/contrib/tensorrt/kernels/trt_engine_op.cc b/tensorflow/contrib/tensorrt/kernels/trt_engine_op.cc
index b32371b642..8881c48fe6 100644
--- a/tensorflow/contrib/tensorrt/kernels/trt_engine_op.cc
+++ b/tensorflow/contrib/tensorrt/kernels/trt_engine_op.cc
@@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/contrib/tensorrt/kernels/trt_engine_op.h"
+#include "tensorflow/contrib/tensorrt/plugin/trt_plugin_factory.h"
#include "tensorflow/contrib/tensorrt/log/trt_logger.h"
#include "tensorflow/core/platform/logging.h"
@@ -58,7 +59,8 @@ TRTEngineOp::TRTEngineOp(OpKernelConstruction* context) : OpKernel(context) {
IRuntime* infer = nvinfer1::createInferRuntime(logger);
trt_engine_ptr_.reset(infer->deserializeCudaEngine(
- serialized_engine.c_str(), serialized_engine.size(), nullptr));
+ serialized_engine.c_str(), serialized_engine.size(),
+ &PluginFactoryTensorRT::GetInstance()));
trt_execution_context_ptr_.reset(trt_engine_ptr_->createExecutionContext());
// Runtime is safe to delete after engine creation
infer->destroy();
diff --git a/tensorflow/contrib/tensorrt/plugin/trt_plugin.cc b/tensorflow/contrib/tensorrt/plugin/trt_plugin.cc
new file mode 100644
index 0000000000..0e4a157d79
--- /dev/null
+++ b/tensorflow/contrib/tensorrt/plugin/trt_plugin.cc
@@ -0,0 +1,89 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/contrib/tensorrt/plugin/trt_plugin.h"
+#include <cassert>
+#include <cstring>
+#include "tensorflow/contrib/tensorrt/plugin/trt_plugin_utils.h"
+
+#if GOOGLE_CUDA
+#if GOOGLE_TENSORRT
+
+namespace tensorflow {
+namespace tensorrt {
+
+PluginTensorRT::PluginTensorRT(const void* serialized_data, size_t length) {
+ // The op-name hash at the head of the buffer was already validated by
+ // PluginFactoryTensorRT::createPlugin (and a virtual GetPluginName() call
+ // would not dispatch to the derived class during construction), so skip it.
+ const char* buffer =
+ static_cast<const char*>(serialized_data) + sizeof(size_t);
+
+ size_t count = *reinterpret_cast<const size_t*>(buffer);
+ buffer += sizeof(size_t);
+
+ for (size_t i = 0; i < count; i++) {
+ nvinfer1::Dims dim;
+ std::memcpy(&(dim.nbDims), buffer, sizeof(dim.nbDims));
+ buffer += sizeof(dim.nbDims);
+ std::memcpy(dim.d, buffer, sizeof(dim.d));
+ buffer += sizeof(dim.d);
+ std::memcpy(dim.type, buffer, sizeof(dim.type));
+ buffer += sizeof(dim.type);
+ input_dim_list_.emplace_back(dim);
+ }
+}
+
+size_t PluginTensorRT::getSerializationSize() {
+ nvinfer1::Dims dim;
+ // Hash + dim count + one fixed-size record per stored input dimension
+ // (serialize() below writes one record per entry in input_dim_list_).
+ return sizeof(size_t) + sizeof(input_dim_list_.size()) +
+ input_dim_list_.size() *
+ (sizeof(dim.nbDims) + sizeof(dim.d) + sizeof(dim.type));
+}
+
+void PluginTensorRT::serialize(void* serialized_data) {
+ size_t encode_op_name = EncodeOpName(GetPluginName());
+ char* buffer = static_cast<char*>(serialized_data);
+ std::memcpy(buffer, &encode_op_name, sizeof(size_t));
+ buffer += sizeof(size_t);
+
+ auto list_size = input_dim_list_.size();
+ std::memcpy(buffer, &list_size, sizeof(input_dim_list_.size()));
+ buffer += sizeof(input_dim_list_.size());
+
+ for (size_t i = 0; i < input_dim_list_.size(); i++) {
+ const auto& dim = input_dim_list_[i];
+ std::memcpy(buffer, &(dim.nbDims), sizeof(dim.nbDims));
+ buffer += sizeof(dim.nbDims);
+ std::memcpy(buffer, dim.d, sizeof(dim.d));
+ buffer += sizeof(dim.d);
+ std::memcpy(buffer, dim.type, sizeof(dim.type));
+ buffer += sizeof(dim.type);
+ }
+}
+
+bool PluginTensorRT::StoreAttribute(const string& key, const void* ptr,
+ const size_t size) {
+ if (attr_map_.count(key) != 0) return false;
+
+ attr_map_.emplace(key, std::vector<char>(size));
+ std::memcpy(attr_map_[key].data(), ptr, size);
+ return true;
+}
+
+} // namespace tensorrt
+} // namespace tensorflow
+
+#endif // GOOGLE_CUDA
+#endif // GOOGLE_TENSORRT
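For reference, serialize() and the deserializing constructor agree on a fixed layout: the op-name hash, then the input-dim count, then one record per dimension. A sketch of the size computation under that layout; BasePayloadSize is a hypothetical name, not part of this patch:

    #include <cstddef>
    #include "tensorrt/include/NvInfer.h"

    // [ size_t op-name hash ][ size_t dim count ]
    // then per input dim: [ int nbDims ][ d array ][ type array ]
    inline size_t BasePayloadSize(size_t num_input_dims) {
      nvinfer1::Dims dim;
      return 2 * sizeof(size_t) +
             num_input_dims *
                 (sizeof(dim.nbDims) + sizeof(dim.d) + sizeof(dim.type));
    }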
diff --git a/tensorflow/contrib/tensorrt/plugin/trt_plugin.h b/tensorflow/contrib/tensorrt/plugin/trt_plugin.h
new file mode 100644
index 0000000000..1bbfe62a4e
--- /dev/null
+++ b/tensorflow/contrib/tensorrt/plugin/trt_plugin.h
@@ -0,0 +1,81 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_CONTRIB_TENSORRT_PLUGIN_TRT_PLUGIN
+#define TENSORFLOW_CONTRIB_TENSORRT_PLUGIN_TRT_PLUGIN
+
+#include <iostream>
+#include <string>
+#include <unordered_map>
+#include <vector>
+
+#if GOOGLE_CUDA
+#if GOOGLE_TENSORRT
+#include "tensorrt/include/NvInfer.h"
+
+namespace tensorflow {
+namespace tensorrt {
+
+using std::string;
+using std::unordered_map;
+
+class PluginTensorRT : public nvinfer1::IPlugin {
+ public:
+ PluginTensorRT() {}
+ PluginTensorRT(const void* serialized_data, size_t length);
+ virtual string GetPluginName() = 0;
+ virtual bool Finalize() = 0;
+
+ virtual bool SetAttribute(const string& key, const void* ptr,
+ const size_t size) = 0;
+ virtual bool GetAttribute(const string& key, const void** ptr,
+ size_t& size) = 0;
+
+ void configure(const nvinfer1::Dims* inputs, int nbInputs,
+ const nvinfer1::Dims* outputs, int nbOutputs,
+ int maxBatchSize) override {
+ for (int index = 0; index < nbInputs; index++) {
+ nvinfer1::Dims dim;
+ dim.nbDims = inputs[index].nbDims;
+ for (int i = 0; i < dim.nbDims; i++) {
+ dim.d[i] = inputs[index].d[i];
+ dim.type[i] = inputs[index].type[i];
+ }
+ input_dim_list_.emplace_back(dim);
+ }
+ }
+
+ virtual bool StoreAttribute(const string& key, const void* ptr,
+ const size_t size);
+
+ virtual size_t getSerializationSize() override;
+ virtual void serialize(void* buffer) override;
+
+ protected:
+ std::unordered_map<string, std::vector<char> > attr_map_;
+
+ std::vector<nvinfer1::Dims> input_dim_list_;
+};
+
+} // namespace tensorrt
+} // namespace tensorflow
+
+#endif // GOOGLE_TENSORRT
+#endif // GOOGLE_CUDA
+
+#endif // TENSORFLOW_CONTRIB_TENSORRT_PLUGIN_TRT_PLUGIN
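PluginTensorRT leaves the op-specific work to subclasses. A minimal sketch of a concrete plugin, assuming the GetAttribute signature above; StubPluginTRT and all its trivial bodies are hypothetical, and a real plugin would launch its CUDA kernel in enqueue():

    #include "tensorflow/contrib/tensorrt/plugin/trt_plugin.h"

    namespace tensorflow {
    namespace tensorrt {

    // Hypothetical identity-like plugin; every name here is illustrative.
    class StubPluginTRT : public PluginTensorRT {
     public:
      StubPluginTRT() {}
      StubPluginTRT(const void* data, size_t length)
          : PluginTensorRT(data, length) {}

      string GetPluginName() override { return "StubPluginTRT"; }
      bool Finalize() override { return true; }

      bool SetAttribute(const string& key, const void* ptr,
                        const size_t size) override {
        return StoreAttribute(key, ptr, size);  // keep raw bytes in attr_map_
      }
      bool GetAttribute(const string& key, const void** ptr,
                        size_t& size) override {
        auto iter = attr_map_.find(key);
        if (iter == attr_map_.end()) return false;
        *ptr = iter->second.data();
        size = iter->second.size();
        return true;
      }

      // Remaining nvinfer1::IPlugin virtuals, stubbed out.
      int getNbOutputs() const override { return 1; }
      nvinfer1::Dims getOutputDimensions(int index,
                                         const nvinfer1::Dims* inputs,
                                         int nbInputDims) override {
        return inputs[0];
      }
      int initialize() override { return 0; }
      void terminate() override {}
      size_t getWorkspaceSize(int maxBatchSize) const override { return 0; }
      int enqueue(int batchSize, const void* const* inputs, void** outputs,
                  void* workspace, cudaStream_t stream) override {
        return 0;  // a real plugin enqueues its CUDA kernel here
      }
    };

    }  // namespace tensorrt
    }  // namespace tensorflow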
diff --git a/tensorflow/contrib/tensorrt/plugin/trt_plugin_factory.cc b/tensorflow/contrib/tensorrt/plugin/trt_plugin_factory.cc
new file mode 100644
index 0000000000..799c609a3e
--- /dev/null
+++ b/tensorflow/contrib/tensorrt/plugin/trt_plugin_factory.cc
@@ -0,0 +1,81 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/contrib/tensorrt/plugin/trt_plugin_factory.h"
+
+#if GOOGLE_CUDA
+#if GOOGLE_TENSORRT
+
+namespace tensorflow {
+namespace tensorrt {
+
+PluginTensorRT* PluginFactoryTensorRT::createPlugin(const char* layerName,
+ const void* serial_data,
+ size_t serial_length) {
+ size_t parsed_byte = 0;
+ // extract op_name from serial_data
+ size_t encoded_op_name =
+ ExtractOpName(serial_data, serial_length, parsed_byte);
+
+ if (!IsPlugin(encoded_op_name)) {
+ return nullptr;
+ }
+
+ // Serialize access to the registry and the owned-plugin list.
+ instance_m_.lock();
+ auto plugin_ptr =
+ plugin_registry_[encoded_op_name].first(serial_data, serial_length);
+ owned_plugins_.emplace_back(plugin_ptr);
+ instance_m_.unlock();
+
+ return plugin_ptr;
+}
+
+PluginTensorRT* PluginFactoryTensorRT::CreatePlugin(const string* op_name) {
+ if (!IsPlugin(op_name)) return nullptr;
+
+ instance_m_.lock();
+ auto plugin_ptr = plugin_registry_[EncodeLayerName(op_name)].second();
+ owned_plugins_.emplace_back(plugin_ptr);
+ instance_m_.unlock();
+
+ return plugin_ptr;
+}
+
+bool PluginFactoryTensorRT::RegisterPlugin(
+ const string* op_name, PluginDeserializeFunc deserialize_func,
+ PluginConstructFunc construct_func) {
+ if (IsPlugin(op_name)) return false;
+
+ // get instance_m_ first before write to registry;
+ instance_m_.lock();
+ auto ret = plugin_registry_.emplace(
+ EncodeLayerName(op_name),
+ std::make_pair(deserialize_func, construct_func));
+ instance_m_.unlock();
+
+ return ret.second;
+}
+
+// TODO(jie): implement once plugin ownership moves to resource management
+// (see the ownership comment in trt_plugin_factory.h).
+void PluginFactoryTensorRT::DestroyPlugins() {}
+
+} // namespace tensorrt
+} // namespace tensorflow
+
+#endif // GOOGLE_CUDA
+#endif // GOOGLE_TENSORRT
diff --git a/tensorflow/contrib/tensorrt/plugin/trt_plugin_factory.h b/tensorflow/contrib/tensorrt/plugin/trt_plugin_factory.h
new file mode 100644
index 0000000000..e68f4629d0
--- /dev/null
+++ b/tensorflow/contrib/tensorrt/plugin/trt_plugin_factory.h
@@ -0,0 +1,83 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_CONTRIB_TENSORRT_PLUGIN_TRT_PLUGIN_FACTORY
+#define TENSORFLOW_CONTRIB_TENSORRT_PLUGIN_TRT_PLUGIN_FACTORY
+
+#include <memory>
+#include <mutex>
+#include <unordered_map>
+#include "trt_plugin.h"
+#include "trt_plugin_utils.h"
+
+#if GOOGLE_CUDA
+#if GOOGLE_TENSORRT
+#include "tensorrt/include/NvInfer.h"
+
+namespace tensorflow {
+namespace tensorrt {
+
+class PluginFactoryTensorRT : public nvinfer1::IPluginFactory {
+ public:
+ // Deserialization method: rebuilds a plugin from the blob embedded in a
+ // serialized engine (nvinfer1::IPluginFactory override).
+ PluginTensorRT* createPlugin(const char* layerName, const void* serialData,
+ size_t serialLength) override;
+
+ // Construction method: creates a fresh plugin during graph conversion.
+ PluginTensorRT* CreatePlugin(const string* op_name);
+
+ static PluginFactoryTensorRT& GetInstance() {
+ static PluginFactoryTensorRT factory_instance;
+ return factory_instance;
+ }
+
+ bool RegisterPlugin(const string* op_name,
+ PluginDeserializeFunc deserialize_func,
+ PluginConstructFunc construct_func);
+
+ bool IsPlugin(const size_t encode_name) {
+ return plugin_registry_.find(encode_name) != plugin_registry_.end();
+ }
+
+ bool IsPlugin(const string* op_name) {
+ return IsPlugin(EncodeLayerName(op_name));
+ }
+
+ size_t EncodeLayerName(const string* op_name) {
+ return EncodeOpName(*op_name);
+ }
+
+ void DestroyPlugins();
+
+ protected:
+ std::unordered_map<size_t,
+ std::pair<PluginDeserializeFunc, PluginConstructFunc> >
+ plugin_registry_;
+
+ // TODO(jie): Owned plugin should be associated with different sessions;
+ // should really hand ownership of plugins to resource management;
+ std::vector<std::unique_ptr<PluginTensorRT> > owned_plugins_;
+ std::mutex instance_m_;
+};
+
+} // namespace tensorrt
+} // namespace tensorflow
+
+#endif // GOOGLE_TENSORRT
+#endif // GOOGLE_CUDA
+
+#endif // TENSORFLOW_CONTRIB_TENSORRT_PLUGIN_TRT_PLUGIN_FACTORY
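The factory is a singleton keyed by the op-name hash, so a plugin must be registered once before conversion or engine deserialization can find it. A hedged usage sketch, reusing the hypothetical StubPluginTRT from the sketch above:

    #include "tensorflow/contrib/tensorrt/plugin/trt_plugin_factory.h"

    namespace tensorflow {
    namespace tensorrt {

    // Register StubPluginTRT at static-initialization time (illustrative).
    static const bool kStubRegistered = []() {
      string op_name = "StubPluginTRT";
      return PluginFactoryTensorRT::GetInstance().RegisterPlugin(
          &op_name,
          // deserialize_func: rebuild the plugin from a serialized engine.
          [](const void* data, size_t length) -> PluginTensorRT* {
            return new StubPluginTRT(data, length);
          },
          // construct_func: fresh instance during graph conversion.
          []() -> PluginTensorRT* { return new StubPluginTRT(); });
    }();

    }  // namespace tensorrt
    }  // namespace tensorflow

With this in place, IsTensorRTCandidate() treats nodes of that op type as convertible, and deserializeCudaEngine() can rebuild the corresponding plugin layers.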
diff --git a/tensorflow/contrib/tensorrt/plugin/trt_plugin_utils.cc b/tensorflow/contrib/tensorrt/plugin/trt_plugin_utils.cc
new file mode 100644
index 0000000000..b14480cfa6
--- /dev/null
+++ b/tensorflow/contrib/tensorrt/plugin/trt_plugin_utils.cc
@@ -0,0 +1,36 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/contrib/tensorrt/plugin/trt_plugin_utils.h"
+
+#if GOOGLE_CUDA
+#if GOOGLE_TENSORRT
+
+namespace tensorflow {
+namespace tensorrt {
+
+size_t ExtractOpName(const void* serial_data, size_t serial_length,
+ size_t& incremental) {
+ incremental = sizeof(size_t);
+ if (serial_length < incremental) return 0;
+ size_t encoded_op_name = *static_cast<const size_t*>(serial_data);
+ return encoded_op_name;
+}
+
+} // namespace tensorrt
+} // namespace tensorflow
+
+#endif // GOOGLE_CUDA
+#endif // GOOGLE_TENSORRT
diff --git a/tensorflow/contrib/tensorrt/plugin/trt_plugin_utils.h b/tensorflow/contrib/tensorrt/plugin/trt_plugin_utils.h
new file mode 100644
index 0000000000..e9675d84cd
--- /dev/null
+++ b/tensorflow/contrib/tensorrt/plugin/trt_plugin_utils.h
@@ -0,0 +1,51 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_CONTRIB_TENSORRT_PLUGIN_TRT_PLUGIN_UTILS
+#define TENSORFLOW_CONTRIB_TENSORRT_PLUGIN_TRT_PLUGIN_UTILS
+
+#include <functional>
+#include "tensorflow/contrib/tensorrt/plugin/trt_plugin.h"
+
+#if GOOGLE_CUDA
+#if GOOGLE_TENSORRT
+#include "tensorrt/include/NvInfer.h"
+
+namespace tensorflow {
+namespace tensorrt {
+
+typedef std::function<PluginTensorRT*(const void*, size_t)>
+ PluginDeserializeFunc;
+
+typedef std::function<PluginTensorRT*(void)> PluginConstructFunc;
+
+inline size_t EncodeOpName(const std::string& str) {
+ return std::hash<std::string>{}(str);
+}
+
+// TODO(jie): work on error handling here
+size_t ExtractOpName(const void* serial_data, size_t serial_length,
+ size_t& incremental);
+
+} // namespace tensorrt
+} // namespace tensorflow
+
+#endif // GOOGLE_TENSORRT
+#endif // GOOGLE_CUDA
+
+#endif // TENSORFLOW_CONTRIB_TENSORRT_PLUGIN_TRT_PLUGIN_UTILS
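A sketch of the name-hash round trip these helpers implement: serialize() writes EncodeOpName(name) as the leading size_t, and ExtractOpName() reads it back, reporting the consumed byte count through `incremental` (the test function below is illustrative only):

    #include <cassert>
    #include "tensorflow/contrib/tensorrt/plugin/trt_plugin_utils.h"

    namespace tensorflow {
    namespace tensorrt {

    void OpNameRoundTrip() {
      size_t encoded = EncodeOpName("StubPluginTRT");
      size_t incremental = 0;
      // The hash itself stands in for a serialized plugin buffer here.
      size_t decoded = ExtractOpName(&encoded, sizeof(encoded), incremental);
      assert(decoded == encoded);
      assert(incremental == sizeof(size_t));
    }

    }  // namespace tensorrt
    }  // namespace tensorflow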
diff --git a/tensorflow/contrib/tensorrt/shape_fn/trt_shfn.cc b/tensorflow/contrib/tensorrt/shape_fn/trt_shfn.cc
index 8b475177bc..30b5616475 100644
--- a/tensorflow/contrib/tensorrt/shape_fn/trt_shfn.cc
+++ b/tensorflow/contrib/tensorrt/shape_fn/trt_shfn.cc
@@ -14,6 +14,7 @@ limitations under the License.
==============================================================================*/
#include "tensorflow/contrib/tensorrt/shape_fn/trt_shfn.h"
+#include "tensorflow/contrib/tensorrt/plugin/trt_plugin_factory.h"
#include <string>
#include <vector>
@@ -33,7 +34,8 @@ tensorflow::Status TRTEngineOpShapeInference(InferenceContext* context) {
TF_RETURN_IF_ERROR(context->GetAttr("serialized_engine", &serialized_engine));
nvinfer1::IRuntime* infer = nvinfer1::createInferRuntime(logger);
nvinfer1::ICudaEngine* trt_engine = infer->deserializeCudaEngine(
- serialized_engine.c_str(), serialized_engine.size(), nullptr);
+ serialized_engine.c_str(), serialized_engine.size(),
+ &tensorrt::PluginFactoryTensorRT::GetInstance());
int num_batch = -1;
std::vector<::tensorflow::DataType> input_type;