aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/stream_executor/kernel.h
blob: 8ef091f929c0ae5a068059732b57c0729fd5be07 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// Suite of datatypes to represent data-parallel kernel objects (code entities).
// Kernel is the untyped variant, whereas TypedKernel takes a type signature
// to do some template-based helper generation and give compile-time type
// checking for kernel launch parameters.
//
// Users typically don't see KernelBase, they see typed kernels, analogous to a
// typed function pointer. TypedKernels express their argument types via
// template parameters like so:
//
//  TypedKernel<DeviceMemory<int>*, int>
//
// Which expresses a data parallel kernel signature for:
//
//  void(int*, int);
//
// And for a const memory region:
//
//  TypedKernel<const DeviceMemory<int>&, int>
//
// Corresponds to a data parallel kernel signature for:
//
//  void(const int*, int)
//
// Note that kernels always have a void return type, so results typically must
// be memcpy'ied from device memory to the host.
//
// Also note that a scalar integer residing in device memory and an array of
// integers residing in device memory have the same signature: DeviceMemory<T>.
// However, in the future, checks may be added for additional safety that arrays
// of minimum sizes are passed when those minimum sizes are contractually
// expected by the kernel.
//
// For user-defined types whose definitions are appropriately shared between the
// host code doing the launching and the kernel code being launched, the user
// defined types are similarly permitted to be expressed as residing in device
// memory:
//
//  TypedKernel<DeviceMemory<MyUserDefinedStructure>>
//
// And, when the alignment and padding are agreed upon, POD types will also be
// able to be passed by value; for example, it is a common idiom to specify a
// bunch of options simultaneously with a structure:
//
//  TypedKernel<MyOptionsStructurePassedByValue, DeviceMemory<float>>
//
// Which corresponds to a data parallel kernel signature like:
//
//  void(MyOptionsStructurePassedByValue value, float *result);
//
// Users typically won't need to type out the TypedKernel signature in full; it
// will be typedef'd by automatically generated code. For example, see
// perftools::gputools::executor_sample::VecReduceAddKernel.

#ifndef TENSORFLOW_STREAM_EXECUTOR_KERNEL_H_
#define TENSORFLOW_STREAM_EXECUTOR_KERNEL_H_

#include <array>
#include <memory>
#include <tuple>
#include <type_traits>
#include <vector>

#include "tensorflow/stream_executor/device_memory.h"
#include "tensorflow/stream_executor/kernel_cache_config.h"
#include "tensorflow/stream_executor/lib/array_slice.h"
#include "tensorflow/stream_executor/lib/inlined_vector.h"
#include "tensorflow/stream_executor/lib/stringpiece.h"
#include "tensorflow/stream_executor/platform/port.h"

namespace perftools {
namespace gputools {

class DeviceMemoryBase;
template <typename ElemT>
class DeviceMemory;
class StreamExecutor;

namespace internal {
class KernelInterface;
}  // namespace internal

// KernelMetadata holds runtime-queryable attributes of a loaded kernel, such as
// registers allocated, shared memory used, etc.
// Not all platforms support reporting of all information, so each accessor
// returns false if the associated field is not populated in the underlying
// platform.
class KernelMetadata {
 public:
  KernelMetadata()
      : has_registers_per_thread_(false), has_shared_memory_bytes_(false) {}

  // Returns the number of registers used per thread executing this kernel.
  bool registers_per_thread(int *registers_per_thread) const;

  // Sets the number of registers used per thread executing this kernel.
  void set_registers_per_thread(int registers_per_thread);

  // Returns the amount of [static] shared memory used per block executing this
  // kernel. Note that dynamic shared memory allocations are not (and can not)
  // be reported here (since they're not specified until kernel launch time).
  bool shared_memory_bytes(int *shared_memory_bytes) const;

  // Sets the amount of [static] shared memory used per block executing this
  // kernel.
  void set_shared_memory_bytes(int shared_memory_bytes);

 private:
  // Holds the value returned by registers_per_thread above.
  bool has_registers_per_thread_;
  int registers_per_thread_;

