diff options
author | Peter Hawkins <phawkins@google.com> | 2017-06-02 11:04:35 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2017-06-02 11:08:26 -0700 |
commit | 0f2db739163809782049b2c956355506c88c77e5 (patch) | |
tree | 44c7f8f83d005f5e56ab82c6f9e6f25d3d0eddc8 /tensorflow/compiler/jit/union_find.h | |
parent | d5421cf58e4b84832974e51ebc2c3a11ad86efb7 (diff) |
[TF:XLA] Split union-find implementation in mark_for_compilation_pass.cc into a separate library, make it more generic.
PiperOrigin-RevId: 157850985
Diffstat (limited to 'tensorflow/compiler/jit/union_find.h')
-rw-r--r-- | tensorflow/compiler/jit/union_find.h | 81 |
1 files changed, 81 insertions, 0 deletions
diff --git a/tensorflow/compiler/jit/union_find.h b/tensorflow/compiler/jit/union_find.h new file mode 100644 index 0000000000..a1a7a6a4d0 --- /dev/null +++ b/tensorflow/compiler/jit/union_find.h @@ -0,0 +1,81 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_JIT_UNION_FIND_H_ +#define TENSORFLOW_COMPILER_JIT_UNION_FIND_H_ + +namespace tensorflow { + +// Union-Find data structure. +// Each cluster has an associated value; when merging clusters we can control +// which value becomes the representative of the merged clusters. Values must be +// copyable. +template <typename T> +class UnionFind { + public: + UnionFind() : rank_(0), size_(1), parent_(nullptr) {} + + // Returns the number of elements in a cluster. + int Size() { return FindRoot()->size_; } + + // Merges this cluster with 'other'. This cluster's value becomes + // the value of the merged cluster; the value of 'other' is ignored. + void Merge(UnionFind* other); + + // Each cluster has an associated value. Retrieves the value associated + // with this cluster. + T& Get() { return FindRoot()->value_; } + + private: + // Finds the root element of the cluster. Performs path compression. + UnionFind* FindRoot(); + + int rank_; + int size_; // Size of the cluster. + UnionFind* parent_; + T value_; +}; + +template <typename T> +void UnionFind<T>::Merge(UnionFind* other) { + UnionFind<T>* a = FindRoot(); + UnionFind<T>* b = other->FindRoot(); + if (a == b) return; + if (a->rank_ > b->rank_) { + b->parent_ = a; + a->size_ += b->size_; + return; + } + + a->parent_ = b; + if (a->rank_ == b->rank_) { + b->rank_++; + } + b->value_ = a->value_; + b->size_ += a->size_; +} + +template <typename T> +UnionFind<T>* UnionFind<T>::FindRoot() { + if (!parent_) return this; + // Path compression: update intermediate nodes to point to the root of the + // equivalence class. + parent_ = parent_->FindRoot(); + return parent_; +} + +} // namespace tensorflow + +#endif // TENSORFLOW_COMPILER_JIT_UNION_FIND_H_ |