/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_CORE_FRAMEWORK_VARIANT_OP_REGISTRY_H_
#define TENSORFLOW_CORE_FRAMEWORK_VARIANT_OP_REGISTRY_H_

#include <string>
#include <unordered_set>
#include <vector>

#define EIGEN_USE_THREADS

#include "tensorflow/core/framework/tensor.pb.h"
#include "tensorflow/core/framework/type_index.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/framework/variant.h"
#include "tensorflow/core/framework/variant_encode_decode.h"
#include "tensorflow/core/lib/gtl/flatmap.h"
#include "tensorflow/core/lib/hash/hash.h"
#include "tensorflow/core/platform/abi.h"

namespace tensorflow {

class OpKernelContext;

// A global UnaryVariantOpRegistry is used to hold callback functions
// for different variant types.  It is used by ShapeOp, RankOp, and
// SizeOp, by Variant decoding, by device copies, and by unary and
// binary variant ops.

enum VariantUnaryOp {
  INVALID_VARIANT_UNARY_OP = 0,
  ZEROS_LIKE_VARIANT_UNARY_OP = 1,
  CONJ_VARIANT_UNARY_OP = 2,
};

enum VariantBinaryOp {
  INVALID_VARIANT_BINARY_OP = 0,
  ADD_VARIANT_BINARY_OP = 1,
};

enum VariantDeviceCopyDirection {
  INVALID_DEVICE_COPY_DIRECTION = 0,
  HOST_TO_DEVICE = 1,
  DEVICE_TO_HOST = 2,
  DEVICE_TO_DEVICE = 3,
};

class UnaryVariantOpRegistry {
 public:
  typedef std::function<Status(const Variant& v, TensorShape*)> VariantShapeFn;
  typedef std::function<bool(Variant*)> VariantDecodeFn;
  typedef std::function<Status(OpKernelContext*, const Variant&, Variant*)>
      VariantUnaryOpFn;
  typedef std::function<Status(OpKernelContext*, const Variant&, const Variant&,
                               Variant*)>
      VariantBinaryOpFn;

  // An AsyncTensorDeviceCopyFn is a function provided to
  // the user-provided DeviceCopyFn callback as the third argument ("copier").
  //
  // Expected inputs:
  //   from: A Tensor on the host (if performing cpu->gpu copy), or
  //         device (if performing gpu->cpu or gpu->gpu copy).
  //   to: An empty/uninitialized tensor.  It will be updated upon
  //       successful return of the function with the correct dtype and shape.
  //       However, the copied data will not be available until the compute
  //       stream has been synchronized.
  //
  // Returns:
  //   The status upon memory allocation / initialization of the
  //   "to" tensor, and enqueue of the copy onto the compute stream.
  //   Any failure of the copy itself will update the underlying
  //   stream status and propagate through the runtime independent
  //   of the caller.
  typedef std::function<Status(const Tensor& from, Tensor* to)>
      AsyncTensorDeviceCopyFn;

  // The AsyncVariantDeviceCopyFn is the signature of the 'device_copy_fn'
  // expected to be passed to the registration macro
  // INTERNAL_REGISTER_UNARY_VARIANT_DEVICE_COPY_FUNCTION.
  typedef std::function<Status(const Variant& from, Variant* to,
                               AsyncTensorDeviceCopyFn copy_fn)>
      AsyncVariantDeviceCopyFn;

  // Add a shape lookup function to the registry.
  void RegisterShapeFn(const TypeIndex& type_index,
                       const VariantShapeFn& shape_fn);

  // Returns nullptr if no shape function was found for the given TypeIndex.
  VariantShapeFn* GetShapeFn(const TypeIndex& type_index);

  // Add a decode function to the registry.
  void RegisterDecodeFn(const string& type_name,
                        const VariantDecodeFn& decode_fn);

  // Returns nullptr if no decode function was found for the given TypeName.
  VariantDecodeFn* GetDecodeFn(StringPiece type_name);

  // Add a device copy function (for the given direction) to the registry.
  void RegisterDeviceCopyFn(const VariantDeviceCopyDirection direction,
                            const TypeIndex& type_index,
                            const AsyncVariantDeviceCopyFn& device_copy_fn);

  // Returns nullptr if no copy function was found for the given
  // TypeIndex and direction.
  AsyncVariantDeviceCopyFn* GetDeviceCopyFn(
      const VariantDeviceCopyDirection direction, const TypeIndex& type_index);

  // Add a unary op function to the registry.
  void RegisterUnaryOpFn(VariantUnaryOp op, const string& device,
                         const TypeIndex& type_index,
                         const VariantUnaryOpFn& unary_op_fn);

  // Returns nullptr if no unary op function was found for the given
  // op, device, and TypeIndex.
  VariantUnaryOpFn* GetUnaryOpFn(VariantUnaryOp op, StringPiece device,
                                 const TypeIndex& type_index);

  // Add a binary op function to the registry.
  void RegisterBinaryOpFn(VariantBinaryOp op, const string& device,
                          const TypeIndex& type_index,
                          const VariantBinaryOpFn& add_fn);

  // Returns nullptr if no binary op function was found for the given
  // op, device, and TypeIndex.
  VariantBinaryOpFn* GetBinaryOpFn(VariantBinaryOp op, StringPiece device,
                                   const TypeIndex& type_index);

  // Get a pointer to a global UnaryVariantOpRegistry object
  static UnaryVariantOpRegistry* Global();

  // Get a pointer to a global persistent string storage object.
  // ISO/IEC C++ working draft N4296 clarifies that insertion into an
  // std::unordered_set does not invalidate memory locations of
  // *values* inside the set (though it may invalidate existing
  // iterators).  In other words, one may safely point a StringPiece to
  // a value in the set without that StringPiece being invalidated by
  // future insertions.
  static std::unordered_set<string>* PersistentStringStorage();

 private:
  struct TypeIndexHash {
    std::size_t operator()(const TypeIndex& x) const { return x.hash_code(); }
  };

  gtl::FlatMap<TypeIndex, VariantShapeFn, TypeIndexHash> shape_fns;
  gtl::FlatMap<StringPiece, VariantDecodeFn, StringPieceHasher> decode_fns;

  // Map std::pair<VariantDeviceCopyDirection, TypeIndex> to function.
  struct PairHash {
    template <typename Direction>
    std::size_t operator()(const std::pair<Direction, TypeIndex>& x) const {
      // The hash of an enum is just its value as a std::size_t.
      std::size_t ret = static_cast<std::size_t>(std::get<0>(x));
      ret = Hash64Combine(ret, std::get<1>(x).hash_code());
      return ret;
    }
  };

  gtl::FlatMap<std::pair<VariantDeviceCopyDirection, TypeIndex>,
               AsyncVariantDeviceCopyFn, PairHash>
      device_copy_fns;

  // Map an (op, device, type_index) key to a function.
  //
  // Using std::tuple directly as the map key breaks by falling victim
  // to "too perfect forwarding"; see
  // https://stackoverflow.com/questions/44475317/variadic-template-issue
  // and the references therein.  FuncTuple below is used as the key
  // type instead.
  template <typename Op>
  struct FuncTuple {
    FuncTuple(const Op& op, const StringPiece& dev, const TypeIndex& type_index)
        : op_type_(op), device_(dev), type_index_(type_index) {}
    Op op_type_;
    StringPiece device_;
    TypeIndex type_index_;
  };
  // Friend declaration for operator==; needed for clang.
  template <typename Op>
  friend bool operator==(const FuncTuple<Op>& l, const FuncTuple<Op>& r);
  struct TupleHash {
    template <typename Op>
    std::size_t operator()(
        const std::tuple<Op, StringPiece, TypeIndex>& x) const {
      // The hash of an enum is just its value as a std::size_t.
      std::size_t ret = static_cast<std::size_t>(std::get<0>(x));
      ret = Hash64Combine(ret, sp_hasher_(std::get<1>(x)));
      ret = Hash64Combine(ret, std::get<2>(x).hash_code());
      return ret;
    }

    template <typename Op>
    std::size_t operator()(const FuncTuple<Op>& x) const {
      // The hash of an enum is just its value as a std::size_t.
      std::size_t ret = static_cast<std::size_t>(x.op_type_);
      ret = Hash64Combine(ret, sp_hasher_(x.device_));
      ret = Hash64Combine(ret, x.type_index_.hash_code());
      return ret;
    }
    StringPieceHasher sp_hasher_;
  };
  gtl::FlatMap<FuncTuple<VariantUnaryOp>, VariantUnaryOpFn, TupleHash>
      unary_op_fns;
  gtl::FlatMap<FuncTuple<VariantBinaryOp>, VariantBinaryOpFn, TupleHash>
      binary_op_fns;

  // Find or insert a string into a persistent string storage
  // container; return the StringPiece pointing to the permanent string
  // location.
  static StringPiece GetPersistentStringPiece(const string& str) {
    const auto string_storage = PersistentStringStorage();
    auto found = string_storage->find(str);
    if (found == string_storage->end()) {
      auto inserted = string_storage->insert(str);
      return StringPiece(*inserted.first);
    } else {
      return StringPiece(*found);
    }
  }
};
template <typename Op>
inline bool operator==(const UnaryVariantOpRegistry::FuncTuple<Op>& lhs,
                       const UnaryVariantOpRegistry::FuncTuple<Op>& rhs) {
  return (lhs.op_type_ == rhs.op_type_) && (lhs.device_ == rhs.device_) &&
         (lhs.type_index_ == rhs.type_index_);
}
// Gets a TensorShape from a Tensor containing a scalar Variant.
// Returns an Internal error if the Variant does not have a registered shape
// function, or if it's a serialized Variant that cannot be decoded.
//
// REQUIRES:
//   variant_tensor.dtype() == DT_VARIANT
//   variant_tensor.dims() == 0
//
Status GetUnaryVariantShape(const Tensor& variant_tensor, TensorShape* shape);
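
// Example usage (a minimal sketch from a function returning Status;
// "variant_tensor" is a hypothetical scalar DT_VARIANT tensor whose
// contained type has a registered shape function):
//
//   TensorShape shape;
//   TF_RETURN_IF_ERROR(GetUnaryVariantShape(variant_tensor, &shape));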

// Decodes the Variant whose data_type has a registered decode
// function.  Returns false if the Variant does not have a registered
// decode function, or if the decoding function fails.
//
// REQUIRES:
//   variant is not null.
//
bool DecodeUnaryVariant(Variant* variant);
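
// Example usage (a minimal sketch; "variant" is a hypothetical Variant
// currently holding a serialized VariantTensorDataProto of a registered
// type):
//
//   if (!DecodeUnaryVariant(&variant)) {
//     return errors::Internal("Could not decode variant");
//   }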

// Copies a variant between CPU<->GPU or between GPU<->GPU.
// The variant 'from' must have a registered DeviceCopyFn for the
// given direction.  The returned variant 'to' will have (some subset
// of) its tensors stored on the destination, according to the
// registered DeviceCopyFn for the given direction.  Returns an
// Internal error if the Variant does not have a registered
// DeviceCopyFn for the given direction, or if initiating the copy
// fails.
//
// REQUIRES:
//   'to' is not null.
//
Status VariantDeviceCopy(
    const VariantDeviceCopyDirection direction, const Variant& from,
    Variant* to,
    const UnaryVariantOpRegistry::AsyncTensorDeviceCopyFn& copy_fn);

// Sets *v_out = unary_op(v).  The variant v must have a registered
// UnaryOp function for the given Device.  Returns an Internal error
// if v does not have a registered unary_op function for this device, or if
// UnaryOp fails.
//
// REQUIRES:
//   v_out is not null.
//
template <typename Device>
Status UnaryOpVariant(OpKernelContext* ctx, VariantUnaryOp op, const Variant& v,
                      Variant* v_out) {
  const string& device = DeviceName<Device>::value;
  UnaryVariantOpRegistry::VariantUnaryOpFn* unary_op_fn =
      UnaryVariantOpRegistry::Global()->GetUnaryOpFn(op, device, v.TypeId());
  if (unary_op_fn == nullptr) {
    return errors::Internal(
        "No unary variant unary_op function found for unary variant op enum: ",
        op, " Variant type_name: ", v.TypeName(), " for device type: ", device);
  }
  return (*unary_op_fn)(ctx, v, v_out);
}
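
// Example usage (a minimal sketch, assuming the conventional kernel
// alias "typedef Eigen::ThreadPoolDevice CPUDevice;"; "v" is a
// hypothetical Variant whose type has a registered CPU zeros_like
// function):
//
//   Variant v_out;
//   OP_REQUIRES_OK(ctx, UnaryOpVariant<CPUDevice>(
//                           ctx, ZEROS_LIKE_VARIANT_UNARY_OP, v, &v_out));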

// Sets *out = binary_op(a, b).  The variants a and b must be the same type
// and have a registered binary_op function for the given Device.  Returns an
// Internal error if a and b are not the same type_name, if a does not have
// a registered op function for this device, or if BinaryOp fails.
//
// REQUIRES:
//   out is not null.
//
template <typename Device>
Status BinaryOpVariants(OpKernelContext* ctx, VariantBinaryOp op,
                        const Variant& a, const Variant& b, Variant* out) {
  if (a.TypeId() != b.TypeId()) {
    return errors::Internal(
        "BianryOpVariants: Variants a and b have different "
        "type ids.  Type names: '",
        a.TypeName(), "' vs. '", b.TypeName(), "'");
  }
  const string& device = DeviceName<Device>::value;
  UnaryVariantOpRegistry::VariantBinaryOpFn* binary_op_fn =
      UnaryVariantOpRegistry::Global()->GetBinaryOpFn(op, device, a.TypeId());
  if (binary_op_fn == nullptr) {
    return errors::Internal(
        "No unary variant binary_op function found for binary variant op "
        "enum: ",
        op, " Variant type_name: '", a.TypeName(), "' for device type: ",
        device);
  }
  return (*binary_op_fn)(ctx, a, b, out);
}
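
// Example usage (a minimal sketch, under the same CPUDevice alias
// assumption as above; "a" and "b" are hypothetical Variants of the
// same registered type):
//
//   Variant out;
//   OP_REQUIRES_OK(ctx, BinaryOpVariants<CPUDevice>(
//                           ctx, ADD_VARIANT_BINARY_OP, a, b, &out));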

namespace variant_op_registry_fn_registration {

template <typename T>
class UnaryVariantShapeRegistration {
 public:
  typedef std::function<Status(const T& t, TensorShape*)> LocalVariantShapeFn;

  UnaryVariantShapeRegistration(const TypeIndex& type_index,
                                const LocalVariantShapeFn& shape_fn) {
    const string type_index_name = port::MaybeAbiDemangle(type_index.name());
    UnaryVariantOpRegistry::Global()->RegisterShapeFn(
        type_index,
        [type_index_name, shape_fn](const Variant& v,
                                    TensorShape* s) -> Status {
          const T* t = v.get<T>();
          if (t == nullptr) {
            return errors::Internal(
                "VariantShapeFn: Could not access object, type_index: ",
                type_index_name);
          }
          return shape_fn(*t, s);
        });
  }
};

template <typename T>
class UnaryVariantDecodeRegistration {
 public:
  UnaryVariantDecodeRegistration(const string& type_name) {
    // The Variant is passed by pointer because it is mutated in
    // place: the lambda below decodes the VariantTensorDataProto held
    // by the variant and swaps the decoded value into *v.  The
    // variant is not modified in any other way.
    UnaryVariantOpRegistry::Global()->RegisterDecodeFn(
        type_name, [type_name](Variant* v) -> bool {
          DCHECK_NE(v, nullptr);
          VariantTensorDataProto* t = v->get<VariantTensorDataProto>();
          if (t == nullptr) {
            return false;
          }
          Variant decoded = T();
          VariantTensorData data(std::move(*t));
          if (!decoded.Decode(std::move(data))) {
            return false;
          }
          std::swap(decoded, *v);
          return true;
        });
  }
};

template <typename T>
class UnaryVariantDeviceCopyRegistration {
 public:
  typedef std::function<Status(const T& t, T* t_out,
                               UnaryVariantOpRegistry::AsyncTensorDeviceCopyFn)>
      LocalVariantDeviceCopyFn;
  UnaryVariantDeviceCopyRegistration(
      const VariantDeviceCopyDirection direction, const TypeIndex& type_index,
      const LocalVariantDeviceCopyFn& device_copy_fn) {
    const string type_index_name = port::MaybeAbiDemangle(type_index.name());
    UnaryVariantOpRegistry::Global()->RegisterDeviceCopyFn(
        direction, type_index,
        [type_index_name, device_copy_fn](
            const Variant& from, Variant* to,
            UnaryVariantOpRegistry::AsyncTensorDeviceCopyFn
                device_copy_tensor_fn) -> Status {
          DCHECK_NE(to, nullptr);
          *to = T();
          if (from.get<T>() == nullptr) {
            return errors::Internal(
                "VariantCopyToGPUFn: Could not access object, type_index: ",
                type_index_name);
          }
          const T& t = *from.get<T>();
          T* t_out = to->get<T>();
          return device_copy_fn(t, t_out, device_copy_tensor_fn);
        });
  }
};

template <typename T>
class UnaryVariantUnaryOpRegistration {
  typedef std::function<Status(OpKernelContext* ctx, const T& t, T* t_out)>
      LocalVariantUnaryOpFn;

 public:
  UnaryVariantUnaryOpRegistration(VariantUnaryOp op, const string& device,
                                  const TypeIndex& type_index,
                                  const LocalVariantUnaryOpFn& unary_op_fn) {
    const string type_index_name = port::MaybeAbiDemangle(type_index.name());
    UnaryVariantOpRegistry::Global()->RegisterUnaryOpFn(
        op, device, type_index,
        [type_index_name, unary_op_fn](OpKernelContext* ctx, const Variant& v,
                                       Variant* v_out) -> Status {
          DCHECK_NE(v_out, nullptr);
          *v_out = T();
          if (v.get<T>() == nullptr) {
            return errors::Internal(
                "VariantUnaryOpFn: Could not access object, type_index: ",
                type_index_name);
          }
          const T& t = *v.get<T>();
          T* t_out = v_out->get<T>();
          return unary_op_fn(ctx, t, t_out);
        });
  }
};

template <typename T>
class UnaryVariantBinaryOpRegistration {
  typedef std::function<Status(OpKernelContext* ctx, const T& a, const T& b,
                               T* out)>
      LocalVariantBinaryOpFn;

 public:
  UnaryVariantBinaryOpRegistration(VariantBinaryOp op, const string& device,
                                   const TypeIndex& type_index,
                                   const LocalVariantBinaryOpFn& binary_op_fn) {
    const string type_index_name = port::MaybeAbiDemangle(type_index.name());
    UnaryVariantOpRegistry::Global()->RegisterBinaryOpFn(
        op, device, type_index,
        [type_index_name, binary_op_fn](OpKernelContext* ctx, const Variant& a,
                                        const Variant& b,
                                        Variant* out) -> Status {
          DCHECK_NE(out, nullptr);
          *out = T();
          if (a.get<T>() == nullptr) {
            return errors::Internal(
                "VariantBinaryOpFn: Could not access object 'a', type_index: ",
                type_index_name);
          }
          if (b.get<T>() == nullptr) {
            return errors::Internal(
                "VariantBinaryOpFn: Could not access object 'b', type_index: ",
                type_index_name);
          }
          const T& t_a = *a.get<T>();
          const T& t_b = *b.get<T>();
          T* t_out = out->get<T>();
          return binary_op_fn(ctx, t_a, t_b, t_out);
        });
  }
};

}  // namespace variant_op_registry_fn_registration

// Register a unary shape variant function with the signature:
//    Status ShapeFn(const T& t, TensorShape* s);
// to Variants having TypeIndex type_index.
#define REGISTER_UNARY_VARIANT_SHAPE_FUNCTION(T, shape_function) \
  REGISTER_UNARY_VARIANT_SHAPE_FUNCTION_UNIQ_HELPER(             \
      __COUNTER__, T, MakeTypeIndex<T>(), shape_function)

#define REGISTER_UNARY_VARIANT_SHAPE_FUNCTION_UNIQ_HELPER(ctr, T, type_index, \
                                                          shape_function)     \
  REGISTER_UNARY_VARIANT_SHAPE_FUNCTION_UNIQ(ctr, T, type_index, shape_function)

#define REGISTER_UNARY_VARIANT_SHAPE_FUNCTION_UNIQ(ctr, T, type_index,         \
                                                   shape_function)             \
  static variant_op_registry_fn_registration::UnaryVariantShapeRegistration<T> \
      register_unary_variant_op_shape_registration_fn_##ctr(type_index,        \
                                                            shape_function)
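
// Example usage (a minimal sketch; MyVariantType and MyShapeFn are
// hypothetical and not part of TensorFlow, and MyVariantType must also
// satisfy the usual Variant requirements, e.g. TypeName()):
//
//   struct MyVariantType {
//     TensorShape shape;
//     string TypeName() const { return "MyVariantType"; }
//   };
//   Status MyShapeFn(const MyVariantType& t, TensorShape* s) {
//     *s = t.shape;
//     return Status::OK();
//   }
//   REGISTER_UNARY_VARIANT_SHAPE_FUNCTION(MyVariantType, MyShapeFn);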

// Register a unary decode variant function for the given type.
#define REGISTER_UNARY_VARIANT_DECODE_FUNCTION(T, type_name) \
  REGISTER_UNARY_VARIANT_DECODE_FUNCTION_UNIQ_HELPER(__COUNTER__, T, type_name)

#define REGISTER_UNARY_VARIANT_DECODE_FUNCTION_UNIQ_HELPER(ctr, T, type_name) \
  REGISTER_UNARY_VARIANT_DECODE_FUNCTION_UNIQ(ctr, T, type_name)

#define REGISTER_UNARY_VARIANT_DECODE_FUNCTION_UNIQ(ctr, T, type_name)        \
  static variant_op_registry_fn_registration::UnaryVariantDecodeRegistration< \
      T>                                                                      \
      register_unary_variant_op_decoder_fn_##ctr(type_name)
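
// Example usage (a minimal sketch; MyVariantType is hypothetical and
// must define the TypeName()/Encode()/Decode() methods expected by
// variant_encode_decode.h):
//
//   REGISTER_UNARY_VARIANT_DECODE_FUNCTION(MyVariantType, "MyVariantType");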

// ****** NOTE ******
// FOR INTERNAL USE ONLY.  IF YOU USE THIS WE MAY BREAK YOUR CODE.
// ****** NOTE ******
//
// Register a device copy variant function for the given copy
// direction and type; where direction is the enum
// VariantDeviceCopyDirection, and the device_copy_fn has signature:
//
//   Status device_copy_fn(
//     const T& t, T* t_out,
//     const UnaryVariantOpRegistry::AsyncTensorDeviceCopyFn& copier);
//
// And device_copy_fn calls copier 0 or more times.  For details on
// the behavior of the copier function, see the comments at the
// declaration of UnaryVariantOpRegistry::AsyncTensorDeviceCopyFn.
//
// Note, the device_copy_fn may choose to keep some tensors
// on host, e.g. by assigning to->tensor = from.tensor (assuming
// from.tensor is already on host); or by setting
//   to->tensor = Tensor(cpu_allocator(), ...)
// and manually updating its values.
//
// If this is the case, the CopyFns for HOST_TO_DEVICE,
// DEVICE_TO_HOST, and DEVICE_TO_DEVICE must perform host-to-host
// copies in a consistent manner.  For example, one must always
// manually copy any "always on host" tensors in all directions instead of e.g.
//   - performing a host-to-host copy in one direction,
//   - using the provided copier function in the reverse direction.
// Doing the latter will cause program failures.
//
// ****** NOTE ******
// FOR INTERNAL USE ONLY.  IF YOU USE THIS WE MAY BREAK YOUR CODE.
// ****** NOTE ******
#define INTERNAL_REGISTER_UNARY_VARIANT_DEVICE_COPY_FUNCTION(T, direction,   \
                                                             device_copy_fn) \
  INTERNAL_REGISTER_UNARY_VARIANT_DEVICE_COPY_FUNCTION_UNIQ_HELPER(          \
      __COUNTER__, T, direction, MakeTypeIndex<T>(), device_copy_fn)

#define INTERNAL_REGISTER_UNARY_VARIANT_DEVICE_COPY_FUNCTION_UNIQ_HELPER( \
    ctr, T, direction, type_index, device_copy_fn)                        \
  INTERNAL_REGISTER_UNARY_VARIANT_DEVICE_COPY_FUNCTION_UNIQ(              \
      ctr, T, direction, type_index, device_copy_fn)

#define INTERNAL_REGISTER_UNARY_VARIANT_DEVICE_COPY_FUNCTION_UNIQ( \
    ctr, T, direction, type_index, device_copy_fn)                 \
  static variant_op_registry_fn_registration::                     \
      UnaryVariantDeviceCopyRegistration<T>                        \
          register_unary_variant_op_device_copy_fn_##ctr(          \
              direction, type_index, device_copy_fn)
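
// Example usage (internal-only, and a minimal sketch; MyVariantType and
// MyHostToDeviceCopyFn are hypothetical, with MyVariantType assumed to
// hold a single Tensor member).  The device_copy_fn calls the supplied
// copier once per tensor that must move:
//
//   static Status MyHostToDeviceCopyFn(
//       const MyVariantType& from, MyVariantType* to,
//       const UnaryVariantOpRegistry::AsyncTensorDeviceCopyFn& copier) {
//     return copier(from.tensor, &to->tensor);
//   }
//   INTERNAL_REGISTER_UNARY_VARIANT_DEVICE_COPY_FUNCTION(
//       MyVariantType, VariantDeviceCopyDirection::HOST_TO_DEVICE,
//       MyHostToDeviceCopyFn);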

// Register a unary_op variant function with the signature:
//    Status UnaryOpFn(OpKernelContext* ctx, const T& t, T* t_out);
// to Variants having TypeIndex type_index, for device string device,
// for VariantUnaryOp enum op.
#define REGISTER_UNARY_VARIANT_UNARY_OP_FUNCTION(op, device, T,     \
                                                 unary_op_function) \
  REGISTER_UNARY_VARIANT_UNARY_OP_FUNCTION_UNIQ_HELPER(             \
      __COUNTER__, op, device, T, MakeTypeIndex<T>(), unary_op_function)

#define REGISTER_UNARY_VARIANT_UNARY_OP_FUNCTION_UNIQ_HELPER(       \
    ctr, op, device, T, type_index, unary_op_function)              \
  REGISTER_UNARY_VARIANT_UNARY_OP_FUNCTION_UNIQ(ctr, op, device, T, \
                                                type_index, unary_op_function)

#define REGISTER_UNARY_VARIANT_UNARY_OP_FUNCTION_UNIQ(                         \
    ctr, op, device, T, type_index, unary_op_function)                         \
  static variant_op_registry_fn_registration::UnaryVariantUnaryOpRegistration< \
      T>                                                                       \
      register_unary_variant_op_unary_op_fn_##ctr(op, device, type_index,      \
                                                  unary_op_function)
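
// Example usage (a minimal sketch; MyVariantType and MyZerosLikeFn are
// hypothetical):
//
//   static Status MyZerosLikeFn(OpKernelContext* ctx,
//                               const MyVariantType& t,
//                               MyVariantType* t_out);
//   REGISTER_UNARY_VARIANT_UNARY_OP_FUNCTION(
//       ZEROS_LIKE_VARIANT_UNARY_OP, DEVICE_CPU, MyVariantType,
//       MyZerosLikeFn);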

// Register a binary_op variant function with the signature:
//    Status BinaryOpFn(OpKernelContext* ctx, const T& a, const T& b, T* out);
// to Variants having TypeIndex type_index, for device string device,
// for VariantBinaryOp enum op.
#define REGISTER_UNARY_VARIANT_BINARY_OP_FUNCTION(op, device, T,      \
                                                  binary_op_function) \
  REGISTER_UNARY_VARIANT_BINARY_OP_FUNCTION_UNIQ_HELPER(              \
      __COUNTER__, op, device, T, MakeTypeIndex<T>(), binary_op_function)

#define REGISTER_UNARY_VARIANT_BINARY_OP_FUNCTION_UNIQ_HELPER( \
    ctr, op, device, T, type_index, binary_op_function)        \
  REGISTER_UNARY_VARIANT_BINARY_OP_FUNCTION_UNIQ(              \
      ctr, op, device, T, type_index, binary_op_function)

#define REGISTER_UNARY_VARIANT_BINARY_OP_FUNCTION_UNIQ(                      \
    ctr, op, device, T, type_index, binary_op_function)                      \
  static variant_op_registry_fn_registration::                               \
      UnaryVariantBinaryOpRegistration<T>                                    \
          register_unary_variant_op_binary_op_fn_##ctr(                      \
              op, device, type_index, binary_op_function)
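
// Example usage (a minimal sketch; MyVariantType and MyAddFn are
// hypothetical):
//
//   static Status MyAddFn(OpKernelContext* ctx, const MyVariantType& a,
//                         const MyVariantType& b, MyVariantType* out);
//   REGISTER_UNARY_VARIANT_BINARY_OP_FUNCTION(
//       ADD_VARIANT_BINARY_OP, DEVICE_CPU, MyVariantType, MyAddFn);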

}  // end namespace tensorflow

#endif  // TENSORFLOW_CORE_FRAMEWORK_VARIANT_OP_REGISTRY_H_