  // Holds the value returned by shared_memory_bytes above.
  bool has_shared_memory_bytes_;
  int64 shared_memory_bytes_;
};

// A data-parallel kernel (code entity) for launching via the StreamExecutor,
// analogous to a void* device function pointer. See TypedKernel for the typed
// variant.
//
// Thread-compatible.
class KernelBase {
 public:
  // Defaulted (memberwise) move construction. Because implementation_ is a
  // std::unique_ptr, the moved-to kernel takes over the implementation object.
  KernelBase(KernelBase &&) = default;

  // Constructs an "empty" (not-yet-loaded) kernel instance.
  //
  // parent is the StreamExecutor that will be responsible for loading the
  // implementation of this kernel. It must not be null.
  explicit KernelBase(StreamExecutor *parent);

  // Test-only constructor that can take a mock KernelInterface implementation.
  // NOTE(review): the raw pointer is presumably adopted by implementation_
  // (a std::unique_ptr); the definition is not in this header -- confirm.
  KernelBase(StreamExecutor *parent, internal::KernelInterface *implementation);

  // Releases resources associated with the kernel instance (i.e.
  // platform-specific implementation).
  ~KernelBase();

  // Returns the number of parameters that this kernel accepts. (Arity refers to
  // nullary, unary, ...).
  unsigned Arity() const;

  // Returns the StreamExecutor that represents the platform this kernel
  // executes upon.
  StreamExecutor *parent() const { return parent_; }

  // Returns a const pointer to the (opaque) platform-dependent implementation.
  const internal::KernelInterface *implementation() const {
    return implementation_.get();
  }

  // Returns a non-const pointer to the (opaque) platform-dependent
  // implementation.
  internal::KernelInterface *implementation() { return implementation_.get(); }

  // Replaces the stored metadata with a copy of the given one.
  void set_metadata(const KernelMetadata &metadata) { metadata_ = metadata; }

  // Returns the metadata currently stored for this kernel.
  const KernelMetadata &metadata() const { return metadata_; }

  // Sets the preferred cache configuration for a kernel. This is just a
  // suggestion to the runtime, and may not be honored during execution.
  void SetPreferredCacheConfig(KernelCacheConfig config);

  // Gets the preferred cache configuration for a kernel.
  KernelCacheConfig GetPreferredCacheConfig() const;

  // Sets the kernel's name. Defined outside this header; presumably it also
  // refreshes demangled_name_ -- confirm against the implementation.
  void set_name(port::StringPiece name);

  // Returns the stored kernel name.
  const string &name() const { return name_; }

  // Returns the stored demangled kernel name.
  const string &demangled_name() const { return demangled_name_; }

 private:
  // The StreamExecutor that loads this kernel object.
  StreamExecutor *parent_;

  // Implementation delegated to for platform-specific functionality.
  std::unique_ptr<internal::KernelInterface> implementation_;

  // Kernel name and its demangled form, as exposed by name() and
  // demangled_name().
  string name_;
  string demangled_name_;

  // Runtime-queryable attributes exposed through metadata()/set_metadata().
  KernelMetadata metadata_;

