aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/compiler/xla/service/logical_buffer_analysis.h
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2017-08-21 12:59:30 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2017-08-21 13:03:00 -0700
commita6e811ff29406aa582b14e2e07d2881c81d3e4a8 (patch)
treee3eb6a57cfc06911deb71e41daeea7faf527dfd9 /tensorflow/compiler/xla/service/logical_buffer_analysis.h
parent3c68189114618d4c88a3694984dead304668158d (diff)
[XLA] Separate logical buffer creation and ownership out of tuple pointer analysis.
PiperOrigin-RevId: 165963591
Diffstat (limited to 'tensorflow/compiler/xla/service/logical_buffer_analysis.h')
-rw-r--r--tensorflow/compiler/xla/service/logical_buffer_analysis.h96
1 files changed, 96 insertions, 0 deletions
diff --git a/tensorflow/compiler/xla/service/logical_buffer_analysis.h b/tensorflow/compiler/xla/service/logical_buffer_analysis.h
new file mode 100644
index 0000000000..de9fe1b0a4
--- /dev/null
+++ b/tensorflow/compiler/xla/service/logical_buffer_analysis.h
@@ -0,0 +1,96 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef THIRD_PARTY_TENSORFLOW_COMPILER_XLA_SERVICE_LOGICAL_BUFFER_ANALYSIS_H_
+#define THIRD_PARTY_TENSORFLOW_COMPILER_XLA_SERVICE_LOGICAL_BUFFER_ANALYSIS_H_
+
+#include "tensorflow/compiler/xla/service/hlo_instruction.h"
+#include "tensorflow/compiler/xla/service/hlo_module.h"
+#include "tensorflow/compiler/xla/service/logical_buffer.h"
+#include "tensorflow/compiler/xla/statusor.h"
+#include "tensorflow/core/lib/hash/hash.h"
+
+namespace xla {
+// A class to create all the logical buffers defined by the HLO ops in a module.
+class LogicalBufferAnalysis : public DfsHloVisitorWithDefault {
+ public:
+ // Runs points-to analysis on 'module'.
+ static StatusOr<std::unique_ptr<LogicalBufferAnalysis>> Run(
+ const HloModule* module);
+
+ // Returns the logical buffer with the given ID.
+ LogicalBuffer& GetBuffer(LogicalBuffer::Id id) const;
+
+ // Returns the logical buffer that represents the output of a given HLO
+ // at a given index.
+ LogicalBuffer& GetBuffer(HloInstruction* instruction,
+ const ShapeIndex& index) const;
+
+ const std::vector<std::unique_ptr<LogicalBuffer>>& logical_buffers() const {
+ return logical_buffers_;
+ }
+ LogicalBuffer::Id num_logical_buffers() const { return next_buffer_id_; }
+
+ private:
+ explicit LogicalBufferAnalysis(const HloModule* module) : module_(module) {}
+ Status Analyze();
+
+ // The module this analysis is performed on.
+ const HloModule* module_;
+
+ // Create a new logical buffer and return a reference to it. The newly created
+ // buffer is stored in an internal vector of LogicalBuffers and can be
+ // accessed with GetBuffer.
+ void NewLogicalBuffer(HloInstruction* instruction, const ShapeIndex& index);
+
+ Status DefaultAction(HloInstruction* hlo_instruction) override;
+ Status HandleTuple(
+ HloInstruction* tuple,
+ tensorflow::gtl::ArraySlice<HloInstruction*> operands) override;
+ Status HandleGetTupleElement(HloInstruction* get_tuple_element,
+ HloInstruction* operand) override;
+ Status HandleBitcast(HloInstruction* bitcast) override;
+ Status HandleCopy(HloInstruction* copy) override;
+ Status HandleSelect(HloInstruction* select, HloInstruction* pred,
+ HloInstruction* on_true,
+ HloInstruction* on_false) override;
+
+ // A map from the buffer ID to the logical buffer
+ std::vector<std::unique_ptr<LogicalBuffer>> logical_buffers_;
+
+ struct Hasher {
+ size_t operator()(
+ std::pair<const HloInstruction*, const ShapeIndex> p) const {
+ size_t inst_hash = tensorflow::hash<const HloInstruction*>()(p.first);
+ for (auto index = p.second.begin(); index != p.second.end(); ++index) {
+ inst_hash = tensorflow::Hash64Combine(*index, inst_hash);
+ }
+ return inst_hash;
+ }
+ };
+
+ // A map from an hlo + shape index to the logical buffer representing
+ // the appropriate output.
+ std::unordered_map<std::pair<const HloInstruction*, const ShapeIndex>,
+ LogicalBuffer*, Hasher>
+ output_buffers_;
+
+ // The ID of the next logical buffer created.
+ LogicalBuffer::Id next_buffer_id_ = 0;
+};
+
+} // namespace xla
+
+#endif // THIRD_PARTY_TENSORFLOW_COMPILER_XLA_SERVICE_LOGICAL_BUFFER_ANALYSIS_H_