diff options
Diffstat (limited to 'tensorflow/core/kernels/data/captured_function.cc')
-rw-r--r-- | tensorflow/core/kernels/data/captured_function.cc | 66 |
1 files changed, 48 insertions, 18 deletions
diff --git a/tensorflow/core/kernels/data/captured_function.cc b/tensorflow/core/kernels/data/captured_function.cc index ad2365b25b..31c8f5c0ea 100644 --- a/tensorflow/core/kernels/data/captured_function.cc +++ b/tensorflow/core/kernels/data/captured_function.cc @@ -17,6 +17,7 @@ limitations under the License. #include <utility> #include "tensorflow/core/common_runtime/function.h" +#include "tensorflow/core/common_runtime/step_stats_collector.h" #include "tensorflow/core/framework/cancellation.h" #include "tensorflow/core/lib/gtl/optional.h" #include "tensorflow/core/lib/random/random.h" @@ -358,7 +359,8 @@ Status CapturedFunction::RunInstantiated(const std::vector<Tensor>& args, void CapturedFunction::RunAsync(IteratorContext* ctx, std::vector<Tensor>&& args, std::vector<Tensor>* rets, - FunctionLibraryRuntime::DoneCallback done) { + FunctionLibraryRuntime::DoneCallback done, + const string& prefix) { // NOTE(mrry): This method does not transfer ownership of `ctx`, and it may // be deleted before `done` is called. Take care not to capture `ctx` in any // code that may execute asynchronously in this function. @@ -391,23 +393,51 @@ void CapturedFunction::RunAsync(IteratorContext* ctx, // will be required to plumb it through the `IteratorContext`. auto c_mgr = new CancellationManager; f_opts.cancellation_manager = c_mgr; - - tf_shared_lock l(mu_); - ctx->lib()->Run(f_opts, handle, frame, - std::bind( - [rets, step_container, c_mgr, frame]( - FunctionLibraryRuntime::DoneCallback done, - // Begin unbound arguments. - Status s) { - delete step_container; - delete c_mgr; - if (s.ok()) { - s = frame->ConsumeRetvals(rets); - } - delete frame; - done(s); - }, - std::move(done), std::placeholders::_1)); + StepStats* stats = nullptr; + StepStatsCollector* stats_collector = nullptr; + std::shared_ptr<model::Node> node; + if (ctx->model()) { + node = ctx->model()->LookupNode(prefix); + if (node) { + // TODO(b/114104975): Use something light-weight here. + stats = new StepStats(); + stats_collector = new StepStatsCollector(stats); + } + } + f_opts.stats_collector = stats_collector; + + auto callback = std::bind( + [rets, step_container, c_mgr, frame, stats, stats_collector, node]( + FunctionLibraryRuntime::DoneCallback done, + // Begin unbound arguments. + Status s) { + delete step_container; + delete c_mgr; + if (s.ok()) { + s = frame->ConsumeRetvals(rets); + } + delete frame; + if (node) { + int64 delta = 0; + stats_collector->Finalize(); + for (auto dev_stats : stats->dev_stats()) { + for (auto node_stats : dev_stats.node_stats()) { + delta += node_stats.all_end_rel_nanos(); + } + } + delete stats_collector; + delete stats; + node->add_processing_time(delta); + node->start_work(); + } + done(s); + if (node) { + node->stop_work(); + } + }, + std::move(done), std::placeholders::_1); + + ctx->lib()->Run(f_opts, handle, frame, std::move(callback)); } CapturedFunction::CapturedFunction(const NameAttrList& func, |