path: root/tensorflow/contrib/tensorrt/kernels
author	Wen-Heng (Jack) Chung <whchung@gmail.com>	2018-07-11 17:57:38 +0000
committer	Wen-Heng (Jack) Chung <whchung@gmail.com>	2018-07-12 17:14:09 +0000
commit	25021d386cd989aedde11b72c5db36b7c1bfd2b4 (patch)
tree	8c325d003489372b16747a4cbe5dec4fe3f276f1 /tensorflow/contrib/tensorrt/kernels
parent	135e419e780423a888ddd45e479129493336c52b (diff)
[ROCm] Interface changes for StreamExecutor to support both CUDA and ROCm
1) StreamInterface::CudaStreamMemberHack(): although StreamExecutor and the GPU common runtime are largely orthogonal, this particular routine in StreamExecutor is used by the GPU common runtime and a couple of other operators. In this commit it is renamed to StreamInterface::GpuStreamMemberHack() and its call sites are updated accordingly.
2) StreamExecutorInterface::CudaContextHack(): this member is renamed to StreamExecutorInterface::GpuContextHack().
Changes introduced in this commit include:
- some StreamExecutor interfaces and their CUDA implementation
- GPU common runtime code affected by the StreamExecutor interface changes
- operators affected by the StreamExecutor interface changes
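
As an illustration of the renamed accessor, below is a minimal, self-contained C++ sketch of the call-site pattern the hunks in this diff update: a stream-implementation interface exposes the backend's native stream through GpuStreamMemberHack() (formerly CudaStreamMemberHack()), and callers reinterpret_cast the returned pointer to the platform stream type. StreamInterfaceSketch, FakeGpuStream, and the plain void* handle are simplified stand-ins for illustration only, not the actual StreamExecutor classes.

#include <cassert>
#include <cstdio>

// Stand-in for a native platform stream handle (e.g. cudaStream_t or hipStream_t).
using FakePlatformStream = void*;

// Simplified stand-in for the stream-implementation interface.
class StreamInterfaceSketch {
 public:
  virtual ~StreamInterfaceSketch() = default;
  // Renamed from CudaStreamMemberHack(): returns a pointer to the backend's
  // native stream member so the same accessor works for both CUDA and ROCm.
  virtual void** GpuStreamMemberHack() = 0;
};

// Fake backend implementation that owns a "native" stream handle.
class FakeGpuStream : public StreamInterfaceSketch {
 public:
  void** GpuStreamMemberHack() override { return &native_stream_; }

 private:
  FakePlatformStream native_stream_ = nullptr;
};

int main() {
  FakeGpuStream stream_impl;
  // Call-site pattern mirrored from the diff below: fetch the opaque handle
  // and reinterpret_cast it to the platform stream type the caller expects.
  const FakePlatformStream* native =
      reinterpret_cast<const FakePlatformStream*>(stream_impl.GpuStreamMemberHack());
  assert(native != nullptr);
  std::printf("native stream handle: %p\n", *native);
  return 0;
}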
Diffstat (limited to 'tensorflow/contrib/tensorrt/kernels')
-rw-r--r--	tensorflow/contrib/tensorrt/kernels/trt_engine_op.cc	4
1 file changed, 2 insertions(+), 2 deletions(-)
diff --git a/tensorflow/contrib/tensorrt/kernels/trt_engine_op.cc b/tensorflow/contrib/tensorrt/kernels/trt_engine_op.cc
index 8a17eb02f1..3daf810a4b 100644
--- a/tensorflow/contrib/tensorrt/kernels/trt_engine_op.cc
+++ b/tensorflow/contrib/tensorrt/kernels/trt_engine_op.cc
@@ -230,7 +230,7 @@ void TRTEngineOp::ExecuteCalibration(tensorflow::OpKernelContext* ctx,
reinterpret_cast<const cudaStream_t*>(ctx->op_device_context()
->stream()
->implementation()
- ->CudaStreamMemberHack()));
+ ->GpuStreamMemberHack()));
calib_res->calibrator_->setBatch(input_data, *stream);
VLOG(2) << "Passed calibration data";
ExecuteNativeSegment(ctx, helper);
@@ -380,7 +380,7 @@ void TRTEngineOp::ComputeAsync(tensorflow::OpKernelContext* ctx,
reinterpret_cast<const cudaStream_t*>(ctx->op_device_context()
->stream()
->implementation()
- ->CudaStreamMemberHack()));
+ ->GpuStreamMemberHack()));
// TODO(jie): trt enqueue does not return error
auto& trt_execution_context_ptr = engine_ctx_pair.second;