aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/compiler/xla/service/hlo_input_output_alias_config_test.cc
diff options
context:
space:
mode:
authorGravatar Yunxing Dai <yunxing@google.com>2018-10-08 21:18:36 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-10-08 21:23:03 -0700
commit375c109659d2d0e6265447dffdeb460693b3cccf (patch)
treea6f09b6472cff1ade7fc91c1ff0d5e3f473da774 /tensorflow/compiler/xla/service/hlo_input_output_alias_config_test.cc
parentd58712b7fc8de0e1f87fe2ea5221bc3c85230ed3 (diff)
[XLA] Introduce input/output alias config.
- This CL intruduces input/output alias config in HLO module that allows any HLO pass to configure it. Once the alias_config is set, each backend needs to follow the contract during execution time to make sure the input and output are indeed aliased. - Copy insertion / buffer assignment and alias analysis has been updated to correctly honor the config and avoid any possible liveness interference. PiperOrigin-RevId: 216299501
Diffstat (limited to 'tensorflow/compiler/xla/service/hlo_input_output_alias_config_test.cc')
-rw-r--r--tensorflow/compiler/xla/service/hlo_input_output_alias_config_test.cc184
1 files changed, 184 insertions, 0 deletions
diff --git a/tensorflow/compiler/xla/service/hlo_input_output_alias_config_test.cc b/tensorflow/compiler/xla/service/hlo_input_output_alias_config_test.cc
new file mode 100644
index 0000000000..3b61ff04e6
--- /dev/null
+++ b/tensorflow/compiler/xla/service/hlo_input_output_alias_config_test.cc
@@ -0,0 +1,184 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/xla/service/hlo_input_output_alias_config.h"
+
+#include <memory>
+#include <string>
+
+#include "absl/algorithm/container.h"
+#include "tensorflow/compiler/xla/service/hlo_computation.h"
+#include "tensorflow/compiler/xla/service/hlo_dce.h"
+#include "tensorflow/compiler/xla/service/hlo_instruction.h"
+#include "tensorflow/compiler/xla/service/hlo_memory_scheduler.h"
+#include "tensorflow/compiler/xla/service/hlo_opcode.h"
+#include "tensorflow/compiler/xla/service/hlo_ordering.h"
+#include "tensorflow/compiler/xla/service/hlo_parser.h"
+#include "tensorflow/compiler/xla/shape_util.h"
+#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
+#include "tensorflow/compiler/xla/types.h"
+#include "tensorflow/compiler/xla/xla_data.pb.h"
+#include "tensorflow/core/lib/core/status_test_util.h"
+
+namespace xla {
+namespace {
+class HloInputOutputAliasConfigTest : public HloTestBase {
+ protected:
+ void expect_aliased(const ShapeIndex& output_index, int64 param_number,
+ const ShapeIndex& param_index,
+ const HloInputOutputAliasConfig& config) {
+ absl::optional<ShapeIndex> aliased_output =
+ config.GetAliasedOutput(param_number, param_index);
+
+ EXPECT_TRUE(aliased_output);
+ EXPECT_EQ(aliased_output.value(), output_index);
+
+ absl::optional<std::pair<int64, ShapeIndex>> aliased_param =
+ config.GetAliasedParameter(output_index);
+
+ EXPECT_TRUE(aliased_param);
+ EXPECT_EQ(aliased_param.value(), std::make_pair(param_number, param_index));
+ }
+
+ void expect_not_aliased(const ShapeIndex& output_index, int64 param_number,
+ const ShapeIndex& param_index,
+ const HloInputOutputAliasConfig& config) {
+ absl::optional<ShapeIndex> aliased_output =
+ config.GetAliasedOutput(param_number, param_index);
+
+ EXPECT_FALSE(aliased_output && aliased_output == output_index);
+
+ absl::optional<std::pair<int64, ShapeIndex>> aliased_param =
+ config.GetAliasedParameter(output_index);
+
+ EXPECT_FALSE(aliased_param && aliased_param->first == param_number &&
+ aliased_param->second == param_index);
+ }
+};
+
+TEST_F(HloInputOutputAliasConfigTest, SimpleAliasing) {
+ const string module_str = R"(
+HloModule TEST
+
+ENTRY main {
+ a = f32[] parameter(0)
+ b = f32[] parameter(1)
+ ROOT root = (f32[], f32[]) tuple(%a, %b)
+}
+)";
+ TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
+ ParseHloString(module_str));
+
+ HloInputOutputAliasConfig config(
+ module->entry_computation()->root_instruction()->shape());
+
+ TF_ASSERT_OK(config.SetUpAlias(/*output_index=*/{0}, /*param_number=*/1,
+ /*param_index=*/{}));
+
+ expect_aliased(/*output_index=*/{0}, /*param_number=*/1,
+ /*param_index=*/{}, config);
+
+ expect_not_aliased(/*output_index=*/{1}, /*param_number=*/1,
+ /*param_index=*/{}, config);
+
+ expect_not_aliased(/*output_index=*/{0}, /*param_number=*/0,
+ /*param_index=*/{}, config);
+}
+
+TEST_F(HloInputOutputAliasConfigTest, SimpleAliasingWithTupleInput) {
+ const string module_str = R"(
+HloModule TEST
+
+ENTRY main {
+ param = (f32[], f32[]) parameter(0)
+ gte1 = f32[] get-tuple-element(%param), index=0
+ gte2 = f32[] get-tuple-element(%param), index=1
+ ROOT root = (f32[], f32[]) tuple(%gte1, %gte2)
+}
+)";
+ TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
+ ParseHloString(module_str));
+
+ HloInputOutputAliasConfig config(
+ module->entry_computation()->root_instruction()->shape());
+
+ TF_ASSERT_OK(config.SetUpAlias(/*output_index=*/{0}, /*param_number=*/0,
+ /*param_index=*/{0}));
+
+ TF_ASSERT_OK(config.SetUpAlias(/*output_index=*/{1}, /*param_number=*/0,
+ /*param_index=*/{1}));
+
+ expect_aliased(/*output_index=*/{0}, /*param_number=*/0,
+ /*param_index=*/{0}, config);
+
+ expect_aliased(/*output_index=*/{1}, /*param_number=*/0,
+ /*param_index=*/{1}, config);
+
+ expect_not_aliased(/*output_index=*/{1}, /*param_number=*/1,
+ /*param_index=*/{}, config);
+
+ expect_not_aliased(/*output_index=*/{0}, /*param_number=*/0,
+ /*param_index=*/{}, config);
+}
+
+TEST_F(HloInputOutputAliasConfigTest, InputDoNotAliasTwice) {
+ const string module_str = R"(
+HloModule TEST
+
+ENTRY main {
+ a = f32[] parameter(0)
+ b = f32[] parameter(1)
+ ROOT root = (f32[], f32[]) tuple(%a, %b)
+}
+)";
+ TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
+ ParseHloString(module_str));
+
+ HloInputOutputAliasConfig config(
+ module->entry_computation()->root_instruction()->shape());
+
+ TF_ASSERT_OK(config.SetUpAlias(/*output_index=*/{0}, /*param_number=*/0,
+ /*param_index=*/{}));
+
+ TF_ASSERT_OK(config.SetUpAlias(/*output_index=*/{1}, /*param_number=*/0,
+ /*param_index=*/{}));
+
+ ASSERT_IS_NOT_OK(config.Verify(*module));
+}
+
+TEST_F(HloInputOutputAliasConfigTest, OutputDoNotAliasTwice) {
+ const string module_str = R"(
+HloModule TEST
+
+ENTRY main {
+ a = f32[] parameter(0)
+ b = f32[] parameter(1)
+ ROOT root = (f32[], f32[]) tuple(%a, %b)
+}
+)";
+ TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
+ ParseHloString(module_str));
+
+ HloInputOutputAliasConfig config(
+ module->entry_computation()->root_instruction()->shape());
+
+ TF_ASSERT_OK(config.SetUpAlias(/*output_index=*/{0}, /*param_number=*/0,
+ /*param_index=*/{}));
+
+ ASSERT_IS_NOT_OK(config.SetUpAlias(/*output_index=*/{0}, /*param_number=*/1,
+ /*param_index=*/{}));
+}
+} // namespace
+} // namespace xla