diff options
author | Mark Heffernan <meheff@google.com> | 2017-06-29 16:51:03 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2017-06-29 16:55:06 -0700 |
commit | 38faead386d9d19b9af10150ac9d0cddd7b788e8 (patch) | |
tree | 1125922090b009a45fc7328d5736fe2fcffbb78d /tensorflow/compiler/xla/service/hlo_reachability.h | |
parent | 4eb71a7f1aa01639a57b67413e7225fc26512ead (diff) |
[XLA] Move HLO reachability into its own file and make update-able.
As part of the CL, change the underlying representation in the reachability map to BitVectors which allows efficient update by OR'ing the vectors together.
PiperOrigin-RevId: 160591849
Diffstat (limited to 'tensorflow/compiler/xla/service/hlo_reachability.h')
-rw-r--r-- | tensorflow/compiler/xla/service/hlo_reachability.h | 138 |
1 files changed, 138 insertions, 0 deletions
diff --git a/tensorflow/compiler/xla/service/hlo_reachability.h b/tensorflow/compiler/xla/service/hlo_reachability.h new file mode 100644 index 0000000000..d7bdac9c86 --- /dev/null +++ b/tensorflow/compiler/xla/service/hlo_reachability.h @@ -0,0 +1,138 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_HLO_REACHABILITY_H_ +#define TENSORFLOW_COMPILER_XLA_SERVICE_HLO_REACHABILITY_H_ + +#include <list> +#include <vector> + +#include "tensorflow/compiler/xla/map_util.h" +#include "tensorflow/compiler/xla/types.h" +#include "tensorflow/core/lib/core/status.h" +#include "tensorflow/core/lib/gtl/array_slice.h" +#include "tensorflow/core/lib/gtl/flatmap.h" +#include "tensorflow/core/platform/types.h" + +namespace xla { + +class HloInstruction; + +// A class for computing and representing reachability between HloInstructions. +class HloReachabilityMap { + public: + // Sets up an empty reachable matrix for the full set of instructions + // specified in 'instructions'. + explicit HloReachabilityMap(const std::list<HloInstruction*>& instructions); + + // Set the reachability set of 'instruction' to the union of the reachability + // sets of 'inputs'. Upon return, IsReachable(x, instruction) where + // 'x' is not 'instruction' will return true iff IsReachable(x, input) is true + // for some 'input' in 'inputs'. Also sets 'instruction' to be reachable from + // itself. Returns whether the reachability set of 'instruction' changed. + bool SetReachabilityToUnion( + tensorflow::gtl::ArraySlice<const HloInstruction*> inputs, + const HloInstruction* instruction); + + // Sets entry so that IsReachable(a, b) will return true + void SetReachable(const HloInstruction* a, const HloInstruction* b); + + // Returns true if "b" is reachable from "a" + bool IsReachable(const HloInstruction* a, const HloInstruction* b) const; + + // Returns true if "b" is reachable from "a" or "a" is reachable from "b" + bool IsConnected(const HloInstruction* a, const HloInstruction* b) const; + + private: + // A bit-vector implementation specialized for this use case which provides a + // fast bitwise OR operation not available in tensorflow::gtl::BitMap. + class BitVector { + public: + BitVector() = default; + BitVector(size_t size) + : size_(size), vector_((size + kBits - 1) / kBits, 0) {} + + // Return the bit at the given index. + bool Get(size_t index) const { + DCHECK(index >= 0 && index < size_); + return vector_[index / kBits] & (1ull << (index % kBits)); + } + + // Set the bit at the given index. + void Set(size_t index) { + DCHECK(index >= 0 && index < size_); + vector_[index / kBits] |= 1ull << (index % kBits); + } + + // Set this bitvector to the Logical OR of this bitvector and 'other'. + void OrWith(const BitVector& other) { + for (size_t i = 0; i < vector_.size(); ++i) { + vector_[i] |= other.vector_[i]; + } + } + + // Set the bitvector to all zeros. + void SetToZero() { std::fill(vector_.begin(), vector_.end(), 0); } + + bool operator==(const BitVector& other) const { + return vector_ == other.vector_; + } + bool operator!=(const BitVector& other) const { + return vector_ != other.vector_; + } + + private: + using Word = uint64; + static const size_t kBits = 64; + + // Number of bits in the bitvector. + size_t size_; + + std::vector<Word> vector_; + }; + + // Return the bitvector storing the reachability-to of the given instruction. + const BitVector& GetBitVector(const HloInstruction* instruction) const { + return bit_vectors_[GetIndex(instruction)]; + } + BitVector& GetBitVector(const HloInstruction* instruction) { + return bit_vectors_[GetIndex(instruction)]; + } + + // Return the index of the given instruction. The value is used to index into + // the vector of BitVectors and the BitVectors themselves. + int GetIndex(const HloInstruction* instruction) const { + return FindOrDie(indices_, instruction); + } + + // The number of instructions in the reachability map. + const size_t size_; + + // Dense assignment from HloInstruction* to number. These numbers index + // into the bit_vectors_ vector and into the bits within a BitVector. + tensorflow::gtl::FlatMap<const HloInstruction*, int> indices_; + + // Bitvectors holding the reachability to each instruction. The bit vector for + // instruction X includes ones for each instruction which X is reachable from. + std::vector<BitVector> bit_vectors_; + + // A temporary used by SetReachabilityToUnion to avoid an allocation with each + // call to the method. + BitVector tmp_bit_vector_; +}; + +} // namespace xla + +#endif // TENSORFLOW_COMPILER_XLA_SERVICE_HLO_REACHABILITY_H_ |