aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core/kernels/data/captured_function.cc
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/core/kernels/data/captured_function.cc')
-rw-r--r--tensorflow/core/kernels/data/captured_function.cc66
1 files changed, 48 insertions, 18 deletions
diff --git a/tensorflow/core/kernels/data/captured_function.cc b/tensorflow/core/kernels/data/captured_function.cc
index ad2365b25b..31c8f5c0ea 100644
--- a/tensorflow/core/kernels/data/captured_function.cc
+++ b/tensorflow/core/kernels/data/captured_function.cc
@@ -17,6 +17,7 @@ limitations under the License.
#include <utility>
#include "tensorflow/core/common_runtime/function.h"
+#include "tensorflow/core/common_runtime/step_stats_collector.h"
#include "tensorflow/core/framework/cancellation.h"
#include "tensorflow/core/lib/gtl/optional.h"
#include "tensorflow/core/lib/random/random.h"
@@ -358,7 +359,8 @@ Status CapturedFunction::RunInstantiated(const std::vector<Tensor>& args,
void CapturedFunction::RunAsync(IteratorContext* ctx,
std::vector<Tensor>&& args,
std::vector<Tensor>* rets,
- FunctionLibraryRuntime::DoneCallback done) {
+ FunctionLibraryRuntime::DoneCallback done,
+ const string& prefix) {
// NOTE(mrry): This method does not transfer ownership of `ctx`, and it may
// be deleted before `done` is called. Take care not to capture `ctx` in any
// code that may execute asynchronously in this function.
@@ -391,23 +393,51 @@ void CapturedFunction::RunAsync(IteratorContext* ctx,
// will be required to plumb it through the `IteratorContext`.
auto c_mgr = new CancellationManager;
f_opts.cancellation_manager = c_mgr;
-
- tf_shared_lock l(mu_);
- ctx->lib()->Run(f_opts, handle, frame,
- std::bind(
- [rets, step_container, c_mgr, frame](
- FunctionLibraryRuntime::DoneCallback done,
- // Begin unbound arguments.
- Status s) {
- delete step_container;
- delete c_mgr;
- if (s.ok()) {
- s = frame->ConsumeRetvals(rets);
- }
- delete frame;
- done(s);
- },
- std::move(done), std::placeholders::_1));
+ StepStats* stats = nullptr;
+ StepStatsCollector* stats_collector = nullptr;
+ std::shared_ptr<model::Node> node;
+ if (ctx->model()) {
+ node = ctx->model()->LookupNode(prefix);
+ if (node) {
+ // TODO(b/114104975): Use something light-weight here.
+ stats = new StepStats();
+ stats_collector = new StepStatsCollector(stats);
+ }
+ }
+ f_opts.stats_collector = stats_collector;
+
+ auto callback = std::bind(
+ [rets, step_container, c_mgr, frame, stats, stats_collector, node](
+ FunctionLibraryRuntime::DoneCallback done,
+ // Begin unbound arguments.
+ Status s) {
+ delete step_container;
+ delete c_mgr;
+ if (s.ok()) {
+ s = frame->ConsumeRetvals(rets);
+ }
+ delete frame;
+ if (node) {
+ int64 delta = 0;
+ stats_collector->Finalize();
+ for (auto dev_stats : stats->dev_stats()) {
+ for (auto node_stats : dev_stats.node_stats()) {
+ delta += node_stats.all_end_rel_nanos();
+ }
+ }
+ delete stats_collector;
+ delete stats;
+ node->add_processing_time(delta);
+ node->start_work();
+ }
+ done(s);
+ if (node) {
+ node->stop_work();
+ }
+ },
+ std::move(done), std::placeholders::_1);
+
+ ctx->lib()->Run(f_opts, handle, frame, std::move(callback));
}
CapturedFunction::CapturedFunction(const NameAttrList& func,