  // Macro defined elsewhere; by its name it disables copy construction and
  // copy assignment for this class.
  SE_DISALLOW_COPY_AND_ASSIGN(KernelBase);
};

// Whether T is a DeviceMemory-family pointer.
//
// Primary template: a type without a dedicated specialization is not a
// DeviceMemory-family pointer. The answer is exposed as the usual `::value`
// member, inherited from std::false_type so that it is a fully defined
// constant (safe to bind to a const reference even before C++17) and the trait
// composes with other standard type traits.
template <typename T>
struct IsDeviceMemoryPointer : std::false_type {};

// A pointer to any DeviceMemory<U> is a DeviceMemory-family pointer. The
// `::value` member comes from std::true_type, which also gives it a proper
// definition for ODR-uses before C++17.
template <typename U>
struct IsDeviceMemoryPointer<DeviceMemory<U> *> : std::true_type {};

// A pointer to the untyped DeviceMemoryBase is a DeviceMemory-family pointer.
template <>
struct IsDeviceMemoryPointer<DeviceMemoryBase *> : std::true_type {};

// Whether T is a DeviceMemory-family value-like thing (which includes a
// reference). This trait is useful because we pack values in the same manner as
// references.
//
// Primary template: a type without a dedicated specialization is not
// value-like device memory. `::value` is inherited from std::false_type so it
// is a fully defined constant (safe to bind to a const reference even before
// C++17).
template <typename T>
struct IsDeviceMemoryValueLike : std::false_type {};

// A (non-const) reference to DeviceMemory<U> is value-like device memory.
template <typename U>
struct IsDeviceMemoryValueLike<DeviceMemory<U> &> : std::true_type {};

// We need to treat SharedDeviceMemory types differently than other DeviceMemory
// types (since they maintain no allocations), hence these specializations.
template <typename U>
struct IsDeviceMemoryValueLike<SharedDeviceMemory<U> &> : std::false_type {};

// A reference to the untyped DeviceMemoryBase is value-like device memory.
template <>
struct IsDeviceMemoryValueLike<DeviceMemoryBase &> : std::true_type {};

// DeviceMemory<U> held by value is value-like device memory.
template <typename U>
struct IsDeviceMemoryValueLike<DeviceMemory<U>> : std::true_type {};

// SharedDeviceMemory<U> held by value is explicitly excluded (see the note on
// the reference specialization for SharedDeviceMemory).
template <typename U>
struct IsDeviceMemoryValueLike<SharedDeviceMemory<U>> : std::false_type {};

// The untyped DeviceMemoryBase held by value is value-like device memory.
template <>
struct IsDeviceMemoryValueLike<DeviceMemoryBase> : std::true_type {};

// Whether U is a SharedDeviceMemory value or reference.
//
// Primary template: a type without a dedicated specialization is not shared
// device memory. `::value` is inherited from std::false_type so it is a fully
// defined constant (safe to bind to a const reference even before C++17).
template <typename U>
struct IsSharedDeviceMemory : std::false_type {};

// A (non-const) reference to SharedDeviceMemory<U> is shared device memory.
// `::value` comes from std::true_type, which also gives it a proper definition
// for ODR-uses before C++17.
template <typename U>
struct IsSharedDeviceMemory<SharedDeviceMemory<U> &> : std::true_type {};

// SharedDeviceMemory<U> held by value is shared device memory.
template <typename U>
struct IsSharedDeviceMemory<SharedDeviceMemory<U>> : std::true_type {};

// Basic data about a kernel argument.
struct KernelArg {
  // True when this entry describes a shared-memory argument.
  bool is_shared;
  // Address of the argument's value; null for shared-memory arguments.
  const void *address;
  // Size of the argument, or the number of bytes requested for a
  // shared-memory argument.
  size_t size;
};

// An iterator for traversing all the arguments of a KernelArgsArray.
//
// The iterator merges two separately stored sequences back into the original
// argument order: the non-shared arguments (address + size) and the
// shared-memory arguments (byte count + the position each one occupied in the
// original argument list). It only stores the pointers it is given, so the
// underlying arrays must outlive the iterator.
class KernelArgIterator {
 public:
  // number_of_argument_addresses is the length of arg_addresses_data and
  // arg_sizes_data; number_of_shared_memory_arguments is the length of
  // shmem_bytes_data and shmem_indices_data. The shared-memory indices are
  // consumed front to back, so they are expected to be in ascending order.
  KernelArgIterator(int number_of_argument_addresses,
                    int number_of_shared_memory_arguments,
                    const void *const *arg_addresses_data,
                    const size_t *arg_sizes_data,
                    const size_t *shmem_bytes_data,
                    const size_t *shmem_indices_data)
      : arg_index_(0),
        number_of_arguments_(number_of_argument_addresses +
                             number_of_shared_memory_arguments),
        arg_address_iter_(arg_addresses_data),
        arg_size_iter_(arg_sizes_data),
        shmem_bytes_iter_(shmem_bytes_data),
        shmem_indices_iter_(shmem_indices_data),
        shmem_indices_end_(shmem_indices_data +
                           number_of_shared_memory_arguments) {}

