aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/compiler/jit/union_find.h
diff options
context:
space:
mode:
authorGravatar Peter Hawkins <phawkins@google.com>2017-06-02 11:04:35 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2017-06-02 11:08:26 -0700
commit0f2db739163809782049b2c956355506c88c77e5 (patch)
tree44c7f8f83d005f5e56ab82c6f9e6f25d3d0eddc8 /tensorflow/compiler/jit/union_find.h
parentd5421cf58e4b84832974e51ebc2c3a11ad86efb7 (diff)
[TF:XLA] Split union-find implementation in mark_for_compilation_pass.cc into a separate library, make it more generic.
PiperOrigin-RevId: 157850985
Diffstat (limited to 'tensorflow/compiler/jit/union_find.h')
-rw-r--r--tensorflow/compiler/jit/union_find.h81
1 files changed, 81 insertions, 0 deletions
diff --git a/tensorflow/compiler/jit/union_find.h b/tensorflow/compiler/jit/union_find.h
new file mode 100644
index 0000000000..a1a7a6a4d0
--- /dev/null
+++ b/tensorflow/compiler/jit/union_find.h
@@ -0,0 +1,81 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_COMPILER_JIT_UNION_FIND_H_
+#define TENSORFLOW_COMPILER_JIT_UNION_FIND_H_
+
+namespace tensorflow {
+
+// Union-Find data structure.
+// Each cluster has an associated value; when merging clusters we can control
+// which value becomes the representative of the merged clusters. Values must be
+// copyable.
+template <typename T>
+class UnionFind {
+ public:
+ UnionFind() : rank_(0), size_(1), parent_(nullptr) {}
+
+ // Returns the number of elements in a cluster.
+ int Size() { return FindRoot()->size_; }
+
+ // Merges this cluster with 'other'. This cluster's value becomes
+ // the value of the merged cluster; the value of 'other' is ignored.
+ void Merge(UnionFind* other);
+
+ // Each cluster has an associated value. Retrieves the value associated
+ // with this cluster.
+ T& Get() { return FindRoot()->value_; }
+
+ private:
+ // Finds the root element of the cluster. Performs path compression.
+ UnionFind* FindRoot();
+
+ int rank_;
+ int size_; // Size of the cluster.
+ UnionFind* parent_;
+ T value_;
+};
+
+template <typename T>
+void UnionFind<T>::Merge(UnionFind* other) {
+ UnionFind<T>* a = FindRoot();
+ UnionFind<T>* b = other->FindRoot();
+ if (a == b) return;
+ if (a->rank_ > b->rank_) {
+ b->parent_ = a;
+ a->size_ += b->size_;
+ return;
+ }
+
+ a->parent_ = b;
+ if (a->rank_ == b->rank_) {
+ b->rank_++;
+ }
+ b->value_ = a->value_;
+ b->size_ += a->size_;
+}
+
+template <typename T>
+UnionFind<T>* UnionFind<T>::FindRoot() {
+ if (!parent_) return this;
+ // Path compression: update intermediate nodes to point to the root of the
+ // equivalence class.
+ parent_ = parent_->FindRoot();
+ return parent_;
+}
+
+} // namespace tensorflow
+
+#endif // TENSORFLOW_COMPILER_JIT_UNION_FIND_H_