aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core/common_runtime/gpu/gpu_device.cc
blob: 02f70d835d500b5ad389d04a1af611d529595e9a (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
1001
1002
1003
1004
1005
1006
/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// TODO(opensource): Use a more generic sounding preprocessor name than
// GOOGLE_CUDA
#if GOOGLE_CUDA

#define EIGEN_USE_GPU

#include "tensorflow/core/common_runtime/gpu/gpu_device.h"

#include <stdlib.h>
#include <string.h>
#include <algorithm>
#include <memory>
#include <utility>
#include <vector>

#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/common_runtime/device_factory.h"
#include "tensorflow/core/common_runtime/gpu/gpu_event_mgr.h"
#include "tensorflow/core/common_runtime/gpu/gpu_init.h"
#include "tensorflow/core/common_runtime/gpu/gpu_stream_util.h"
#include "tensorflow/core/common_runtime/gpu/gpu_util.h"
#include "tensorflow/core/common_runtime/gpu/process_state.h"
#include "tensorflow/core/common_runtime/gpu_device_context.h"
#include "tensorflow/core/common_runtime/local_device.h"
#include "tensorflow/core/framework/allocator.h"
#include "tensorflow/core/framework/device_base.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/graph/types.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/gtl/stl_util.h"
#include "tensorflow/core/lib/strings/numbers.h"
#include "tensorflow/core/lib/strings/str_util.h"
#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/platform/cuda.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/platform/stream_executor.h"
#include "tensorflow/core/platform/tracing.h"
#include "tensorflow/core/platform/types.h"
#include "tensorflow/core/public/session_options.h"
#include "tensorflow/core/util/device_name_utils.h"
#include "tensorflow/core/util/stream_executor_util.h"

namespace gpu = ::perftools::gputools;

namespace tensorflow {

// Eigen Ops directly allocate memory only for temporary buffers used
// during OpKernel::Compute().  The recommended way of allocating such
// memory is via OpKernelContext::allocate_temp().  However, Eigen Ops
// don't have access to OpKernelContext, instead they get access to
// memory directly through the device allocator.  As an Open Source
// project, Eigen assumes allocator semantics similar to those of the
// CUDA memory allocator, and may not work correctly due to race
// conditions if used with some other allocator.  For safety, we need
// to delay deallocation calls out of Eigen until all events on the
// corresponding stream have completed.  The EigenCudaStreamDevice class
// below serves this purpose.

class EigenCudaStreamDevice : public ::Eigen::StreamInterface {
 public:
  EigenCudaStreamDevice()
      : scratch_(nullptr), semaphore_(nullptr), context_(nullptr) {
    Eigen::initializeDeviceProp();
  }
  ~EigenCudaStreamDevice() {}
  void Reinitialize(OpKernelContext* context, const cudaStream_t* cuda_stream,
                    int gpu_id, ::tensorflow::Allocator* alloc, char* scratch) {
    if (LogMemory::IsEnabled()) {
      operation_ = context->op_kernel().name() + "/EigenAllocator";
      step_id_ = context->step_id();
    }
    context_ = context;
    scratch_ = scratch;
    semaphore_ =
        reinterpret_cast<unsigned int*>(scratch + Eigen::kCudaScratchSize);
    stream_ = cuda_stream;
    allocator_ = alloc;
    device_prop_ = &Eigen::m_deviceProperties[gpu_id];
  }

  const cudaStream_t& stream() const override { return *stream_; }
  const cudaDeviceProp& deviceProperties() const override {
    return *device_prop_;
  }

  void* allocate(size_t num_bytes) const override {
    void* ret = allocator_->AllocateRaw(32 /* alignment */, num_bytes);
    if (ret == nullptr) {
      if (context_) {
        context_->SetStatus(errors::ResourceExhausted(
            strings::StrCat("Ran out of GPU memory when allocating ", num_bytes,
                            " bytes for ", operation_)));
      } else {
        LOG(FATAL)
            << "EigenAllocator for GPU ran out of memory when allocating "
            << num_bytes << ". See error logs for more detailed info.";
      }
    }
    if (LogMemory::IsEnabled()) {
      LogMemory::RecordRawAllocation(operation_, step_id_, num_bytes, ret,
                                     allocator_);
    }
    return ret;
  }
  void deallocate(void* buffer) const override {
    if (LogMemory::IsEnabled()) {
      LogMemory::RecordRawDeallocation(operation_, step_id_, buffer, allocator_,
                                       true);
    }
    AsyncFreeData* afData =
        new AsyncFreeData(allocator_, buffer, operation_, step_id_);
    cudaError_t err = cudaStreamAddCallback(*stream_, asyncFree, afData, 0);
    CHECK_EQ(err, cudaSuccess);
  }

  // Return a pointer to a per stream scratchpad of 1024 bytes residing
  // in global memory.
  void* scratchpad() const { return scratch_; }

  // Return a semaphore. The semaphore is initially initialized to 0, and
  // each kernel using it is responsible for resetting to 0 upon completion
  // to maintain the invariant that the semaphore is always equal to 0 upon
  // each kernel start.
  unsigned int* semaphore() const { return semaphore_; }

 private:
  struct AsyncFreeData {
    AsyncFreeData(::tensorflow::Allocator* a, void* p, const string& o,
                  const int64 s)
        : allocator_(a), address_(p), operation_(o), step_id_(s) {}
    ::tensorflow::Allocator* allocator_;
    void* address_;
    const string operation_;
    const int64 step_id_;
  };

  static void CUDART_CB asyncFree(cudaStream_t stream, cudaError_t status,
                                  void* userData) {
    AsyncFreeData* data = static_cast<AsyncFreeData*>(userData);
    if (LogMemory::IsEnabled()) {
      LogMemory::RecordRawDeallocation(data->operation_, data->step_id_,
                                       data->address_, data->allocator_, false);
    }
    data->allocator_->DeallocateRaw(data->address_);
    delete data;
  }

  string operation_;
  int64 step_id_;
  const cudaStream_t* stream_;          // Not owned.
  const cudaDeviceProp* device_prop_;   // Not owned.
  ::tensorflow::Allocator* allocator_;  // Not owned.
  mutable char* scratch_;
  mutable unsigned int* semaphore_;
  OpKernelContext* context_;

  TF_DISALLOW_COPY_AND_ASSIGN(EigenCudaStreamDevice);
};

// Records the device attributes and borrowed allocators. Streams, scratch
// buffers and device contexts are not created here; Init() does that.
// Constructing a GPU device also calls
// ProcessState::singleton()->EnableGPUDevice().
BaseGPUDevice::BaseGPUDevice(const SessionOptions& options, const string& name,
                             Bytes memory_limit, const DeviceLocality& locality,
                             int gpu_id, const string& physical_device_desc,
                             Allocator* gpu_allocator, Allocator* cpu_allocator,
                             bool sync_every_op, int32 max_streams)
    : LocalDevice(
          options,
          Device::BuildDeviceAttributes(name, DEVICE_GPU, memory_limit,
                                        locality, physical_device_desc)),
      gpu_allocator_(gpu_allocator),
      cpu_allocator_(cpu_allocator),
      gpu_id_(gpu_id),
      sync_every_op_(sync_every_op),
      max_streams_(max_streams) {
  ProcessState* const process_state = ProcessState::singleton();
  process_state->EnableGPUDevice();
}

// Releases what Init() created: the GpuDeviceInfo, the per-stream scratch
// buffers, this device's references on its device contexts, and the streams.
BaseGPUDevice::~BaseGPUDevice() {
  delete gpu_device_info_;
  // The scratch buffers were obtained from gpu_allocator_ in Init(); return
  // them so they are not leaked when the device goes away.
  for (char* scratch_buffer : scratch_) {
    gpu_allocator_->DeallocateRaw(scratch_buffer);
  }
  for (auto ctx : device_contexts_) ctx->Unref();
  for (auto& stream_group : streams_) {
    delete stream_group.compute;
    delete stream_group.host_to_device;
    delete stream_group.device_to_host;
    delete stream_group.device_to_device;
  }
}

// Prepares the device for use.
//
// Steps:
//   1. Looks up the StreamExecutor for gpu_id_ and creates the EventMgr.
//   2. For each of the max_streams_ stream groups, creates:
//        - compute, host-to-device, device-to-host and device-to-device
//          streams;
//        - a zeroed Eigen scratch buffer (Eigen::kCudaScratchSize bytes
//          followed by an unsigned int semaphore);
//        - a GPUDeviceContext.
//   3. Publishes a GpuDeviceInfo describing stream group 0.
//
// Returns:
//   - Internal if the executor is unavailable;
//   - InvalidArgument if max_streams_ < 1;
//   - FailedPrecondition if a scratch buffer cannot be allocated or zeroed.
// Streams and device contexts created before a failure are still released by
// the destructor.
Status BaseGPUDevice::Init(const SessionOptions& options) {
  auto executor_status = GPUMachineManager()->ExecutorForDevice(gpu_id_);
  if (!executor_status.status().ok()) {
    return errors::Internal("Failed to get StreamExecutor for device ",
                            gpu_id_);
  }

  executor_ = executor_status.ValueOrDie();
  em_.reset(new EventMgr(executor_, options.config.gpu_options()));

  if (max_streams_ < 1) {
    return errors::InvalidArgument("Invalid value for max_streams.");
  }

  // Creates and initializes a new stream on executor_, logging it as
  // "<label>[<index>]" at VLOG(2). The caller stores the pointer in streams_,
  // whose entries the destructor deletes.
  auto create_stream = [this](const char* label, int index) {
    gpu::Stream* new_stream = new gpu::Stream(executor_);
    new_stream->Init();
    VLOG(2) << "Created " << label << "[" << index << "] = " << new_stream;
    return new_stream;
  };

  // Each scratch buffer is Eigen::kCudaScratchSize bytes of scratch space
  // followed by the unsigned int semaphore used by EigenCudaStreamDevice.
  const size_t scratch_buffer_size =
      Eigen::kCudaScratchSize + sizeof(unsigned int);

  // Create the specified number of GPU stream groups.
  for (int i = 0; i < max_streams_; i++) {
    gpu::Stream* stream = create_stream("stream", i);
    gpu::Stream* host_to_device_stream =
        create_stream("host_to_device_stream", i);
    gpu::Stream* device_to_host_stream =
        create_stream("device_to_host_stream", i);
    gpu::Stream* device_to_device_stream =
        create_stream("device_to_device_stream", i);

    streams_.push_back({stream, host_to_device_stream, device_to_host_stream,
                        device_to_device_stream});

    void* scratch_buffer = gpu_allocator_->AllocateRaw(
        Allocator::kAllocatorAlignment, scratch_buffer_size);
    if (scratch_buffer == nullptr) {
      return errors::FailedPrecondition(
          "Failed to allocate scratch buffer for device ", gpu_id_);
    }
    scratch_.push_back(static_cast<char*>(scratch_buffer));

    // Zero the whole buffer, including the trailing semaphore, which kernels
    // expect to be 0 when they start.
    gpu::DeviceMemory<char> mem(
        gpu::DeviceMemoryBase(scratch_buffer, scratch_buffer_size));
    if (!executor_->SynchronousMemZero(&mem, scratch_buffer_size)) {
      return errors::FailedPrecondition(
          "Failed to memzero scratch buffer for device ", gpu_id_);
    }

    device_contexts_.push_back(
        new GPUDeviceContext(i, stream, host_to_device_stream,
                             device_to_host_stream, device_to_device_stream));
  }

  gpu_device_info_ = new GpuDeviceInfo;
  gpu_device_info_->stream = streams_[0].compute;
  gpu_device_info_->default_context = device_contexts_[0];
  gpu_device_info_->event_mgr = em_.get();
  gpu_device_info_->gpu_id = gpu_id_;
  set_tensorflow_gpu_device_info(gpu_device_info_);

  return Status::OK();
}

// Accessed tensors only need to be recorded when several streams are in use.
// With a single stream, tensor references are released when the kernel
// launch finishes rather than when the kernel's execution finishes.
bool BaseGPUDevice::RequiresRecordingAccessedTensors() const {
  const bool uses_multiple_streams = streams_.size() > 1;
  return uses_multiple_streams;
}

// Assigns every node of `graph` to one of this device's streams (via
// gpu_stream_util::AssignStreams). It then stores the matching
// GPUDeviceContext in `device_context_map`, indexed by node id, taking one
// reference on the context per entry.
//
// With a single stream the map is left untouched; ops without a device
// context fall back to device_contexts_[0] (see ComputeHelper).
Status BaseGPUDevice::FillContextMap(const Graph* graph,
                                     DeviceContextMap* device_context_map) {
  VLOG(2) << "FillContextMap";

  const size_t num_streams = streams_.size();
  // Nothing to assign with a single stream (or none, if Init() never ran).
  if (num_streams <= 1) {
    return Status::OK();
  }
  const int64 before = Env::Default()->NowMicros();
  gpu_stream_util::AssignStreamsOpts opts;
  opts.max_streams = static_cast<int32>(num_streams);
  std::unordered_map<int, int> node_to_stream_id;
  TF_RETURN_IF_ERROR(
      gpu_stream_util::AssignStreams(graph, opts, &node_to_stream_id));
  int64 elapsed = Env::Default()->NowMicros() - before;
  VLOG(3) << "AssignStreams took " << elapsed << "us";

  // Fill in the context map.  It is OK for this map to contain
  // duplicate DeviceContexts so long as we increment the refcount.
  device_context_map->resize(graph->num_node_ids());
  for (Node* n : graph->nodes()) {
    // operator[] yields stream 0 for any node AssignStreams did not map.
    const int mapped_stream = node_to_stream_id[n->id()];
    // mapped_stream indexes device_contexts_, which has one entry per stream,
    // so it must be strictly below num_streams. The cast also makes a
    // negative id fail the check.
    CHECK_LT(static_cast<size_t>(mapped_stream), num_streams);
    auto ctx = device_contexts_[mapped_stream];
    VLOG(3) << "Assigned stream " << mapped_stream << " ==> stream["
            << ctx->stream_id() << "] for node id " << n->id() << " "
            << n->type_string() << " " << n->name();
    ctx->Ref();
    (*device_context_map)[n->id()] = ctx;
  }

  return Status::OK();
}

// Runs a synchronous kernel inside a tracing activity. When tracing
// annotations are enabled, the kernel also runs under a ScopedAnnotation.
// Synchronous execution of the internal "_Recv" op is rejected with an
// Internal error.
void BaseGPUDevice::Compute(OpKernel* op_kernel, OpKernelContext* context) {
  // The Hash64 of the kernel name is only computed while tracing is active;
  // the ScopedActivity itself is cheap otherwise.
  // TODO(pbar) This would no longer be needed if Ops have a unique id.
  const uint64 activity_id =
      port::Tracing::IsActive() ? Hash64(op_kernel->name()) : 0;
  port::Tracing::ScopedActivity region(port::Tracing::EventCategory::kCompute,
                                       activity_id);

  // NOTE(tucker): Eigen-implemented ops (and any op that launches a cuda
  // kernel directly) need a stacked-scoped environment directing them to the
  // proper device; other ops are expected to use StreamExecutor directly and
  // correctly. The discrimination is hacky: at the moment the only non-Eigen
  // GPU Op is the recv-op, which is known to be asynchronous.
  const bool is_internal_recv =
      op_kernel->is_internal() && op_kernel->type_string() == "_Recv";
  if (is_internal_recv) {
    context->SetStatus(errors::Internal(
        "Invalid synchronous 'Compute' on GPU for '_Recv' op"));
    return;
  }

  if (!port::Tracing::ScopedAnnotation::Enabled()) {
    ComputeHelper(op_kernel, context);
    return;
  }

  port::Tracing::ScopedAnnotation annotation(op_kernel->name(),
                                             op_kernel->type_string());
  ComputeHelper(op_kernel, context);
}

// Runs a synchronous kernel on this device.
//
// The kernel runs on the stream of the op's GPUDeviceContext, or of
// device_contexts_[0] when the op has none. When the device has more than one
// stream, the op's stream is first made to wait (ThenWaitFor) on the stream of
// every input whose device context uses a different stream. If the kernel
// succeeds and sync_every_op_ is set, all GPU streams are then synchronized
// and the result is stored as the op's status.
void BaseGPUDevice::ComputeHelper(OpKernel* op_kernel,
                                  OpKernelContext* context) {
  GPUDeviceContext* gpu_device_context = device_contexts_[0];
  if (context->op_device_context() != nullptr) {
    gpu_device_context =
        static_cast<GPUDeviceContext*>(context->op_device_context());
  }
  gpu::Stream* stream = gpu_device_context->stream();
  const auto stream_id = gpu_device_context->stream_id();

  // Query verbosity once; vlog_2 can only be true when vlog_1 is.
  const bool vlog_1 = VLOG_IS_ON(1);
  const bool vlog_2 = vlog_1 && VLOG_IS_ON(2);

  if (vlog_1) {
    VLOG(1) << "GpuDevice::Compute " << op_kernel->name() << " op "
            << op_kernel->def().op() << " on GPU" << gpu_id_ << " stream["
            << stream_id << "]";
  }

  const auto num_streams = streams_.size();
  if (num_streams > 1) {
    // If this op's device context is different from the other contexts,
    // we must wait on the stream.
    for (int i = 0; i < context->num_inputs(); ++i) {
      const GPUDeviceContext* idc =
          static_cast<GPUDeviceContext*>(context->input_device_context(i));
      // Records an Internal error on the context if input i has no device
      // context (OP_REQUIRES also returns early, per its usual TF semantics).
      OP_REQUIRES(context, idc != nullptr,
                  errors::Internal("Input device context ", i,
                                   " was not set properly."));
      if (vlog_2) {
        // Debug-only logging of each present input's address and byte size.
        const void* base;
        size_t len;
        if (context->has_input(i)) {
          // Ref-typed inputs are read through mutable_input(); others through
          // input().
          if (IsRefType(context->input_dtype(i))) {
            Tensor tensor = context->mutable_input(i, false);
            base = DMAHelper::base(&tensor);
            len = tensor.TotalBytes();
          } else {
            const Tensor& tensor = context->input(i);
            base = DMAHelper::base(&tensor);
            len = tensor.TotalBytes();
          }
          LOG(INFO) << "Input " << i << " " << base << "  " << len;
          LOG(INFO) << "  stream[" << stream_id << "].ThenWaitFor(stream["
                    << idc->stream_id() << "])"
                    << ((idc->stream() == stream) ? " not needed" : "");
        }
      }
      // Order this op after the work enqueued on input i's stream when that
      // stream differs from the op's own stream.
      if (idc->stream() != stream) stream->ThenWaitFor(idc->stream());
    }
  }
  // Activate the stream's executor context for the rest of this scope, so
  // that kernels the op launches directly run on this GPU (see the note in
  // Compute()).
  gpu::cuda::ScopedActivateExecutorContext scoped_activation{stream->parent()};
  op_kernel->Compute(context);
  if (context->status().ok()) {
    if (sync_every_op_) {
      // Note: GPUUtil::Sync() only syncs the default stream.
      // We need to either sync the stream used by this op, or
      // all streams.  Given that this flag is typically used for
      // debugging it makes more sense to sync all GPU activity.
      context->SetStatus(GPUUtil::SyncAll(this));
    }
  }
}

// Hands `tensor_refs` to EventMgr::ThenDeleteTensors on the stream of
// `device_context`, or of device_contexts_[0] when none is given. As its name
// suggests, this presumably defers dropping the references until that
// stream's pending work is done.
void BaseGPUDevice::ConsumeListOfAccessedTensors(
    DeviceContext* device_context, const TensorReferenceVector& tensor_refs) {
  GPUDeviceContext* const target_context =
      (device_context == nullptr)
          ? device_contexts_[0]
          : static_cast<GPUDeviceContext*>(device_context);
  em_->ThenDeleteTensors(target_context->stream(), tensor_refs);
}

// Device::Sync is expected to wait for every stream owned by this device,
// not only the current one, so delegate to GPUUtil::SyncAll.
Status BaseGPUDevice::Sync() {
  return GPUUtil::SyncAll(this);
}

// Starts an asynchronous kernel under a TraceMe activity. The op's
// GPUDeviceContext (or device_contexts_[0] if none is set) is only used here
// to report the stream id in the VLOG(1) message.
void BaseGPUDevice::ComputeAsync(AsyncOpKernel* op_kernel,
                                 OpKernelContext* context,
                                 AsyncOpKernel::DoneCallback done) {
  GPUDeviceContext* gpu_device_context = device_contexts_[0];
  if (context->op_device_context() != nullptr) {
    gpu_device_context =
        static_cast<GPUDeviceContext*>(context->op_device_context());
  }
  const auto stream_id = gpu_device_context->stream_id();

  VLOG(1) << "GpuDevice::ComputeAsync " << op_kernel->name() << " op "
          << op_kernel->def().op() << " on GPU" << gpu_id_ << " stream["
          << stream_id << "]";

  // When TraceMe profiling is off (which is the default), the
  // following TraceMe constructor is simply a conditional test of
  // false value. Measurements show that its overhead is negligible.
  port::Tracing::TraceMe activity(op_kernel->name(), op_kernel->type_string());
  // `done` is not used after this call, so move it instead of copying the
  // callback.
  op_kernel->ComputeAsync(context, std::move(done));
}

// Builds a Tensor from `tensor_proto`.
//
// The proto is always parsed into GPU-compatible host memory first. If
// `alloc_attrs` asks for device memory, the parsed tensor is then copied onto
// this GPU through device_contexts_[0], and the function blocks until the
// copy completes.
//
// Returns:
//   - InvalidArgument if the proto cannot be parsed;
//   - Internal if the dtype cannot be DMA-copied;
//   - ResourceExhausted if the device tensor cannot be allocated;
//   - otherwise, the status reported by the host-to-device copy.
Status BaseGPUDevice::MakeTensorFromProto(const TensorProto& tensor_proto,
                                          const AllocatorAttributes alloc_attrs,
                                          Tensor* tensor) {
  AllocatorAttributes host_attrs;
  host_attrs.set_on_host(true);
  host_attrs.set_gpu_compatible(true);
  Allocator* const host_alloc = GetAllocator(host_attrs);

  Tensor parsed(tensor_proto.dtype());
  if (!parsed.FromProto(host_alloc, tensor_proto)) {
    return errors::InvalidArgument("Cannot parse tensor from proto: ",
                                   tensor_proto.DebugString());
  }

  // Host placement requested: the parsed tensor is already the answer.
  if (alloc_attrs.on_host()) {
    *tensor = parsed;
    return Status::OK();
  }

  if (!DMAHelper::CanUseDMA(&parsed)) {
    return errors::Internal("GPU copy from non-DMA ",
                            DataTypeString(parsed.dtype()), " tensor");
  }

  Tensor device_copy(GetAllocator(alloc_attrs), parsed.dtype(),
                     parsed.shape());
  // An uninitialized destination most likely means we ran out of memory.
  if (!device_copy.IsInitialized()) {
    return errors::ResourceExhausted(
        "OOM when allocating tensor of shape ", parsed.shape().DebugString(),
        " and type ", DataTypeString(parsed.dtype()));
  }

  port::Tracing::ScopedAnnotation annotation("MakeTensorFromProto");
  Status copy_status;
  Notification copy_done;
  device_contexts_[0]->CopyCPUTensorToDevice(
      &parsed, this, &device_copy,
      [&copy_done, &copy_status](const Status& s) {
        copy_status = s;
        copy_done.Notify();
      });
  copy_done.WaitForNotification();
  *tensor = device_copy;
  return copy_status;
}

namespace {
// PerOpGpuDevice whose Eigen::GpuDevice is backed by an EigenCudaStreamDevice.
// Eigen work therefore uses the stream, allocator and scratch buffer supplied
// through Reinitialize().
class ConcretePerOpGpuDevice : public PerOpGpuDevice {
 public:
  // device_ is constructed with a pointer to stream_device_. stream_device_
  // is declared before device_, so it is constructed first; keep that member
  // order.
  ConcretePerOpGpuDevice() : device_(&stream_device_) {}

  // Forwards all arguments to EigenCudaStreamDevice::Reinitialize().
  void Reinitialize(OpKernelContext* context, const cudaStream_t* cuda_stream,
                    int gpu_id, Allocator* base_allocator, char* scratch) {
    stream_device_.Reinitialize(context, cuda_stream, gpu_id, base_allocator,
                                scratch);
  }

  const Eigen::GpuDevice& device() const override { return device_; }

 private:
  EigenCudaStreamDevice stream_device_;
  Eigen::GpuDevice device_;
};
}  // namespace

// Points the per-op Eigen device at the compute stream and scratch buffer of
// stream group `stream_id`, allocating temporaries from `allocator`.
// `device` is assumed to have been created by MakeGpuDevice(), i.e. to be a
// ConcretePerOpGpuDevice.
void BaseGPUDevice::ReinitializeDevice(OpKernelContext* context,
                                       PerOpGpuDevice* device, int stream_id,
                                       Allocator* allocator) {
  auto* concrete_device = static_cast<ConcretePerOpGpuDevice*>(device);
  DCHECK(concrete_device);

  gpu::Stream* compute_stream = streams_[stream_id].compute;
  const auto* cuda_stream = reinterpret_cast<const cudaStream_t*>(
      compute_stream->implementation()->CudaStreamMemberHack());
  char* stream_scratch = scratch_[stream_id];

  concrete_device->Reinitialize(context, cuda_stream, gpu_id_, allocator,
                                stream_scratch);
}

// Returns a newly heap-allocated per-op Eigen device wrapper. It presumably
// becomes owned by the caller, who rebinds it with ReinitializeGpuDevice().
PerOpGpuDevice* BaseGPUDevice::MakeGpuDevice() {
  PerOpGpuDevice* per_op_device = new ConcretePerOpGpuDevice;
  return per_op_device;
}

// Binds the per-op Eigen device to the stream of `dc`, which is treated as a
// GPUDeviceContext, or to stream 0 when `dc` is null. CHECK-fails if the
// context's stream id is out of range.
void BaseGPUDevice::ReinitializeGpuDevice(OpKernelContext* context,
                                          PerOpGpuDevice* device,
                                          DeviceContext* dc,
                                          Allocator* allocator) {
  int stream_id = 0;
  if (dc != nullptr) {
    stream_id = static_cast<const GPUDeviceContext*>(dc)->stream_id();
    VLOG(1) << "  eigen_gpu_device(" << dc << ") => stream[" << stream_id
            << "]";
    CHECK_LT(stream_id, streams_.size());
  }
  ReinitializeDevice(context, device, stream_id, allocator);
}

// Creates and initializes one BaseGPUDevice per GPU to use, named
// "<name_prefix>/gpu:<i>", and appends them to `devices`.
//
// The GPUs used are the first N ids returned by GetValidDeviceIds() for the
// configured visible_device_list. N is options.config.device_count()["GPU"]
// capped at the number of valid ids, or all of them when no count is
// configured. A device whose Init() fails is destroyed before the error is
// returned.
Status BaseGPUDeviceFactory::CreateDevices(const SessionOptions& options,
                                           const string& name_prefix,
                                           std::vector<Device*>* devices) {
  std::vector<int> valid_gpu_ids;
  TF_RETURN_IF_ERROR(GetValidDeviceIds(
      options.config.gpu_options().visible_device_list(), &valid_gpu_ids));

  size_t num_gpus = valid_gpu_ids.size();
  const auto iter = options.config.device_count().find("GPU");
  if (iter != options.config.device_count().end()) {
    // A negative count converts to a huge size_t and therefore leaves the
    // limit at valid_gpu_ids.size(), matching the previous behavior.
    num_gpus = std::min(num_gpus, static_cast<size_t>(iter->second));
  }

  for (size_t i = 0; i < num_gpus; ++i) {
    BaseGPUDevice* gpu_device;
    TF_RETURN_IF_ERROR(CreateGPUDevice(options,
                                       strings::StrCat(name_prefix, "/gpu:", i),
                                       valid_gpu_ids[i], &gpu_device));
    // Own the device until it is handed to `devices`, so it is not leaked if
    // Init() fails.
    std::unique_ptr<BaseGPUDevice> owned_device(gpu_device);
    TF_RETURN_IF_ERROR(owned_device->Init(options));
    devices->push_back(owned_device.release());
  }

  return Status::OK();
}

namespace {
// Returns how many bytes of GPU memory to leave unallocated when
// per_process_gpu_memory_fraction is not set.
//
// Heuristic:
//   * available_memory < 2 GiB: reserve 200 MiB.
//   * otherwise:                reserve max(300 MiB, 5% of available_memory).
//
// In the future we could be more sophisticated by using a table of
// devices.
int64 MinSystemMemory(int64 available_memory) {
  constexpr int64 kSmallDeviceThreshold = int64{1} << 31;       // 2 GiB
  constexpr int64 kSmallDeviceReserve = int64{200} << 20;       // 200 MiB
  constexpr int64 kLargeDeviceMinReserve = int64{300} << 20;    // 300 MiB
  constexpr double kLargeDeviceReserveFraction = 0.05;

  if (available_memory < kSmallDeviceThreshold) {
    return kSmallDeviceReserve;
  }
  return std::max(kLargeDeviceMinReserve,
                  static_cast<int64>(available_memory *
                                     kLargeDeviceReserveFraction));
}
}  // namespace

// Formats the physical device description as
// "device: <id>, name: <name>, pci bus id: <bus id>".
static string GetShortDeviceDescription(int device_id,
                                        const gpu::DeviceDescription& desc) {
  const auto& gpu_name = desc.name();
  const auto& bus_id = desc.pci_bus_id();
  return strings::StrCat("device: ", device_id, ", name: ", gpu_name,
                         ", pci bus id: ", bus_id);
}

// Creates, but does not Init(), the BaseGPUDevice named `name` for physical
// GPU `gpu_id`.
//
// Memory limit:
//   - If per_process_gpu_memory_fraction is non-zero, that fraction of the
//     GPU's total memory.
//   - Otherwise, the currently available memory minus MinSystemMemory(),
//     applied only when that reservation is smaller than what is available.
//
// Locality: the bus id is the GPU's NUMA node + 1, using node 0 when the node
// is unknown.
Status BaseGPUDeviceFactory::CreateGPUDevice(const SessionOptions& options,
                                             const string& name, int gpu_id,
                                             BaseGPUDevice** out_device) {
  CHECK_GE(gpu_id, 0);

  // Inspect the physical device through its StreamExecutor.
  gpu::Platform* gpu_platform = GPUMachineManager();
  CHECK_LT(gpu_id, gpu_platform->VisibleDeviceCount());
  gpu::StreamExecutor* se =
      gpu_platform->ExecutorForDevice(gpu_id).ValueOrDie();
  const gpu::DeviceDescription& desc = se->GetDeviceDescription();

  int numa_node = desc.numa_node();
  if (numa_node < 0) {
    // The StreamExecutor could not determine the GPU's NUMA affinity. That is
    // harmless unless this is a multi-socket board with GPUs on different
    // buses, where data transfers may later be slower than expected or fail.
    LOG(INFO) << "Could not identify NUMA node of " << name
              << ", defaulting to 0.  Your kernel may not have been built "
              << "with NUMA support.";
    numa_node = 0;
  }

  int64 total_memory, available_memory;
  CHECK(se->DeviceMemoryUsage(&available_memory, &total_memory));

  const double memory_fraction =
      options.config.gpu_options().per_process_gpu_memory_fraction();
  int64 allocated_memory;
  if (memory_fraction != 0) {
    allocated_memory = total_memory * memory_fraction;
  } else {
    const int64 system_reserve = MinSystemMemory(available_memory);
    allocated_memory = (system_reserve < available_memory)
                           ? available_memory - system_reserve
                           : available_memory;
  }
  const Bytes allocated_bytes = static_cast<Bytes>(allocated_memory);

  // The bus id is derived from NUMA affinity rather than the GPU id, because
  // GPUs are virtualized in some environments. NUMA nodes count from 0 while
  // bus ids count from 1.
  DeviceLocality dev_locality;
  dev_locality.set_bus_id(numa_node + 1);
  VLOG(1) << "GPUDevice id " << gpu_id << " on bus " << dev_locality.bus_id()
          << " numa: " << numa_node << " pci: " << desc.pci_bus_id();

  ProcessState* process_state = ProcessState::singleton();
  Allocator* gpu_allocator = process_state->GetGPUAllocator(
      options.config.gpu_options(), gpu_id, allocated_memory);
  Allocator* cpu_allocator = process_state->GetCPUAllocator(numa_node);
  *out_device = CreateGPUDevice(options, name, allocated_bytes, dev_locality,
                                gpu_id, GetShortDeviceDescription(gpu_id, desc),
                                gpu_allocator, cpu_allocator);

  return Status::OK();
}

// Returns the default minimum multiprocessor count a GPU needs to be used.
// This is 8, lowered to the highest core count among the visible GPUs whose
// executor can be obtained. If no such GPU exists, 8 is returned.
static int GetDefaultMinGPUMultiprocessorCount(
    gpu::Platform* gpu_manager, const std::vector<int>& visible_gpu_order) {
  static const int kDefaultMinGPUMultiprocessorCount = 8;

  // Find the highest multi-processor count across all visible GPUs.
  int max_count = -1;
  for (const int gpu_id : visible_gpu_order) {
    auto exec_status = gpu_manager->ExecutorForDevice(gpu_id);
    if (!exec_status.ok()) {
      continue;
    }
    gpu::StreamExecutor* se = exec_status.ValueOrDie();
    const gpu::DeviceDescription& desc = se->GetDeviceDescription();
    max_count = std::max(max_count, desc.core_count());
  }

  if (max_count < 0) {
    return kDefaultMinGPUMultiprocessorCount;
  }
  // Never require more multiprocessors than the strongest visible GPU has.
  return std::min(kDefaultMinGPUMultiprocessorCount, max_count);
}

// Returns the minimum multiprocessor count a GPU needs to be used.
//
// Reads the TF_MIN_GPU_MULTIPROCESSOR_COUNT environment variable when it holds
// a non-negative integer. Otherwise falls back to
// GetDefaultMinGPUMultiprocessorCount(); an invalid (non-empty) value also
// logs an error.
static int GetMinGPUMultiprocessorCount(
    gpu::Platform* gpu_manager, const std::vector<int>& visible_gpu_order) {
  const char* const env_value = getenv("TF_MIN_GPU_MULTIPROCESSOR_COUNT");
  const bool env_unset = (env_value == nullptr) || (env_value[0] == '\0');
  if (env_unset) {
    return GetDefaultMinGPUMultiprocessorCount(gpu_manager, visible_gpu_order);
  }

  int parsed_count = -1;
  const bool parsed_ok = strings::safe_strto32(env_value, &parsed_count);
  if (parsed_ok && parsed_count >= 0) {
    return parsed_count;
  }

  const int fallback_count =
      GetDefaultMinGPUMultiprocessorCount(gpu_manager, visible_gpu_order);
  LOG(ERROR) << "Invalid minimum GPU multiprocessor count: [" << env_value
             << "]. "
             << "Using the default value: " << fallback_count;
  return fallback_count;
}

namespace {

struct CudaVersion {
  // Initialize from version_name in the form of "3.5"
  explicit CudaVersion(const std::string& version_name) {
    size_t dot_pos = version_name.find('.');
    CHECK(dot_pos != string::npos)
        << "Illegal version name: [" << version_name << "]";
    string major_str = version_name.substr(0, dot_pos);
    CHECK(strings::safe_strto32(major_str, &major_part))
        << "Illegal version name: [" << version_name << "]";
    string minor_str = version_name.substr(dot_pos + 1);
    CHECK(strings::safe_strto32(minor_str, &minor_part))
        << "Illegal version name: [" << version_name << "]";
  }
  CudaVersion() {}
  bool operator<(const CudaVersion& other) const {
    if (this->major_part != other.major_part) {
      return this->major_part < other.major_part;
    }
    return this->minor_part < other.minor_part;
  }
  friend std::ostream& operator<<(std::ostream& os,
                                  const CudaVersion& version) {
    os << version.major_part << "." << version.minor_part;
    return os;
  }
  int major_part = -1;
  int minor_part = -1;
};

// Compute capabilities this binary was built for. TF_CUDA_CAPABILITIES is a
// build-time macro; NOTE(review): it is presumably a comma-separated list of
// CudaVersion(...) initializers (the constructor is explicit) — confirm
// against the build configuration. The table is never modified after
// initialization, so it is declared const; callers take a copy.
const std::vector<CudaVersion> supported_cuda_compute_capabilities = {
    TF_CUDA_CAPABILITIES,};

// Returns the compute capabilities supported by this binary. This is a copy of
// supported_cuda_compute_capabilities followed, when the
// TF_EXTRA_CUDA_CAPABILITIES macro is defined, by each of its entries.
std::vector<CudaVersion> GetSupportedCudaComputeCapabilities() {
  std::vector<CudaVersion> capabilities(supported_cuda_compute_capabilities);
#ifdef TF_EXTRA_CUDA_CAPABILITIES
// TF_EXTRA_CUDA_CAPABILITIES is expected to be a comma-separated sequence,
// e.g.:
//   TF_EXTRA_CUDA_CAPABILITIES=3.0,4.0,5.0
// Two levels of macro expansion are needed so that the macro's value, rather
// than its name, is turned into a string literal.
#define TF_XSTRING(...) #__VA_ARGS__
#define TF_STRING(s) TF_XSTRING(s)
  const string extra_caps_str = TF_STRING(TF_EXTRA_CUDA_CAPABILITIES);
#undef TF_STRING
#undef TF_XSTRING
  for (const string& capability : str_util::Split(extra_caps_str, ',')) {
    capabilities.emplace_back(capability);
  }
#endif
  return capabilities;
}

// Builds a table recording whether each ordered pair of visible GPUs supports
// peer access.
//
// Keys are (i, j) POSITIONS in `visible_gpu_order`, not the GPU ids stored
// there. Each value is from->CanEnablePeerAccessTo(to), where `from` and `to`
// are the executors of visible_gpu_order[i] and visible_gpu_order[j]. Every
// pair, including i == j, gets an entry.
//
// ExecutorForDevice() must succeed for every listed id; a failure aborts via
// ValueOrDie().
std::unique_ptr<std::map<std::pair<int, int>, bool>> GetPeerAccessMap(
    gpu::Platform* platform, const std::vector<int>& visible_gpu_order) {
  std::unique_ptr<std::map<std::pair<int, int>, bool>> map(
      new std::map<std::pair<int, int>, bool>);
  // Signed count so that it matches the int keys without sign-compare issues.
  const int num_gpus = static_cast<int>(visible_gpu_order.size());
  for (int i = 0; i < num_gpus; ++i) {
    // The source executor does not depend on j; look it up once per row.
    gpu::StreamExecutor* from =
        platform->ExecutorForDevice(visible_gpu_order[i]).ValueOrDie();
    for (int j = 0; j < num_gpus; ++j) {
      gpu::StreamExecutor* to =
          platform->ExecutorForDevice(visible_gpu_order[j]).ValueOrDie();
      (*map)[{i, j}] = from->CanEnablePeerAccessTo(to);
    }
  }

  return map;
}

// Tries to enable peer access for every ordered pair of GPUs in
// `visible_gpu_order`, including pairs of a GPU with itself.
//
// Pairs the executor reports as not peerable are logged at INFO. Pairs for
// which enabling fails are logged at WARNING together with the failure
// status, and otherwise tolerated (best effort).
//
// Returns an Internal error only in the extreme case where at least one pair
// was reported as peerable but not a single enable attempt succeeded.
// Otherwise returns OK.
//
// NOTE(review): ExecutorForDevice() must succeed for every id; a failure
// aborts via ValueOrDie(). Callers are expected to have validated this.
Status EnablePeerAccess(gpu::Platform* platform,
                        const std::vector<int>& visible_gpu_order) {
  int possible_peer_count = 0;
  int enabled_peer_count = 0;
  for (const int i_gpu_id : visible_gpu_order) {
    // We have already validated that ExecutorForDevice() calls return OK.
    // The source executor is the same for every j, so look it up once.
    gpu::StreamExecutor* from =
        platform->ExecutorForDevice(i_gpu_id).ValueOrDie();
    for (const int j_gpu_id : visible_gpu_order) {
      gpu::StreamExecutor* to =
          platform->ExecutorForDevice(j_gpu_id).ValueOrDie();

      if (!from->CanEnablePeerAccessTo(to)) {
        LOG(INFO) << "Peer access not supported between device ordinals "
                  << i_gpu_id << " and " << j_gpu_id;
        continue;
      }

      ++possible_peer_count;
      auto status = from->EnablePeerAccessTo(to);
      if (!status.ok()) {
        // Keep going, but say WHY enabling failed so misconfigurations can be
        // diagnosed from the log.
        LOG(WARNING)
            << "Unable to enable peer access between device ordinals "
            << i_gpu_id << " and " << j_gpu_id
            << ", status: " << status.ToString();
      } else {
        ++enabled_peer_count;
      }
    }
  }

  // Return an error in the extreme failure case where the driver
  // reported that peering was possible but not a single peering was
  // successful.  This is to catch possible system misconfigurations
  // or more fundamental issues.
  if (possible_peer_count > 0 && enabled_peer_count == 0) {
    return errors::Internal(possible_peer_count,
                            " potential peer access pairs were reported by the "
                            "driver, but no peering could be enabled.");
  }
  return Status::OK();
}

}  // namespace

// Chooses the GPUs TensorFlow should create devices for. Their visible device
// ordinals are appended to `*ids`; existing contents are kept.
//
// `visible_device_list` is either empty, meaning every visible GPU in ordinal
// order, or a comma-separated list of visible ordinals giving the virtual GPU
// order. This function:
//   * returns InvalidArgument if an entry does not parse, is out of range, or
//     is repeated;
//   * logs the properties of each GPU seen for the first time. When at least
//     one such GPU exists, it also enables peer access across the whole list
//     and logs a DMA matrix;
//   * returns FailedPrecondition if the binary lists no supported capability;
//   * skips GPUs below the lowest supported compute capability or below the
//     minimum multiprocessor count.
// Errors from ValidateGPUMachineManager(), ExecutorForDevice() for newly seen
// GPUs, and EnablePeerAccess() are propagated.
Status BaseGPUDeviceFactory::GetValidDeviceIds(
    const string& visible_device_list, std::vector<int>* ids) {
  TF_RETURN_IF_ERROR(ValidateGPUMachineManager());

  gpu::Platform* gpu_manager = GPUMachineManager();
  // Nothing to do without a GPU platform or when no GPU is visible.
  if (gpu_manager == nullptr || gpu_manager->VisibleDeviceCount() <= 0) {
    return Status::OK();
  }

  // Virtual-to-visible GPU mapping. By default it is the identity.
  std::vector<int> visible_gpu_order;
  if (visible_device_list.empty()) {
    const int visible_count = gpu_manager->VisibleDeviceCount();
    visible_gpu_order.reserve(visible_count);
    for (int ordinal = 0; ordinal < visible_count; ++ordinal) {
      visible_gpu_order.push_back(ordinal);
    }
  } else {
    for (const string& gpu_id_str : str_util::Split(visible_device_list, ',')) {
      int32 gpu_id;
      if (!strings::safe_strto32(gpu_id_str, &gpu_id)) {
        return errors::InvalidArgument(
            "Could not parse entry in 'visible_device_list': '", gpu_id_str,
            "'.  visible_device_list = ", visible_device_list);
      }
      if (gpu_id < 0 || gpu_id >= gpu_manager->VisibleDeviceCount()) {
        return errors::InvalidArgument(
            "'visible_device_list' listed an invalid GPU id '", gpu_id,
            "' but visible device count is ",
            gpu_manager->VisibleDeviceCount());
      }
      visible_gpu_order.push_back(gpu_id);
    }
  }

  // Each GPU may be listed at most once.
  const std::set<int> unique_ids(visible_gpu_order.begin(),
                                 visible_gpu_order.end());
  if (unique_ids.size() != visible_gpu_order.size()) {
    return errors::InvalidArgument(
        "visible_device_list contained a duplicate entry: ",
        visible_device_list);
  }

  // Describe every GPU that has not been seen by an earlier call.
  bool new_gpu_found = false;
  for (size_t i = 0; i < visible_gpu_order.size(); ++i) {
    const int gpu_id = visible_gpu_order[i];
    if (visible_gpu_initialized_[gpu_id]) continue;
    visible_gpu_initialized_[gpu_id] = true;
    new_gpu_found = true;

    auto executor = gpu_manager->ExecutorForDevice(gpu_id);
    if (!executor.ok()) {
      return StreamExecutorUtil::ConvertStatus(executor.status());
    }
    auto stream_exec = executor.ValueOrDie();

    int64 free_bytes;
    int64 total_bytes;
    if (!stream_exec->DeviceMemoryUsage(&free_bytes, &total_bytes)) {
      // Logs internally on failure; report zero memory.
      free_bytes = total_bytes = 0;
    }
    const auto& description = stream_exec->GetDeviceDescription();
    int cc_major;
    int cc_minor;
    if (!description.cuda_compute_capability(&cc_major, &cc_minor)) {
      // Logs internally on failure; report capability 0.0.
      cc_major = cc_minor = 0;
    }
    LOG(INFO) << "Found device " << i << " with properties: "
              << "\nname: " << description.name() << "\nmajor: " << cc_major
              << " minor: " << cc_minor << " memoryClockRate (GHz) "
              << description.clock_rate_ghz() << "\npciBusID "
              << description.pci_bus_id() << "\nTotal memory: "
              << strings::HumanReadableNumBytes(total_bytes)
              << "\nFree memory: "
              << strings::HumanReadableNumBytes(free_bytes);
  }

  if (new_gpu_found) {
    TF_RETURN_IF_ERROR(EnablePeerAccess(gpu_manager, visible_gpu_order));

    // Log a matrix showing which devices can DMA to one another. The access
    // map is keyed by positions in visible_gpu_order.
    auto access_map = GetPeerAccessMap(gpu_manager, visible_gpu_order);
    string header = "DMA: ";
    for (const int gpu_id : visible_gpu_order) {
      strings::StrAppend(&header, gpu_id, " ");
    }
    LOG(INFO) << header;
    const int num_visible = static_cast<int>(visible_gpu_order.size());
    for (int row = 0; row < num_visible; ++row) {
      string row_text = strings::StrCat(visible_gpu_order[row], ":   ");
      for (int col = 0; col < num_visible; ++col) {
        row_text.append((*access_map)[{row, col}] ? "Y " : "N ");
      }
      LOG(INFO) << row_text;
    }
  }

  const std::vector<CudaVersion> supported_capabilities =
      GetSupportedCudaComputeCapabilities();
  if (supported_capabilities.empty()) {
    return errors::FailedPrecondition(
        "No supported cuda capabilities in binary.");
  }
  const CudaVersion min_supported_capability = *std::min_element(
      supported_capabilities.begin(), supported_capabilities.end());

  const int min_gpu_core_count =
      GetMinGPUMultiprocessorCount(gpu_manager, visible_gpu_order);

  // Keep only devices with sufficient capability and multiprocessor count.
  for (const int32 visible_gpu_id : visible_gpu_order) {
    auto exec_status = gpu_manager->ExecutorForDevice(visible_gpu_id);
    if (!exec_status.ok()) continue;
    gpu::StreamExecutor* se = exec_status.ValueOrDie();
    const gpu::DeviceDescription& desc = se->GetDeviceDescription();

    CudaVersion device_capability;
    if (!desc.cuda_compute_capability(&device_capability.major_part,
                                      &device_capability.minor_part)) {
      continue;
    }

    // Reject GPUs older than the oldest capability the binary supports.
    if (device_capability < min_supported_capability) {
      LOG(INFO) << "Ignoring visible gpu device "
                << "(" << GetShortDeviceDescription(visible_gpu_id, desc)
                << ") "
                << "with Cuda compute capability " << device_capability
                << ". The minimum required Cuda capability is "
                << min_supported_capability << ".";
      continue;
    }

    // Reject slow GPUs. By default the threshold is the highest
    // multiprocessor count among visible GPUs, capped at 8. The
    // TF_MIN_GPU_MULTIPROCESSOR_COUNT environment variable overrides it.
    if (desc.core_count() < min_gpu_core_count) {
      LOG(INFO) << "Ignoring gpu device "
                << "(" << GetShortDeviceDescription(visible_gpu_id, desc)
                << ") "
                << "with Cuda multiprocessor count: " << desc.core_count()
                << ". The minimum required count is " << min_gpu_core_count
                << ". You can adjust this requirement with the env var "
                   "TF_MIN_GPU_MULTIPROCESSOR_COUNT.";
      continue;
    }

    const size_t new_id = ids->size();
    ids->push_back(visible_gpu_id);

    LOG(INFO) << "Creating TensorFlow device (/gpu:" << new_id << ") -> "
              << "(" << GetShortDeviceDescription(visible_gpu_id, desc) << ")";
  }

  return Status::OK();
}

}  // namespace tensorflow

#endif  // GOOGLE_CUDA