aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/compiler/xla/service/hlo_casting_utils.h
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2018-06-04 14:52:29 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-06-04 14:55:12 -0700
commit142ccf3666e07d011aa83fdd6be8c17f721fbc99 (patch)
tree923e1406474314f68d2d0ded369b39cc988ca976 /tensorflow/compiler/xla/service/hlo_casting_utils.h
parent06c4fb61f269e18ca2f4b9a73d1b92e48bd095bf (diff)
Add rip-offs of LLVM's cast, dyn_cast, cast_or_null, dyn_cast_or_null in preparation to split HloInstruction into subclasses. This initial implementation uses C++ dynamic_cast, so it also adds vtable to HloInstruction.
PiperOrigin-RevId: 199199109
Diffstat (limited to 'tensorflow/compiler/xla/service/hlo_casting_utils.h')
-rw-r--r--tensorflow/compiler/xla/service/hlo_casting_utils.h101
1 files changed, 101 insertions, 0 deletions
diff --git a/tensorflow/compiler/xla/service/hlo_casting_utils.h b/tensorflow/compiler/xla/service/hlo_casting_utils.h
new file mode 100644
index 0000000000..b15f1f24c6
--- /dev/null
+++ b/tensorflow/compiler/xla/service/hlo_casting_utils.h
@@ -0,0 +1,101 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+// Casting utility functions for HLO instructions.
+
+#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_HLO_CASTING_UTILS_H_
+#define TENSORFLOW_COMPILER_XLA_SERVICE_HLO_CASTING_UTILS_H_
+
+#include <type_traits>
+
+#include "tensorflow/compiler/xla/service/hlo_instruction.h"
+
+namespace xla {
+
+// SFINAE helper: names a valid type only when T derives from HloInstruction,
+// so the casting templates below participate in overload resolution solely
+// for HLO instruction subclasses.
+template <class T>
+using EnableIfDerivedFromHlo =
+    typename std::enable_if<std::is_base_of<HloInstruction, T>::value>::type;
+
+// TODO(b/93238915): Switch implementation from C++'s dynamic_cast to LLVM-like
+// RTTI if it turns out to be a performance issue.
+// Downcasts an HloInstruction pointer to subclass T. Dies (CHECK failure) when
+// the argument is nullptr or when the runtime type does not match T.
+//
+// Similar to LLVM's cast.
+template <class T, EnableIfDerivedFromHlo<T>* = nullptr>
+const T* Cast(const HloInstruction* instruction) {
+  CHECK(instruction != nullptr);
+  const T* result = dynamic_cast<const T*>(instruction);
+  // A null result means the dynamic type was not T; that is a caller bug.
+  CHECK(result != nullptr);
+  return result;
+}
+
+// Non-const overload of Cast.
+template <class T, EnableIfDerivedFromHlo<T>* = nullptr>
+T* Cast(HloInstruction* instruction) {
+  // Bind to a pointer-to-const first so the const overload is selected
+  // (calling Cast<T> directly on a non-const pointer would recurse into this
+  // overload), then strip constness from the result.
+  const HloInstruction* const_instruction = instruction;
+  return const_cast<T*>(Cast<T>(const_instruction));
+}
+
+// Works just like Cast, except that a null pointer argument is propagated to
+// the result instead of being an error.
+//
+// Similar to LLVM's cast_or_null.
+template <class T, EnableIfDerivedFromHlo<T>* = nullptr>
+const T* CastOrNull(const HloInstruction* instruction) {
+  if (instruction == nullptr) {
+    return nullptr;
+  }
+  return Cast<T>(instruction);
+}
+
+// Non-const overload of CastOrNull.
+template <class T, EnableIfDerivedFromHlo<T>* = nullptr>
+T* CastOrNull(HloInstruction* instruction) {
+  // Dispatch to the const overload via an implicit const conversion, then
+  // strip constness from the result.
+  const HloInstruction* const_instruction = instruction;
+  return const_cast<T*>(CastOrNull<T>(const_instruction));
+}
+
+// Downcasts an HloInstruction pointer to subclass T. Dies (CHECK failure) when
+// the argument is nullptr; returns nullptr when the runtime type does not
+// match T.
+//
+// Similar to LLVM's dyn_cast.
+template <class T, EnableIfDerivedFromHlo<T>* = nullptr>
+const T* DynCast(const HloInstruction* instruction) {
+  CHECK(instruction != nullptr);
+  // dynamic_cast yields nullptr when the runtime type is not T.
+  const T* result = dynamic_cast<const T*>(instruction);
+  return result;
+}
+
+// Non-const overload of DynCast.
+template <class T, EnableIfDerivedFromHlo<T>* = nullptr>
+T* DynCast(HloInstruction* instruction) {
+  // Dispatch to the const overload via an implicit const conversion, then
+  // strip constness from the result.
+  const HloInstruction* const_instruction = instruction;
+  return const_cast<T*>(DynCast<T>(const_instruction));
+}
+
+// Works just like DynCast, except that a null pointer argument is propagated
+// to the result instead of being an error.
+//
+// Similar to LLVM's dyn_cast_or_null.
+template <class T, EnableIfDerivedFromHlo<T>* = nullptr>
+const T* DynCastOrNull(const HloInstruction* instruction) {
+  if (instruction == nullptr) {
+    return nullptr;
+  }
+  return DynCast<T>(instruction);
+}
+
+// Non-const overload of DynCastOrNull.
+template <class T, EnableIfDerivedFromHlo<T>* = nullptr>
+T* DynCastOrNull(HloInstruction* instruction) {
+  // Dispatch to the const overload via an implicit const conversion, then
+  // strip constness from the result.
+  const HloInstruction* const_instruction = instruction;
+  return const_cast<T*>(DynCastOrNull<T>(const_instruction));
+}
+
+} // namespace xla
+
+#endif // TENSORFLOW_COMPILER_XLA_SERVICE_HLO_CASTING_UTILS_H_