Diffstat (limited to 'tensorflow/contrib/tensorrt/kernels/trt_engine_op.cc')
 tensorflow/contrib/tensorrt/kernels/trt_engine_op.cc | 20 +++++++++++++++++---
 1 file changed, 17 insertions(+), 3 deletions(-)
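Summary: this change updates trt_engine_op.cc in three ways. It follows the StreamExecutor rename of CudaStreamMemberHack() to GpuStreamMemberHack() at both stream-extraction call sites, adds INT32 input and output bindings behind an NV_TENSORRT_MAJOR > 3 guard (the kINT32 enum only exists from TensorRT 4 onward), and passes PluginFactoryTensorRT::GetInstance() instead of nullptr to deserializeCudaEngine() so static engines containing custom plugin layers can be restored.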
diff --git a/tensorflow/contrib/tensorrt/kernels/trt_engine_op.cc b/tensorflow/contrib/tensorrt/kernels/trt_engine_op.cc
index 8a17eb02f1..646d62483f 100644
--- a/tensorflow/contrib/tensorrt/kernels/trt_engine_op.cc
+++ b/tensorflow/contrib/tensorrt/kernels/trt_engine_op.cc
@@ -15,9 +15,11 @@ limitations under the License.
#include "tensorflow/contrib/tensorrt/kernels/trt_engine_op.h"
#include <algorithm>
+
#include "tensorflow/contrib/tensorrt/convert/convert_nodes.h"
#include "tensorflow/contrib/tensorrt/convert/utils.h"
#include "tensorflow/contrib/tensorrt/log/trt_logger.h"
+#include "tensorflow/contrib/tensorrt/plugin/trt_plugin_factory.h"
#include "tensorflow/contrib/tensorrt/resources/trt_resource_manager.h"
#include "tensorflow/contrib/tensorrt/resources/trt_resources.h"
#include "tensorflow/core/framework/graph_to_functiondef.h"
@@ -230,7 +232,7 @@ void TRTEngineOp::ExecuteCalibration(tensorflow::OpKernelContext* ctx,
reinterpret_cast<const cudaStream_t*>(ctx->op_device_context()
->stream()
->implementation()
- ->CudaStreamMemberHack()));
+ ->GpuStreamMemberHack()));
calib_res->calibrator_->setBatch(input_data, *stream);
VLOG(2) << "Passed calibration data";
ExecuteNativeSegment(ctx, helper);
@@ -316,6 +318,11 @@ void TRTEngineOp::ComputeAsync(tensorflow::OpKernelContext* ctx,
ctx->SetStatus(tensorflow::errors::InvalidArgument(
"INT8 inputs are not supported!"));
return;
+#if NV_TENSORRT_MAJOR > 3
+ case nvinfer1::DataType::kINT32:
+ buffers[binding_index] = (void*)(input_tensor.flat<int32>().data());
+ break;
+#endif
default:
LOG(ERROR) << "Unknown TRT data type: " << int(dtype);
ctx->SetStatus(tensorflow::errors::InvalidArgument(
@@ -368,6 +375,12 @@ void TRTEngineOp::ComputeAsync(tensorflow::OpKernelContext* ctx,
ctx->SetStatus(tensorflow::errors::InvalidArgument(
"INT8 outputs are not supported!"));
return;
+#if NV_TENSORRT_MAJOR > 3
+ case nvinfer1::DataType::kINT32:
+ buffers[binding_index] =
+ reinterpret_cast<void*>(output_tensor->flat<int32>().data());
+ break;
+#endif
default:
LOG(ERROR) << "Unknown TRT data type: " << static_cast<int>(dtype);
ctx->SetStatus(tensorflow::errors::InvalidArgument(
@@ -380,7 +393,7 @@ void TRTEngineOp::ComputeAsync(tensorflow::OpKernelContext* ctx,
reinterpret_cast<const cudaStream_t*>(ctx->op_device_context()
->stream()
->implementation()
- ->CudaStreamMemberHack()));
+ ->GpuStreamMemberHack()));
// TODO(jie): trt enqueue does not return error
auto& trt_execution_context_ptr = engine_ctx_pair.second;
@@ -446,7 +459,8 @@ TRTEngineOp::EngineCtxPair& TRTEngineOp::GetEngine(int batch_size,
#endif
TrtUniquePtrType<nvinfer1::ICudaEngine> static_engine(
infer->deserializeCudaEngine(serialized_segment_.c_str(),
- serialized_segment_.size(), nullptr));
+ serialized_segment_.size(),
+ PluginFactoryTensorRT::GetInstance()));
auto raw_static_engine = static_engine.get();
const auto max_batch_size = raw_static_engine->getMaxBatchSize();
engine_map_[max_batch_size] = {
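For readers skimming the diff, below is a condensed sketch of the binding-setup switch that the two kINT32 hunks extend. It is not the actual TensorFlow source: the SetBinding helper and its bool return convention are illustrative assumptions (the real op reports errors through the OpKernelContext), and only the cases visible in the diff are reproduced.

#include <vector>

#include "NvInfer.h"  // nvinfer1::DataType, NV_TENSORRT_MAJOR

// Points one TensorRT binding slot at an already-allocated flat tensor
// buffer, selected on the engine's declared data type for that binding.
bool SetBinding(nvinfer1::DataType dtype, void* tensor_data,
                int binding_index, std::vector<void*>* buffers) {
  switch (dtype) {
    case nvinfer1::DataType::kFLOAT:
      (*buffers)[binding_index] = tensor_data;
      return true;
    case nvinfer1::DataType::kINT8:
      return false;  // rejected with InvalidArgument in the real op
#if NV_TENSORRT_MAJOR > 3
    case nvinfer1::DataType::kINT32:  // the arm this patch adds; TRT 4+ only
      (*buffers)[binding_index] = tensor_data;
      return true;
#endif
    default:
      return false;  // unknown TRT data type
  }
}

The preprocessor guard matters because nvinfer1::DataType::kINT32 is not declared in TensorRT 3 headers; without it, the file would fail to compile against older TensorRT installations that this contrib code still supports.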