From b01ea7a51c07f6d2988d7f2aa117374591d1e25a Mon Sep 17 00:00:00 2001 From: Mark Heffernan Date: Thu, 4 Oct 2018 14:18:58 -0700 Subject: Rename "Inliner" to "MapInliner". PiperOrigin-RevId: 215801897 --- tensorflow/compiler/xla/service/BUILD | 69 ++++---- tensorflow/compiler/xla/service/cpu/BUILD | 2 +- .../compiler/xla/service/cpu/cpu_compiler.cc | 4 +- tensorflow/compiler/xla/service/inliner.cc | 125 -------------- tensorflow/compiler/xla/service/inliner.h | 39 ----- tensorflow/compiler/xla/service/inliner_test.cc | 181 --------------------- tensorflow/compiler/xla/service/interpreter/BUILD | 2 +- .../compiler/xla/service/interpreter/compiler.cc | 2 +- tensorflow/compiler/xla/service/map_inliner.cc | 124 ++++++++++++++ tensorflow/compiler/xla/service/map_inliner.h | 39 +++++ .../compiler/xla/service/map_inliner_test.cc | 181 +++++++++++++++++++++ 11 files changed, 382 insertions(+), 386 deletions(-) delete mode 100644 tensorflow/compiler/xla/service/inliner.cc delete mode 100644 tensorflow/compiler/xla/service/inliner.h delete mode 100644 tensorflow/compiler/xla/service/inliner_test.cc create mode 100644 tensorflow/compiler/xla/service/map_inliner.cc create mode 100644 tensorflow/compiler/xla/service/map_inliner.h create mode 100644 tensorflow/compiler/xla/service/map_inliner_test.cc (limited to 'tensorflow/compiler') diff --git a/tensorflow/compiler/xla/service/BUILD b/tensorflow/compiler/xla/service/BUILD index 2f8bab0614..4797cf3330 100644 --- a/tensorflow/compiler/xla/service/BUILD +++ b/tensorflow/compiler/xla/service/BUILD @@ -1841,42 +1841,6 @@ tf_cc_test( ], ) -cc_library( - name = "inliner", - srcs = ["inliner.cc"], - hdrs = ["inliner.h"], - deps = [ - ":hlo", - ":hlo_pass", - ":hlo_query", - "//tensorflow/compiler/xla:status_macros", - "//tensorflow/compiler/xla:types", - "//tensorflow/core:lib", - "@com_google_absl//absl/types:span", - ], -) - -tf_cc_test( - name = "inliner_test", - srcs = ["inliner_test.cc"], - deps = [ - ":cpu_plugin", - ":hlo", - 
":hlo_matchers", - ":inliner", - "//tensorflow/compiler/xla:literal", - "//tensorflow/compiler/xla:shape_util", - "//tensorflow/compiler/xla:test", - "//tensorflow/compiler/xla:util", - "//tensorflow/compiler/xla:xla_data_proto", - "//tensorflow/compiler/xla/tests:hlo_test_base", - "//tensorflow/compiler/xla/tests:hlo_verified_test_base", - "//tensorflow/compiler/xla/tests:literal_test_util", - "//tensorflow/compiler/xla/tests:xla_internal_test_main", - "@com_google_absl//absl/memory", - ], -) - cc_library( name = "computation_placer", srcs = ["computation_placer.cc"], @@ -3492,6 +3456,39 @@ cc_library( deps = ["//tensorflow/core:lib"], ) +cc_library( + name = "map_inliner", + srcs = ["map_inliner.cc"], + hdrs = ["map_inliner.h"], + deps = [ + ":hlo", + ":hlo_pass", + ":hlo_query", + "//tensorflow/compiler/xla:status_macros", + "//tensorflow/compiler/xla:types", + "//tensorflow/core:lib", + "@com_google_absl//absl/types:span", + ], +) + +tf_cc_test( + name = "map_inliner_test", + srcs = ["map_inliner_test.cc"], + deps = [ + ":hlo", + ":hlo_matchers", + ":map_inliner", + "//tensorflow/compiler/xla:literal", + "//tensorflow/compiler/xla:shape_util", + "//tensorflow/compiler/xla:test", + "//tensorflow/compiler/xla:xla_data_proto", + "//tensorflow/compiler/xla/tests:hlo_verified_test_base", + "//tensorflow/compiler/xla/tests:literal_test_util", + "//tensorflow/compiler/xla/tests:xla_internal_test_main", # fixdeps: keep + "@com_google_absl//absl/memory", + ], +) + tf_cc_test( name = "hlo_casting_utils_test", srcs = ["hlo_casting_utils_test.cc"], diff --git a/tensorflow/compiler/xla/service/cpu/BUILD b/tensorflow/compiler/xla/service/cpu/BUILD index ae4c6e962d..58abb330a6 100644 --- a/tensorflow/compiler/xla/service/cpu/BUILD +++ b/tensorflow/compiler/xla/service/cpu/BUILD @@ -94,6 +94,7 @@ cc_library( ":target_machine_features", "@com_google_absl//absl/types:span", "//tensorflow/compiler/tf2xla:cpu_function_runtime", + "//tensorflow/compiler/xla/service:map_inliner", 
"//tensorflow/compiler/xla/service:scatter_expander", "//tensorflow/compiler/xla:literal", "//tensorflow/compiler/xla:protobuf_util", @@ -127,7 +128,6 @@ cc_library( "//tensorflow/compiler/xla/service:hlo_subcomputation_unification", "//tensorflow/compiler/xla/service:hlo_verifier", "//tensorflow/compiler/xla/service:indexed_array_analysis", - "//tensorflow/compiler/xla/service:inliner", "//tensorflow/compiler/xla/service:llvm_compiler", "//tensorflow/compiler/xla/service:reduce_precision_insertion", "//tensorflow/compiler/xla/service:reshape_mover", diff --git a/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc b/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc index afc94f2185..5834f67285 100644 --- a/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc +++ b/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc @@ -86,8 +86,8 @@ limitations under the License. #include "tensorflow/compiler/xla/service/hlo_subcomputation_unification.h" #include "tensorflow/compiler/xla/service/hlo_verifier.h" #include "tensorflow/compiler/xla/service/indexed_array_analysis.h" -#include "tensorflow/compiler/xla/service/inliner.h" #include "tensorflow/compiler/xla/service/llvm_ir/llvm_util.h" +#include "tensorflow/compiler/xla/service/map_inliner.h" #include "tensorflow/compiler/xla/service/reduce_precision_insertion.h" #include "tensorflow/compiler/xla/service/reshape_mover.h" #include "tensorflow/compiler/xla/service/scatter_expander.h" @@ -249,7 +249,7 @@ Status CpuCompiler::RunHloPassesThroughLayoutAssn( &pipeline, module->config().debug_options(), ReducePrecisionInsertion::PassTiming::BEFORE_OPTIMIZATION); - pipeline.AddPass<Inliner>(); + pipeline.AddPass<MapInliner>(); // TODO(b/65775800): Fix wrong output bug in Call and remove the CallInliner // pass.
diff --git a/tensorflow/compiler/xla/service/inliner.cc b/tensorflow/compiler/xla/service/inliner.cc deleted file mode 100644 index 50c408f5bb..0000000000 --- a/tensorflow/compiler/xla/service/inliner.cc +++ /dev/null @@ -1,125 +0,0 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include "tensorflow/compiler/xla/service/inliner.h" - -#include -#include - -#include "absl/types/span.h" -#include "tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h" -#include "tensorflow/compiler/xla/service/hlo_computation.h" -#include "tensorflow/compiler/xla/service/hlo_instruction.h" -#include "tensorflow/compiler/xla/service/hlo_opcode.h" -#include "tensorflow/compiler/xla/service/hlo_query.h" -#include "tensorflow/compiler/xla/status_macros.h" -#include "tensorflow/compiler/xla/types.h" -#include "tensorflow/core/lib/core/errors.h" -#include "tensorflow/core/lib/core/status.h" -#include "tensorflow/core/platform/logging.h" - -namespace xla { - -// InlinerVisitor traverses the HLO computation and inlines maps. -class InlinerVisitor : public DfsHloVisitorWithDefault { - public: - explicit InlinerVisitor(HloComputation* computation) - : computation_(computation) {} - - // Default visitor action is to do nothing and return OK. 
- Status DefaultAction(HloInstruction* /*hlo_instruction*/) override { - return Status::OK(); - } - - Status HandleMap(HloInstruction* map) override; - - // Runs the visitor on a computation. - StatusOr Run(HloComputation* computation); - - private: - // Current HloComputation instance the InlinerVisitor is traversing. - HloComputation* computation_; - - // Whether algebraic simplification has occurred. - bool changed_ = false; -}; - -StatusOr InlinerVisitor::Run(HloComputation* computation) { - changed_ = false; - computation_ = computation; - TF_RETURN_IF_ERROR(computation->root_instruction()->Accept(this)); - return changed_; -} - -Status InlinerVisitor::HandleMap(HloInstruction* map) { - HloComputation* function = map->to_apply(); - HloInstruction& root = *function->root_instruction(); - // TODO(b/29249531): Add DCE pass to remove unused HloComputations. - // Only inlining functions that are simply a single operation until a better - // profitability model for inlining is defined. - if (hlo_query::AllOperandsAreParameters(root)) { - if (root.opcode() == HloOpcode::kFusion || - root.opcode() == HloOpcode::kTrace) { - // Cloning not supported for these instructions. - return Status::OK(); - } - VLOG(10) << "inlining map({X ... Y}, op) => : op(X ... Y) with function " - << root.ToShortString(); - if (root.opcode() == HloOpcode::kParameter) { - // If the root is a parameter, then use the corresponding operand as the - // result of the computation. - TF_RETURN_IF_ERROR( - map->ReplaceAllUsesWith(map->operands()[root.parameter_number()])); - TF_RETURN_IF_ERROR(computation_->RemoveInstruction(map)); - } else if (root.opcode() == HloOpcode::kConstant) { - // If the input is a constant then the shape of the constant could be - // different than the map shape. Hence, a broadcast is needed, else the - // cloned operand with new shape and operands work. 
- // - // The constant is in an embedded computation and needs to be recreated - // as part of the computation that the broadcast is inserted into. - HloInstruction* constant = computation_->AddInstruction(root.Clone()); - HloInstruction* placed_instruction = computation_->AddInstruction( - HloInstruction::CreateBroadcast(map->shape(), constant, {})); - TF_RETURN_IF_ERROR( - computation_->ReplaceInstruction(map, placed_instruction)); - } else { - std::vector params; - for (int64 o = 0; o < root.operands().size(); o++) { - params.push_back(map->operands()[root.operand(o)->parameter_number()]); - } - HloInstruction* placed_instruction = computation_->AddInstruction( - root.CloneWithNewOperands(map->shape(), params)); - TF_RETURN_IF_ERROR( - computation_->ReplaceInstruction(map, placed_instruction)); - } - changed_ = true; - return Status::OK(); - } - - return Status::OK(); -} - -StatusOr Inliner::Run(HloModule* module) { - InlinerVisitor visitor(/*computation=*/nullptr); - bool changed = false; - for (HloComputation* computation : module->computations()) { - TF_ASSIGN_OR_RETURN(bool computation_changed, visitor.Run(computation)); - changed |= computation_changed; - } - return changed; -} - -} // namespace xla diff --git a/tensorflow/compiler/xla/service/inliner.h b/tensorflow/compiler/xla/service/inliner.h deleted file mode 100644 index e20af08fb7..0000000000 --- a/tensorflow/compiler/xla/service/inliner.h +++ /dev/null @@ -1,39 +0,0 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
-See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_INLINER_H_ -#define TENSORFLOW_COMPILER_XLA_SERVICE_INLINER_H_ - -#include "tensorflow/compiler/xla/service/hlo_module.h" -#include "tensorflow/compiler/xla/service/hlo_pass_interface.h" - -namespace xla { - -// A pass which performs inlining. Which can result, for example, in functions -// that were previously being mapped by Map instead directly applied to the -// forwarded operands (i.e., map({X, Y}, max) -> max(X, Y)). -class Inliner : public HloModulePass { - public: - ~Inliner() override = default; - absl::string_view name() const override { return "inline"; } - - // Run inlining on the given computation. Returns whether the computation was - // changed. - StatusOr Run(HloModule* module) override; -}; - -} // namespace xla - -#endif // TENSORFLOW_COMPILER_XLA_SERVICE_INLINER_H_ diff --git a/tensorflow/compiler/xla/service/inliner_test.cc b/tensorflow/compiler/xla/service/inliner_test.cc deleted file mode 100644 index 98e0f2cfd7..0000000000 --- a/tensorflow/compiler/xla/service/inliner_test.cc +++ /dev/null @@ -1,181 +0,0 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. 
-==============================================================================*/ - -#include "tensorflow/compiler/xla/service/inliner.h" - -#include -#include - -#include "absl/memory/memory.h" -#include "tensorflow/compiler/xla/literal.h" -#include "tensorflow/compiler/xla/service/hlo_computation.h" -#include "tensorflow/compiler/xla/service/hlo_instruction.h" -#include "tensorflow/compiler/xla/service/hlo_matchers.h" -#include "tensorflow/compiler/xla/service/hlo_opcode.h" -#include "tensorflow/compiler/xla/shape_util.h" -#include "tensorflow/compiler/xla/test.h" -#include "tensorflow/compiler/xla/tests/hlo_verified_test_base.h" -#include "tensorflow/compiler/xla/tests/literal_test_util.h" -#include "tensorflow/compiler/xla/xla_data.pb.h" - -namespace op = xla::testing::opcode_matchers; - -namespace xla { -namespace { - -using InlinerTest = HloVerifiedTestBase; - -// Test that `map` with `max` is transformed to `max` -TEST_F(InlinerTest, MapMax) { - Shape r0f32 = ShapeUtil::MakeShape(F32, {}); - - auto max_builder = HloComputation::Builder(TestName()); - auto param1 = max_builder.AddInstruction( - HloInstruction::CreateParameter(0, r0f32, "x")); - auto param2 = max_builder.AddInstruction( - HloInstruction::CreateParameter(1, r0f32, "y")); - max_builder.AddInstruction(HloInstruction::CreateBinary( - param1->shape(), HloOpcode::kMaximum, param1, param2)); - auto max_f32 = max_builder.Build(); - - auto builder = HloComputation::Builder("MapMaxFunction"); - auto lhs = builder.AddInstruction(HloInstruction::CreateConstant( - LiteralUtil::CreateR1({1, 2, 3, 4}))); - auto rhs = builder.AddInstruction(HloInstruction::CreateConstant( - LiteralUtil::CreateR1({4, 3, 2, 1}))); - builder.AddInstruction( - HloInstruction::CreateMap(lhs->shape(), {lhs, rhs}, max_f32.get())); - - auto computation = builder.Build(); - auto hlo_module = CreateNewModule(); - hlo_module->AddEmbeddedComputation(std::move(max_f32)); - hlo_module->AddEntryComputation(std::move(computation)); - - 
Inliner inliner; - EXPECT_TRUE(inliner.Run(hlo_module).ValueOrDie()); - EXPECT_THAT(hlo_module->entry_computation()->root_instruction(), - op::Maximum(lhs, rhs)); - - // Verify execution on CPU. - auto result = ExecuteAndTransfer(hlo_module->Clone(), {}); - auto expected = LiteralUtil::CreateR1({4, 3, 3, 4}); - EXPECT_TRUE(LiteralTestUtil::Equal(result, expected)); -} - -// Test that `constant` function is changed to `broadcast`. -TEST_F(InlinerTest, MapConstant) { - Shape r0f32 = ShapeUtil::MakeShape(F32, {}); - - auto const2_builder = HloComputation::Builder(TestName()); - auto param1 = const2_builder.AddInstruction( - HloInstruction::CreateParameter(0, r0f32, "x")); - (void)param1; - const2_builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR0(2.0f))); - auto const2_f32 = const2_builder.Build(); - - auto builder = HloComputation::Builder("MapConstFunction"); - auto lhs = builder.AddInstruction(HloInstruction::CreateConstant( - LiteralUtil::CreateR2({{1, 2, 3, 4}, {5, 6, 7, 8}}))); - builder.AddInstruction( - HloInstruction::CreateMap(lhs->shape(), {lhs}, const2_f32.get())); - - auto computation = builder.Build(); - auto hlo_module = CreateNewModule(); - hlo_module->AddEmbeddedComputation(std::move(const2_f32)); - hlo_module->AddEntryComputation(std::move(computation)); - HloInstruction* root = hlo_module->entry_computation()->root_instruction(); - Inliner inliner; - EXPECT_TRUE(inliner.Run(hlo_module).ValueOrDie()); - root = hlo_module->entry_computation()->root_instruction(); - EXPECT_THAT(root, op::Broadcast(op::Constant())); - - // Verify execution on CPU. 
- auto result = ExecuteAndTransfer(hlo_module->Clone(), {}); - auto expected = LiteralUtil::CreateR2({{2, 2, 2, 2}, {2, 2, 2, 2}}); - EXPECT_TRUE(LiteralTestUtil::Equal(result, expected)); -} - -TEST_F(InlinerTest, MapSubtractOppositeOrder) { - Shape r0f32 = ShapeUtil::MakeShape(F32, {}); - - // Note that the parameter ordinals are in the opposite order to their - // position as operands - auto max_builder = HloComputation::Builder(TestName()); - auto param1 = max_builder.AddInstruction( - HloInstruction::CreateParameter(1, r0f32, "x")); - auto param2 = max_builder.AddInstruction( - HloInstruction::CreateParameter(0, r0f32, "y")); - max_builder.AddInstruction(HloInstruction::CreateBinary( - param1->shape(), HloOpcode::kSubtract, param1, param2)); - auto max_f32 = max_builder.Build(); - - auto builder = HloComputation::Builder("MapSubFunction"); - auto lhs = builder.AddInstruction(HloInstruction::CreateConstant( - LiteralUtil::CreateR1({1, 2, 3, 4}))); - auto rhs = builder.AddInstruction(HloInstruction::CreateConstant( - LiteralUtil::CreateR1({4, 3, 2, 1}))); - builder.AddInstruction( - HloInstruction::CreateMap(lhs->shape(), {lhs, rhs}, max_f32.get())); - - auto computation = builder.Build(); - auto hlo_module = CreateNewModule(); - hlo_module->AddEmbeddedComputation(std::move(max_f32)); - hlo_module->AddEntryComputation(std::move(computation)); - - Inliner inliner; - EXPECT_TRUE(inliner.Run(hlo_module).ValueOrDie()); - EXPECT_THAT(hlo_module->entry_computation()->root_instruction(), - op::Subtract(rhs, lhs)); - - // Verify execution on CPU. 
- auto result = ExecuteAndTransfer(hlo_module->Clone(), {}); - auto expected = LiteralUtil::CreateR1({3, 1, -1, -3}); - EXPECT_TRUE(LiteralTestUtil::Equal(result, expected)); -} - -TEST_F(InlinerTest, MapParameter) { - Shape r0f32 = ShapeUtil::MakeShape(F32, {}); - - auto param_builder = HloComputation::Builder(TestName()); - param_builder.AddInstruction(HloInstruction::CreateParameter(0, r0f32, "p0")); - param_builder.AddInstruction(HloInstruction::CreateParameter(1, r0f32, "p1")); - auto param_f32 = param_builder.Build(); - - auto builder = HloComputation::Builder("MapParamFunction"); - auto lhs = builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR0(1))); - auto rhs = builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR0(4))); - builder.AddInstruction( - HloInstruction::CreateMap(lhs->shape(), {lhs, rhs}, param_f32.get())); - - auto computation = builder.Build(); - auto hlo_module = CreateNewVerifiedModule(); - hlo_module->AddEmbeddedComputation(std::move(param_f32)); - hlo_module->AddEntryComputation(std::move(computation)); - - Inliner inliner; - EXPECT_TRUE(inliner.Run(hlo_module.get()).ValueOrDie()); - EXPECT_THAT(hlo_module->entry_computation()->root_instruction(), rhs); - - // Verify execution on CPU. 
- auto result = ExecuteAndTransfer(hlo_module->Clone(), {}); - auto expected = LiteralUtil::CreateR0(4); - EXPECT_TRUE(LiteralTestUtil::Equal(result, expected)); -} - -} // namespace -} // namespace xla diff --git a/tensorflow/compiler/xla/service/interpreter/BUILD b/tensorflow/compiler/xla/service/interpreter/BUILD index 146c9052f1..1484e14df1 100644 --- a/tensorflow/compiler/xla/service/interpreter/BUILD +++ b/tensorflow/compiler/xla/service/interpreter/BUILD @@ -45,8 +45,8 @@ cc_library( "//tensorflow/compiler/xla/service:hlo_pass", "//tensorflow/compiler/xla/service:hlo_pass_pipeline", "//tensorflow/compiler/xla/service:hlo_subcomputation_unification", - "//tensorflow/compiler/xla/service:inliner", "//tensorflow/compiler/xla/service:layout_assignment", + "//tensorflow/compiler/xla/service:map_inliner", "//tensorflow/compiler/xla/service:reshape_mover", "//tensorflow/compiler/xla/service:while_loop_simplifier", "//tensorflow/core:lib", diff --git a/tensorflow/compiler/xla/service/interpreter/compiler.cc b/tensorflow/compiler/xla/service/interpreter/compiler.cc index 27fe89375d..7c79eb7d79 100644 --- a/tensorflow/compiler/xla/service/interpreter/compiler.cc +++ b/tensorflow/compiler/xla/service/interpreter/compiler.cc @@ -28,9 +28,9 @@ limitations under the License. 
#include "tensorflow/compiler/xla/service/hlo_pass_fix.h" #include "tensorflow/compiler/xla/service/hlo_pass_pipeline.h" #include "tensorflow/compiler/xla/service/hlo_subcomputation_unification.h" -#include "tensorflow/compiler/xla/service/inliner.h" #include "tensorflow/compiler/xla/service/interpreter/executable.h" #include "tensorflow/compiler/xla/service/layout_assignment.h" +#include "tensorflow/compiler/xla/service/map_inliner.h" #include "tensorflow/compiler/xla/service/reshape_mover.h" #include "tensorflow/compiler/xla/service/while_loop_simplifier.h" #include "tensorflow/compiler/xla/status_macros.h" diff --git a/tensorflow/compiler/xla/service/map_inliner.cc b/tensorflow/compiler/xla/service/map_inliner.cc new file mode 100644 index 0000000000..2200ef054a --- /dev/null +++ b/tensorflow/compiler/xla/service/map_inliner.cc @@ -0,0 +1,124 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "tensorflow/compiler/xla/service/map_inliner.h" + +#include +#include + +#include "absl/types/span.h" +#include "tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h" +#include "tensorflow/compiler/xla/service/hlo_computation.h" +#include "tensorflow/compiler/xla/service/hlo_instruction.h" +#include "tensorflow/compiler/xla/service/hlo_opcode.h" +#include "tensorflow/compiler/xla/service/hlo_query.h" +#include "tensorflow/compiler/xla/status_macros.h" +#include "tensorflow/compiler/xla/types.h" +#include "tensorflow/core/lib/core/errors.h" +#include "tensorflow/core/lib/core/status.h" +#include "tensorflow/core/platform/logging.h" + +namespace xla { + +// MapInlinerVisitor traverses the HLO computation and inlines maps. +class MapInlinerVisitor : public DfsHloVisitorWithDefault { + public: + explicit MapInlinerVisitor(HloComputation* computation) + : computation_(computation) {} + + // Default visitor action is to do nothing and return OK. + Status DefaultAction(HloInstruction* /*hlo_instruction*/) override { + return Status::OK(); + } + + Status HandleMap(HloInstruction* map) override; + + // Runs the visitor on a computation. + StatusOr Run(HloComputation* computation); + + private: + // Current HloComputation instance the MapInlinerVisitor is traversing. + HloComputation* computation_; + + // Whether algebraic simplification has occurred. + bool changed_ = false; +}; + +StatusOr MapInlinerVisitor::Run(HloComputation* computation) { + changed_ = false; + computation_ = computation; + TF_RETURN_IF_ERROR(computation->root_instruction()->Accept(this)); + return changed_; +} + +Status MapInlinerVisitor::HandleMap(HloInstruction* map) { + HloComputation* function = map->to_apply(); + HloInstruction& root = *function->root_instruction(); + // Only inlining functions that are simply a single operation until a better + // profitability model for inlining is defined. 
+ if (hlo_query::AllOperandsAreParameters(root)) { + if (root.opcode() == HloOpcode::kFusion || + root.opcode() == HloOpcode::kTrace) { + // Cloning not supported for these instructions. + return Status::OK(); + } + VLOG(10) << "inlining map({X ... Y}, op) => : op(X ... Y) with function " + << root.ToShortString(); + if (root.opcode() == HloOpcode::kParameter) { + // If the root is a parameter, then use the corresponding operand as the + // result of the computation. + TF_RETURN_IF_ERROR( + map->ReplaceAllUsesWith(map->operands()[root.parameter_number()])); + TF_RETURN_IF_ERROR(computation_->RemoveInstruction(map)); + } else if (root.opcode() == HloOpcode::kConstant) { + // If the input is a constant then the shape of the constant could be + // different than the map shape. Hence, a broadcast is needed, else the + // cloned operand with new shape and operands work. + // + // The constant is in an embedded computation and needs to be recreated + // as part of the computation that the broadcast is inserted into. 
+ HloInstruction* constant = computation_->AddInstruction(root.Clone()); + HloInstruction* placed_instruction = computation_->AddInstruction( + HloInstruction::CreateBroadcast(map->shape(), constant, {})); + TF_RETURN_IF_ERROR( + computation_->ReplaceInstruction(map, placed_instruction)); + } else { + std::vector params; + for (int64 o = 0; o < root.operands().size(); o++) { + params.push_back(map->operands()[root.operand(o)->parameter_number()]); + } + HloInstruction* placed_instruction = computation_->AddInstruction( + root.CloneWithNewOperands(map->shape(), params)); + TF_RETURN_IF_ERROR( + computation_->ReplaceInstruction(map, placed_instruction)); + } + changed_ = true; + return Status::OK(); + } + + return Status::OK(); +} + +StatusOr MapInliner::Run(HloModule* module) { + MapInlinerVisitor visitor(/*computation=*/nullptr); + bool changed = false; + for (HloComputation* computation : module->computations()) { + TF_ASSIGN_OR_RETURN(bool computation_changed, visitor.Run(computation)); + changed |= computation_changed; + } + return changed; +} + +} // namespace xla diff --git a/tensorflow/compiler/xla/service/map_inliner.h b/tensorflow/compiler/xla/service/map_inliner.h new file mode 100644 index 0000000000..b679118118 --- /dev/null +++ b/tensorflow/compiler/xla/service/map_inliner.h @@ -0,0 +1,39 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_MAP_INLINER_H_ +#define TENSORFLOW_COMPILER_XLA_SERVICE_MAP_INLINER_H_ + +#include "tensorflow/compiler/xla/service/hlo_module.h" +#include "tensorflow/compiler/xla/service/hlo_pass_interface.h" + +namespace xla { + +// A pass which performs map inlining. This replaces kMap instructions with +// their equivalent sequence of array operations. For example: +// map({X, Y}, add) -> add(X, Y)). +class MapInliner : public HloModulePass { + public: + ~MapInliner() override = default; + absl::string_view name() const override { return "map-inline"; } + + // Run map inlining on the given computation. Returns whether the computation + // was changed. + StatusOr<bool> Run(HloModule* module) override; +}; + +} // namespace xla + +#endif // TENSORFLOW_COMPILER_XLA_SERVICE_MAP_INLINER_H_ diff --git a/tensorflow/compiler/xla/service/map_inliner_test.cc b/tensorflow/compiler/xla/service/map_inliner_test.cc new file mode 100644 index 0000000000..84059dd0f7 --- /dev/null +++ b/tensorflow/compiler/xla/service/map_inliner_test.cc @@ -0,0 +1,181 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License.
+==============================================================================*/ + +#include "tensorflow/compiler/xla/service/map_inliner.h" + +#include +#include + +#include "absl/memory/memory.h" +#include "tensorflow/compiler/xla/literal.h" +#include "tensorflow/compiler/xla/service/hlo_computation.h" +#include "tensorflow/compiler/xla/service/hlo_instruction.h" +#include "tensorflow/compiler/xla/service/hlo_matchers.h" +#include "tensorflow/compiler/xla/service/hlo_opcode.h" +#include "tensorflow/compiler/xla/shape_util.h" +#include "tensorflow/compiler/xla/test.h" +#include "tensorflow/compiler/xla/tests/hlo_verified_test_base.h" +#include "tensorflow/compiler/xla/tests/literal_test_util.h" +#include "tensorflow/compiler/xla/xla_data.pb.h" + +namespace op = xla::testing::opcode_matchers; + +namespace xla { +namespace { + +using MapInlinerTest = HloVerifiedTestBase; + +// Test that `map` with `max` is transformed to `max` +TEST_F(MapInlinerTest, MapMax) { + Shape r0f32 = ShapeUtil::MakeShape(F32, {}); + + auto max_builder = HloComputation::Builder(TestName()); + auto param1 = max_builder.AddInstruction( + HloInstruction::CreateParameter(0, r0f32, "x")); + auto param2 = max_builder.AddInstruction( + HloInstruction::CreateParameter(1, r0f32, "y")); + max_builder.AddInstruction(HloInstruction::CreateBinary( + param1->shape(), HloOpcode::kMaximum, param1, param2)); + auto max_f32 = max_builder.Build(); + + auto builder = HloComputation::Builder("MapMaxFunction"); + auto lhs = builder.AddInstruction(HloInstruction::CreateConstant( + LiteralUtil::CreateR1({1, 2, 3, 4}))); + auto rhs = builder.AddInstruction(HloInstruction::CreateConstant( + LiteralUtil::CreateR1({4, 3, 2, 1}))); + builder.AddInstruction( + HloInstruction::CreateMap(lhs->shape(), {lhs, rhs}, max_f32.get())); + + auto computation = builder.Build(); + auto hlo_module = CreateNewModule(); + hlo_module->AddEmbeddedComputation(std::move(max_f32)); + hlo_module->AddEntryComputation(std::move(computation)); 
+ + MapInliner inliner; + EXPECT_TRUE(inliner.Run(hlo_module).ValueOrDie()); + EXPECT_THAT(hlo_module->entry_computation()->root_instruction(), + op::Maximum(lhs, rhs)); + + // Verify execution on CPU. + auto result = ExecuteAndTransfer(hlo_module->Clone(), {}); + auto expected = LiteralUtil::CreateR1({4, 3, 3, 4}); + EXPECT_TRUE(LiteralTestUtil::Equal(result, expected)); +} + +// Test that `constant` function is changed to `broadcast`. +TEST_F(MapInlinerTest, MapConstant) { + Shape r0f32 = ShapeUtil::MakeShape(F32, {}); + + auto const2_builder = HloComputation::Builder(TestName()); + auto param1 = const2_builder.AddInstruction( + HloInstruction::CreateParameter(0, r0f32, "x")); + (void)param1; + const2_builder.AddInstruction( + HloInstruction::CreateConstant(LiteralUtil::CreateR0(2.0f))); + auto const2_f32 = const2_builder.Build(); + + auto builder = HloComputation::Builder("MapConstFunction"); + auto lhs = builder.AddInstruction(HloInstruction::CreateConstant( + LiteralUtil::CreateR2({{1, 2, 3, 4}, {5, 6, 7, 8}}))); + builder.AddInstruction( + HloInstruction::CreateMap(lhs->shape(), {lhs}, const2_f32.get())); + + auto computation = builder.Build(); + auto hlo_module = CreateNewModule(); + hlo_module->AddEmbeddedComputation(std::move(const2_f32)); + hlo_module->AddEntryComputation(std::move(computation)); + HloInstruction* root = hlo_module->entry_computation()->root_instruction(); + MapInliner inliner; + EXPECT_TRUE(inliner.Run(hlo_module).ValueOrDie()); + root = hlo_module->entry_computation()->root_instruction(); + EXPECT_THAT(root, op::Broadcast(op::Constant())); + + // Verify execution on CPU. 
+ auto result = ExecuteAndTransfer(hlo_module->Clone(), {}); + auto expected = LiteralUtil::CreateR2({{2, 2, 2, 2}, {2, 2, 2, 2}}); + EXPECT_TRUE(LiteralTestUtil::Equal(result, expected)); +} + +TEST_F(MapInlinerTest, MapSubtractOppositeOrder) { + Shape r0f32 = ShapeUtil::MakeShape(F32, {}); + + // Note that the parameter ordinals are in the opposite order to their + // position as operands + auto max_builder = HloComputation::Builder(TestName()); + auto param1 = max_builder.AddInstruction( + HloInstruction::CreateParameter(1, r0f32, "x")); + auto param2 = max_builder.AddInstruction( + HloInstruction::CreateParameter(0, r0f32, "y")); + max_builder.AddInstruction(HloInstruction::CreateBinary( + param1->shape(), HloOpcode::kSubtract, param1, param2)); + auto max_f32 = max_builder.Build(); + + auto builder = HloComputation::Builder("MapSubFunction"); + auto lhs = builder.AddInstruction(HloInstruction::CreateConstant( + LiteralUtil::CreateR1({1, 2, 3, 4}))); + auto rhs = builder.AddInstruction(HloInstruction::CreateConstant( + LiteralUtil::CreateR1({4, 3, 2, 1}))); + builder.AddInstruction( + HloInstruction::CreateMap(lhs->shape(), {lhs, rhs}, max_f32.get())); + + auto computation = builder.Build(); + auto hlo_module = CreateNewModule(); + hlo_module->AddEmbeddedComputation(std::move(max_f32)); + hlo_module->AddEntryComputation(std::move(computation)); + + MapInliner inliner; + EXPECT_TRUE(inliner.Run(hlo_module).ValueOrDie()); + EXPECT_THAT(hlo_module->entry_computation()->root_instruction(), + op::Subtract(rhs, lhs)); + + // Verify execution on CPU. 
+ auto result = ExecuteAndTransfer(hlo_module->Clone(), {}); + auto expected = LiteralUtil::CreateR1({3, 1, -1, -3}); + EXPECT_TRUE(LiteralTestUtil::Equal(result, expected)); +} + +TEST_F(MapInlinerTest, MapParameter) { + Shape r0f32 = ShapeUtil::MakeShape(F32, {}); + + auto param_builder = HloComputation::Builder(TestName()); + param_builder.AddInstruction(HloInstruction::CreateParameter(0, r0f32, "p0")); + param_builder.AddInstruction(HloInstruction::CreateParameter(1, r0f32, "p1")); + auto param_f32 = param_builder.Build(); + + auto builder = HloComputation::Builder("MapParamFunction"); + auto lhs = builder.AddInstruction( + HloInstruction::CreateConstant(LiteralUtil::CreateR0(1))); + auto rhs = builder.AddInstruction( + HloInstruction::CreateConstant(LiteralUtil::CreateR0(4))); + builder.AddInstruction( + HloInstruction::CreateMap(lhs->shape(), {lhs, rhs}, param_f32.get())); + + auto computation = builder.Build(); + auto hlo_module = CreateNewVerifiedModule(); + hlo_module->AddEmbeddedComputation(std::move(param_f32)); + hlo_module->AddEntryComputation(std::move(computation)); + + MapInliner inliner; + EXPECT_TRUE(inliner.Run(hlo_module.get()).ValueOrDie()); + EXPECT_THAT(hlo_module->entry_computation()->root_instruction(), rhs); + + // Verify execution on CPU. + auto result = ExecuteAndTransfer(hlo_module->Clone(), {}); + auto expected = LiteralUtil::CreateR0(4); + EXPECT_TRUE(LiteralTestUtil::Equal(result, expected)); +} + +} // namespace +} // namespace xla -- cgit v1.2.3