aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-07-19 19:14:03 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-07-19 19:14:08 -0700
commit3a56de02398b61226924955c6ee4297a4ecb5d45 (patch)
treeaff45cd43e2d4ac57b9d77104b368ae86851dabb
parent6155633318d648ce6ad9567ad5163bd9fc763454 (diff)
parentc519794c7cca51d2c75aa53b56a1448804f68647 (diff)
Merge pull request #20969 from aaroey:fix_plugin_test
PiperOrigin-RevId: 205340359
-rw-r--r--tensorflow/contrib/tensorrt/convert/convert_nodes.cc2
-rw-r--r--tensorflow/contrib/tensorrt/custom_plugin_examples/BUILD2
-rw-r--r--tensorflow/contrib/tensorrt/kernels/trt_engine_op.cc5
3 files changed, 8 insertions, 1 deletions
diff --git a/tensorflow/contrib/tensorrt/convert/convert_nodes.cc b/tensorflow/contrib/tensorrt/convert/convert_nodes.cc
index 65fef27533..49e825151a 100644
--- a/tensorflow/contrib/tensorrt/convert/convert_nodes.cc
+++ b/tensorflow/contrib/tensorrt/convert/convert_nodes.cc
@@ -2588,6 +2588,8 @@ void Converter::register_op_converters() {
op_registry_["BatchMatMul"] = ConvertBatchMatMul;
op_registry_["TopKV2"] = ConvertTopK;
#endif
+
+ plugin_converter_ = ConvertPlugin;
}
} // namespace
diff --git a/tensorflow/contrib/tensorrt/custom_plugin_examples/BUILD b/tensorflow/contrib/tensorrt/custom_plugin_examples/BUILD
index a89cf3ab8b..69058c5826 100644
--- a/tensorflow/contrib/tensorrt/custom_plugin_examples/BUILD
+++ b/tensorflow/contrib/tensorrt/custom_plugin_examples/BUILD
@@ -112,7 +112,9 @@ cuda_py_test(
],
tags = [
"manual",
+ "no_windows",
"noguitar",
+ "nomac",
"notap",
],
)
diff --git a/tensorflow/contrib/tensorrt/kernels/trt_engine_op.cc b/tensorflow/contrib/tensorrt/kernels/trt_engine_op.cc
index 54009179a8..646d62483f 100644
--- a/tensorflow/contrib/tensorrt/kernels/trt_engine_op.cc
+++ b/tensorflow/contrib/tensorrt/kernels/trt_engine_op.cc
@@ -15,9 +15,11 @@ limitations under the License.
#include "tensorflow/contrib/tensorrt/kernels/trt_engine_op.h"
#include <algorithm>
+
#include "tensorflow/contrib/tensorrt/convert/convert_nodes.h"
#include "tensorflow/contrib/tensorrt/convert/utils.h"
#include "tensorflow/contrib/tensorrt/log/trt_logger.h"
+#include "tensorflow/contrib/tensorrt/plugin/trt_plugin_factory.h"
#include "tensorflow/contrib/tensorrt/resources/trt_resource_manager.h"
#include "tensorflow/contrib/tensorrt/resources/trt_resources.h"
#include "tensorflow/core/framework/graph_to_functiondef.h"
@@ -457,7 +459,8 @@ TRTEngineOp::EngineCtxPair& TRTEngineOp::GetEngine(int batch_size,
#endif
TrtUniquePtrType<nvinfer1::ICudaEngine> static_engine(
infer->deserializeCudaEngine(serialized_segment_.c_str(),
- serialized_segment_.size(), nullptr));
+ serialized_segment_.size(),
+ PluginFactoryTensorRT::GetInstance()));
auto raw_static_engine = static_engine.get();
const auto max_batch_size = raw_static_engine->getMaxBatchSize();
engine_map_[max_batch_size] = {