  // Returns true if another argument is present in the iterator.
  //
  // Const-qualified because it only inspects state, so it can be called
  // through a const reference to the iterator.
  bool has_next() const { return arg_index_ < number_of_arguments_; }

  // Returns the next argument in the iterator.
  //
  // Returns a default-constructed KernelArg if there is no next argument.
  KernelArg next() {
    KernelArg result = {};
    if (!has_next()) {
      return result;
    }

    // The next argument is a shared-memory one exactly when unconsumed
    // shared-memory entries remain and the first of them was recorded at the
    // current overall position.
    const bool next_is_shared = (shmem_indices_iter_ != shmem_indices_end_) &&
                                (arg_index_ == *shmem_indices_iter_);
    if (next_is_shared) {
      result.is_shared = true;
      result.address = nullptr;
      result.size = *shmem_bytes_iter_;
      ++shmem_indices_iter_;
      ++shmem_bytes_iter_;
    } else {
      result.is_shared = false;
      result.address = *arg_address_iter_;
      result.size = *arg_size_iter_;
      ++arg_address_iter_;
      ++arg_size_iter_;
    }
    ++arg_index_;
    return result;
  }

 private:
  // Position of the next argument in the merged (original) argument order.
  size_t arg_index_;
  // Total number of arguments, shared and non-shared.
  size_t number_of_arguments_;
  // Cursors into the non-shared argument arrays.
  const void *const *arg_address_iter_;
  const size_t *arg_size_iter_;
  // Cursors into the shared-memory argument arrays.
  const size_t *shmem_bytes_iter_;
  const size_t *shmem_indices_iter_;
  // One past the last shared-memory index.
  const size_t *const shmem_indices_end_;
};

// Base class for KernelArgsArray.
//
// Supports all the getter methods that do not depend on the compile-time number
// of arguments template parameter.
//
// This class exists as a way to pass kernel arguments to
// StreamExecutorInterface::Launch. That Launch method is virtual, so it can't
// be templated to accept any KernelArgsArray type; therefore a reference to
// this base type is passed instead.
//
// Performance is not a concern here because each of these methods will be
// called at most once per kernel launch. Past performance concerns with
// KernelArgsArray have been in reference to the argument packing routines which
// are called once per kernel argument. Those packing routines are now handled
// by the templated KernelArgsArray subclass of this class where they can take
// advantage of compile-time knowledge of the number of arguments in order to be
// very efficient.
class KernelArgsArrayBase {
 public:
  // Virtual so that a derived argument array can be destroyed through a
  // pointer or reference to this base type.
  virtual ~KernelArgsArrayBase() = default;

  // Gets the number of arguments added so far, including shared memory
  // arguments.
  virtual size_t number_of_arguments() const = 0;

  // Gets the total number of shared memory bytes added so far.
  virtual uint64 number_of_shared_bytes() const = 0;

  // Gets the list of argument addresses.
  virtual port::ArraySlice<const void *> argument_addresses() const = 0;

