diff options
author | 2017-08-21 12:59:30 -0700 | |
---|---|---|
committer | 2017-08-21 13:03:00 -0700 | |
commit | a6e811ff29406aa582b14e2e07d2881c81d3e4a8 (patch) | |
tree | e3eb6a57cfc06911deb71e41daeea7faf527dfd9 /tensorflow/compiler/xla/service/logical_buffer_analysis.h | |
parent | 3c68189114618d4c88a3694984dead304668158d (diff) |
[XLA] Separate logical buffer creation and ownership out of tuple pointer analysis.
PiperOrigin-RevId: 165963591
Diffstat (limited to 'tensorflow/compiler/xla/service/logical_buffer_analysis.h')
-rw-r--r-- | tensorflow/compiler/xla/service/logical_buffer_analysis.h | 96 |
1 files changed, 96 insertions, 0 deletions
diff --git a/tensorflow/compiler/xla/service/logical_buffer_analysis.h b/tensorflow/compiler/xla/service/logical_buffer_analysis.h new file mode 100644 index 0000000000..de9fe1b0a4 --- /dev/null +++ b/tensorflow/compiler/xla/service/logical_buffer_analysis.h @@ -0,0 +1,96 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef THIRD_PARTY_TENSORFLOW_COMPILER_XLA_SERVICE_LOGICAL_BUFFER_ANALYSIS_H_ +#define THIRD_PARTY_TENSORFLOW_COMPILER_XLA_SERVICE_LOGICAL_BUFFER_ANALYSIS_H_ + +#include "tensorflow/compiler/xla/service/hlo_instruction.h" +#include "tensorflow/compiler/xla/service/hlo_module.h" +#include "tensorflow/compiler/xla/service/logical_buffer.h" +#include "tensorflow/compiler/xla/statusor.h" +#include "tensorflow/core/lib/hash/hash.h" + +namespace xla { +// A class to create all the logical buffers defined by the HLO ops in a module. +class LogicalBufferAnalysis : public DfsHloVisitorWithDefault { + public: + // Runs points-to analysis on 'module'. + static StatusOr<std::unique_ptr<LogicalBufferAnalysis>> Run( + const HloModule* module); + + // Returns the logical buffer with the given ID. + LogicalBuffer& GetBuffer(LogicalBuffer::Id id) const; + + // Returns the logical buffer that represents the output of a given HLO + // at a given index. + LogicalBuffer& GetBuffer(HloInstruction* instruction, + const ShapeIndex& index) const; + + const std::vector<std::unique_ptr<LogicalBuffer>>& logical_buffers() const { + return logical_buffers_; + } + LogicalBuffer::Id num_logical_buffers() const { return next_buffer_id_; } + + private: + explicit LogicalBufferAnalysis(const HloModule* module) : module_(module) {} + Status Analyze(); + + // The module this analysis is performed on. + const HloModule* module_; + + // Create a new logical buffer and return a reference to it. The newly created + // buffer is stored in an internal vector of LogicalBuffers and can be + // accessed with GetBuffer. + void NewLogicalBuffer(HloInstruction* instruction, const ShapeIndex& index); + + Status DefaultAction(HloInstruction* hlo_instruction) override; + Status HandleTuple( + HloInstruction* tuple, + tensorflow::gtl::ArraySlice<HloInstruction*> operands) override; + Status HandleGetTupleElement(HloInstruction* get_tuple_element, + HloInstruction* operand) override; + Status HandleBitcast(HloInstruction* bitcast) override; + Status HandleCopy(HloInstruction* copy) override; + Status HandleSelect(HloInstruction* select, HloInstruction* pred, + HloInstruction* on_true, + HloInstruction* on_false) override; + + // A map from the buffer ID to the logical buffer + std::vector<std::unique_ptr<LogicalBuffer>> logical_buffers_; + + struct Hasher { + size_t operator()( + std::pair<const HloInstruction*, const ShapeIndex> p) const { + size_t inst_hash = tensorflow::hash<const HloInstruction*>()(p.first); + for (auto index = p.second.begin(); index != p.second.end(); ++index) { + inst_hash = tensorflow::Hash64Combine(*index, inst_hash); + } + return inst_hash; + } + }; + + // A map from an hlo + shape index to the logical buffer representing + // the appropriate output. + std::unordered_map<std::pair<const HloInstruction*, const ShapeIndex>, + LogicalBuffer*, Hasher> + output_buffers_; + + // The ID of the next logical buffer created. + LogicalBuffer::Id next_buffer_id_ = 0; +}; + +} // namespace xla + +#endif // THIRD_PARTY_TENSORFLOW_COMPILER_XLA_SERVICE_LOGICAL_BUFFER_ANALYSIS_H_ |