aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core/common_runtime/eager/execute.cc
blob: 181b222b4c90dbca25b0b4ebdf0f5297ed5a9662 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/core/common_runtime/eager/execute.h"

#include <vector>

#include "tensorflow/core/common_runtime/device.h"
#include "tensorflow/core/common_runtime/device_set.h"
#include "tensorflow/core/common_runtime/eager/context.h"
#include "tensorflow/core/common_runtime/eager/copy_to_device_node.h"
#include "tensorflow/core/common_runtime/eager/execute_node.h"
#include "tensorflow/core/common_runtime/eager/kernel_and_device.h"
#include "tensorflow/core/common_runtime/eager/tensor_handle.h"
#ifndef __ANDROID__
#include "tensorflow/core/distributed_runtime/eager/eager_client.h"
#include "tensorflow/core/distributed_runtime/eager/remote_execute_node.h"
#endif
#include "tensorflow/core/framework/step_stats.pb.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/gtl/inlined_vector.h"
#include "tensorflow/core/lib/random/random.h"
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/platform/mutex.h"
#include "tensorflow/core/util/ptr_util.h"

namespace tensorflow {

namespace {

// Name of the node attribute that requests XLA compilation of a function.
// Mirrors the definition in third_party/tensorflow/compiler/jit/defs.h; it is
// duplicated here because XLA is not currently compiled on Windows, so this
// file cannot depend on that header directly.
constexpr const char* kXlaCompileAttr = "_XlaCompile";

// Lazily appends one DeviceStepStats entry per device known to `ctx`, naming
// each new entry after the corresponding device. Entries that already exist
// are left untouched, so once every device has an entry this is a no-op.
void MaybeInitializeStepStats(StepStats* step_stats, EagerContext* ctx) {
  const auto* devices = ctx->devices();
  for (int idx = step_stats->dev_stats_size(); idx < devices->size(); ++idx) {
    step_stats->add_dev_stats()->set_device(devices->at(idx)->name());
  }
}

// Returns the position of `device` in ctx->devices(), matching either by
// pointer identity or by device name. A null `device` stands for the host
// CPU. Falls back to index 0 when nothing matches. `step_stats` is currently
// not consulted.
int StepStatsDeviceIndex(StepStats* step_stats, EagerContext* ctx,
                         Device* device) {
  Device* const target = (device != nullptr) ? device : ctx->HostCPU();
  int idx = 0;
  for (Device* candidate : *ctx->devices()) {
    if (candidate == target || candidate->name() == target->name()) {
      return idx;
    }
    ++idx;
  }
  // TODO(apassos) do not fall back to host CPU if device is unknown.
  return 0;
}

// Ensures that input #i of `op` lives on `expected_device`, copying it there
// when the context's device placement policy permits an implicit copy.
//
// This function expects *handle to point to an existing tensor handle. The
// function will (maybe) update *handle to point to the newly copied tensor
// handle; the passed-in *handle is Unreffed if it is replaced.
//
// A null device on the handle or on the op stands for the host CPU.
// Returns InvalidArgument when the devices differ and the policy is
// DEVICE_PLACEMENT_EXPLICIT (or DEVICE_PLACEMENT_SILENT_FOR_INT32 with a
// non-int32 tensor), and Internal when the copy itself fails. When
// `run_metadata` is non-null, a "_Send" node entry timing the copy is appended
// to its step stats under the source device.
Status MaybeCopyInputToExpectedDevice(EagerOperation* op, int i,
                                      const Device* expected_device,
                                      RunMetadata* run_metadata,
                                      TensorHandle** handle) {
  EagerContext* ctx = op->EagerContext();
  Device* handle_device = nullptr;
  TF_RETURN_IF_ERROR((*handle)->Device(&handle_device));
  const Device* actual_device =
      handle_device == nullptr ? ctx->HostCPU() : handle_device;
  const Device* op_device =
      op->Device() == nullptr ? ctx->HostCPU() : op->Device();

  if (expected_device != actual_device) {
    switch (ctx->GetDevicePlacementPolicy()) {
      case DEVICE_PLACEMENT_SILENT_FOR_INT32:
        // TODO(xpan): See if we could bubble python related error up
        // to python level.
        if ((*handle)->dtype == DT_INT32) {
          // Note: enabling silent copies of int32 tensors to match behavior
          // of graph mode.
          break;
        }
        TF_FALLTHROUGH_INTENDED;
      case DEVICE_PLACEMENT_EXPLICIT:
        return errors::InvalidArgument(
            "Tensors on conflicting devices:"
            " cannot compute ",
            op->Name(), " as input #", i, " was expected to be on ",
            expected_device->name(), " but is actually on ",
            actual_device->name(), " (operation running on ", op_device->name(),
            ")",
            " Tensors can be copied explicitly using .gpu() or .cpu() "
            "methods,"
            " or transparently copied by using tf.enable_eager_execution("
            "device_policy=tfe.DEVICE_PLACEMENT_SILENT). Copying tensors "
            "between devices"
            " may slow down your model");
      case DEVICE_PLACEMENT_WARN:
        LOG(WARNING) << "before computing " << op->Name() << " input #" << i
                     << " was expected to be on " << expected_device->name()
                     << " but is actually on " << actual_device->name()
                     << " (operation running on " << op_device->name()
                     << "). This triggers a copy which can be a performance "
                        "bottleneck.";
        break;
      case DEVICE_PLACEMENT_SILENT:  // Do nothing.
        break;
    }
    // We are only here if the policy is warn or silent copies, so we should
    // trigger a copy.
    auto pre_time_nanos = Env::Default()->NowNanos();
    TensorHandle* result_handle = nullptr;
    Status status = EagerCopyToDevice(
        *handle, ctx, expected_device->name().c_str(), &result_handle);
    if (run_metadata != nullptr) {
      int64 now_nanos = Env::Default()->NowNanos();
      // In local execution the caller passes ctx->RunMetadataProto() here,
      // and EagerExecute writes that same proto while holding MetadataMu()
      // (presumably from the async executor thread as well). Take the lock so
      // the two writers cannot race. The lock is acquired only after
      // EagerCopyToDevice has returned, since that path may itself reach
      // EagerExecute, and the mutex is not reentrant.
      mutex_lock ml(*ctx->MetadataMu());
      auto* step_stats = run_metadata->mutable_step_stats();
      MaybeInitializeStepStats(step_stats, ctx);
      // Record the sending on the source device for now.
      int device_idx = StepStatsDeviceIndex(step_stats, ctx, handle_device);
      auto* dev_stats = step_stats->mutable_dev_stats(device_idx);
      auto* node_stats = dev_stats->add_node_stats();
      node_stats->set_node_name("_Send");
      node_stats->set_all_start_micros(pre_time_nanos /
                                       EnvTime::kMicrosToNanos);
      node_stats->set_all_start_nanos(pre_time_nanos);
      node_stats->set_op_end_rel_micros((now_nanos - pre_time_nanos) /
                                        EnvTime::kMicrosToNanos);
      node_stats->set_op_end_rel_nanos(now_nanos - pre_time_nanos);
    }
    if (!status.ok()) {
      if (result_handle != nullptr) result_handle->Unref();
      return errors::Internal("Failed copying input tensor from ",
                              actual_device->name(), " to ",
                              expected_device->name(), " in order to run ",
                              op->Name(), ": ", status.error_message());
    }

    (*handle)->Unref();
    *handle = result_handle;
  }
  return Status::OK();
}

// For every input of `op`, moves the input onto the device the kernel expects
// (the host CPU for HOST_MEMORY inputs, `op_device` otherwise) and then checks
// that its dtype matches the kernel's declared input type. Copies are recorded
// in `run_metadata` when it is non-null. Returns InvalidArgument when the
// input count or an input dtype does not match the kernel.
Status ValidateInputTypeAndPlacement(EagerContext* ctx, Device* op_device,
                                     EagerOperation* op, const OpKernel* kernel,
                                     RunMetadata* run_metadata) {
  Device* const host_device = ctx->HostCPU();
  const MemoryTypeVector& memtypes = kernel->input_memory_types();
  const auto num_inputs = op->Inputs().size();
  if (memtypes.size() != num_inputs) {
    return errors::InvalidArgument("expected ", memtypes.size(),
                                   " inputs, got ", num_inputs);
  }
  for (int idx = 0; idx < num_inputs; ++idx) {
    const Device* target =
        (memtypes[idx] == HOST_MEMORY) ? host_device : op_device;
    TF_RETURN_IF_ERROR(MaybeCopyInputToExpectedDevice(
        op, idx, target, run_metadata, &((*op->MutableInputs())[idx])));
    // Re-read the input: the call above may have replaced the handle.
    const TensorHandle* input = op->Inputs()[idx];
    const auto wanted = kernel->input_type(idx);
    if (input->dtype != wanted) {
      return errors::InvalidArgument(
          "cannot compute ", op->Name(), " as input #", idx, "(zero-based)",
          " was expected to be a ", DataTypeString(wanted),
          " tensor but is a ", DataTypeString(input->dtype), " tensor");
    }
  }
  return Status::OK();
}

// Chooses a device for `ndef` from the context's devices. It takes the first
// device type reported as supporting the node (presumably the highest-priority
// one, since candidates come from a prioritized list) and returns the first
// context device of that type. Returns Internal if no type supports the node,
// and Unknown if no context device has the chosen type.
Status SelectDevice(const NodeDef& ndef, EagerContext* ctx, Device** device) {
  DeviceSet device_set;
  for (Device* candidate : *ctx->devices()) device_set.AddDevice(candidate);

  DeviceTypeVector supported_types;
  const Status support_status = SupportedDeviceTypesForNode(
      device_set.PrioritizedDeviceTypeList(), ndef, &supported_types);
  if (!support_status.ok()) return support_status;
  if (supported_types.empty()) {
    return errors::Internal("Could not find valid device for node ",
                            ndef.DebugString());
  }

  const auto& preferred_type = supported_types[0].type_string();
  for (Device* candidate : *ctx->devices()) {
    if (candidate->device_type() == preferred_type) {
      *device = candidate;
      return Status::OK();
    }
  }
  return errors::Unknown("Could not find a device for node ",
                         ndef.DebugString());
}

// Computes the output dtypes of `op` from its attributes and its OpDef. If the
// context's function library has a function named op->Name(), that function's
// signature is used as the OpDef; otherwise the registered op definition is.
Status GetOutputDTypes(EagerOperation* op, DataTypeVector* output_dtypes) {
  const auto& node_def = op->MutableAttrs()->BuildNodeDef();
  const FunctionDef* fdef = op->EagerContext()->FuncLibDef()->Find(op->Name());
  const OpDef* op_def = nullptr;
  if (fdef == nullptr) {
    TF_RETURN_IF_ERROR(OpDefForOp(op->Name().c_str(), &op_def));
  } else {
    op_def = &fdef->signature();
  }
  return OutputTypesForNode(node_def, *op_def, output_dtypes);
}

}  // namespace

namespace {
// Returns true if `d` should be treated as local. A null device is local, and
// so is every device when the context has no remote device manager. Otherwise
// `d` is local iff the local device manager can look it up by name.
bool IsLocal(EagerContext* ctx, tensorflow::Device* d) {
  if (d != nullptr && ctx->remote_device_mgr() != nullptr) {
    tensorflow::Device* found = nullptr;
    return ctx->local_device_mgr()->LookupDevice(d->name(), &found).ok();
  }
  return true;
}

// Returns true if both devices have the same job, replica and task in their
// parsed names. A null device stands for the host CPU.
bool OnSameTask(EagerContext* ctx, Device* first, Device* second) {
  const auto& lhs = (first != nullptr ? first : ctx->HostCPU())->parsed_name();
  const auto& rhs =
      (second != nullptr ? second : ctx->HostCPU())->parsed_name();
  return lhs.job == rhs.job && lhs.replica == rhs.replica &&
         lhs.task == rhs.task;
}

// Executes `op` on a local device.
//
// Ops with a DT_RESOURCE input are moved to that input's device. The kernel
// is then looked up in (or created and added to) the context's kernel cache,
// inputs are moved to the devices the kernel expects, and the op is either
// enqueued on the async executor or run inline.
//
// On entry *num_retvals is an upper bound on the number of outputs the caller
// accepts (InvalidArgument if the kernel produces more). On success it is set
// to the kernel's output count and `retvals` is resized to match.
Status EagerLocalExecute(EagerOperation* op,
                         gtl::InlinedVector<TensorHandle*, 2>* retvals,
                         int* num_retvals) {
  EagerContext* ctx = op->EagerContext();
  auto status = ctx->GetStatus();
  if (!status.ok()) return status;
  // Ensure all resource-touching ops run in the device the resource is,
  // regardless of anything else that has been specified. This is identical to
  // the graph mode behavior.
  for (int i = 0; i < op->Inputs().size(); ++i) {
    Device* input_op_device = nullptr;
    status = op->Inputs()[i]->OpDevice(&input_op_device);
    if (!status.ok()) return status;
    VLOG(2) << "for op " << op->Name() << " input " << i << " "
            << DataTypeString(op->Inputs()[i]->dtype) << " "
            << (input_op_device == nullptr ? "cpu" : input_op_device->name())
            << " " << (op->Device() == nullptr ? "cpu" : op->Device()->name());
    if (op->Inputs()[i]->dtype == DT_RESOURCE &&
        (input_op_device != op->Device() || input_op_device == nullptr)) {
      Device* d = input_op_device == nullptr ? ctx->HostCPU() : input_op_device;
      VLOG(1) << "Changing device of operation " << op->Name() << " to "
              << d->name() << " because input #" << i
              << " is a resource in this device.";
      op->SetDevice(d);
    }
  }
  Device* device = op->Device();

  Fprint128 cache_key = op->MutableAttrs()->CacheKey(
      device == nullptr ? "unspecified" : device->name());
  KernelAndDevice* kernel = ctx->GetCachedKernel(cache_key);
  if (kernel == nullptr) {
    // If we are running a function on explicitly requested TPU,
    // compile it with XLA.
    // Note that it is not ideal, but currently ok, to set this
    // attribute after computing the kernel cache key above.
    if (op->is_function() && device != nullptr &&
        device->device_type() == "TPU") {
      op->MutableAttrs()->Set(kXlaCompileAttr, true);
    }

    const NodeDef& ndef = op->MutableAttrs()->BuildNodeDef();
    if (device == nullptr) {
      status = SelectDevice(ndef, ctx, &device);
      if (!status.ok()) return status;
    }
    CHECK(device != nullptr);
    if (ctx->LogDevicePlacement()) {
      LOG(INFO) << "Executing op " << ndef.op() << " in device "
                << device->name();
    }
    // Own the new kernel until it is handed to the cache, so that every early
    // return below releases it instead of leaking it.
    std::unique_ptr<KernelAndDevice> new_kernel(
        new KernelAndDevice(ctx->GetRendezvous()));
    // Knowledge of the implementation of Init (and in-turn
    // FunctionLibraryRuntime::CreateKernel) tells us that ctx->func_lib_def
    // will be accessed, so grab on to the lock.
    // See WARNING comment in Execute (before kernel->Run) - would be nice to
    // rework to avoid this subtlety.
    tf_shared_lock l(*ctx->FunctionsMu());
    status = KernelAndDevice::Init(ndef, ctx->func_lib(device), ctx->runner(),
                                   new_kernel.get());
    if (!status.ok()) return status;
    // Update output_dtypes inside `kernel`.
    const OpDef* op_def = nullptr;
    const FunctionDef* function_def = ctx->FuncLibDef()->Find(ndef.op());
    if (function_def != nullptr) {
      op_def = &(function_def->signature());
    }
    if (op_def == nullptr) {
      status = OpDefForOp(ndef.op().c_str(), &op_def);
      if (!status.ok()) return status;
    }
    DataTypeVector input_dtypes;
    status = InOutTypesForNode(ndef, *op_def, &input_dtypes,
                               new_kernel->mutable_output_dtypes());
    if (!status.ok()) return status;
    // Ownership passes to the kernel cache.
    kernel = new_kernel.release();
    ctx->AddKernelToCache(cache_key, kernel);
  }
  const DataTypeVector& output_dtypes = kernel->output_dtypes();
  const int output_dtypes_size = static_cast<int>(output_dtypes.size());
  if (output_dtypes_size > *num_retvals) {
    return errors::InvalidArgument("Expecting ", output_dtypes.size(),
                                   " outputs, but *num_retvals is ",
                                   *num_retvals);
  }
  *num_retvals = output_dtypes_size;
  if (device == nullptr) {
    // TODO(apassos) debug how the assignment below might return a different
    // device from the one requested above.
    device = kernel->device();
  }
  status = ValidateInputTypeAndPlacement(
      ctx, device, op, kernel->kernel(),
      ctx->ShouldStoreMetadata() ? ctx->RunMetadataProto() : nullptr);
  if (!status.ok()) return status;
  std::unique_ptr<NodeExecStats> maybe_stats;
  if (ctx->ShouldStoreMetadata()) {
    int64 now_nanos = Env::Default()->NowNanos();
    maybe_stats.reset(new NodeExecStats);
    maybe_stats->set_node_name(op->Name());
    maybe_stats->set_all_start_micros(now_nanos / EnvTime::kMicrosToNanos);
    maybe_stats->set_all_start_nanos(now_nanos);
    maybe_stats->set_op_start_rel_micros(0);
    maybe_stats->set_op_start_rel_nanos(0);
    maybe_stats->set_scheduled_micros(now_nanos / EnvTime::kMicrosToNanos);
    maybe_stats->set_scheduled_nanos(now_nanos);
    // TODO(apassos) track referenced tensors
  }
  retvals->resize(*num_retvals);
  if (ctx->Async()) {
    // Note that for async mode, execution order will make sure that all
    // input handles are ready before executing them.
    // TODO(agarwal): Consider executing "cheap" kernels inline for performance.
    tensorflow::uint64 id = ctx->NextId();
    for (int i = 0; i < *num_retvals; ++i) {
      (*retvals)[i] = new TensorHandle(id, output_dtypes[i], ctx);
    }
    EagerNode* node =
        new ExecuteNode(id, ctx, op->Device(), op->Inputs(), kernel,
                        maybe_stats.release(), output_dtypes, *retvals);
    ctx->ExecutorAdd(node);
  } else {
    // Execute checks if retvals[i] is nullptr or not to figure if it needs to
    // allocate it.
    status = EagerExecute(ctx, op->Device(), op->Inputs(), kernel,
                          maybe_stats.get(), retvals->data(), *num_retvals);
  }

  return status;
}

#ifndef __ANDROID__
// Returns a callback that, when invoked, asks the remote context `context_id`
// (through `eager_client`) to drop its reference to output `output_num` of the
// remote op `op_id`. In async mode the request is queued on the context's
// executor. Otherwise it is sent immediately, and the completion callback
// frees the request and response without inspecting the status.
std::function<void()> GetRemoteTensorDestructor(
    EagerContext* ctx, eager::EagerClient* eager_client, uint64 context_id,
    uint64 op_id, int output_num) {
  return [ctx, eager_client, context_id, op_id, output_num]() {
    std::unique_ptr<eager::EnqueueRequest> decref_request(
        new eager::EnqueueRequest);
    decref_request->set_context_id(context_id);
    auto* target = decref_request->add_queue()->mutable_handle_to_decref();
    target->set_op_id(op_id);
    target->set_output_num(output_num);

    if (!ctx->Async()) {
      // Fire and forget: the completion callback owns both messages.
      eager::EnqueueRequest* raw_request = decref_request.release();
      eager::EnqueueResponse* raw_response = new eager::EnqueueResponse;
      eager_client->EnqueueAsync(
          raw_request, raw_response,
          [raw_request, raw_response](const tensorflow::Status& s) {
            delete raw_request;
            delete raw_response;
          });
      return tensorflow::Status::OK();
    }

    const tensorflow::uint64 node_id = ctx->NextId();
    ctx->ExecutorAdd(new eager::RemoteExecuteNode(
        node_id, std::move(decref_request), eager_client));
    return tensorflow::Status::OK();
  };
}
#endif

// Copies the local tensor behind `h` to the remote device `recv_device` using
// the EagerService.SendTensor RPC, issued from this client.
//
// When !ctx->UseSendTensorRPC(), tensors are instead shipped between remote
// devices by the receiver invoking the WorkerService.RecvTensor RPC *on the
// sender* (Rendezvous::RecvAsync() invoked by the _Recv kernel). However, in
// some configurations the node that has the tensor to be copied isn't running
// a server (WorkerService RPC interface). For such cases this function sends
// the tensor using the EagerService.SendTensor RPC *on the receiver*.
//
// The call blocks until the RPC completes. On success *result is a new handle
// for the remote copy, with its shape already set from the local tensor.
Status EagerRemoteSendTensor(EagerContext* ctx, TensorHandle* h,
                             Device* recv_device, TensorHandle** result) {
#ifdef __ANDROID__
  return errors::Unimplemented(
      "Eager's remote execution is not available on Android devices.");
#else
  eager::EagerClient* eager_client;
  uint64 context_id;
  TF_RETURN_IF_ERROR(
      ctx->GetClientAndContextID(recv_device, &eager_client, &context_id));

  const tensorflow::uint64 id = ctx->NextId();
  eager::SendTensorRequest request;
  request.set_context_id(context_id);
  request.set_op_id(id);
  request.set_device_name(recv_device->name());

  const Tensor* tensor;
  TF_RETURN_IF_ERROR(h->Tensor(&tensor));
  tensor->AsProtoTensorContent(request.add_tensors());

  // TODO(nareshmodi): support making this call async.
  eager::SendTensorResponse response;
  Status rpc_status;
  Notification done;
  eager_client->SendTensorAsync(&request, &response,
                                [&done, &rpc_status](const Status& s) {
                                  rpc_status = s;
                                  done.Notify();
                                });
  done.WaitForNotification();
  TF_RETURN_IF_ERROR(rpc_status);

  *result = new TensorHandle(
      id, /*output_num=*/0, /*remote_shape_node_id=*/0, tensor->dtype(),
      GetRemoteTensorDestructor(ctx, eager_client, context_id, id, 0),
      recv_device, recv_device, ctx);
  (*result)->SetRemoteShape(MakeUnique<TensorShape>(tensor->shape()));
  return Status::OK();
#endif
}

// Executes `op` on a remote device by sending an EnqueueRequest through the
// eager client for op->Device().
//
// Inputs living on a different task than the op are first copied to the op's
// device. `retvals` must have room for *num_retvals handles, and *num_retvals
// must equal the op's number of outputs (InvalidArgument otherwise). New
// remote tensor handles are written to `retvals`. Their shapes are filled in
// once the remote side responds: inline in sync mode, or from the executor's
// completion callback in async mode.
Status EagerRemoteExecute(EagerOperation* op, TensorHandle** retvals,
                          int* num_retvals) {
#ifdef __ANDROID__
  return errors::Unimplemented(
      "Eager's remote execution is not available on Android devices.");
#else
  EagerContext* ctx = op->EagerContext();

  eager::EagerClient* eager_client;
  uint64 context_id;
  TF_RETURN_IF_ERROR(
      ctx->GetClientAndContextID(op->Device(), &eager_client, &context_id));

  std::unique_ptr<eager::EnqueueRequest> request(new eager::EnqueueRequest);
  eager::EnqueueResponse response;

  request->set_context_id(context_id);

  auto* remote_op = request->add_queue()->mutable_operation();

  for (int i = 0; i < op->Inputs().size(); i++) {
    tensorflow::Device* input_device;
    TF_RETURN_IF_ERROR(op->Inputs()[i]->Device(&input_device));
    if (op->Device() != input_device &&
        // If the expected and actual devices are on the same task, don't
        // explicitly copy, and instead depend on the copy to happen locally
        // when the op is executed on the device.
        !OnSameTask(ctx, op->Device(), input_device)) {
      // TODO(b/110044833): It's possible the same tensor gets copied to the
      // remote device repeatedly.
      TF_RETURN_IF_ERROR(MaybeCopyInputToExpectedDevice(
          op, i, op->Device(), /* run_metadata= */ nullptr,
          &(*op->MutableInputs())[i]));
    }

    tensorflow::TensorHandle* input = op->Inputs()[i];

    tensorflow::int64 op_id;
    int32 output_num;
    TF_RETURN_IF_ERROR(input->RemoteAddress(&op_id, &output_num));

    auto* remote_op_input = remote_op->add_inputs();
    remote_op_input->set_op_id(op_id);
    remote_op_input->set_output_num(output_num);
  }

  remote_op->set_id(op->EagerContext()->NextId());
  remote_op->set_name(op->Name());
  // Inputs set above.
  op->Attrs().FillAttrValueMap(remote_op->mutable_attrs());
  remote_op->set_device(op->Device()->name());

  DataTypeVector output_dtypes;
  TF_RETURN_IF_ERROR(GetOutputDTypes(op, &output_dtypes));

  if (*num_retvals != static_cast<int>(output_dtypes.size())) {
    return errors::InvalidArgument(
        "num_retvals does not match expected output dtypes");
  }

  tensorflow::Device* op_device = op->Device();

  bool is_async = op->EagerContext()->Async();
  uint64 remote_node_id = 0;

  if (is_async) {
    remote_node_id = op->EagerContext()->NextId();
  }

  const tensorflow::uint64 id = remote_op->id();
  for (int i = 0; i < *num_retvals; i++) {
    // TODO(nareshmodi): Change the callback to instead add the decref to a list
    // of pending decrefs that we can send as a batch with the next execute.
    std::function<void()> destructor =
        GetRemoteTensorDestructor(ctx, eager_client, context_id, id, i);

    retvals[i] = new TensorHandle(id, i, remote_node_id, output_dtypes[i],
                                  std::move(destructor), op_device, op_device,
                                  op->EagerContext());
  }

  if (is_async) {
    // Copy the output handles, since the container for them might get
    // destroyed. Each copy holds an extra reference that keeps the handle
    // alive until the completion callback below runs.
    gtl::InlinedVector<TensorHandle*, 2> retvals_copy;
    for (int i = 0; i < *num_retvals; i++) {
      retvals_copy.push_back(retvals[i]);
      retvals_copy[i]->Ref();
    }
    // Unable to capture via std::move, so bind instead.
    auto* node = new eager::RemoteExecuteNode(
        remote_node_id, std::move(request), eager_client, op->Inputs(),
        std::bind(
            [](const gtl::InlinedVector<TensorHandle*, 2>& retvals,
               const Status& status, const eager::EnqueueResponse& response) {
              for (int i = 0; i < retvals.size(); i++) {
                if (status.ok()) {
                  retvals[i]->SetRemoteShape(MakeUnique<TensorShape>(
                      response.queue_response(0).shape(i)));
                }
                // Drop the reference taken above whether or not the remote
                // execution succeeded; returning early on failure would leak
                // every output handle.
                retvals[i]->Unref();
              }
            },
            std::move(retvals_copy), std::placeholders::_1,
            std::placeholders::_2));
    op->EagerContext()->ExecutorAdd(node);
  } else {
    Notification n;
    Status status;
    eager_client->EnqueueAsync(request.get(), &response,
                               [&n, &status](const Status& s) {
                                 status = s;
                                 n.Notify();
                               });
    n.WaitForNotification();

    if (!status.ok()) return status;

    for (int i = 0; i < *num_retvals; i++) {
      retvals[i]->SetRemoteShape(
          MakeUnique<TensorShape>(response.queue_response(0).shape(i)));
    }
  }

  return Status::OK();
#endif
}
}  // namespace

// Executes `op`. Dispatches to local execution when op->Device() is local (or
// unspecified) and to remote execution otherwise. For remote execution,
// `retvals` must already hold *num_retvals slots, because handles are written
// directly into retvals->data().
Status EagerExecute(EagerOperation* op,
                    gtl::InlinedVector<TensorHandle*, 2>* retvals,
                    int* num_retvals) {
  EagerContext* ctx = op->EagerContext();
  if (IsLocal(ctx, op->Device())) {
    return EagerLocalExecute(op, retvals, num_retvals);
  }

  // Placement of remote ops is logged here; local execution logs it when a
  // new kernel is created.
  if (ctx->LogDevicePlacement()) {
    LOG(INFO) << "Executing op " << op->Name() << " in device "
              << op->Device()->name();
  }
  return EagerRemoteExecute(op, retvals->data(), num_retvals);
}

// Runs `kernel` synchronously on the tensors behind `op_inputs` and stores
// the outputs in `retvals[0..num_retvals)`.
//
// `device` is the op's device, or null to use the kernel's device. Outputs
// whose memory type is HOST_MEMORY are attached to a null (host) device rather
// than `device`. Null entries of `retvals` receive new handles; existing
// handles are filled in via SetTensorAndDevice. When `maybe_stats` is non-null
// the op's end time is recorded. If the context is storing metadata, the stats
// are then appended to the run metadata while holding MetadataMu().
Status EagerExecute(EagerContext* ctx, Device* device,
                    const gtl::InlinedVector<TensorHandle*, 4>& op_inputs,
                    KernelAndDevice* kernel, NodeExecStats* maybe_stats,
                    TensorHandle** retvals, int num_retvals) {
  if (device == nullptr) {
    // TODO(apassos) debug how the assignment below might return a different
    // device from the one requested above.
    device = kernel->device();
  }

  std::vector<Tensor> inputs;
  inputs.reserve(op_inputs.size());
  for (TensorHandle* input_handle : op_inputs) {
    const Tensor* input_tensor = nullptr;
    TF_RETURN_IF_ERROR(input_handle->Tensor(&input_tensor));
    inputs.push_back(*input_tensor);
  }
  std::vector<Tensor> outputs(1);
  // WARNING: kernel->Run utilizes the FunctionLibraryRuntime
  // (ctx->func_lib(device)), which in turn holds a pointer to func_lib_def.
  // But knowledge of the implementation
  // of FunctionLibraryRuntime tells us that func_lib_def is not accessed by
  // FunctionLibraryRuntime::Run(), so there is no thread-safety concern here.
  // This is quite subtle. Re-work things to make this better?  (Would it make
  // sense for FunctionLibraryRuntime to ensure thread-safe access to
  // FunctionLibraryDefinition?).  TODO(apassos) figure out how to record stats
  // for ops which are a part of functions.
  // TODO(agarwal): change Run to take vector of handles ?
  TF_RETURN_IF_ERROR(kernel->Run(&inputs, &outputs, maybe_stats));

  if (maybe_stats != nullptr) {
    const int64 end_nanos = Env::Default()->NowNanos();
    maybe_stats->set_op_end_rel_micros(end_nanos / EnvTime::kMicrosToNanos -
                                       maybe_stats->all_start_micros());
    maybe_stats->set_op_end_rel_nanos(end_nanos -
                                      maybe_stats->all_start_nanos());
    mutex_lock ml(*ctx->MetadataMu());
    if (ctx->ShouldStoreMetadata()) {
      auto* step_stats = ctx->RunMetadataProto()->mutable_step_stats();
      // Lazily create one entry per device the first time stats are stored.
      while (step_stats->dev_stats_size() < ctx->devices()->size()) {
        step_stats->add_dev_stats();
      }
      // Locate `device` by pointer; entry 0 is used if it is not found.
      int device_idx = 0;
      for (int idx = 0; idx < ctx->devices()->size(); ++idx) {
        if ((*ctx->devices())[idx] == device) {
          device_idx = idx;
          break;
        }
      }
      auto* dev_stats = step_stats->mutable_dev_stats(device_idx);
      dev_stats->set_device(device->name());
      *dev_stats->add_node_stats() = *maybe_stats;
    }
  }

  DCHECK_EQ(num_retvals, outputs.size());
  const MemoryTypeVector& output_memory_types =
      kernel->kernel()->output_memory_types();
  for (int i = 0; i < num_retvals; ++i) {
    Device* const output_device =
        (device != nullptr && output_memory_types[i] == HOST_MEMORY) ? nullptr
                                                                     : device;
    if (retvals[i] != nullptr) {
      retvals[i]->SetTensorAndDevice(outputs[i], output_device, device);
    } else {
      retvals[i] = new TensorHandle(outputs[i], output_device, device, ctx);
    }
  }
  return Status::OK();
}

namespace {

// Copies `h` to the local device `dstd`. In async mode the copy is enqueued on
// the context's executor, and *result is the destination handle, which may not
// be ready yet. Otherwise the copy happens inline.
Status LocalEagerCopyToDevice(TensorHandle* h, EagerContext* ctx, Device* dstd,
                              TensorHandle** result) {
  TF_RETURN_IF_ERROR(ctx->GetStatus());
  if (!ctx->Async()) {
    return h->CopyToDevice(ctx, dstd, result);
  }
  // `h` may not be ready yet; execution order guarantees it will be by the
  // time the copy actually runs.
  CopyToDeviceNode* copy_node = new CopyToDeviceNode(h, dstd, ctx);
  TensorHandle* dst = copy_node->dst();
  // After ExecutorAdd the executor thread can reach `copy_node`, so it is not
  // touched again here.
  ctx->ExecutorAdd(copy_node);
  *result = dst;
  return Status::OK();
}

// Resolves `device_name` to a device. The local device manager is searched
// first, then the remote one if the context has it. A null or empty name
// resolves to the host CPU, and *device is preset to the host CPU before any
// lookup.
Status FindDeviceFromName(EagerContext* ctx, const char* device_name,
                          Device** device) {
  *device = ctx->HostCPU();
  if (device_name == nullptr || device_name[0] == '\0') {
    return Status::OK();
  }

  Status local_status =
      ctx->local_device_mgr()->LookupDevice(device_name, device);
  if (!local_status.ok() && ctx->remote_device_mgr() != nullptr) {
    return ctx->remote_device_mgr()->LookupDevice(device_name, device);
  }
  return local_status;
}

// Builds and executes a `_Send` op on `device` with `h` as its input.
// `wire_id` is used as the tensor_name attribute and `recv_device` as the
// receiving device. The op produces no outputs.
Status ExecuteSend(EagerContext* ctx, tensorflow::Device* device,
                   TensorHandle* h, StringPiece wire_id,
                   const string& recv_device) {
  const tensorflow::AttrTypeMap* types;
  TF_RETURN_IF_ERROR(tensorflow::AttrTypeMapForOp("_Send", &types));
  tensorflow::EagerOperation send_op(ctx, "_Send", types);
  send_op.AddInput(h);
  send_op.SetDevice(device);

  auto* attrs = send_op.MutableAttrs();
  attrs->Set("tensor_name", wire_id);
  attrs->Set("send_device", device->name());
  attrs->Set("send_device_incarnation",
             static_cast<int64>(device->attributes().incarnation()));
  attrs->Set("recv_device", recv_device);
  attrs->Set("client_terminated", false);
  attrs->Set("T", h->dtype);

  int num_outputs = 0;
  gtl::InlinedVector<TensorHandle*, 2> retvals;
  return EagerExecute(&send_op, &retvals, &num_outputs);
}

// Builds and executes a `_Recv` op on `device` for a tensor of type `dtype`.
// `wire_id` is the tensor_name attribute, and `send_device` and
// `send_device_incarnation` identify the sender. On success *result is the
// op's single output handle.
Status ExecuteRecv(EagerContext* ctx, tensorflow::Device* device,
                   DataType dtype, StringPiece wire_id,
                   const string& send_device, int64 send_device_incarnation,
                   TensorHandle** result) {
  const tensorflow::AttrTypeMap* types;
  TF_RETURN_IF_ERROR(tensorflow::AttrTypeMapForOp("_Recv", &types));
  tensorflow::EagerOperation recv_op(ctx, "_Recv", types);
  recv_op.SetDevice(device);

  auto* attrs = recv_op.MutableAttrs();
  attrs->Set("tensor_name", wire_id);
  attrs->Set("send_device", send_device);
  attrs->Set("send_device_incarnation", send_device_incarnation);
  attrs->Set("recv_device", device->name());
  attrs->Set("client_terminated", false);
  attrs->Set("tensor_type", dtype);

  int num_outputs = 1;
  gtl::InlinedVector<TensorHandle*, 2> retvals(num_outputs);
  TF_RETURN_IF_ERROR(EagerExecute(&recv_op, &retvals, &num_outputs));
  *result = retvals.at(0);
  return Status::OK();
}

// Returns a wire ID of the form "<random prefix>_<counter>". The random prefix
// is drawn once per process, so that a worker serving several clients sees no
// collisions. The counter increases by one on every call, under a mutex.
string GetUniqueWireID() {
  static const tensorflow::uint64 process_prefix = random::New64();
  static tensorflow::mutex counter_mu(tensorflow::LINKER_INITIALIZED);
  static tensorflow::int64 next_id GUARDED_BY(counter_mu) = 0;

  tensorflow::int64 this_id;
  {
    tensorflow::mutex_lock lock(counter_mu);
    this_id = next_id++;
  }
  return strings::StrCat(process_prefix, "_", this_id);
}

}  // namespace

// Copies `h` to the device named `device_name` (the host CPU if the name is
// null or empty). The strategy depends on where the two ends are:
//  * local sender and local receiver: the local copy path;
//  * local sender and remote receiver, when the context enables it: the
//    SendTensor RPC;
//  * anything else: a _Send/_Recv pair exchanging the tensor under a freshly
//    generated wire ID.
Status EagerCopyToDevice(TensorHandle* h, EagerContext* ctx,
                         const char* device_name, TensorHandle** result) {
  tensorflow::Device* send_device;
  TF_RETURN_IF_ERROR(h->Device(&send_device));
  if (send_device == nullptr) send_device = ctx->HostCPU();
  const bool sender_is_local = IsLocal(ctx, send_device);

  tensorflow::Device* recv_device;
  TF_RETURN_IF_ERROR(FindDeviceFromName(ctx, device_name, &recv_device));
  const bool recver_is_local = IsLocal(ctx, recv_device);

  if (sender_is_local && recver_is_local) {
    return LocalEagerCopyToDevice(h, ctx, recv_device, result);
  }
  // Past this point at least one end is remote, so a local sender implies a
  // remote receiver.
  if (ctx->UseSendTensorRPC() && sender_is_local) {
    return EagerRemoteSendTensor(ctx, h, recv_device, result);
  }

  const string wire_id = GetUniqueWireID();
  TF_RETURN_IF_ERROR(
      ExecuteSend(ctx, send_device, h, wire_id, recv_device->name()));
  return ExecuteRecv(ctx, recv_device, h->dtype, wire_id, send_device->name(),
                     send_device->attributes().incarnation(), result);
}
}  // namespace tensorflow