  // Gets an iterator to the arguments in the array.
  virtual KernelArgIterator arg_iterator() const = 0;
};

// A list of arguments for a kernel call.
//
// The template parameter kNumArgs is the maximum number of arguments which can
// be stored in the list.
//
// Contains a list of addresses for non-shared-memory arguments and a list of
// sizes for shared-memory arguments. Since the shared-memory arguments may be
// interspersed with the non-shared-memory arguments, it also stores a list of
// the indices at which the shared-memory arguments appeared.
//
// For example, if the argument address list contains {a, b, c, d, e}, the
// shared-memory arguments list contains the sizes of {A, B, C}, and the
// shared-memory indices list contains {0, 3, 5}, then the original list of
// arguments was {A, a, b, B, c, C, d, e}.
//
// This way of storing the arguments makes CUDA kernel calls efficient because
// they only require the argument address list and the total number of shared
// bytes, but it also makes it possible for OpenCL kernel calls because they
// depend on the location of each shared-memory argument and its size.
//
// Note that the code for adding arguments has been identified as a performance
// hotspot in some real-world applications so this structure has been optimized
// for the performance of argument adding.
template <size_t kNumArgs>
class KernelArgsArray : public KernelArgsArrayBase {
 public:
  // Creates an empty argument list: no argument addresses, no shared-memory
  // entries and a shared-byte total of zero. The backing arrays are not
  // initialized here; only the first "count" entries of each are meaningful.
  explicit KernelArgsArray()
      : total_shared_memory_bytes_(0),
        number_of_argument_addresses_(0),
        number_of_shared_memory_arguments_(0) {}

  // Appends a by-value argument.
  //
  // Only the address of arg is recorded (together with sizeof(arg)), so the
  // referenced object has to stay alive for as long as this argument array is
  // in use. No bounds check is made against kNumArgs.
  template <typename T>
  void add_argument(const T &arg) {
    const size_t slot = number_of_argument_addresses_++;
    argument_addresses_[slot] = static_cast<const void *>(&arg);
    argument_sizes_[slot] = sizeof(arg);
  }

  // Appends a device memory argument.
  //
  // The opaque pointer held by arg is copied into storage owned by this
  // object, and the address of that copy is what gets recorded as the
  // argument address.
  void add_device_memory_argument(const DeviceMemoryBase &arg) {
    const size_t slot = number_of_argument_addresses_++;
    device_memory_opaque_pointers_[slot] = arg.opaque();
    argument_addresses_[slot] = &device_memory_opaque_pointers_[slot];
    argument_sizes_[slot] = sizeof(void *);
  }

  // Appends a shared memory argument.
  //
  // A shared argument is fully described by its byte count, so that is all
  // this takes. Its position in the overall argument order is recorded so the
  // original interleaving can be reconstructed later.
  void add_shared_bytes(size_t number_of_bytes) {
    const size_t slot = number_of_shared_memory_arguments_;
    shared_memory_indices_[slot] = number_of_argument_addresses_ + slot;
    shared_memory_bytes_[slot] = number_of_bytes;
    number_of_shared_memory_arguments_ = slot + 1;
    total_shared_memory_bytes_ += number_of_bytes;
  }

  // Total count of arguments added so far, shared-memory ones included.
  size_t number_of_arguments() const override {
    return number_of_shared_memory_arguments_ + number_of_argument_addresses_;
  }

  // Sum of the byte counts of all shared-memory arguments added so far.
  uint64 number_of_shared_bytes() const override {
    return total_shared_memory_bytes_;
  }

  // View over the recorded (non-shared) argument addresses.
  port::ArraySlice<const void *> argument_addresses() const override {
    return port::ArraySlice<const void *>(argument_addresses_.data(),
                                          number_of_argument_addresses_);
  }

  // Iterator that replays every argument in its original order. It points
  // into this object's arrays.
  KernelArgIterator arg_iterator() const override {
    return KernelArgIterator(number_of_argument_addresses_,
                             number_of_shared_memory_arguments_,
                             argument_addresses_.data(),
                             argument_sizes_.data(),
                             shared_memory_bytes_.data(),
                             shared_memory_indices_.data());
  }

 private:
  // Storage for the copies of opaque device pointers taken by
  // add_device_memory_argument.
  std::array<const void *, kNumArgs> device_memory_opaque_pointers_;

  // Address of each non-shared-memory argument.
  std::array<const void *, kNumArgs> argument_addresses_;

  // Size of each non-shared-memory argument.
  std::array<size_t, kNumArgs> argument_sizes_;

  // Byte count of each shared-memory argument.
  std::array<size_t, kNumArgs> shared_memory_bytes_;

  // Position of each shared-memory argument in the overall argument order.
  std::array<size_t, kNumArgs> shared_memory_indices_;

