diff options
author | Yunxing Dai <yunxing@google.com> | 2018-10-08 21:18:36 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-10-08 21:23:03 -0700 |
commit | 375c109659d2d0e6265447dffdeb460693b3cccf (patch) | |
tree | a6f09b6472cff1ade7fc91c1ff0d5e3f473da774 /tensorflow/compiler/xla/service/hlo_input_output_alias_config.h | |
parent | d58712b7fc8de0e1f87fe2ea5221bc3c85230ed3 (diff) |
[XLA] Introduce input/output alias config.
- This CL intruduces input/output alias config in HLO module that allows any HLO pass to configure it. Once the alias_config is set, each backend needs to follow the contract during execution time to make sure the input and output are indeed aliased.
- Copy insertion / buffer assignment and alias analysis has been updated to correctly honor the config and avoid any possible liveness interference.
PiperOrigin-RevId: 216299501
Diffstat (limited to 'tensorflow/compiler/xla/service/hlo_input_output_alias_config.h')
-rw-r--r-- | tensorflow/compiler/xla/service/hlo_input_output_alias_config.h | 101 |
1 files changed, 101 insertions, 0 deletions
diff --git a/tensorflow/compiler/xla/service/hlo_input_output_alias_config.h b/tensorflow/compiler/xla/service/hlo_input_output_alias_config.h new file mode 100644 index 0000000000..02c46f65c8 --- /dev/null +++ b/tensorflow/compiler/xla/service/hlo_input_output_alias_config.h @@ -0,0 +1,101 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_HLO_INPUT_OUTPUT_ALIAS_CONFIG_H_ +#define TENSORFLOW_COMPILER_XLA_SERVICE_HLO_INPUT_OUTPUT_ALIAS_CONFIG_H_ + +#include <utility> + +#include "absl/types/optional.h" +#include "tensorflow/compiler/xla/service/hlo.pb.h" +#include "tensorflow/compiler/xla/shape_tree.h" +#include "tensorflow/compiler/xla/shape_util.h" + +namespace xla { + +class HloModule; + +// This class specifies the alias map from output index to parameter number and +// parameter index in the entry computation. +class HloInputOutputAliasConfig { + public: + HloInputOutputAliasConfig() = default; + + explicit HloInputOutputAliasConfig(Shape shape) : alias_(shape) {} + + virtual ~HloInputOutputAliasConfig() = default; + + // Sets up alias config from `output_index` to `param_index` at + // `param_number`. + Status SetUpAlias(const ShapeIndex& output_index, int64 param_number, + const ShapeIndex& param_index); + + // Returns true if the given parameter is aliased with one of the output + // buffers. + bool ParameterHasAlias(int64 param_number) const; + + // (De)Serializes an HloInputOutoutAliasConfig to/from an + // HloInputOutoutAliasProto. + HloInputOutputAliasProto ToProto() const; + + static StatusOr<HloInputOutputAliasConfig> CreateFromProto( + const HloModule* module, const HloInputOutputAliasProto& proto); + + // Returns the output index that the given parameter and parameter index is + // aliased with. A nullopt is returned if there is no output that is aliased + // with the parameter number and index. + absl::optional<ShapeIndex> GetAliasedOutput( + int64 param_number, const ShapeIndex& param_index) const; + + // Returns the number of parameter and index of the parameter buffer that the + // given output buffer index is aliased with. A nullopt is returned if there + // is no parameter is aliased with the specific output. + absl::optional<std::pair<int64, ShapeIndex>> GetAliasedParameter( + const ShapeIndex& output_index) const; + + using AliasFn = + std::function<void(const ShapeIndex& output_index, int64 param_number, + const ShapeIndex& param_index)>; + + // Iterates through each aliased output and input. + void ForEachAlias(AliasFn fn) const; + + using AliasFnWithStatus = + std::function<Status(const ShapeIndex& output_index, int64 param_number, + const ShapeIndex& param_index)>; + + // Verifies that the given config is valid for the given module. + // Specifically, the config's input and output should be in-bound and size of + // the aliased buffers should match. + Status Verify(const HloModule& module) const; + + Status ForEachAliasWithStatus(AliasFnWithStatus fn) const; + + string ToString() const; + + private: + // A ShapeTree which indicates the list of buffers that's expected to be + // aliased. The key on this shape tree represents the output index. The value + // is a pair of parameter number and index into the buffer. If the value is + // nullopt, it means there is no parameter aliasing for this output. + ShapeTree<absl::optional<std::pair<int64, ShapeIndex>>> alias_; +}; + +std::ostream& operator<<(std::ostream& out, + const HloInputOutputAliasConfig& config); + +} // namespace xla + +#endif // TENSORFLOW_COMPILER_XLA_SERVICE_HLO_INPUT_OUTPUT_ALIAS_CONFIG_H_ |