diff options
Diffstat (limited to 'tensorflow/contrib/tensorrt/kernels/trt_engine_op.cc')
-rw-r--r-- | tensorflow/contrib/tensorrt/kernels/trt_engine_op.cc | 17 |
1 files changed, 12 insertions, 5 deletions
diff --git a/tensorflow/contrib/tensorrt/kernels/trt_engine_op.cc b/tensorflow/contrib/tensorrt/kernels/trt_engine_op.cc
index 6851f79ef6..2b42d81f47 100644
--- a/tensorflow/contrib/tensorrt/kernels/trt_engine_op.cc
+++ b/tensorflow/contrib/tensorrt/kernels/trt_engine_op.cc
@@ -22,6 +22,7 @@ limitations under the License.
 #include "tensorflow/contrib/tensorrt/plugin/trt_plugin_factory.h"
 #include "tensorflow/contrib/tensorrt/resources/trt_resource_manager.h"
 #include "tensorflow/contrib/tensorrt/resources/trt_resources.h"
+#include "tensorflow/contrib/tensorrt/test/utils.h"
 #include "tensorflow/core/framework/graph_to_functiondef.h"
 #include "tensorflow/core/lib/core/refcount.h"
 #include "tensorflow/core/lib/strings/str_util.h"
@@ -173,7 +174,7 @@ void TRTEngineOp::ExecuteNativeSegment(OpKernelContext* ctx,
   helper->Ref();  // Increment count for calculating native graph
   VLOG(1) << "Executing native segment " << name();
   lib->Run(opts, native_func_, inputs, outputs,
-           [ctx, outputs, helper](const tensorflow::Status& s) {
+           [this, ctx, outputs, helper](const tensorflow::Status& s) {
              tensorflow::core::ScopedUnref sc(helper);
              VLOG(1) << "Native Segment completed";
              if (!s.ok()) {
@@ -183,6 +184,8 @@ void TRTEngineOp::ExecuteNativeSegment(OpKernelContext* ctx,
              for (size_t t = 0; t < outputs->size(); ++t) {
                ctx->set_output(t, outputs->at(t));
              }
+             test::AddTestValue(StrCat(this->name(), ":ExecuteNativeSegment"),
+                                "done");
              delete outputs;
            });
 }
@@ -228,6 +231,7 @@ void TRTEngineOp::ExecuteCalibration(OpKernelContext* ctx,
                                                 ->implementation()
                                                 ->GpuStreamMemberHack()));
   calib_res->calibrator_->setBatch(input_data, *stream);
+  test::AddTestValue(StrCat(name(), ":ExecuteCalibration"), "done");
   VLOG(2) << "Passed calibration data";
   ExecuteNativeSegment(ctx, helper);
 }
@@ -252,7 +256,7 @@ int TRTEngineOp::GetEngineBatch(OpKernelContext* ctx) {
           StrCat("Engine buffer is full. buffer limit=", max_cached_engines_,
                  ", current entries=");
       for (auto i : cached_engine_batches_) StrAppend(&msg, i, ",");
-      StrAppend(&msg, "Requested batch=", num_batch);
+      StrAppend(&msg, " requested batch=", num_batch);
       LOG(WARNING) << msg;
       return -1;
     }
@@ -270,7 +274,8 @@ void TRTEngineOp::ComputeAsync(OpKernelContext* ctx,
   }
   const int smallest_engine = GetEngineBatch(ctx);
   if (smallest_engine < 0) {
-    LOG(WARNING) << "Failed to get engine batch, running native segment";
+    LOG(WARNING) << "Failed to get engine batch, running native segment for "
+                 << name();
     ExecuteNativeSegment(ctx, helper);
     return;
   }
@@ -280,14 +285,15 @@ void TRTEngineOp::ComputeAsync(OpKernelContext* ctx,
   auto& trt_engine_ptr = engine_ctx_pair.first;
   if (!trt_engine_ptr) {
     LOG(WARNING) << "Engine retrieval for batch size " << num_batch
-                 << " failed. Running native segment";
+                 << " failed. Running native segment for " << name();
     ExecuteNativeSegment(ctx, helper);
     return;
   }
   const bool retry = ExecuteTrtEngine(ctx, num_batch, trt_engine_ptr.get(),
                                       engine_ctx_pair.second.get());
   if (retry) {
-    LOG(WARNING) << "Failed to execute engine, retrying with native segment";
+    LOG(WARNING) << "Failed to execute engine, "
+                 << "retrying with native segment for " << name();
     ExecuteNativeSegment(ctx, helper);
     return;
   }
@@ -406,6 +412,7 @@ bool TRTEngineOp::ExecuteTrtEngine(
     LOG(WARNING) << "Failed to enqueue batch for TRT engine: " << name();
     return kRetry;
   }
+  test::AddTestValue(StrCat(name(), ":ExecuteTrtEngine"), "done");
   // Synchronization will be done by TF.
   return !kRetry;
 }