  // Running sum of shared_memory_bytes_ entries.
  size_t total_shared_memory_bytes_;

  // How many entries of argument_addresses_ / argument_sizes_ are in use.
  size_t number_of_argument_addresses_;

  // How many entries of shared_memory_bytes_ / shared_memory_indices_ are in
  // use.
  size_t number_of_shared_memory_arguments_;
};

// Typed variant of KernelBase, like a typed device function pointer. See the
// file comment for details and example usage.
//
// This class contains template metaprogramming magic to type check the
// parameters passed to a kernel launch are acceptable, and subsequently pack
// them into a form which can be used by the StreamExecutorInterface
// implementation. (i.e.  CUDA and OpenCL both bind void*s with associated
// sizes as kernel arguments.)
//
// Thread-compatible.
template <typename... Params>
class TypedKernel : public KernelBase {
 public:
  // Number of parameters in this kernel's type signature; also used as the
  // capacity of the KernelArgsArray that the packing routines fill.
  static constexpr size_t kNumberOfParameters = sizeof...(Params);

  // Delegates to KernelBase::KernelBase(), see that constructor.
  explicit TypedKernel(StreamExecutor *parent) : KernelBase(parent) {}

  // Test-only constructor that can take a mock KernelInterface implementation.
  // Takes ownership of implementation, it should not be null.
  TypedKernel(StreamExecutor *parent, internal::KernelInterface *implementation)
      : KernelBase(parent, implementation) {}

 private:
  // Stream needs access to the specific parameter-packing functionality that
  // the TypedKernel provides for its corresponding type signature (and no other
  // type signatures).
  friend class Stream;

  // This is the main entry point into the magic. Packs the parameters (which
  // must type check against the class template) into the args and sizes
  // arrays.
  //
  // Const refs are taken as parameters on all of the handlers to avoid
  // implicit type promotion of integers.
  //
  // WARNING: as a performance optimization this method may store pointers to
  // some of the input parameters in the kernel args structure, so any params
  // passed into this method must live at least as long as the kernel args
  // structure.
  void PackParams(KernelArgsArray<kNumberOfParameters> *args,
                  Params &... params) const {
    PackOneParam(args, params...);
  }

  // Recursive step of the variadic expansion: packs the first parameter via
  // one of the single-parameter overloads below, then recurses on the rest.
  // The recursion ends at the zero-parameter overload at the bottom.
  template <typename T, typename... RestOfParams>
  void PackOneParam(KernelArgsArray<kNumberOfParameters> *args, const T &arg,
                    const RestOfParams &... rest) const {
    PackOneParam(args, arg);
    PackOneParam(args, rest...);
  }

  // Packs one (non-DeviceMemoryBase) parameter into the arg and sizes array.
  // The enable_if<> is for excluding DeviceMemoryBase args, which have a
  // separate implementation below.
  //
  // The address of arg is stored (see KernelArgsArray::add_argument), which is
  // why the lifetime warning on PackParams applies. Raw pointers and anything
  // convertible to DeviceMemoryBase are rejected at compile time.
  template <typename T>
  void PackOneParam(
      KernelArgsArray<kNumberOfParameters> *args, const T &arg,
      typename std::enable_if<!IsDeviceMemoryValueLike<T>::value &&
                              !IsDeviceMemoryPointer<T>::value &&
                              !IsSharedDeviceMemory<T>::value>::type * =
          nullptr) const {
    static_assert(!std::is_pointer<T>::value,
                  "cannot pass raw pointer to the device");
    static_assert(!std::is_convertible<T, DeviceMemoryBase>::value,
                  "cannot pass device memory as a normal value");
    args->add_argument(arg);
  }

  // DeviceMemoryBase family reference override.
  // Selected when IsDeviceMemoryValueLike<T> holds; the argument is packed as
  // a device memory argument.
  template <typename T>
  void PackOneParam(
      KernelArgsArray<kNumberOfParameters> *args, const T &arg,
      typename std::enable_if<IsDeviceMemoryValueLike<T>::value>::type * =
          nullptr) const {
    args->add_device_memory_argument(arg);
  }

