about | summary | refs | log | tree | commit | diff | homepage
path: root/tensorflow/contrib/tensorrt/kernels/trt_calib_op.cc
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/contrib/tensorrt/kernels/trt_calib_op.cc')
-rw-r--r-- tensorflow/contrib/tensorrt/kernels/trt_calib_op.cc | 11
1 files changed, 9 insertions, 2 deletions
diff --git a/tensorflow/contrib/tensorrt/kernels/trt_calib_op.cc b/tensorflow/contrib/tensorrt/kernels/trt_calib_op.cc
index 1dcb87e768..aea44fd8a2 100644
--- a/tensorflow/contrib/tensorrt/kernels/trt_calib_op.cc
+++ b/tensorflow/contrib/tensorrt/kernels/trt_calib_op.cc
@@ -21,10 +21,11 @@ limitations under the License.
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/framework/tensor_types.h"
#include "tensorflow/core/framework/types.h"
+#include "tensorflow/core/platform/stream_executor.h"
#if GOOGLE_CUDA
#if GOOGLE_TENSORRT
-#include "cuda_runtime_api.h"
+#include "cuda/include/cuda_runtime_api.h"
#include "tensorrt/include/NvInfer.h"
namespace tensorflow {
@@ -113,7 +114,13 @@ void TRTCalibOp::Compute(tensorflow::OpKernelContext* ctx) {
ctx->set_output(i, t);
}
VLOG(2) << "Filled map for sending";
- calib_res->calibrator_->setBatch(input_data);
+ // Fetch the op's CUDA stream. Logic copied from cuda_kernel_helper, whose helper seems usable only in *.cu.cc files.
+ const cudaStream_t* stream = CHECK_NOTNULL(
+ reinterpret_cast<const cudaStream_t*>(ctx->op_device_context()
+ ->stream()
+ ->implementation()
+ ->CudaStreamMemberHack()));
+ calib_res->calibrator_->setBatch(input_data, *stream);
VLOG(2) << "Passed calibration data";
// TODO(aaroey): make sure we wait for the completion of calibration on the
// last batch in a future PR.