diff options
author | TensorFlower Gardener <gardener@tensorflow.org> | 2018-07-19 19:14:03 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-07-19 19:14:08 -0700 |
commit | 3a56de02398b61226924955c6ee4297a4ecb5d45 (patch) | |
tree | aff45cd43e2d4ac57b9d77104b368ae86851dabb | |
parent | 6155633318d648ce6ad9567ad5163bd9fc763454 (diff) | |
parent | c519794c7cca51d2c75aa53b56a1448804f68647 (diff) |
Merge pull request #20969 from aaroey:fix_plugin_test
PiperOrigin-RevId: 205340359
-rw-r--r-- | tensorflow/contrib/tensorrt/convert/convert_nodes.cc | 2 | ||||
-rw-r--r-- | tensorflow/contrib/tensorrt/custom_plugin_examples/BUILD | 2 | ||||
-rw-r--r-- | tensorflow/contrib/tensorrt/kernels/trt_engine_op.cc | 5 |
3 files changed, 8 insertions, 1 deletions
diff --git a/tensorflow/contrib/tensorrt/convert/convert_nodes.cc b/tensorflow/contrib/tensorrt/convert/convert_nodes.cc index 65fef27533..49e825151a 100644 --- a/tensorflow/contrib/tensorrt/convert/convert_nodes.cc +++ b/tensorflow/contrib/tensorrt/convert/convert_nodes.cc @@ -2588,6 +2588,8 @@ void Converter::register_op_converters() { op_registry_["BatchMatMul"] = ConvertBatchMatMul; op_registry_["TopKV2"] = ConvertTopK; #endif + + plugin_converter_ = ConvertPlugin; } } // namespace diff --git a/tensorflow/contrib/tensorrt/custom_plugin_examples/BUILD b/tensorflow/contrib/tensorrt/custom_plugin_examples/BUILD index a89cf3ab8b..69058c5826 100644 --- a/tensorflow/contrib/tensorrt/custom_plugin_examples/BUILD +++ b/tensorflow/contrib/tensorrt/custom_plugin_examples/BUILD @@ -112,7 +112,9 @@ cuda_py_test( ], tags = [ "manual", + "no_windows", "noguitar", + "nomac", "notap", ], ) diff --git a/tensorflow/contrib/tensorrt/kernels/trt_engine_op.cc b/tensorflow/contrib/tensorrt/kernels/trt_engine_op.cc index 54009179a8..646d62483f 100644 --- a/tensorflow/contrib/tensorrt/kernels/trt_engine_op.cc +++ b/tensorflow/contrib/tensorrt/kernels/trt_engine_op.cc @@ -15,9 +15,11 @@ limitations under the License. #include "tensorflow/contrib/tensorrt/kernels/trt_engine_op.h" #include <algorithm> + #include "tensorflow/contrib/tensorrt/convert/convert_nodes.h" #include "tensorflow/contrib/tensorrt/convert/utils.h" #include "tensorflow/contrib/tensorrt/log/trt_logger.h" +#include "tensorflow/contrib/tensorrt/plugin/trt_plugin_factory.h" #include "tensorflow/contrib/tensorrt/resources/trt_resource_manager.h" #include "tensorflow/contrib/tensorrt/resources/trt_resources.h" #include "tensorflow/core/framework/graph_to_functiondef.h" @@ -457,7 +459,8 @@ TRTEngineOp::EngineCtxPair& TRTEngineOp::GetEngine(int batch_size, #endif TrtUniquePtrType<nvinfer1::ICudaEngine> static_engine( infer->deserializeCudaEngine(serialized_segment_.c_str(), - serialized_segment_.size(), nullptr)); + serialized_segment_.size(), + PluginFactoryTensorRT::GetInstance())); auto raw_static_engine = static_engine.get(); const auto max_batch_size = raw_static_engine->getMaxBatchSize(); engine_map_[max_batch_size] = { |