  // DeviceMemoryBase family pointer override.
  // Selected when IsDeviceMemoryPointer<T> holds. The pointer is dereferenced
  // without a null check, so arg must not be null.
  template <typename T>
  void PackOneParam(
      KernelArgsArray<kNumberOfParameters> *args, T arg,
      typename std::enable_if<IsDeviceMemoryPointer<T>::value>::type * =
          nullptr) const {
    DeviceMemoryBase *ptr = static_cast<DeviceMemoryBase *>(arg);
    args->add_device_memory_argument(*ptr);
  }

  // Dynamic shared device memory has a size, but no associated allocation on
  // the host; internally, the device will allocate storage.
  // Only arg.size() is recorded, via KernelArgsArray::add_shared_bytes.
  template <typename T>
  void PackOneParam(
      KernelArgsArray<kNumberOfParameters> *args, T arg,
      typename std::enable_if<IsSharedDeviceMemory<T>::value>::type * =
          nullptr) const {
    args->add_shared_bytes(arg.size());
  }

  // Base case for variadic template expansion - nothing to do!
  void PackOneParam(KernelArgsArray<kNumberOfParameters> *args) const {}

  // Macro defined elsewhere; by its name it disables copy construction and
  // copy assignment for this class.
  SE_DISALLOW_COPY_AND_ASSIGN(TypedKernel);
};

// Template metaprogramming helper type that helps us produce better error
// messages at compile time when there are mismatches between the parameter
// type list and the argument type list.
//
// ParamTuple is a std::tuple of the kernel's declared parameter types and
// ArgTuple is a std::tuple of the argument types supplied by the caller.
template <typename ParamTuple, typename ArgTuple>
struct KernelInvocationChecker {
  // Whether the parameter tuple and argument tuple match in length.
  static constexpr bool kLengthMatches =
      std::tuple_size<ParamTuple>::value == std::tuple_size<ArgTuple>::value;

  // The length of the argument type list. This is also the length of the
  // parameter type list whenever kLengthMatches is true.
  static constexpr int kTupleLength =
      static_cast<int>(std::tuple_size<ArgTuple>::value);

  // Number of positions that may be compared element by element. When the
  // lengths differ this is 0, so that the per-position check never
  // instantiates std::tuple_element with an index that is out of range for
  // the shorter tuple (which would be a hard compile error rather than a
  // `false` result).
  static constexpr int kCheckedLength = kLengthMatches ? kTupleLength : 0;

  // Helper trait to say whether the parameter wants a DeviceMemory-reference
  // compatible type. This is for inexact type matches, so that it doesn't have
  // to be precisely a const DeviceMemory<T>&, but can also be a value that
  // represents the same.
  template <typename ParamType, typename ArgType>
  struct IsCompatibleDeviceMemoryRef {
    static constexpr bool value = false;
  };

  // See type trait definition above.
  template <typename U>
  struct IsCompatibleDeviceMemoryRef<const DeviceMemory<U> &, DeviceMemory<U>> {
    static constexpr bool value = true;
  };

  // See type trait definition above.
  template <typename U>
  struct IsCompatibleDeviceMemoryRef<const SharedDeviceMemory<U> &,
                                     SharedDeviceMemory<U>> {
    static constexpr bool value = true;
  };

  // Returns whether ParamT and ArgT are compatible for data parallel kernel
  // parameter packing without any assert functionality.
  //
  // They are compatible when they are the same type once a top-level const is
  // removed from ParamT, or when the pair is accepted by
  // IsCompatibleDeviceMemoryRef.
  template <typename ParamT, typename ArgT>
  static constexpr bool CompatibleNoAssert() {
    return std::is_same<typename std::remove_const<ParamT>::type,
                        ArgT>::value ||
           IsCompatibleDeviceMemoryRef<ParamT, ArgT>::value;
  }

