diff options
author | A. Unique TensorFlower <gardener@tensorflow.org> | 2018-06-04 14:52:29 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-06-04 14:55:12 -0700 |
commit | 142ccf3666e07d011aa83fdd6be8c17f721fbc99 (patch) | |
tree | 923e1406474314f68d2d0ded369b39cc988ca976 /tensorflow/compiler/xla/service/hlo_casting_utils.h | |
parent | 06c4fb61f269e18ca2f4b9a73d1b92e48bd095bf (diff) |
Add rip-offs of LLVM's cast, dyn_cast, cast_or_null, dyn_cast_or_null in preparation to split HloInstruction into subclasses. This initial implementation uses C++ dynamic_cast, so it also adds vtable to HloInstruction.
PiperOrigin-RevId: 199199109
Diffstat (limited to 'tensorflow/compiler/xla/service/hlo_casting_utils.h')
-rw-r--r-- | tensorflow/compiler/xla/service/hlo_casting_utils.h | 101 |
1 files changed, 101 insertions, 0 deletions
diff --git a/tensorflow/compiler/xla/service/hlo_casting_utils.h b/tensorflow/compiler/xla/service/hlo_casting_utils.h new file mode 100644 index 0000000000..b15f1f24c6 --- /dev/null +++ b/tensorflow/compiler/xla/service/hlo_casting_utils.h @@ -0,0 +1,101 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// Casting utilitiy functions for HLO instructions. + +#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_HLO_CASTING_UTILS_H_ +#define TENSORFLOW_COMPILER_XLA_SERVICE_HLO_CASTING_UTILS_H_ + +#include "tensorflow/compiler/xla/service/hlo_instruction.h" + +namespace xla { + +template <class T> +using EnableIfDerivedFromHlo = + typename std::enable_if<std::is_base_of<HloInstruction, T>::value>::type; + +// TODO(b/93238915): Switch implementation from C++'s dynamic_cast to LLVM-like +// RTTI if it turns out to be a performance issue. +// Casts an HloInstruction pointer to one of its subclasses, dies if argument is +// nullptr or runtime information does not match. +// +// Similar to LLVM's cast. +template <class T, EnableIfDerivedFromHlo<T>* = nullptr> +const T* Cast(const HloInstruction* instruction) { + CHECK(instruction != nullptr); + const T* casted = dynamic_cast<const T*>(instruction); + CHECK(casted != nullptr); + return casted; +} + +// Non-const overload of Cast. +template <class T, EnableIfDerivedFromHlo<T>* = nullptr> +T* Cast(HloInstruction* instruction) { + return const_cast<T*>( + Cast<T>(const_cast<const HloInstruction*>(instruction))); +} + +// Works just like the Cast, except that it allows for a null pointer as an +// argument which it then propagates. +// +// Similar to LLVM's cast_or_null. +template <class T, EnableIfDerivedFromHlo<T>* = nullptr> +const T* CastOrNull(const HloInstruction* instruction) { + return instruction != nullptr ? Cast<T>(instruction) : nullptr; +} + +// Non-const overload of CastOrNull. +template <class T, EnableIfDerivedFromHlo<T>* = nullptr> +T* CastOrNull(HloInstruction* instruction) { + return const_cast<T*>( + CastOrNull<T>(const_cast<const HloInstruction*>(instruction))); +} + +// Casts an HloInstruction pointer to one of its subclasses, dies if argument is +// nullptr, returns nullptr if runtime information does not match. +// +// Similar to LLVM's dyn_cast. +template <class T, EnableIfDerivedFromHlo<T>* = nullptr> +const T* DynCast(const HloInstruction* instruction) { + CHECK(instruction != nullptr); + return dynamic_cast<const T*>(instruction); +} + +// Non-const overload of DynCast. +template <class T, EnableIfDerivedFromHlo<T>* = nullptr> +T* DynCast(HloInstruction* instruction) { + return const_cast<T*>( + DynCast<T>(const_cast<const HloInstruction*>(instruction))); +} + +// Works just like the DynCast, except that it allows for a null pointer as an +// argument which it then propagates. +// +// Similar to LLVM's dyn_cast_or_null. +template <class T, EnableIfDerivedFromHlo<T>* = nullptr> +const T* DynCastOrNull(const HloInstruction* instruction) { + return instruction != nullptr ? DynCast<T>(instruction) : nullptr; +} + +// Non-const overload of DynCastOrNull. +template <class T, EnableIfDerivedFromHlo<T>* = nullptr> +T* DynCastOrNull(HloInstruction* instruction) { + return const_cast<T*>( + DynCastOrNull<T>(const_cast<const HloInstruction*>(instruction))); +} + +} // namespace xla + +#endif // TENSORFLOW_COMPILER_XLA_SERVICE_HLO_CASTING_UTILS_H_ |