diff options
author | 2018-07-11 17:57:38 +0000 | |
---|---|---|
committer | 2018-07-12 17:14:09 +0000 | |
commit | 25021d386cd989aedde11b72c5db36b7c1bfd2b4 (patch) | |
tree | 8c325d003489372b16747a4cbe5dec4fe3f276f1 /tensorflow/contrib/tensorrt/kernels | |
parent | 135e419e780423a888ddd45e479129493336c52b (diff) |
[ROCm] Interface changes for StreamExecutor to support both CUDA and ROCm
1) StreamInterface::CudaStreamMemberHack()
Despite the fact that StreamExecutor and GPU common runtime are largely
orthogonal, this particular routine in StreamExecutor is used in GPU common
runtime and a couple of other operators. In this commit it's renamed as
StreamInterface::GpuStreamMemberHack() and their call sites are also changed.
2) StreamExecutorInterface::CudaContextHack()
This member is renamed to StramExecutorInterface::GpuContextHack().
Changes introduced in this commit includes:
- some StreamExecutor interfaces and CUDA implementation
- GPU common runtime related to interface changes in StreamExecutor
- operators affected by interface changes in StreamExecutor
Diffstat (limited to 'tensorflow/contrib/tensorrt/kernels')
-rw-r--r-- | tensorflow/contrib/tensorrt/kernels/trt_engine_op.cc | 4 |
1 files changed, 2 insertions, 2 deletions
diff --git a/tensorflow/contrib/tensorrt/kernels/trt_engine_op.cc b/tensorflow/contrib/tensorrt/kernels/trt_engine_op.cc index 8a17eb02f1..3daf810a4b 100644 --- a/tensorflow/contrib/tensorrt/kernels/trt_engine_op.cc +++ b/tensorflow/contrib/tensorrt/kernels/trt_engine_op.cc @@ -230,7 +230,7 @@ void TRTEngineOp::ExecuteCalibration(tensorflow::OpKernelContext* ctx, reinterpret_cast<const cudaStream_t*>(ctx->op_device_context() ->stream() ->implementation() - ->CudaStreamMemberHack())); + ->GpuStreamMemberHack())); calib_res->calibrator_->setBatch(input_data, *stream); VLOG(2) << "Passed calibration data"; ExecuteNativeSegment(ctx, helper); @@ -380,7 +380,7 @@ void TRTEngineOp::ComputeAsync(tensorflow::OpKernelContext* ctx, reinterpret_cast<const cudaStream_t*>(ctx->op_device_context() ->stream() ->implementation() - ->CudaStreamMemberHack())); + ->GpuStreamMemberHack())); // TODO(jie): trt enqueue does not return error auto& trt_execution_context_ptr = engine_ctx_pair.second; |