aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/tensorrt/kernels/trt_engine_op.cc
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/contrib/tensorrt/kernels/trt_engine_op.cc')
-rw-r--r--tensorflow/contrib/tensorrt/kernels/trt_engine_op.cc17
1 files changed, 12 insertions, 5 deletions
diff --git a/tensorflow/contrib/tensorrt/kernels/trt_engine_op.cc b/tensorflow/contrib/tensorrt/kernels/trt_engine_op.cc
index 6851f79ef6..2b42d81f47 100644
--- a/tensorflow/contrib/tensorrt/kernels/trt_engine_op.cc
+++ b/tensorflow/contrib/tensorrt/kernels/trt_engine_op.cc
@@ -22,6 +22,7 @@ limitations under the License.
#include "tensorflow/contrib/tensorrt/plugin/trt_plugin_factory.h"
#include "tensorflow/contrib/tensorrt/resources/trt_resource_manager.h"
#include "tensorflow/contrib/tensorrt/resources/trt_resources.h"
+#include "tensorflow/contrib/tensorrt/test/utils.h"
#include "tensorflow/core/framework/graph_to_functiondef.h"
#include "tensorflow/core/lib/core/refcount.h"
#include "tensorflow/core/lib/strings/str_util.h"
@@ -173,7 +174,7 @@ void TRTEngineOp::ExecuteNativeSegment(OpKernelContext* ctx,
helper->Ref(); // Increment count for calculating native graph
VLOG(1) << "Executing native segment " << name();
lib->Run(opts, native_func_, inputs, outputs,
- [ctx, outputs, helper](const tensorflow::Status& s) {
+ [this, ctx, outputs, helper](const tensorflow::Status& s) {
tensorflow::core::ScopedUnref sc(helper);
VLOG(1) << "Native Segment completed";
if (!s.ok()) {
@@ -183,6 +184,8 @@ void TRTEngineOp::ExecuteNativeSegment(OpKernelContext* ctx,
for (size_t t = 0; t < outputs->size(); ++t) {
ctx->set_output(t, outputs->at(t));
}
+ test::AddTestValue(StrCat(this->name(), ":ExecuteNativeSegment"),
+ "done");
delete outputs;
});
}
@@ -228,6 +231,7 @@ void TRTEngineOp::ExecuteCalibration(OpKernelContext* ctx,
->implementation()
->GpuStreamMemberHack()));
calib_res->calibrator_->setBatch(input_data, *stream);
+ test::AddTestValue(StrCat(name(), ":ExecuteCalibration"), "done");
VLOG(2) << "Passed calibration data";
ExecuteNativeSegment(ctx, helper);
}
@@ -252,7 +256,7 @@ int TRTEngineOp::GetEngineBatch(OpKernelContext* ctx) {
StrCat("Engine buffer is full. buffer limit=", max_cached_engines_,
", current entries=");
for (auto i : cached_engine_batches_) StrAppend(&msg, i, ",");
- StrAppend(&msg, "Requested batch=", num_batch);
+ StrAppend(&msg, " requested batch=", num_batch);
LOG(WARNING) << msg;
return -1;
}
@@ -270,7 +274,8 @@ void TRTEngineOp::ComputeAsync(OpKernelContext* ctx,
}
const int smallest_engine = GetEngineBatch(ctx);
if (smallest_engine < 0) {
- LOG(WARNING) << "Failed to get engine batch, running native segment";
+ LOG(WARNING) << "Failed to get engine batch, running native segment for "
+ << name();
ExecuteNativeSegment(ctx, helper);
return;
}
@@ -280,14 +285,15 @@ void TRTEngineOp::ComputeAsync(OpKernelContext* ctx,
auto& trt_engine_ptr = engine_ctx_pair.first;
if (!trt_engine_ptr) {
LOG(WARNING) << "Engine retrieval for batch size " << num_batch
- << " failed. Running native segment";
+ << " failed. Running native segment for " << name();
ExecuteNativeSegment(ctx, helper);
return;
}
const bool retry = ExecuteTrtEngine(ctx, num_batch, trt_engine_ptr.get(),
engine_ctx_pair.second.get());
if (retry) {
- LOG(WARNING) << "Failed to execute engine, retrying with native segment";
+ LOG(WARNING) << "Failed to execute engine, "
+ << "retrying with native segment for " << name();
ExecuteNativeSegment(ctx, helper);
return;
}
@@ -406,6 +412,7 @@ bool TRTEngineOp::ExecuteTrtEngine(
LOG(WARNING) << "Failed to enqueue batch for TRT engine: " << name();
return kRetry;
}
+ test::AddTestValue(StrCat(name(), ":ExecuteTrtEngine"), "done");
// Synchronization will be done by TF.
return !kRetry;
}