  // Checks whether ParamT and ArgT are compatible for data parallel kernel
  // parameter packing. kArgumentNumber is unused; it is just for error
  // display.
  //
  // NOTE: if you encounter an error here, you can see the mismatch by looking
  // at the end of the last error message, which will be of the form:
  //
  //    ...::Compatible<const perftools::gputools::DeviceMemory<OneThing> &,
  //                    perftools::gputools::DeviceMemory<AnotherThing>, true,
  //                    0>'
  //    requested here
  //
  // This means that the 0th argument you passed to the kernel invocation should
  // have been DeviceMemory<OneThing> but was observed to be
  // DeviceMemory<AnotherThing>.
  template <typename ParamT, typename ArgT, bool kShouldStaticAssert,
            int kArgumentNumber>
  static constexpr bool Compatible() {
    static_assert(
        kShouldStaticAssert ? CompatibleNoAssert<ParamT, ArgT>() : true,
        "parameter type (LHS) is not compatible with argument type (RHS)");
    return CompatibleNoAssert<ParamT, ArgT>();
  }

  // Checks the parameter/argument match at kArgumentNumber for an out of bounds
  // argument number.
  //
  // This is the base case: we've run out of arguments to check, so we're all
  // good.
  template <int kArgumentNumber, bool kShouldStaticAssert>
  static constexpr bool CheckParam(
      typename std::enable_if<(kArgumentNumber < 0)>::type *dummy = nullptr) {
    return true;
  }

  // Checks the parameter/argument match at kArgumentNumber, then recurses
  // towards index 0.
  // kShouldStaticAssert determines whether to assert out on a mismatch, or just
  // yield the constexpr boolean value.
  template <int kArgumentNumber, bool kShouldStaticAssert>
  static constexpr bool CheckParam(
      typename std::enable_if<kArgumentNumber >= 0>::type *dummy = nullptr) {
    typedef typename std::tuple_element<kArgumentNumber, ParamTuple>::type
        ParamT;
    typedef typename std::tuple_element<kArgumentNumber, ArgTuple>::type ArgT;
    return Compatible<ParamT, ArgT, kShouldStaticAssert, kArgumentNumber>() &&
           CheckParam<kArgumentNumber - 1, kShouldStaticAssert>();
  }

  // Checks the parameters/arguments for match, but doesn't static assert out.
  // This is useful for testing/inspecting whether a set of parameters match in
  // things like tests.
  //
  // A length mismatch yields false. The per-position recursion starts from
  // kCheckedLength, so in that case it goes straight to the base case.
  static constexpr bool CheckAllNoStaticAssert() {
    return kLengthMatches && CheckParam<kCheckedLength - 1, false>();
  }

  // Checks the parameters and static asserts out with a helpful error message
  // (and useful template parameters in the instantiation stack) if there is an
  // error.
  //
  // On a length mismatch only the length static_assert below fires; the
  // per-position recursion is reduced to its base case via kCheckedLength.
  static constexpr bool CheckAllStaticAssert() {
    static_assert(kLengthMatches,
                  "argument length mismatched against typed kernel parameters");
    return kLengthMatches && CheckParam<kCheckedLength - 1, true>();
  }
};

// Convenience trait answering "can a kernel of type KernelT be invoked with
// this list of argument types?" through its kResult member.
//
// Primary template: KernelT is not a TypedKernel, so the answer is always
// false.
template <typename KernelT, typename... Params>
struct KernelParamsOk {
  static constexpr bool kResult = false;
};

// TypedKernel case: Params... are the kernel's declared parameter types and
// Args... the candidate argument types. The two lists are compared with
// KernelInvocationChecker in its non-asserting mode, so kResult simply reports
// whether they match.
template <typename... Params, typename... Args>
struct KernelParamsOk<TypedKernel<Params...>, Args...> {
  static constexpr bool kResult =
      KernelInvocationChecker<std::tuple<Params...>,
                              std::tuple<Args...>>::CheckAllNoStaticAssert();
};

}  // namespace gputools
}  // namespace perftools

#endif  // TENSORFLOW_STREAM_EXECUTOR_KERNEL_H_