aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
-rw-r--r--tensorflow/compiler/xla/service/BUILD77
-rw-r--r--tensorflow/compiler/xla/service/dfs_hlo_visitor.h4
-rw-r--r--tensorflow/compiler/xla/service/hlo_clone_context.h97
-rw-r--r--tensorflow/compiler/xla/service/hlo_computation.cc47
-rw-r--r--tensorflow/compiler/xla/service/hlo_computation.h21
-rw-r--r--tensorflow/compiler/xla/service/hlo_cse.cc23
-rw-r--r--tensorflow/compiler/xla/service/hlo_cse_test.cc67
-rw-r--r--tensorflow/compiler/xla/service/hlo_domain_isolator.cc104
-rw-r--r--tensorflow/compiler/xla/service/hlo_domain_isolator.h56
-rw-r--r--tensorflow/compiler/xla/service/hlo_domain_map.cc168
-rw-r--r--tensorflow/compiler/xla/service/hlo_domain_map.h108
-rw-r--r--tensorflow/compiler/xla/service/hlo_domain_metadata.h83
-rw-r--r--tensorflow/compiler/xla/service/hlo_domain_remover.cc149
-rw-r--r--tensorflow/compiler/xla/service/hlo_domain_remover.h56
-rw-r--r--tensorflow/compiler/xla/service/hlo_domain_test.cc432
-rw-r--r--tensorflow/compiler/xla/service/hlo_element_type_converter.cc11
-rw-r--r--tensorflow/compiler/xla/service/hlo_evaluator.cc3
-rw-r--r--tensorflow/compiler/xla/service/hlo_graph_dumper.cc1
-rw-r--r--tensorflow/compiler/xla/service/hlo_instruction.cc87
-rw-r--r--tensorflow/compiler/xla/service/hlo_instruction.h58
-rw-r--r--tensorflow/compiler/xla/service/hlo_instruction_test.cc48
-rw-r--r--tensorflow/compiler/xla/service/hlo_module.cc74
-rw-r--r--tensorflow/compiler/xla/service/hlo_module.h7
-rw-r--r--tensorflow/compiler/xla/service/hlo_module_group_metadata.cc78
-rw-r--r--tensorflow/compiler/xla/service/hlo_module_group_metadata.h12
-rw-r--r--tensorflow/compiler/xla/service/hlo_opcode.h1
-rw-r--r--tensorflow/compiler/xla/service/hlo_sharding.cc25
-rw-r--r--tensorflow/compiler/xla/service/hlo_sharding.h14
-rw-r--r--tensorflow/compiler/xla/service/hlo_sharding_metadata.cc401
-rw-r--r--tensorflow/compiler/xla/service/hlo_sharding_metadata.h67
-rw-r--r--tensorflow/compiler/xla/service/hlo_verifier.cc1
-rw-r--r--tensorflow/compiler/xla/service/instruction_fusion.cc1
-rw-r--r--tensorflow/compiler/xla/service/logical_buffer_analysis.cc6
-rw-r--r--tensorflow/compiler/xla/service/logical_buffer_analysis.h1
-rw-r--r--tensorflow/compiler/xla/service/shape_inference.cc3
-rw-r--r--tensorflow/compiler/xla/service/tuple_points_to_analysis.cc8
-rw-r--r--tensorflow/compiler/xla/service/tuple_points_to_analysis.h1
-rw-r--r--tensorflow/compiler/xla/shape_tree.h3
-rw-r--r--tensorflow/compiler/xla/shape_util.cc21
-rw-r--r--tensorflow/compiler/xla/shape_util.h17
-rw-r--r--tensorflow/compiler/xla/tools/parser/hlo_parser.cc1
41 files changed, 2252 insertions, 190 deletions
diff --git a/tensorflow/compiler/xla/service/BUILD b/tensorflow/compiler/xla/service/BUILD
index 5472f9a637..7e4a75a6e3 100644
--- a/tensorflow/compiler/xla/service/BUILD
+++ b/tensorflow/compiler/xla/service/BUILD
@@ -273,7 +273,9 @@ cc_library(
hdrs = [
"dfs_hlo_visitor.h",
"dfs_hlo_visitor_with_default.h",
+ "hlo_clone_context.h",
"hlo_computation.h",
+ "hlo_domain_metadata.h",
"hlo_instruction.h",
"hlo_module.h",
"hlo_opcode.h",
@@ -415,6 +417,7 @@ tf_cc_test(
"//tensorflow/compiler/xla:util",
"//tensorflow/compiler/xla/tests:hlo_test_base",
"//tensorflow/compiler/xla/tests:xla_internal_test_main",
+ "//tensorflow/compiler/xla/tools/parser:hlo_parser",
],
)
@@ -2339,6 +2342,7 @@ cc_library(
hdrs = ["hlo_cse.h"],
deps = [
":hlo",
+ ":hlo_domain_map",
":hlo_pass",
"//tensorflow/compiler/xla:literal_util",
"//tensorflow/compiler/xla:shape_util",
@@ -2404,6 +2408,79 @@ tf_cc_test(
)
cc_library(
+ name = "hlo_domain_map",
+ srcs = ["hlo_domain_map.cc"],
+ hdrs = ["hlo_domain_map.h"],
+ deps = [
+ ":hlo",
+ "//tensorflow/compiler/xla:statusor",
+ "//tensorflow/compiler/xla:types",
+ "//tensorflow/compiler/xla:util",
+ "//tensorflow/core:lib",
+ ],
+)
+
+cc_library(
+ name = "hlo_sharding_metadata",
+ srcs = ["hlo_sharding_metadata.cc"],
+ hdrs = [
+ "hlo_sharding_metadata.h",
+ ],
+ deps = [
+ ":hlo",
+ "//tensorflow/compiler/xla:shape_tree",
+ "//tensorflow/compiler/xla:shape_util",
+ "//tensorflow/core:lib",
+ ],
+)
+
+cc_library(
+ name = "hlo_domain_isolator",
+ srcs = ["hlo_domain_isolator.cc"],
+ hdrs = ["hlo_domain_isolator.h"],
+ deps = [
+ ":hlo",
+ ":hlo_graph_dumper",
+ ":hlo_pass",
+ "//tensorflow/compiler/xla:types",
+ "//tensorflow/compiler/xla:util",
+ ],
+)
+
+cc_library(
+ name = "hlo_domain_remover",
+ srcs = ["hlo_domain_remover.cc"],
+ hdrs = ["hlo_domain_remover.h"],
+ deps = [
+ ":hlo",
+ ":hlo_domain_isolator",
+ ":hlo_domain_map",
+ ":hlo_graph_dumper",
+ ":hlo_pass",
+ "//tensorflow/compiler/xla:types",
+ "//tensorflow/compiler/xla:util",
+ "//tensorflow/core:lib",
+ ],
+)
+
+tf_cc_test(
+ name = "hlo_domain_test",
+ srcs = ["hlo_domain_test.cc"],
+ deps = [
+ ":hlo",
+ ":hlo_domain_isolator",
+ ":hlo_domain_remover",
+ ":hlo_sharding_metadata",
+ "//tensorflow/compiler/xla:test",
+ "//tensorflow/compiler/xla/legacy_flags:debug_options_flags",
+ "//tensorflow/compiler/xla/tests:hlo_test_base",
+ "//tensorflow/compiler/xla/tests:xla_internal_test_main",
+ "//tensorflow/compiler/xla/tools/parser:hlo_parser",
+ "//tensorflow/core:test",
+ ],
+)
+
+cc_library(
name = "hlo_element_type_converter",
srcs = ["hlo_element_type_converter.cc"],
hdrs = ["hlo_element_type_converter.h"],
diff --git a/tensorflow/compiler/xla/service/dfs_hlo_visitor.h b/tensorflow/compiler/xla/service/dfs_hlo_visitor.h
index b9d7ec9c2e..64678d9d74 100644
--- a/tensorflow/compiler/xla/service/dfs_hlo_visitor.h
+++ b/tensorflow/compiler/xla/service/dfs_hlo_visitor.h
@@ -197,6 +197,10 @@ class DfsHloVisitorBase {
return HandleElementwiseUnary(hlo);
}
+ virtual Status HandleDomain(HloInstructionPtr hlo) {
+ return HandleElementwiseUnary(hlo);
+ }
+
virtual Status HandleInfeed(HloInstructionPtr hlo) = 0;
virtual Status HandleOutfeed(HloInstructionPtr hlo) = 0;
virtual Status HandleHostCompute(HloInstructionPtr hlo) = 0;
diff --git a/tensorflow/compiler/xla/service/hlo_clone_context.h b/tensorflow/compiler/xla/service/hlo_clone_context.h
new file mode 100644
index 0000000000..658643b427
--- /dev/null
+++ b/tensorflow/compiler/xla/service/hlo_clone_context.h
@@ -0,0 +1,97 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_HLO_CLONE_CONTEXT_H_
+#define TENSORFLOW_COMPILER_XLA_SERVICE_HLO_CLONE_CONTEXT_H_
+
+#include <string>
+
+#include "tensorflow/compiler/xla/map_util.h"
+#include "tensorflow/core/lib/gtl/flatmap.h"
+
+namespace xla {
+
+class HloInstruction;
+class HloComputation;
+class HloModule;
+
+// Data structure used to track the cloning of HloInstruction and HloComputation
+// objects.
+class HloCloneContext {
+ public:
+ // Creates a new HloCloneContext object to clone HloInstruction and
+ // HloComputation objects to be added to the module specified as argument.
+ // The suffix string will be appended to computation names.
+ explicit HloCloneContext(HloModule* module, const string& suffix = "")
+ : module_(module), suffix_(suffix) {}
+
+ HloModule* module() const { return module_; }
+
+ const string& suffix() const { return suffix_; }
+
+ void MapInstruction(const HloInstruction* old_instruction,
+ HloInstruction* new_instruction) {
+ instructions_[old_instruction] = new_instruction;
+ }
+
+ void MapComputation(const HloComputation* old_computation,
+ HloComputation* new_computation) {
+ computations_[old_computation] = new_computation;
+ }
+
+  // Finds the new instruction mapped to its old copy, or returns nullptr in case
+ // it is not found.
+ HloInstruction* FindInstruction(const HloInstruction* old_instruction) const {
+ return FindOrDefault(instructions_, old_instruction, nullptr);
+ }
+
+  // Finds the new computation mapped to its old copy, or returns nullptr in case
+ // it is not found.
+ HloComputation* FindComputation(const HloComputation* old_computation) const {
+ return FindOrDefault(computations_, old_computation, nullptr);
+ }
+
+ // Retrieves the new instruction mapped to its old copy, or fail if not found.
+ HloInstruction* GetInstruction(const HloInstruction* old_instruction) const {
+ return FindOrDie(instructions_, old_instruction);
+ }
+
+ // Retrieves the new computation mapped to its old copy, or fail if not found.
+ HloComputation* GetComputation(const HloComputation* old_computation) const {
+ return FindOrDie(computations_, old_computation);
+ }
+
+ const tensorflow::gtl::FlatMap<const HloInstruction*, HloInstruction*>&
+ cloned_instructions() const {
+ return instructions_;
+ }
+
+ const tensorflow::gtl::FlatMap<const HloComputation*, HloComputation*>&
+ cloned_computations() const {
+ return computations_;
+ }
+
+ private:
+ HloModule* module_;
+ string suffix_;
+ tensorflow::gtl::FlatMap<const HloInstruction*, HloInstruction*>
+ instructions_;
+ tensorflow::gtl::FlatMap<const HloComputation*, HloComputation*>
+ computations_;
+};
+
+} // namespace xla
+
+#endif // TENSORFLOW_COMPILER_XLA_SERVICE_HLO_CLONE_CONTEXT_H_
diff --git a/tensorflow/compiler/xla/service/hlo_computation.cc b/tensorflow/compiler/xla/service/hlo_computation.cc
index 63c3dc4a59..b61eabbbf5 100644
--- a/tensorflow/compiler/xla/service/hlo_computation.cc
+++ b/tensorflow/compiler/xla/service/hlo_computation.cc
@@ -752,22 +752,21 @@ Status HloComputation::Accept(
}
std::unique_ptr<HloComputation> HloComputation::Clone(
- const string& suffix, HloModule* module,
- HloInstruction::CloneMap* clone_map) {
+ const string& suffix, HloCloneContext* context) {
return CloneWithReplacements(
/*replacements=*/std::unordered_map<const HloInstruction*,
std::unique_ptr<HloInstruction>>(),
- module, clone_map, suffix);
+ context, suffix);
}
std::unique_ptr<HloComputation> HloComputation::CloneWithReplacements(
std::unordered_map<const HloInstruction*, std::unique_ptr<HloInstruction>>
replacements,
- HloModule* module, HloInstruction::CloneMap* clone_map,
- const string& suffix) {
- HloInstruction::CloneMap local_clone_map;
- if (clone_map == nullptr) {
- clone_map = &local_clone_map;
+ HloCloneContext* context, const string& suffix) {
+ std::unique_ptr<HloCloneContext> context_ptr;
+ if (context == nullptr) {
+ context_ptr = MakeUnique<HloCloneContext>(parent(), suffix);
+ context = context_ptr.get();
}
// Look up instr in the replacements map, and return either the replacement,
@@ -792,18 +791,18 @@ std::unique_ptr<HloComputation> HloComputation::CloneWithReplacements(
}
std::vector<std::unique_ptr<HloInstruction>> instructions;
- std::unique_ptr<HloInstruction> new_instr = nullptr;
+ std::unique_ptr<HloInstruction> new_instr;
for (auto instr : postorder) {
std::vector<HloInstruction*> new_operands;
for (auto operand : instr->operands()) {
auto replaced_operand = replace(operand);
CHECK_NE(replaced_operand, nullptr)
- << "Replacements map specifies to leave out " << operand->ToString()
- << ", but it is used by " << instr->ToString() << ".";
- new_operands.push_back(FindOrDie(*clone_map, replaced_operand));
+ << "replacements map tried to eliminate a used instruction "
+ << operand->ToString() << ", used by " << instr->ToString();
+ new_operands.push_back(context->GetInstruction(replaced_operand));
}
- new_instr = instr->CloneWithNewOperands(instr->shape(), new_operands,
- module, clone_map);
+ new_instr =
+ instr->CloneWithNewOperands(instr->shape(), new_operands, context);
instructions.push_back(std::move(new_instr));
}
Builder builder(name() + "." + suffix);
@@ -811,22 +810,23 @@ std::unique_ptr<HloComputation> HloComputation::CloneWithReplacements(
builder.AddInstruction(std::move(instr));
}
auto result = builder.Build(
- /*root_instruction=*/FindOrDie(*clone_map, replace(root_instruction())));
+ /*root_instruction=*/context->GetInstruction(
+ replace(root_instruction())));
// Clone control dependencies.
for (auto instr : postorder) {
- HloInstruction* new_instr = FindOrDie(*clone_map, instr);
+ HloInstruction* new_instr = context->GetInstruction(instr);
for (auto successor : instr->control_successors()) {
auto replaced_successor = replace(successor);
- CHECK_NE(replaced_successor, nullptr)
- << "Replacements map specifies to leave out " << successor->ToString()
- << ", but it is control-depended-on by " << instr->ToString() << ".";
-
- TF_CHECK_OK(new_instr->AddControlDependencyTo(
- FindOrDie(*clone_map, replaced_successor)));
+ // successor may not have been remapped, because it might have been
+ // removed by the replacements map.
+ if (replaced_successor != nullptr) {
+ TF_CHECK_OK(new_instr->AddControlDependencyTo(
+ context->GetInstruction(replaced_successor)));
+ }
}
}
-
+ context->MapComputation(this, result.get());
// We cloned the elements of 'replacements', so they're all going to be
// destroyed. HloInstructions need to be detached from their operands before
// they're destroyed, otherwise they stick around in the operands' users lists
@@ -836,7 +836,6 @@ std::unique_ptr<HloComputation> HloComputation::CloneWithReplacements(
new_instr->DetachFromOperands();
}
}
-
return result;
}
diff --git a/tensorflow/compiler/xla/service/hlo_computation.h b/tensorflow/compiler/xla/service/hlo_computation.h
index 8bc97df036..0da4a305f3 100644
--- a/tensorflow/compiler/xla/service/hlo_computation.h
+++ b/tensorflow/compiler/xla/service/hlo_computation.h
@@ -29,6 +29,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/dfs_hlo_visitor.h"
#include "tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h"
#include "tensorflow/compiler/xla/service/hlo.pb.h"
+#include "tensorflow/compiler/xla/service/hlo_clone_context.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_reachability.h"
#include "tensorflow/compiler/xla/service/name_uniquer.h"
@@ -300,17 +301,11 @@ class HloComputation {
const std::function<Status(const HloInstruction*)>& visitor_func) const;
// Returns a deep copy of this computation including all instructions.
- //
- // If the module pointer is not nullptr, then the cloned computations will be
- // added to this module in order to support deep cloning. Otherwise the module
- // of the computation is used.
- //
- // If clone_map is not nullptr, then each original instruction that is cloned
- // will be inserted and map to its clone. clone_map should not already contain
- // any of the instructions to clone.
- std::unique_ptr<HloComputation> Clone(
- const string& suffix = "clone", HloModule* module = nullptr,
- HloInstruction::CloneMap* clone_map = nullptr);
+ // If the clone context is specified, it will be populated with the cloned
+ // object mappings, and its module() will be used to add new computations
+ // into.
+ std::unique_ptr<HloComputation> Clone(const string& suffix = "clone",
+ HloCloneContext* context = nullptr);
// Like Clone(), but if an instruction is present in replacement_map, we use
// the map's value to replace that instruction in the cloned computation.
@@ -320,9 +315,7 @@ class HloComputation {
std::unique_ptr<HloComputation> CloneWithReplacements(
std::unordered_map<const HloInstruction*, std::unique_ptr<HloInstruction>>
replacements,
- HloModule* module = nullptr,
- HloInstruction::CloneMap* clone_map = nullptr,
- const string& suffix = "clone");
+ HloCloneContext* context = nullptr, const string& suffix = "clone");
// Returns true if the given instruction can be removed from the computation.
// Parameter instructions cannot be removed without violating invariants of
diff --git a/tensorflow/compiler/xla/service/hlo_cse.cc b/tensorflow/compiler/xla/service/hlo_cse.cc
index c17c26c5a4..dab946a099 100644
--- a/tensorflow/compiler/xla/service/hlo_cse.cc
+++ b/tensorflow/compiler/xla/service/hlo_cse.cc
@@ -26,6 +26,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/layout_util.h"
#include "tensorflow/compiler/xla/literal_util.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
+#include "tensorflow/compiler/xla/service/hlo_domain_map.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
#include "tensorflow/compiler/xla/shape_util.h"
@@ -41,16 +42,16 @@ namespace {
// Find and combine identical constants. Constants are identical if they have
// the same type and value.
-bool CombineConstants(HloComputation* computation, bool is_layout_sensitive) {
- bool changed = false;
-
+StatusOr<bool> CombineConstants(HloComputation* computation,
+ bool is_layout_sensitive) {
+ TF_ASSIGN_OR_RETURN(auto domain_map, HloDomainMap::Create(computation, ""));
// Map from ShortDebugString of the layoutless shape of the constant to the
// set of constant instructions with that shape. Layoutless shape is used to
// bin possible common constants together to reduce number of constant
// comparisons. If we end up having too many constant comparisons, a more
// precise binning might have to be used.
std::multimap<string, HloInstruction*> constants;
-
+ int64 combined = 0;
auto inst_it = computation->instructions().begin();
while (inst_it != computation->instructions().end()) {
HloInstruction* instruction = *inst_it;
@@ -70,7 +71,8 @@ bool CombineConstants(HloComputation* computation, bool is_layout_sensitive) {
auto range = constants.equal_range(shape_string);
HloInstruction* match = nullptr;
for (auto it = range.first; it != range.second; ++it) {
- if (instruction->literal() == it->second->literal()) {
+ if (instruction->literal() == it->second->literal() &&
+ domain_map->InSameDomain(it->second, instruction)) {
match = it->second;
break;
}
@@ -81,12 +83,13 @@ bool CombineConstants(HloComputation* computation, bool is_layout_sensitive) {
// Match found, replace this instruction with the one in the multimap.
TF_CHECK_OK(instruction->ReplaceAllUsesWith(match));
TF_CHECK_OK(computation->RemoveInstruction(instruction));
- changed = true;
+ ++combined;
}
}
}
-
- return changed;
+ VLOG(4) << "Combined " << combined << " constants in " << computation->name()
+ << " computation";
+ return combined > 0;
}
// An instruction is considered to be equivalent to another only if they
@@ -123,7 +126,9 @@ StatusOr<bool> HloCSE::Run(HloModule* module) {
continue;
}
- changed |= CombineConstants(computation, is_layout_sensitive_);
+ TF_ASSIGN_OR_RETURN(bool combined,
+ CombineConstants(computation, is_layout_sensitive_));
+ changed |= combined;
// HLO instructions are grouped into equivalency classes by using the
// cse_equal predicate defined above. This set holds a representative
diff --git a/tensorflow/compiler/xla/service/hlo_cse_test.cc b/tensorflow/compiler/xla/service/hlo_cse_test.cc
index 9735764b69..e8c5ca347b 100644
--- a/tensorflow/compiler/xla/service/hlo_cse_test.cc
+++ b/tensorflow/compiler/xla/service/hlo_cse_test.cc
@@ -142,31 +142,46 @@ TEST_F(HloCseTest, ConstantsSameValueDifferentType) {
// Test that constants with the same value but different type are *not*
// commoned.
auto builder = HloComputation::Builder(TestName());
- builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<uint32>(42)));
- builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<int32>(42)));
- builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<uint64>(42.0)));
- builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<int64>(42.0)));
- builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<double>(42.0)));
- builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<float>(42.0f)));
+ std::vector<HloInstruction*> constants;
+ constants.push_back(builder.AddInstruction(
+ HloInstruction::CreateConstant(Literal::CreateR0<uint32>(42))));
+ constants.push_back(builder.AddInstruction(
+ HloInstruction::CreateConstant(Literal::CreateR0<int32>(42))));
+ constants.push_back(builder.AddInstruction(
+ HloInstruction::CreateConstant(Literal::CreateR0<uint64>(42.0))));
+ constants.push_back(builder.AddInstruction(
+ HloInstruction::CreateConstant(Literal::CreateR0<int64>(42.0))));
+ constants.push_back(builder.AddInstruction(
+ HloInstruction::CreateConstant(Literal::CreateR0<double>(42.0))));
+ constants.push_back(builder.AddInstruction(
+ HloInstruction::CreateConstant(Literal::CreateR0<float>(42.0f))));
// Duplicate the float constant to verify something happens.
- builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<float>(42.0f)));
+ constants.push_back(builder.AddInstruction(
+ HloInstruction::CreateConstant(Literal::CreateR0<float>(42.0f))));
+
+ const Shape shape_r0 = ShapeUtil::MakeShape(F32, {});
+ for (int64 i = 0; i < constants.size(); ++i) {
+ constants[i] = builder.AddInstruction(
+ HloInstruction::CreateConvert(shape_r0, constants[i]));
+ }
+ HloInstruction* root = builder.AddInstruction(HloInstruction::CreateBinary(
+ shape_r0, HloOpcode::kAdd, constants[0], constants[1]));
+ for (int64 i = 2; i < constants.size(); ++i) {
+ root = builder.AddInstruction(HloInstruction::CreateBinary(
+ shape_r0, HloOpcode::kAdd, root, constants[i]));
+ }
auto module = CreateNewModule();
auto computation = module->AddEntryComputation(builder.Build());
- EXPECT_EQ(7, computation->instruction_count());
+ EXPECT_EQ(20, computation->instruction_count());
HloCSE cse(/*is_layout_sensitive=*/false);
EXPECT_TRUE(cse.Run(module.get()).ValueOrDie());
- EXPECT_EQ(6, computation->instruction_count());
+ // CSE will remove both the second float(42.0f) and the corresponding
+ // convert/cast.
+ EXPECT_EQ(18, computation->instruction_count());
}
TEST_F(HloCseTest, NonscalarConstants) {
@@ -501,5 +516,25 @@ TEST_F(HloCseTest, CompareComputations) {
EXPECT_EQ(root->operand(0), root->operand(1));
}
+TEST_F(HloCseTest, ConstantsSameValueInDifferentDomains) {
+ // Test that constants with the same value but in different domains (disjoint
+ // in this case) are not collapsed.
+ auto builder = HloComputation::Builder(TestName());
+ builder.AddInstruction(
+ HloInstruction::CreateConstant(Literal::CreateR0<uint32>(42)));
+ builder.AddInstruction(
+ HloInstruction::CreateConstant(Literal::CreateR0<uint32>(42)));
+
+ auto module = CreateNewModule();
+ auto computation = module->AddEntryComputation(builder.Build());
+
+ EXPECT_EQ(2, computation->instruction_count());
+
+ HloCSE cse(/*is_layout_sensitive=*/false);
+ EXPECT_FALSE(cse.Run(module.get()).ValueOrDie());
+
+ EXPECT_EQ(2, computation->instruction_count());
+}
+
} // namespace
} // namespace xla
diff --git a/tensorflow/compiler/xla/service/hlo_domain_isolator.cc b/tensorflow/compiler/xla/service/hlo_domain_isolator.cc
new file mode 100644
index 0000000000..78955db0da
--- /dev/null
+++ b/tensorflow/compiler/xla/service/hlo_domain_isolator.cc
@@ -0,0 +1,104 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/xla/service/hlo_domain_isolator.h"
+
+#include "tensorflow/compiler/xla/map_util.h"
+#include "tensorflow/compiler/xla/service/hlo_computation.h"
+#include "tensorflow/compiler/xla/service/hlo_graph_dumper.h"
+#include "tensorflow/compiler/xla/service/hlo_opcode.h"
+#include "tensorflow/compiler/xla/types.h"
+
+namespace xla {
+
+class HloDomainIsolator::RunContext {
+ public:
+ RunContext(HloModule* module, HloDomainIsolator* isolator)
+ : module_(module), isolator_(isolator) {}
+
+ StatusOr<bool> Run();
+
+ private:
+ // Inserts a kDomain instruction between parent and operand, in case
+ // the attribute (ie, sharding) values change between instruction and operand.
+ // Returns the newly inserted kDomain instruction, or nullptr if no kDomain
+ // instruction was necessary.
+ StatusOr<HloInstruction*> CreateDomain(HloInstruction* instruction,
+ HloInstruction* parent,
+ HloInstruction* operand);
+
+ HloModule* module_;
+ HloDomainIsolator* isolator_;
+};
+
+StatusOr<HloInstruction*> HloDomainIsolator::RunContext::CreateDomain(
+ HloInstruction* instruction, HloInstruction* parent,
+ HloInstruction* operand) {
+ HloInstruction* domain = nullptr;
+ std::unique_ptr<HloInstruction> domain_instruction =
+ isolator_->creator_(instruction, operand);
+ if (domain_instruction != nullptr) {
+ domain = operand->parent()->AddInstruction(std::move(domain_instruction));
+ TF_RETURN_IF_ERROR(operand->ReplaceUseWith(parent, domain));
+ }
+ return domain;
+}
+
+StatusOr<bool> HloDomainIsolator::RunContext::Run() {
+ hlo_graph_dumper::MaybeDumpHloModule(*module_, "Before Domain Isolator");
+
+ int64 added_domains = 0;
+ for (HloComputation* computation : module_->computations()) {
+ // Walk in post order and place all the required kDomain instructions.
+ for (HloInstruction* instruction :
+ computation->MakeInstructionPostOrder()) {
+ if (instruction->opcode() == HloOpcode::kDomain) {
+ continue;
+ }
+ for (HloInstruction* operand : instruction->unique_operands()) {
+ // When applying multiple domains, we could end up stacking more than
+ // one in one edge, so here we want to build the effective
+ // (kDomain-less) instruction->operand edge.
+ HloInstruction* parent = instruction;
+ while (operand->opcode() == HloOpcode::kDomain) {
+ parent = operand;
+ operand = operand->mutable_operand(0);
+ }
+ // Check whether a kDomain is necessary between instruction and operand.
+ TF_ASSIGN_OR_RETURN(HloInstruction * domain,
+ CreateDomain(instruction, parent, operand));
+ if (domain != nullptr) {
+ VLOG(4) << "New domain: " << domain->ToString();
+ ++added_domains;
+ }
+ }
+ }
+ }
+ VLOG(3) << "Added " << added_domains << " kDomain instructions";
+ if (added_domains > 0) {
+ hlo_graph_dumper::MaybeDumpHloModule(*module_, "After Domain Isolator");
+ }
+ return added_domains > 0;
+}
+
+HloDomainIsolator::HloDomainIsolator(DomainCreator creator)
+ : creator_(std::move(creator)) {}
+
+StatusOr<bool> HloDomainIsolator::Run(HloModule* module) {
+ RunContext run_context(module, this);
+ return run_context.Run();
+}
+
+} // namespace xla
diff --git a/tensorflow/compiler/xla/service/hlo_domain_isolator.h b/tensorflow/compiler/xla/service/hlo_domain_isolator.h
new file mode 100644
index 0000000000..e0c5718509
--- /dev/null
+++ b/tensorflow/compiler/xla/service/hlo_domain_isolator.h
@@ -0,0 +1,56 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_HLO_DOMAIN_ISOLATOR_H_
+#define TENSORFLOW_COMPILER_XLA_SERVICE_HLO_DOMAIN_ISOLATOR_H_
+
+#include <memory>
+#include <vector>
+
+#include "tensorflow/compiler/xla/service/hlo_instruction.h"
+#include "tensorflow/compiler/xla/service/hlo_module.h"
+#include "tensorflow/compiler/xla/service/hlo_pass_interface.h"
+
+namespace xla {
+
+// Domain isolation is the task of placing kDomain instructions between HLO
+// instructions having different sharding. A kDomain instruction is essentially
+// used to break an HLO graph edge connecting two instructions with different
+// sharding. If a set of connected instructions have all the same sharding, no
+// kDomain instruction will be placed.
+class HloDomainIsolator : public HloPassInterface {
+ public:
+ // Creates a new kDomain instruction for the edge between the use instruction
+ // (the first HloInstruction argument), and the operand instruction (the
+ // second HloInstruction argument).
+ // Returns nullptr in case no domain separation is necessary.
+ using DomainCreator = std::function<std::unique_ptr<HloInstruction>(
+ HloInstruction*, HloInstruction*)>;
+
+ explicit HloDomainIsolator(DomainCreator creator);
+
+ tensorflow::StringPiece name() const override { return "domain_isolator"; }
+
+ StatusOr<bool> Run(HloModule* module) override;
+
+ private:
+ class RunContext;
+
+ DomainCreator creator_;
+};
+
+} // namespace xla
+
+#endif // TENSORFLOW_COMPILER_XLA_SERVICE_HLO_DOMAIN_ISOLATOR_H_
diff --git a/tensorflow/compiler/xla/service/hlo_domain_map.cc b/tensorflow/compiler/xla/service/hlo_domain_map.cc
new file mode 100644
index 0000000000..acb54c260c
--- /dev/null
+++ b/tensorflow/compiler/xla/service/hlo_domain_map.cc
@@ -0,0 +1,168 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/xla/service/hlo_domain_map.h"
+
+#include <algorithm>
+
+#include "tensorflow/compiler/xla/map_util.h"
+#include "tensorflow/compiler/xla/service/hlo_opcode.h"
+#include "tensorflow/compiler/xla/types.h"
+
+namespace xla {
+
+/* static */ StatusOr<std::unique_ptr<HloDomainMap>> HloDomainMap::Create(
+ HloComputation* computation, string domain_kind) {
+ auto domain_map = WrapUnique(new HloDomainMap(std::move(domain_kind)));
+ TF_RETURN_IF_ERROR(domain_map->Populate(computation));
+ return std::move(domain_map);
+}
+
+/* static */ StatusOr<std::unique_ptr<HloDomainMap>> HloDomainMap::Create(
+ HloModule* module, string domain_kind) {
+ auto domain_map = WrapUnique(new HloDomainMap(std::move(domain_kind)));
+ for (HloComputation* computation : module->computations()) {
+ TF_RETURN_IF_ERROR(domain_map->Populate(computation));
+ }
+ return std::move(domain_map);
+}
+
+bool HloDomainMap::InSameDomain(HloInstruction* instruction1,
+ HloInstruction* instruction2) const {
+ int64 domain_id1 = FindOrDefault(instruction_to_domain_, instruction1, -1);
+ int64 domain_id2 = FindOrDefault(instruction_to_domain_, instruction2, -1);
+ return domain_id1 >= 0 && domain_id1 == domain_id2;
+}
+
+Status HloDomainMap::TryProcessEmptyDomain(HloInstruction* instruction) {
+ TF_RET_CHECK(instruction->opcode() == HloOpcode::kDomain);
+ // We only check operands, so we are sure to not process the empty domain from
+ // both sides.
+ for (HloInstruction* operand : instruction->unique_operands()) {
+ if (IsDomainInstruction(operand)) {
+ auto domain = MakeUnique<DomainMetadata::Domain>();
+ domain->enter_domains.insert(operand);
+ domain->exit_domains.insert(instruction);
+ TF_RETURN_IF_ERROR(InsertDomain(std::move(domain)));
+ }
+ }
+ return Status::OK();
+}
+
+Status HloDomainMap::Populate(HloComputation* computation) {
+ for (HloInstruction* instruction : computation->instructions()) {
+ if (IsDomainInstruction(instruction)) {
+ // If this is a kDomain of the kind we are currently processing, check
+ // whether this is an "empty domain".
+ TF_RETURN_IF_ERROR(TryProcessEmptyDomain(instruction));
+ continue;
+ }
+ int64 domain_id = FindOrDefault(instruction_to_domain_, instruction, -1);
+ if (domain_id >= 0) {
+ // We have already processed this instruction.
+ continue;
+ }
+ TF_ASSIGN_OR_RETURN(std::unique_ptr<DomainMetadata::Domain> domain,
+ CreateDomain(instruction));
+ TF_RETURN_IF_ERROR(InsertDomain(std::move(domain)));
+ }
+ return Status::OK();
+}
+
+Status HloDomainMap::InsertDomain(
+ std::unique_ptr<DomainMetadata::Domain> domain) {
+ int64 domain_id = instruction_domains_.size();
+ instruction_domains_.push_back(std::move(domain));
+ for (HloInstruction* instruction : instruction_domains_.back()->reach_set) {
+ instruction_to_domain_[instruction] = domain_id;
+ }
+ return Status::OK();
+}
+
+Status HloDomainMap::ExpandDomain(HloInstruction* instruction,
+ DomainMetadata::Domain* domain) const {
+ if (domain->reach_set.insert(instruction).second) {
+ // We should not be finding instructions with assigned domain here.
+ // If we assigned a domain to the instruction, it means that all the
+ // instructions reached by it, should have a domain as well.
+ int64 domain_id = FindOrDefault(instruction_to_domain_, instruction, -1);
+ TF_RET_CHECK(domain_id < 0) << "Instruction " << instruction->ToString()
+ << " already has domain " << domain_id;
+ for (HloInstruction* operand : instruction->operands()) {
+ if (IsDomainInstruction(operand)) {
+ // The reach set instruction is a user of the domain instruction
+ // (the instruction sees the kDomain as operand).
+ // IOW the dataflow enters the domain through the kDomain instruction.
+ domain->enter_domains.insert(operand);
+ } else {
+ TF_RETURN_IF_ERROR(ExpandDomain(operand, domain));
+ }
+ }
+ for (HloInstruction* user : instruction->users()) {
+ if (IsDomainInstruction(user)) {
+ // The reach set instruction is an operand of the domain instruction
+ // (the instruction sees the kDomain as user).
+ // IOW the dataflow exits the domain through the kDomain instruction.
+ domain->exit_domains.insert(user);
+ } else {
+ TF_RETURN_IF_ERROR(ExpandDomain(user, domain));
+ }
+ }
+ }
+ return Status::OK();
+}
+
+StatusOr<std::unique_ptr<DomainMetadata::Domain>> HloDomainMap::CreateDomain(
+ HloInstruction* instruction) const {
+ auto domain = MakeUnique<DomainMetadata::Domain>();
+ TF_RETURN_IF_ERROR(ExpandDomain(instruction, domain.get()));
+ domain->instructions = MakeNonDomainInstructions(domain->reach_set);
+ return std::move(domain);
+}
+
+bool HloDomainMap::IsDomainInstruction(HloInstruction* instruction) const {
+ if (instruction->opcode() != HloOpcode::kDomain) {
+ return false;
+ }
+ if (!domain_kind_.empty()) {
+ if (instruction->user_side_metadata().Kind() != domain_kind_) {
+ return false;
+ }
+ // Both user and operand side of the metadata must be of the same kind.
+ CHECK(instruction->operand_side_metadata().Kind() == domain_kind_)
+ << "Instruction " << instruction->ToString()
+ << " has mismatching metadata kinds";
+ }
+ return true;
+}
+
+/* static */ std::vector<HloInstruction*>
+HloDomainMap::MakeNonDomainInstructions(
+ const tensorflow::gtl::FlatSet<HloInstruction*>& instruction_set) {
+ std::vector<HloInstruction*> instructions;
+ instructions.reserve(instruction_set.size());
+ for (HloInstruction* instruction : instruction_set) {
+ if (instruction->opcode() != HloOpcode::kDomain) {
+ instructions.push_back(instruction);
+ }
+ }
+ std::sort(instructions.begin(), instructions.end(),
+ [](HloInstruction* a, HloInstruction* b) {
+ return a->unique_id() < b->unique_id();
+ });
+ return instructions;
+}
+
+} // namespace xla
diff --git a/tensorflow/compiler/xla/service/hlo_domain_map.h b/tensorflow/compiler/xla/service/hlo_domain_map.h
new file mode 100644
index 0000000000..e62ef763fb
--- /dev/null
+++ b/tensorflow/compiler/xla/service/hlo_domain_map.h
@@ -0,0 +1,108 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_HLO_DOMAIN_MAP_H_
+#define TENSORFLOW_COMPILER_XLA_SERVICE_HLO_DOMAIN_MAP_H_
+
+#include <memory>
+#include <vector>
+
+#include "tensorflow/compiler/xla/service/hlo_computation.h"
+#include "tensorflow/compiler/xla/service/hlo_domain_metadata.h"
+#include "tensorflow/compiler/xla/service/hlo_instruction.h"
+#include "tensorflow/compiler/xla/service/hlo_module.h"
+#include "tensorflow/compiler/xla/statusor.h"
+#include "tensorflow/core/lib/core/status.h"
+#include "tensorflow/core/lib/gtl/flatmap.h"
+#include "tensorflow/core/lib/gtl/flatset.h"
+
+namespace xla {
+
+// The HloDomainMap splits a set of instructions within a module or computation,
+// into different domains, separated by kDomain instructions.
+// A domain is composed by a set of instructions which can reach each other via
+// operand/user edges, without crossing a kDomain instruction of a given kind.
+// A domain never crosses computation boundaries.
+class HloDomainMap {
+ public:
+ // Creates a new HloDomainMap, creating all the domains within the input
+ // computation, of the given kind. If domain_kind is not empty, only the
+ // kDomain instructions of domain_kind will be considered as separators.
+ // Otherwise every kDomain instruction will be splitting domains.
+ static StatusOr<std::unique_ptr<HloDomainMap>> Create(
+ HloComputation* computation, string domain_kind);
+
+ // Creates a new HloDomainMap, creating all the domains within the input
+ // module, of the given kind. If domain_kind is not empty, only the
+ // kDomain instructions of domain_kind will be considered as separators.
+ // Otherwise every kDomain instruction will be splitting domains.
+ static StatusOr<std::unique_ptr<HloDomainMap>> Create(HloModule* module,
+ string domain_kind);
+
+ // Retrieves all the domains the input module or computation are composed by.
+ const std::vector<std::unique_ptr<DomainMetadata::Domain>>& GetDomains()
+ const {
+ return instruction_domains_;
+ }
+
+ // Checks whether two instructions are within the same domain.
+ bool InSameDomain(HloInstruction* instruction1,
+ HloInstruction* instruction2) const;
+
+ // Checks whether instruction is a kDomain instruction of the kind we are
+ // currently processing.
+ bool IsDomainInstruction(HloInstruction* instruction) const;
+
+ private:
+ HloDomainMap(string domain_kind) : domain_kind_(std::move(domain_kind)) {}
+
+ // Check if the kDomain instruction is facing (via its operand link) another
+ // kDomain instruction of the same kind, hence defining an empty domain.
+ // If that is the case, create the empty domain and call the proper
+ // normalizer.
+ Status TryProcessEmptyDomain(HloInstruction* instruction);
+
+ Status Populate(HloComputation* computation);
+
+ // Inserts the provided domain into the ones tracked by this object,
+ // creating a new domain ID.
+ Status InsertDomain(std::unique_ptr<DomainMetadata::Domain> domain);
+
+  // From the given instruction, expands operand and user wise, the set of
+ // instructions which can be reached without crossing a kDomain instruction
+ // of the kind specified by domain_kind_.
+ // The domain data structure will be populated with all the reached
+ // instructions, and the boundaries of the domain, with the kDomain
+ // instructions encountered while expanding the reach.
+ Status ExpandDomain(HloInstruction* instruction,
+ DomainMetadata::Domain* domain) const;
+
+ // Creates a domain data structure using the ExpandDomain() API.
+ StatusOr<std::unique_ptr<DomainMetadata::Domain>> CreateDomain(
+ HloInstruction* instruction) const;
+
+ // Out of an instruction set, returns a vector of all the ones which are not
+ // a kDomain kind.
+ static std::vector<HloInstruction*> MakeNonDomainInstructions(
+ const tensorflow::gtl::FlatSet<HloInstruction*>& instruction_set);
+
+ string domain_kind_;
+ std::vector<std::unique_ptr<DomainMetadata::Domain>> instruction_domains_;
+ tensorflow::gtl::FlatMap<HloInstruction*, int64> instruction_to_domain_;
+};
+
+} // namespace xla
+
+#endif // TENSORFLOW_COMPILER_XLA_SERVICE_HLO_DOMAIN_MAP_H_
diff --git a/tensorflow/compiler/xla/service/hlo_domain_metadata.h b/tensorflow/compiler/xla/service/hlo_domain_metadata.h
new file mode 100644
index 0000000000..9853bd39cd
--- /dev/null
+++ b/tensorflow/compiler/xla/service/hlo_domain_metadata.h
@@ -0,0 +1,83 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_HLO_DOMAIN_METADATA_H_
+#define TENSORFLOW_COMPILER_XLA_SERVICE_HLO_DOMAIN_METADATA_H_
+
+#include <memory>
+#include <string>
+#include <vector>
+
+#include "tensorflow/compiler/xla/statusor.h"
+#include "tensorflow/core/lib/core/status.h"
+#include "tensorflow/core/lib/core/stringpiece.h"
+#include "tensorflow/core/lib/gtl/flatset.h"
+
+namespace xla {
+
+// Cannot include hlo_instruction.h as this file is included from there.
+class HloInstruction;
+
+// The DomainMetadata represents the base class for metadata which can be
+// attached to kDomain HLO instructions.
+class DomainMetadata {
+ public:
+ // A Domain data structure captures all the information about a kDomain
+ // bounded instruction set.
+ struct Domain {
+ // The set of instructions which are reachable from each other via
+ // operand/user pathways, without crossing a kDomain instruction of a given
+ // kind. The reach_set can contain kDomain instructions of other kinds, if
+ // two domains of different kind intersect each other.
+ tensorflow::gtl::FlatSet<HloInstruction*> reach_set;
+
+ // The same instructions in reach_set, but purged from kDomain instructions.
+ std::vector<HloInstruction*> instructions;
+
+ // If we consider a graph edge as an arrow oriented from the operand to the
+ // user, the enter_domains will contain the set of kDomain instructions
+ // whose dataflow enters the reach set (domain), while the exit_domains
+ // contains the set of kDomain instructions whose dataflow exit the reach
+ // set.
+ tensorflow::gtl::FlatSet<HloInstruction*> enter_domains;
+ tensorflow::gtl::FlatSet<HloInstruction*> exit_domains;
+ };
+
+ virtual ~DomainMetadata() = default;
+
+ // Clones the metadata object.
+ virtual std::unique_ptr<DomainMetadata> Clone() const = 0;
+
+ // Returns the metadata type. A unique identifier which describes the real
+ // metadata type.
+ virtual tensorflow::StringPiece Kind() const = 0;
+
+ // Compares the metadata object with another one and returns true if the
+  // two match.
+ virtual bool Matches(const DomainMetadata& other) const = 0;
+
+ // Returns a string representation of the metadata.
+ virtual string ToString() const = 0;
+
+ // Given a reachable set (the set of instructions which are reachable from
+ // each other via user/operand pathways, without crossing a kDomain
+  // instruction), makes sure that all of them have metadata attributes which
+ // are coherent with this metadata object.
+ virtual Status NormalizeInstructions(const Domain& domain) const = 0;
+};
+
+} // namespace xla
+
+#endif // TENSORFLOW_COMPILER_XLA_SERVICE_HLO_DOMAIN_METADATA_H_
diff --git a/tensorflow/compiler/xla/service/hlo_domain_remover.cc b/tensorflow/compiler/xla/service/hlo_domain_remover.cc
new file mode 100644
index 0000000000..1d06040b0e
--- /dev/null
+++ b/tensorflow/compiler/xla/service/hlo_domain_remover.cc
@@ -0,0 +1,149 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/xla/service/hlo_domain_remover.h"
+
+#include "tensorflow/compiler/xla/service/hlo_computation.h"
+#include "tensorflow/compiler/xla/service/hlo_domain_isolator.h"
+#include "tensorflow/compiler/xla/service/hlo_domain_map.h"
+#include "tensorflow/compiler/xla/service/hlo_graph_dumper.h"
+#include "tensorflow/compiler/xla/service/hlo_instruction.h"
+#include "tensorflow/compiler/xla/service/hlo_opcode.h"
+#include "tensorflow/compiler/xla/types.h"
+
+namespace xla {
+
+class HloDomainRemover::RunContext {
+ public:
+ RunContext(HloModule* module, HloDomainRemover* remover)
+ : module_(module), remover_(remover) {}
+
+ StatusOr<bool> Run();
+
+ private:
+ // Verifies the consistency of the domain, and normalizes the instructions
+ // within it.
+ Status VerifyAndNormalizeDomain(const DomainMetadata::Domain& domain);
+
+ HloModule* module_;
+ HloDomainRemover* remover_;
+};
+
+Status HloDomainRemover::RunContext::VerifyAndNormalizeDomain(
+ const DomainMetadata::Domain& domain) {
+ // Verify that the whole kDomain frontier bounding the instruction reach set,
+ // has matching metadata.
+ // A kDomain instruction has two sides of metadata, a user facing and an
+ // operand facing.
+ // A reachable instruction set can make contact with a kDomain instruction on
+ // a user facing side (the kDomain is operand of the instruction), or on a
+ // operand facing side (the kDomain is user of the instruction).
+ // And depending on the contact side, the proper metadata object
+ // (user_side_metadata() vs. operand_side_metadata()) needs to be used for
+ // consistency checks.
+ const DomainMetadata* ref_metadata = nullptr;
+ VLOG(4) << "Reach set:";
+ for (HloInstruction* instruction : domain.instructions) {
+ VLOG(4) << " " << instruction->name();
+ }
+ VLOG(4) << " Domains:";
+ for (HloInstruction* instruction : domain.enter_domains) {
+ const DomainMetadata& meta = instruction->user_side_metadata();
+ VLOG(4) << " User side: " << instruction->name();
+ VLOG(4) << " " << meta.ToString();
+ if (ref_metadata == nullptr) {
+ ref_metadata = &meta;
+ } else {
+ TF_RET_CHECK(meta.Matches(*ref_metadata))
+ << "Metadata mismatch at instruction " << instruction->name() << " : "
+ << meta.ToString() << " vs " << ref_metadata->ToString();
+ }
+ }
+ for (HloInstruction* instruction : domain.exit_domains) {
+ const DomainMetadata& meta = instruction->operand_side_metadata();
+ VLOG(4) << " Operand side: " << instruction->name();
+ VLOG(4) << " " << meta.ToString();
+ if (ref_metadata == nullptr) {
+ ref_metadata = &meta;
+ } else {
+ TF_RET_CHECK(meta.Matches(*ref_metadata))
+ << "Metadata mismatch at instruction " << instruction->name() << " : "
+ << meta.ToString() << " vs " << ref_metadata->ToString();
+ }
+ }
+ if (ref_metadata != nullptr) {
+ VLOG(4) << "Applying domain normalization: " << ref_metadata->ToString();
+ TF_RETURN_IF_ERROR(ref_metadata->NormalizeInstructions(domain));
+ } else {
+ // No kDomain instruction was present within this domain, so call the
+ // generic normalization functions and have them apply their heuristic.
+ VLOG(2) << "Applying domain-less normalization";
+ TF_RETURN_IF_ERROR(remover_->normalizer_(domain));
+ }
+ return Status::OK();
+}
+
+StatusOr<bool> HloDomainRemover::RunContext::Run() {
+ VLOG(4) << "Processing metadata domain: '" << remover_->kind_ << "'";
+ hlo_graph_dumper::MaybeDumpHloModule(*module_, "Before Domain Remover");
+
+ int64 removed_domains = 0;
+ for (HloComputation* computation : module_->computations()) {
+    // First create the domain instruction sets. A domain instruction set is
+ // the set of instructions whose edges never cross a kDomain instruction.
+ TF_ASSIGN_OR_RETURN(std::unique_ptr<HloDomainMap> domain_map,
+ HloDomainMap::Create(computation, remover_->kind_));
+ // Verify and normalize every domain populated within the map.
+ for (auto& domain : domain_map->GetDomains()) {
+ TF_RETURN_IF_ERROR(VerifyAndNormalizeDomain(*domain));
+ }
+
+ // Now remove all the kDomain instructions of the kind specified by the
+ // remover, that are within the currently processed computation from the
+ // graph.
+ for (HloInstruction* instruction :
+ computation->MakeInstructionPostOrder()) {
+ for (HloInstruction* operand : instruction->unique_operands()) {
+ if (domain_map->IsDomainInstruction(operand)) {
+ VLOG(5) << "Removing " << operand->name();
+ TF_RETURN_IF_ERROR(
+ operand->ReplaceAllUsesWith(operand->mutable_operand(0)));
+ TF_RETURN_IF_ERROR(computation->RemoveInstruction(operand));
+ ++removed_domains;
+ }
+ }
+ }
+ HloInstruction* root = computation->root_instruction();
+ if (root != nullptr && domain_map->IsDomainInstruction(root)) {
+ VLOG(5) << "Removing " << root->name();
+ computation->set_root_instruction(root->mutable_operand(0));
+ TF_RETURN_IF_ERROR(computation->RemoveInstruction(root));
+ ++removed_domains;
+ }
+ }
+ VLOG(3) << "Removed " << removed_domains << " kDomain instructions of '"
+ << remover_->kind_ << "' kind";
+ if (removed_domains > 0) {
+ hlo_graph_dumper::MaybeDumpHloModule(*module_, "After Domain Remover");
+ }
+ return removed_domains > 0;
+}
+
+StatusOr<bool> HloDomainRemover::Run(HloModule* module) {
+ RunContext run_context(module, this);
+ return run_context.Run();
+}
+
+} // namespace xla
diff --git a/tensorflow/compiler/xla/service/hlo_domain_remover.h b/tensorflow/compiler/xla/service/hlo_domain_remover.h
new file mode 100644
index 0000000000..0c71dd34fd
--- /dev/null
+++ b/tensorflow/compiler/xla/service/hlo_domain_remover.h
@@ -0,0 +1,56 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_HLO_DOMAIN_REMOVER_H_
+#define TENSORFLOW_COMPILER_XLA_SERVICE_HLO_DOMAIN_REMOVER_H_
+
+#include "tensorflow/compiler/xla/service/hlo_domain_metadata.h"
+#include "tensorflow/compiler/xla/service/hlo_module.h"
+#include "tensorflow/compiler/xla/service/hlo_pass_interface.h"
+#include "tensorflow/core/lib/core/status.h"
+
+namespace xla {
+
+// Removes all the kDomain instructions of a given kind from the input module,
+// and calls the normalizer to propagate the properties on the possibly newborn
+// instructions.
+class HloDomainRemover : public HloPassInterface {
+ public:
+ // Creates a new HloDomainRemover object tasked at removing all the kDomain
+ // instructions of a given kind.
+ // In case a reachable set (the set of instructions within a computation,
+ // which are mutually reachable via operand/user pathways) has all the
+ // instructions in it with the same attributes (ie, sharding), a normalizer
+ // function is tasked at applying attribute normalization on the instructions
+ // within such domain.
+ HloDomainRemover(
+ tensorflow::StringPiece kind,
+ std::function<Status(const DomainMetadata::Domain&)> normalizer)
+ : kind_(kind.ToString()), normalizer_(std::move(normalizer)) {}
+
+ tensorflow::StringPiece name() const override { return "domain_remover"; }
+
+ StatusOr<bool> Run(HloModule* module) override;
+
+ private:
+ class RunContext;
+
+ string kind_;
+ std::function<Status(const DomainMetadata::Domain&)> normalizer_;
+};
+
+} // namespace xla
+
+#endif // TENSORFLOW_COMPILER_XLA_SERVICE_HLO_DOMAIN_REMOVER_H_
diff --git a/tensorflow/compiler/xla/service/hlo_domain_test.cc b/tensorflow/compiler/xla/service/hlo_domain_test.cc
new file mode 100644
index 0000000000..f29aac29c0
--- /dev/null
+++ b/tensorflow/compiler/xla/service/hlo_domain_test.cc
@@ -0,0 +1,432 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/xla/legacy_flags/debug_options_flags.h"
+#include "tensorflow/compiler/xla/service/hlo_domain_isolator.h"
+#include "tensorflow/compiler/xla/service/hlo_domain_metadata.h"
+#include "tensorflow/compiler/xla/service/hlo_domain_remover.h"
+#include "tensorflow/compiler/xla/service/hlo_sharding_metadata.h"
+#include "tensorflow/compiler/xla/test.h"
+#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
+#include "tensorflow/compiler/xla/tools/parser/hlo_parser.h"
+#include "tensorflow/core/lib/core/status_test_util.h"
+
+namespace xla {
+namespace {
+
+class HloDomainTest : public HloTestBase {
+ protected:
+ bool FindUserViaDomainPath(HloInstruction* instruction,
+ HloInstruction* operand) const {
+ for (HloInstruction* user : operand->users()) {
+ if (user == instruction) {
+ return true;
+ }
+ if (user->opcode() == HloOpcode::kDomain &&
+ FindUserViaDomainPath(instruction, user)) {
+ return true;
+ }
+ }
+ return false;
+ }
+
+ // Checks whether there is a kDomain instruction in the edge between the
+ // instruction and the operand.
+ bool HasDomainEdge(HloModule* module,
+ tensorflow::StringPiece instruction_name,
+ tensorflow::StringPiece operand_name) {
+ HloInstruction* instruction = FindInstruction(module, instruction_name);
+ HloInstruction* operand = FindInstruction(module, operand_name);
+ CHECK_NE(instruction, nullptr);
+ CHECK_NE(operand, nullptr);
+ if (!instruction->IsUserOf(operand)) {
+ // If instruction is not an immediate user, we must find a path from
+ // operand to instruction anyway, otherwise there is a corruption.
+ if (FindUserViaDomainPath(instruction, operand)) {
+ return true;
+ }
+ LOG(FATAL) << "Bad HLO module generated across the '" << instruction_name
+ << "' and '" << operand_name << "' instructions:\n"
+ << module->ToString();
+ }
+ return false;
+ }
+
+ StatusOr<std::unique_ptr<HloModule>> ParseModule(
+ tensorflow::StringPiece hlo_string) {
+ HloModuleConfig config;
+ config.set_debug_options(legacy_flags::GetDebugOptionsFromFlags());
+ return tools::Parse(hlo_string, config);
+ }
+};
+
+// Dummy DomainMetadata implementation which creates kDomain boundaries around
+// HLO instructions with the same metadata().op_name() values.
+class OpNameMetadata : public DomainMetadata {
+ public:
+ explicit OpNameMetadata(string opname) : opname_(std::move(opname)) {}
+
+ std::unique_ptr<DomainMetadata> Clone() const override {
+ return MakeUnique<OpNameMetadata>(opname_);
+ }
+
+ tensorflow::StringPiece Kind() const override { return KindName(); }
+
+ bool Matches(const DomainMetadata& other) const override {
+ const OpNameMetadata* other_ptr =
+ dynamic_cast<const OpNameMetadata*>(&other);
+ if (other_ptr == nullptr) {
+ // If other is not a OpNameMetadata, then it is clearly a no match.
+ return false;
+ }
+ return opname_ == other_ptr->opname_;
+ }
+
+ string ToString() const override { return opname_; }
+
+ Status NormalizeInstructions(
+ const DomainMetadata::Domain& domain) const override {
+ // For the purposes of this test, nothing to do.
+ return Status::OK();
+ }
+
+ static tensorflow::StringPiece KindName() { return "opname"; }
+
+ private:
+ string opname_;
+};
+
+// Creator function for OpNameMetadata domains.
+std::unique_ptr<HloInstruction> OpNameDomainCreator(HloInstruction* instruction,
+ HloInstruction* operand) {
+ if (instruction->metadata().op_name() == operand->metadata().op_name()) {
+ return nullptr;
+ }
+ std::unique_ptr<DomainMetadata> operand_side_metadata =
+ MakeUnique<OpNameMetadata>(operand->metadata().op_name());
+ std::unique_ptr<DomainMetadata> user_side_metadata =
+ MakeUnique<OpNameMetadata>(instruction->metadata().op_name());
+ return HloInstruction::CreateDomain(operand->shape(), operand,
+ std::move(operand_side_metadata),
+ std::move(user_side_metadata));
+}
+
+Status OpNameDomainNormalizer(const DomainMetadata::Domain& domain) {
+  // Nothing to do for the particular use this test makes of the OpName domains.
+ return Status::OK();
+}
+
+TEST_F(HloDomainTest, CheckDomainLinks) {
+ const char* const hlo_string = R"(
+HloModule Module
+
+ENTRY entry {
+ p0 = (f32[4], f32[4]) parameter(0)
+ a = f32[4] get-tuple-element(p0), index=0
+ b = f32[4] get-tuple-element(p0), index=1
+ c = f32[4] add(f32[4] a, f32[4] b), sharding={maximal device=1}
+ d = f32[4] subtract(a, b), sharding={maximal device=1}
+ e = f32[4] multiply(c, d), sharding={maximal device=1}
+ ROOT f = (f32[4], f32[4], f32[4]) tuple(c, d, e)
+}
+)";
+
+ TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
+ ParseModule(hlo_string));
+ LOG(INFO) << "Original module:\n" << module->ToString();
+
+ HloDomainIsolator isolator(CreateShardingDomain);
+ TF_ASSERT_OK_AND_ASSIGN(bool isolator_changed, isolator.Run(module.get()));
+ EXPECT_TRUE(isolator_changed);
+
+ EXPECT_TRUE(HasDomainEdge(module.get(), "c", "a"));
+ EXPECT_TRUE(HasDomainEdge(module.get(), "c", "b"));
+ EXPECT_TRUE(HasDomainEdge(module.get(), "d", "a"));
+ EXPECT_TRUE(HasDomainEdge(module.get(), "d", "b"));
+ EXPECT_FALSE(HasDomainEdge(module.get(), "e", "c"));
+ EXPECT_FALSE(HasDomainEdge(module.get(), "e", "d"));
+
+ HloDomainRemover remover(ShardingMetadata::KindName(),
+ NormalizeShardingDomain);
+ TF_ASSERT_OK_AND_ASSIGN(bool remover_changed, remover.Run(module.get()));
+ EXPECT_TRUE(remover_changed);
+
+ EXPECT_FALSE(HasDomainEdge(module.get(), "c", "a"));
+ EXPECT_FALSE(HasDomainEdge(module.get(), "c", "b"));
+ EXPECT_FALSE(HasDomainEdge(module.get(), "d", "a"));
+ EXPECT_FALSE(HasDomainEdge(module.get(), "d", "b"));
+ EXPECT_FALSE(HasDomainEdge(module.get(), "e", "c"));
+ EXPECT_FALSE(HasDomainEdge(module.get(), "e", "d"));
+}
+
+TEST_F(HloDomainTest, CheckNoDomainAddedIfNoSharding) {
+ const char* const hlo_string = R"(
+HloModule Module
+
+ENTRY entry {
+ p0 = (f32[4], f32[4]) parameter(0)
+ a = f32[4] get-tuple-element(p0), index=0
+ b = f32[4] get-tuple-element(p0), index=1
+ c = f32[4] add(f32[4] a, f32[4] b)
+ d = f32[4] subtract(a, b)
+ e = f32[4] multiply(c, d)
+ ROOT f = (f32[4], f32[4], f32[4]) tuple(c, d, e)
+}
+)";
+
+ TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
+ ParseModule(hlo_string));
+ LOG(INFO) << "Original module:\n" << module->ToString();
+
+ HloDomainIsolator isolator(CreateShardingDomain);
+ TF_ASSERT_OK_AND_ASSIGN(bool isolator_changed, isolator.Run(module.get()));
+ EXPECT_TRUE(!isolator_changed);
+}
+
+TEST_F(HloDomainTest, CheckDomainAroundIO) {
+ const char* const hlo_string = R"(
+HloModule Module
+
+ENTRY entry {
+ p0 = (f32[4]) parameter(0)
+ a = f32[4] get-tuple-element(p0), index=0
+ b = (f32[4], u32[]) send(a), channel_id=1, sharding={maximal device=0}
+ c = () send-done(b), channel_id=1, sharding={maximal device=0}
+ d = (f32[4], u32[]) recv(), channel_id=2, sharding={maximal device=0}
+ e = f32[4] recv-done(d), channel_id=2, sharding={maximal device=0}
+ f = f32[4] add(a, e)
+ g = f32[4] subtract(a, e)
+ ROOT h = (f32[4], f32[4]) tuple(f, g)
+}
+)";
+
+ TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
+ ParseModule(hlo_string));
+ LOG(INFO) << "Original module:\n" << module->ToString();
+
+ HloDomainIsolator isolator(CreateShardingDomain);
+ TF_ASSERT_OK_AND_ASSIGN(bool isolator_changed, isolator.Run(module.get()));
+ EXPECT_TRUE(isolator_changed);
+
+ EXPECT_TRUE(HasDomainEdge(module.get(), "b", "a"));
+ EXPECT_TRUE(HasDomainEdge(module.get(), "f", "e"));
+ EXPECT_FALSE(HasDomainEdge(module.get(), "a", "p0"));
+ EXPECT_FALSE(HasDomainEdge(module.get(), "c", "b"));
+ EXPECT_FALSE(HasDomainEdge(module.get(), "e", "d"));
+
+ HloDomainRemover remover(ShardingMetadata::KindName(),
+ NormalizeShardingDomain);
+ TF_ASSERT_OK_AND_ASSIGN(bool remover_changed, remover.Run(module.get()));
+ EXPECT_TRUE(remover_changed);
+
+ EXPECT_FALSE(HasDomainEdge(module.get(), "b", "a"));
+ EXPECT_FALSE(HasDomainEdge(module.get(), "f", "e"));
+}
+
+TEST_F(HloDomainTest, CheckNoDomainAddedOnPureIOComputation) {
+ const char* const hlo_string = R"(
+HloModule Module
+
+ENTRY entry {
+ a = (f32[4], u32[]) recv(), channel_id=1, sharding={maximal device=-1}
+ b = f32[4] recv-done(a), channel_id=1, sharding={maximal device=-1}
+ c = f32[4] add(b, b), sharding={maximal device=-1}
+ d = (f32[4], u32[]) send(c), channel_id=2, sharding={maximal device=-1}
+ ROOT e = () send-done(d), channel_id=2, sharding={maximal device=-1}
+}
+)";
+
+ TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
+ ParseModule(hlo_string));
+ LOG(INFO) << "Original module:\n" << module->ToString();
+
+ HloDomainIsolator isolator(CreateShardingDomain);
+ TF_ASSERT_OK_AND_ASSIGN(bool isolator_changed, isolator.Run(module.get()));
+ EXPECT_FALSE(isolator_changed);
+}
+
+TEST_F(HloDomainTest, CheckNormalizationOnPureIOComputation) {
+ const char* const hlo_string = R"(
+HloModule Module
+
+ENTRY entry {
+ a = (f32[4], u32[]) recv(), channel_id=1, sharding={maximal device=0}
+ b = f32[4] recv-done(a), channel_id=1, sharding={maximal device=0}
+ c = f32[4] add(b, b)
+ d = (f32[4], u32[]) send(c), channel_id=2, sharding={maximal device=0}
+ ROOT e = () send-done(d), channel_id=2, sharding={maximal device=0}
+}
+)";
+
+ TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
+ ParseModule(hlo_string));
+ LOG(INFO) << "Original module:\n" << module->ToString();
+
+ HloDomainRemover remover(ShardingMetadata::KindName(),
+ NormalizeShardingDomain);
+ TF_ASSERT_OK_AND_ASSIGN(bool remover_changed, remover.Run(module.get()));
+ EXPECT_FALSE(remover_changed);
+
+ HloInstruction* add = FindInstruction(module.get(), "c");
+ ASSERT_NE(add, nullptr);
+ auto device = add->sharding_unique_device();
+ EXPECT_TRUE(device.has_value());
+ EXPECT_EQ(*device, 0);
+}
+
+TEST_F(HloDomainTest, CheckMultiDomainLinks) {
+ const char* const hlo_string = R"(
+HloModule Module
+
+ENTRY entry {
+ p0 = (f32[4], f32[4]) parameter(0)
+ a = f32[4] get-tuple-element(p0), index=0
+ b = f32[4] get-tuple-element(p0), index=1
+ c = f32[4] add(a, b), sharding={maximal device=1}
+ d = f32[4] subtract(a, c), sharding={maximal device=1}, metadata={op_name="D"}
+ e = f32[4] multiply(c, d), sharding={maximal device=1}, metadata={op_name="D"}
+ f = f32[4] add(e, c), sharding={maximal device=1}
+ ROOT g = (f32[4], f32[4], f32[4]) tuple(c, d, f)
+}
+)";
+
+ TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
+ ParseModule(hlo_string));
+ LOG(INFO) << "Original module:\n" << module->ToString();
+
+ HloDomainIsolator sharding_isolator(CreateShardingDomain);
+ TF_ASSERT_OK_AND_ASSIGN(bool sharding_isolator_changed,
+ sharding_isolator.Run(module.get()));
+ EXPECT_TRUE(sharding_isolator_changed);
+
+ HloDomainIsolator opname_isolator(OpNameDomainCreator);
+ TF_ASSERT_OK_AND_ASSIGN(bool opname_isolator_changed,
+ opname_isolator.Run(module.get()));
+ EXPECT_TRUE(opname_isolator_changed);
+
+ EXPECT_TRUE(HasDomainEdge(module.get(), "c", "a"));
+ EXPECT_TRUE(HasDomainEdge(module.get(), "c", "b"));
+ EXPECT_TRUE(HasDomainEdge(module.get(), "d", "a"));
+ EXPECT_TRUE(HasDomainEdge(module.get(), "d", "c"));
+ EXPECT_FALSE(HasDomainEdge(module.get(), "e", "d"));
+
+ HloDomainRemover sharding_remover(ShardingMetadata::KindName(),
+ NormalizeShardingDomain);
+ TF_ASSERT_OK_AND_ASSIGN(bool sharding_remover_changed,
+ sharding_remover.Run(module.get()));
+ EXPECT_TRUE(sharding_remover_changed);
+
+ HloDomainRemover opname_remover(OpNameMetadata::KindName(),
+ OpNameDomainNormalizer);
+ TF_ASSERT_OK_AND_ASSIGN(bool opname_remover_changed,
+ opname_remover.Run(module.get()));
+ EXPECT_TRUE(opname_remover_changed);
+
+ EXPECT_FALSE(HasDomainEdge(module.get(), "c", "a"));
+ EXPECT_FALSE(HasDomainEdge(module.get(), "c", "b"));
+ EXPECT_FALSE(HasDomainEdge(module.get(), "d", "a"));
+ EXPECT_FALSE(HasDomainEdge(module.get(), "d", "c"));
+}
+
+TEST_F(HloDomainTest, CheckNormalizationOnInfeedTuple) {
+ const char* const hlo_string = R"(
+HloModule Module
+
+ENTRY entry {
+ infeed = (f32[4], f32[4]) infeed(),
+ sharding={{maximal device=1}, {maximal device=0}}
+ gte0 = f32[4] get-tuple-element(infeed), index=0
+ gte1 = f32[4] get-tuple-element(infeed), index=1
+ copy0 = f32[4] copy(gte0)
+ copy1 = f32[4] copy(gte1)
+ ROOT add = f32[4] add(copy0, copy1)
+}
+)";
+
+ TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
+ ParseModule(hlo_string));
+ LOG(INFO) << "Original module:\n" << module->ToString();
+
+ HloDomainIsolator isolator(CreateShardingDomain);
+ TF_ASSERT_OK_AND_ASSIGN(bool isolator_changed, isolator.Run(module.get()));
+ EXPECT_TRUE(isolator_changed);
+
+ EXPECT_TRUE(HasDomainEdge(module.get(), "gte0", "infeed"));
+ EXPECT_TRUE(HasDomainEdge(module.get(), "gte1", "infeed"));
+ EXPECT_FALSE(HasDomainEdge(module.get(), "copy0", "gte0"));
+ EXPECT_FALSE(HasDomainEdge(module.get(), "copy1", "gte1"));
+
+ // Inject unassigned tuple/gte within the infeed domain, to simulate the
+ // HLO passes adding unexpected instructions.
+ //
+ // infeed
+ // / \
+ // GTE0 GTE1
+ // / \
+ // COPY0 COPY1
+ // \ /
+ // \ /
+ // TUPLE
+ // |
+ // DOMAIN
+ HloInstruction* infeed = FindInstruction(module.get(), "infeed");
+ ASSERT_NE(infeed, nullptr);
+ auto infeed_users = infeed->users();
+ HloInstruction* new_gte0 =
+ infeed->parent()->AddInstruction(HloInstruction::CreateGetTupleElement(
+ ShapeUtil::GetTupleElementShape(infeed->shape(), 0), infeed, 0));
+ HloInstruction* new_copy0 =
+ infeed->parent()->AddInstruction(HloInstruction::CreateUnary(
+ new_gte0->shape(), HloOpcode::kCopy, new_gte0));
+ HloInstruction* new_gte1 =
+ infeed->parent()->AddInstruction(HloInstruction::CreateGetTupleElement(
+ ShapeUtil::GetTupleElementShape(infeed->shape(), 1), infeed, 1));
+ HloInstruction* new_copy1 =
+ infeed->parent()->AddInstruction(HloInstruction::CreateUnary(
+ new_gte1->shape(), HloOpcode::kCopy, new_gte1));
+ HloInstruction* new_tuple = infeed->parent()->AddInstruction(
+ HloInstruction::CreateTuple({new_copy0, new_copy1}));
+ for (HloInstruction* user : infeed_users) {
+ TF_EXPECT_OK(infeed->ReplaceUseWith(user, new_tuple));
+ }
+
+ HloDomainRemover remover(ShardingMetadata::KindName(),
+ NormalizeShardingDomain);
+ TF_ASSERT_OK_AND_ASSIGN(bool remover_changed, remover.Run(module.get()));
+ EXPECT_TRUE(remover_changed);
+
+ struct Assignment {
+ HloInstruction* instruction;
+ int64 device;
+ } assignments[] = {
+ {new_gte0, 1},
+ {new_copy0, 1},
+ {new_gte1, 0},
+ {new_copy1, 0},
+ };
+ for (auto& assignment : assignments) {
+ auto device = assignment.instruction->sharding_unique_device();
+ EXPECT_TRUE(device.has_value());
+ EXPECT_EQ(*device, assignment.device);
+ }
+ EXPECT_TRUE(new_tuple->has_sharding());
+ EXPECT_EQ(
+ new_tuple->sharding(),
+ HloSharding::Tuple(new_tuple->shape(), {HloSharding::AssignDevice(1),
+ HloSharding::AssignDevice(0)}));
+}
+
+} // namespace
+} // namespace xla
diff --git a/tensorflow/compiler/xla/service/hlo_element_type_converter.cc b/tensorflow/compiler/xla/service/hlo_element_type_converter.cc
index d236f83aeb..abec29df43 100644
--- a/tensorflow/compiler/xla/service/hlo_element_type_converter.cc
+++ b/tensorflow/compiler/xla/service/hlo_element_type_converter.cc
@@ -119,6 +119,7 @@ StatusOr<bool> HloElementTypeConverter::Run(HloModule* module) {
return false;
}
+ HloCloneContext context(module);
bool changed = false;
for (auto* computation : module->computations()) {
for (auto* hlo : computation->MakeInstructionPostOrder()) {
@@ -180,7 +181,7 @@ StatusOr<bool> HloElementTypeConverter::Run(HloModule* module) {
ShapeUtil::ChangeElementType(hlo->shape(), replace_with_type_);
new_hlo = computation->AddInstruction(
- hlo->CloneWithNewOperands(shape, new_operands, hlo->GetModule()));
+ hlo->CloneWithNewOperands(shape, new_operands, &context));
TF_RETURN_IF_ERROR(new_hlo->CopyAllControlDepsFrom(hlo));
new_hlo = ToElementType(new_hlo, eliminate_type_);
@@ -189,16 +190,16 @@ StatusOr<bool> HloElementTypeConverter::Run(HloModule* module) {
Shape new_shape = GetConvertedTupleShape(hlo->shape(), eliminate_type_,
replace_with_type_);
- new_hlo = computation->AddInstruction(hlo->CloneWithNewOperands(
- new_shape, new_operands, hlo->GetModule()));
+ new_hlo = computation->AddInstruction(
+ hlo->CloneWithNewOperands(new_shape, new_operands, &context));
TF_RETURN_IF_ERROR(new_hlo->CopyAllControlDepsFrom(hlo));
// Convert the elements of the result of `new_hlo` to produce a new
// tuple with shape `old_shape`.
new_hlo = ConvertTupleElements(new_hlo, old_shape);
} else {
- new_hlo = computation->AddInstruction(hlo->CloneWithNewOperands(
- hlo->shape(), new_operands, hlo->GetModule()));
+ new_hlo = computation->AddInstruction(
+ hlo->CloneWithNewOperands(hlo->shape(), new_operands, &context));
TF_RETURN_IF_ERROR(new_hlo->CopyAllControlDepsFrom(hlo));
}
diff --git a/tensorflow/compiler/xla/service/hlo_evaluator.cc b/tensorflow/compiler/xla/service/hlo_evaluator.cc
index e90eb0669d..1e78d775c8 100644
--- a/tensorflow/compiler/xla/service/hlo_evaluator.cc
+++ b/tensorflow/compiler/xla/service/hlo_evaluator.cc
@@ -965,9 +965,10 @@ Status HloEvaluator::HandleFusion(HloInstruction* fusion) {
// Attach cloned computation to an empty HLO module so the existing ones are
// not modified.
HloModule empty_hlo_module("EmptyModuleForFusion", config);
+ HloCloneContext context(&empty_hlo_module);
auto cloned_fused_computation =
fusion->fused_instructions_computation()->Clone(
- /*suffix=*/"clone_with_layout", &empty_hlo_module);
+ /*suffix=*/"clone_with_layout", &context);
for (auto* instruction : cloned_fused_computation->instructions()) {
LayoutUtil::SetToDefaultLayout(instruction->mutable_shape());
}
diff --git a/tensorflow/compiler/xla/service/hlo_graph_dumper.cc b/tensorflow/compiler/xla/service/hlo_graph_dumper.cc
index efdeb6c64f..672b1c017a 100644
--- a/tensorflow/compiler/xla/service/hlo_graph_dumper.cc
+++ b/tensorflow/compiler/xla/service/hlo_graph_dumper.cc
@@ -1010,6 +1010,7 @@ ColorScheme HloDotDumper::GetInstructionColor(const HloInstruction* instr) {
case HloOpcode::kReduceWindow:
case HloOpcode::kSelectAndScatter:
return kPurple;
+ case HloOpcode::kDomain:
case HloOpcode::kFusion:
case HloOpcode::kMap:
return kGray;
diff --git a/tensorflow/compiler/xla/service/hlo_instruction.cc b/tensorflow/compiler/xla/service/hlo_instruction.cc
index db1c33e2f0..dc351e9968 100644
--- a/tensorflow/compiler/xla/service/hlo_instruction.cc
+++ b/tensorflow/compiler/xla/service/hlo_instruction.cc
@@ -37,6 +37,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/window_util.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/gtl/flatmap.h"
+#include "tensorflow/core/lib/gtl/flatset.h"
#include "tensorflow/core/lib/gtl/map_util.h"
#include "tensorflow/core/lib/strings/str_util.h"
#include "tensorflow/core/lib/strings/strcat.h"
@@ -256,6 +257,7 @@ HloInstruction::CreateGetTupleElement(const Shape& shape,
case HloOpcode::kCopy:
case HloOpcode::kCos:
case HloOpcode::kClz:
+ case HloOpcode::kDomain:
case HloOpcode::kExp:
case HloOpcode::kExpm1:
case HloOpcode::kFloor:
@@ -821,6 +823,15 @@ HloInstruction::CreateBroadcastSequence(
return instruction;
}
+void HloInstruction::set_device_sharding(int64 device) {
+ HloSharding device_sharding = HloSharding::AssignDevice(device);
+ if (ShapeUtil::IsTuple(shape())) {
+ set_sharding(HloSharding::Tuple(device_sharding.GetAsShapeTree(shape())));
+ } else {
+ set_sharding(device_sharding);
+ }
+}
+
void HloInstruction::SetupDerivedInstruction(
HloInstruction* derived_instruction) const {
if (sharding_ != nullptr) {
@@ -1225,21 +1236,28 @@ bool HloInstruction::HasSideEffect() const {
return gather_dim_numbers;
}
+/* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateDomain(
+ const Shape& shape, HloInstruction* operand,
+ std::unique_ptr<DomainMetadata> operand_side_metadata,
+ std::unique_ptr<DomainMetadata> user_side_metadata) {
+ auto instruction = WrapUnique(new HloInstruction(HloOpcode::kDomain, shape));
+ instruction->operand_side_metadata_ = std::move(operand_side_metadata);
+ instruction->user_side_metadata_ = std::move(user_side_metadata);
+ instruction->AppendOperand(operand);
+ return instruction;
+}
+
std::unique_ptr<HloInstruction> HloInstruction::CloneWithNewOperands(
const Shape& shape,
tensorflow::gtl::ArraySlice<HloInstruction*> new_operands,
- HloModule* module, CloneMap* clone_map) const {
+ HloCloneContext* context) const {
VLOG(3) << "CloneWithNewOperands:\n " << ToString();
VLOG(3) << " new operands:";
for (const HloInstruction* new_operand : new_operands) {
VLOG(3) << " %" << new_operand->name();
}
- if (module == nullptr) {
- module = GetModule();
- }
std::unique_ptr<HloInstruction> clone;
-
// Explicitly call the factory for the instruction type. This is more robust
// in the face of code changes than copying fields explicitly. This also
// properly sets the user fields of the operands.
@@ -1419,9 +1437,16 @@ std::unique_ptr<HloInstruction> HloInstruction::CloneWithNewOperands(
clone = CreateConstant(literal_->CloneToUnique());
break;
case HloOpcode::kFusion: {
- CHECK_NE(module, nullptr);
- auto new_fused_computation = module->AddEmbeddedComputation(
- fused_instructions_computation()->Clone("clone", module, clone_map));
+ HloModule* module = context != nullptr ? context->module() : GetModule();
+ HloComputation* new_fused_computation = nullptr;
+ if (context != nullptr) {
+ new_fused_computation =
+ context->FindComputation(fused_instructions_computation());
+ }
+ if (new_fused_computation == nullptr) {
+ new_fused_computation = module->AddEmbeddedComputation(
+ fused_instructions_computation()->Clone("clone", context));
+ }
clone = CreateFusion(/*shape=*/shape, /*fusion_kind=*/fusion_kind(),
/*operands=*/new_operands,
/*fusion_computation=*/new_fused_computation);
@@ -1485,14 +1510,25 @@ std::unique_ptr<HloInstruction> HloInstruction::CloneWithNewOperands(
clone = CreateGather(shape, new_operands[0], new_operands[1],
*gather_dimension_numbers_, gather_window_bounds_);
break;
+ case HloOpcode::kDomain:
+ CHECK_EQ(new_operands.size(), 1);
+ clone =
+ CreateDomain(shape, new_operands[0], operand_side_metadata_->Clone(),
+ user_side_metadata_->Clone());
+ break;
case HloOpcode::kTrace:
LOG(FATAL) << "Not yet implemented, clone: " << HloOpcodeString(opcode_);
}
SetupDerivedInstruction(clone.get());
clone->set_parent(parent_);
clone->set_backend_config(backend_config());
- if (clone_map != nullptr) {
- InsertOrDie(clone_map, this, clone.get());
+ if (context != nullptr) {
+ context->MapInstruction(this, clone.get());
+ clone->ReplaceCalledComputations([&](HloComputation* callee) {
+ return callee->parent() != context->module()
+ ? context->module()->DeepCloneComputation(callee, context)
+ : callee;
+ });
}
return clone;
}
@@ -1500,9 +1536,9 @@ std::unique_ptr<HloInstruction> HloInstruction::CloneWithNewOperands(
HloInstruction::~HloInstruction() {}
std::unique_ptr<HloInstruction> HloInstruction::Clone(
- const string& suffix, HloModule* module, CloneMap* clone_map) const {
+ const string& suffix, HloCloneContext* context) const {
std::unique_ptr<HloInstruction> clone =
- CloneWithNewOperands(shape_, operands_, module, clone_map);
+ CloneWithNewOperands(shape_, operands_, context);
if (suffix.empty()) {
clone->name_ = name();
} else {
@@ -1614,6 +1650,17 @@ int64 HloInstruction::operand_index(const HloInstruction* target) const {
LOG(FATAL) << "target was not an operand: " << target->ToString();
}
+HloInstruction::InstructionVector HloInstruction::unique_operands() const {
+ InstructionVector unique;
+ tensorflow::gtl::FlatSet<const HloInstruction*> seen;
+ for (HloInstruction* operand : operands()) {
+ if (seen.insert(operand).second) {
+ unique.push_back(operand);
+ }
+ }
+ return unique;
+}
+
Status HloInstruction::AddControlDependencyTo(HloInstruction* instruction) {
TF_RET_CHECK(instruction->parent() == parent());
if (std::find(control_successors_.begin(), control_successors_.end(),
@@ -1758,6 +1805,7 @@ bool HloInstruction::IdenticalSlowPath(
other.fused_instructions_computation());
// These opcodes have complex or special behavior so just return false.
+ case HloOpcode::kDomain:
case HloOpcode::kRng:
case HloOpcode::kTrace:
case HloOpcode::kWhile:
@@ -2369,7 +2417,13 @@ std::vector<string> HloInstruction::ExtraAttributesToString(
extra.push_back(StrCat("exponent_bits=", exponent_bits_));
extra.push_back(StrCat("mantissa_bits=", mantissa_bits_));
}
-
+ if (operand_side_metadata_ != nullptr) {
+ extra.push_back(
+ StrCat("operand_side=", operand_side_metadata_->ToString()));
+ }
+ if (user_side_metadata_ != nullptr) {
+ extra.push_back(StrCat("user_side=", user_side_metadata_->ToString()));
+ }
// By contract, we print the custom call target even if
// options.print_subcomputation_mode() == kOff, because the call target is not
// an HloComputation.
@@ -2546,6 +2600,7 @@ bool HloInstruction::IsFusable() const {
}
// Some kinds of instructions don't make sense to fuse.
switch (opcode_) {
+ case HloOpcode::kDomain:
case HloOpcode::kParameter:
return false;
// Side effecting instrutions cannot be fused.
@@ -2558,7 +2613,9 @@ HloComputation* HloInstruction::fused_instructions_computation() const {
CHECK_EQ(opcode_, HloOpcode::kFusion);
CHECK(!called_computations_.empty());
auto* fused_instructions_computation = called_computations_.front();
- CHECK(fused_instructions_computation->IsFusionComputation());
+ CHECK(fused_instructions_computation->IsFusionComputation())
+ << "Computation " << fused_instructions_computation->name()
+ << " is not a fusion kind";
return fused_instructions_computation;
}
@@ -2773,6 +2830,8 @@ Status HloInstruction::Visit(DfsHloVisitorBase<HloInstructionPtr>* visitor) {
return visitor->HandleSendDone(this);
case HloOpcode::kGather:
return visitor->HandleGather(this);
+ case HloOpcode::kDomain:
+ return visitor->HandleDomain(this);
// These opcodes are not handled here.
case HloOpcode::kTrace:
diff --git a/tensorflow/compiler/xla/service/hlo_instruction.h b/tensorflow/compiler/xla/service/hlo_instruction.h
index 234dbc8399..6df97c40ba 100644
--- a/tensorflow/compiler/xla/service/hlo_instruction.h
+++ b/tensorflow/compiler/xla/service/hlo_instruction.h
@@ -37,6 +37,8 @@ limitations under the License.
#include "tensorflow/compiler/xla/map_util.h"
#include "tensorflow/compiler/xla/service/dfs_hlo_visitor.h"
#include "tensorflow/compiler/xla/service/hlo.pb.h"
+#include "tensorflow/compiler/xla/service/hlo_clone_context.h"
+#include "tensorflow/compiler/xla/service/hlo_domain_metadata.h"
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
#include "tensorflow/compiler/xla/service/hlo_sharding.h"
#include "tensorflow/compiler/xla/service/name_uniquer.h"
@@ -597,6 +599,13 @@ class HloInstruction {
const GatherDimensionNumbers& gather_dim_numbers,
tensorflow::gtl::ArraySlice<int64> window_bounds);
+ // Creates a kDomain instruction which delimits an HLO domain which has
+ // the provided user and operand side metadata.
+ static std::unique_ptr<HloInstruction> CreateDomain(
+ const Shape& shape, HloInstruction* operand,
+ std::unique_ptr<DomainMetadata> operand_side_metadata,
+ std::unique_ptr<DomainMetadata> user_side_metadata);
+
// Creates a fusion instruction. A fusion instruction contains one or more
// fused instructions forming an expression with a single root
// "fused_root". Additional instructions can be added to the fusion
@@ -676,6 +685,10 @@ class HloInstruction {
using InstructionVector = tensorflow::gtl::InlinedVector<HloInstruction*, 2>;
const InstructionVector& operands() const { return operands_; }
+ // Returns the vector of unique operands, in the same order they are found
+ // within the operand vector.
+ InstructionVector unique_operands() const;
+
// Returns the index of 'target' in the operands sequence.
// Precondition: target must be an operand (or a fatal error will occur).
int64 operand_index(const HloInstruction* target) const;
@@ -1094,16 +1107,20 @@ class HloInstruction {
}
// Returns the sharding unique device, if any.
tensorflow::gtl::optional<int64> sharding_unique_device() const {
- if (sharding_ == nullptr || !sharding_->HasUniqueDevice()) {
+ if (sharding_ == nullptr) {
return tensorflow::gtl::optional<int64>();
}
- return sharding_->UniqueDevice().ValueOrDie();
+ auto device = sharding_->UniqueDevice();
+ return device.ok() ? device.ValueOrDie()
+ : tensorflow::gtl::optional<int64>();
}
// Sets the sharding of this operator. Should only be called by HloModule or
// HloComputation methods.
void set_sharding(const HloSharding& sharding) {
sharding_ = MakeUnique<HloSharding>(sharding);
}
+ // Sets a sharding that assigns the current instruction to the given device.
+ void set_device_sharding(int64 device);
// Remove any sharding from this operator.
void clear_sharding() { sharding_ = nullptr; }
// Return true if this operator has a sharding assigned.
@@ -1117,6 +1134,15 @@ class HloInstruction {
return other->has_sharding() ? sharding() == other->sharding() : false;
}
+ // Retrieves the operand side metadata of a kDomain instruction.
+ const DomainMetadata& operand_side_metadata() const {
+ return *operand_side_metadata_;
+ }
+ // Retrieves the user side metadata of a kDomain instruction.
+ const DomainMetadata& user_side_metadata() const {
+ return *user_side_metadata_;
+ }
+
// When creating a new instruction which either replaces, or shifts up (kCopy
// insertion case), another instruction, we need to make sure the certain
// properties of the new instruction are copied into the derived one. As of
@@ -1317,30 +1343,18 @@ class HloInstruction {
// Precondition: opcode() == HloOpcode::kRng
RandomDistribution random_distribution() const;
- // See documentation for Clone().
- using CloneMap = std::unordered_map<const HloInstruction*, HloInstruction*>;
-
// Clones the HLO instruction. The clone will have the same opcode, shape, and
// operands. After creation the clone has no uses. "this" (the instruction
// cloned from) is not changed. Suffix is the string to append to the name of
- // the instruction to form the name of the cloned instruction. Ignores the
- // control predecessors and successors of this HLO instruction.
- //
- // If the module pointer is not nullptr, then any cloned computations will be
- // added to this module in order to support deep cloning. Otherwise the module
- // of the instruction is used.
- //
- // If clone_map is not nullptr, then each original instruction that is cloned
- // will be inserted and map to its clone. clone_map should not already contain
- // any of the instructions to clone.
- std::unique_ptr<HloInstruction> Clone(const string& suffix = "clone",
- HloModule* module = nullptr,
- CloneMap* clone_map = nullptr) const;
+ // the instruction to form the name of the cloned instruction.
+ // Ignores the control predecessors and successors of this HLO instruction.
+ std::unique_ptr<HloInstruction> Clone(
+ const string& suffix = "clone", HloCloneContext* context = nullptr) const;
// Clones the HLO instruction as above but with new shape and operands.
std::unique_ptr<HloInstruction> CloneWithNewOperands(
const Shape& shape, tensorflow::gtl::ArraySlice<HloInstruction*> operands,
- HloModule* module = nullptr, CloneMap* clone_map = nullptr) const;
+ HloCloneContext* context = nullptr) const;
// Returns the computations this instruction directly calls (if any).
const std::vector<HloComputation*>& called_computations() const {
@@ -1553,7 +1567,7 @@ class HloInstruction {
// Clones a fusion instruction with a new shape and operands.
std::unique_ptr<HloInstruction> CloneFusionWithNewOperands(
const Shape& shape, tensorflow::gtl::ArraySlice<HloInstruction*> operands,
- HloModule* module = nullptr) const;
+ HloCloneContext* context = nullptr) const;
// Returns true if this instruction can legally have the dimensions field
// set. Used for checking precondition of dimensions field accessors.
@@ -1646,6 +1660,10 @@ class HloInstruction {
// The sharding, if one exists.
std::unique_ptr<HloSharding> sharding_;
+ // Fields used by the kDomain instruction.
+ std::unique_ptr<DomainMetadata> operand_side_metadata_;
+ std::unique_ptr<DomainMetadata> user_side_metadata_;
+
// For parameter instructions this field holds the parameter number.
int64 parameter_number_ = 0;
diff --git a/tensorflow/compiler/xla/service/hlo_instruction_test.cc b/tensorflow/compiler/xla/service/hlo_instruction_test.cc
index a61c472c72..e91cf2076f 100644
--- a/tensorflow/compiler/xla/service/hlo_instruction_test.cc
+++ b/tensorflow/compiler/xla/service/hlo_instruction_test.cc
@@ -28,6 +28,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/test.h"
#include "tensorflow/compiler/xla/test_helpers.h"
#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
+#include "tensorflow/compiler/xla/tools/parser/hlo_parser.h"
#include "tensorflow/compiler/xla/util.h"
namespace xla {
@@ -1494,5 +1495,52 @@ TEST_F(HloInstructionTest, CanonnicalStringificationConditional) {
})");
}
+TEST_F(HloInstructionTest, CheckDeepClone) {
+ const char* const hlo_string = R"(
+HloModule Module
+
+addy (lhs: s32[], rhs: s32[]) -> s32[] {
+ lhs = s32[] parameter(0)
+ rhs = s32[] parameter(1)
+ ROOT zadd = s32[] add(lhs, rhs)
+}
+
+calla (x: s32[]) -> s32[] {
+ x = s32[] parameter(0)
+ reduce = s32[] reduce-window(x, x), to_apply=addy
+ ROOT xadd = s32[] add(x, reduce)
+}
+
+body (bparam: s32[]) -> s32[] {
+ constant = s32[] constant(1)
+ bparam = s32[] parameter(0)
+ v = s32[] call(bparam), to_apply=calla
+ ROOT add = s32[] add(constant, bparam)
+}
+
+condition (cparam: s32[]) -> pred[] {
+ xconstant = s32[] constant(5)
+ cparam = s32[] parameter(0)
+ ROOT greater-than = pred[] greater-than(xconstant, cparam)
+}
+
+ENTRY entry (param: s32[]) -> s32[] {
+ eparam = s32[] parameter(0)
+ ROOT while = s32[] while(eparam), condition=condition, body=body
+}
+)";
+ // Check that deep clone really deep-clones every instruction and
+ // computation, without leaving dangling pointers to the old module.
+ TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
+ tools::Parse(hlo_string));
+ std::unique_ptr<HloModule> clone = module->Clone();
+ for (HloComputation* computation : clone->computations()) {
+ EXPECT_EQ(computation->parent(), clone.get());
+ for (HloInstruction* instruction : computation->instructions()) {
+ EXPECT_EQ(instruction->parent()->parent(), clone.get());
+ }
+ }
+}
+
} // namespace
} // namespace xla
diff --git a/tensorflow/compiler/xla/service/hlo_module.cc b/tensorflow/compiler/xla/service/hlo_module.cc
index fbf1d58007..e63424c2df 100644
--- a/tensorflow/compiler/xla/service/hlo_module.cc
+++ b/tensorflow/compiler/xla/service/hlo_module.cc
@@ -496,7 +496,18 @@ std::list<HloComputation*> HloModule::MakeComputationPostOrder() const {
added_computations.insert(computation.get());
}
}
- CHECK_EQ(post_order.size(), computations_.size());
+ if (post_order.size() != computations_.size()) {
+ for (HloComputation* computation : post_order) {
+ LOG(ERROR) << "Post Order: " << computation->name() << " ("
+ << computation->parent()->name() << ")";
+ }
+ for (auto& computation : computations_) {
+ LOG(ERROR) << "Computations: " << computation->name() << " ("
+ << computation->parent()->name() << ")";
+ }
+ LOG(FATAL) << "Mismatch computation count: post_order=" << post_order.size()
+ << " computation_count=" << computations_.size();
+ }
return post_order;
}
@@ -517,54 +528,25 @@ std::unique_ptr<HloModule> HloModule::Clone(const string& suffix) const {
module->entry_computation_handle_ = entry_computation_handle_;
module->has_entry_computation_handle_ = has_entry_computation_handle_;
- std::unordered_map<HloComputation*, HloComputation*> clone_map;
- for (auto& computation : computations_) {
- if (computation->IsFusionComputation()) {
- // Cloning of a fused computation is handled by its fusion instruction.
- continue;
- }
-
- // When cloning a computation, pass in the new module, so that for any
- // fusion instruction in this computation, the fused computation will be
- // deep cloned to the new module.
- auto cloned_computation = computation->Clone(suffix, module.get());
- InsertOrDie(&clone_map, computation.get(), cloned_computation.get());
-
- if (entry_computation_ == computation.get()) {
- module->AddEntryComputation(std::move(cloned_computation));
- } else {
- module->AddEmbeddedComputation(std::move(cloned_computation));
- }
- }
-
- for (auto& cloned_computation : module->computations_) {
- for (auto* instruction : cloned_computation->instructions()) {
- // Rewrite instruction's called_computation to point to the cloned
- // computations.
- instruction->ReplaceCalledComputations([&](HloComputation* hlo) {
- if (hlo->IsFusionComputation()) {
- // Cloning of a fused computation has already been handled when its
- // fusion instruction is cloned. So this hlo computation is already
- // the cloned one.
- return hlo;
- }
- return FindOrDie(clone_map, hlo);
- });
- }
- }
+ HloCloneContext context(module.get(), suffix);
+ auto cloned_computation = entry_computation_->Clone(suffix, &context);
+ module->AddEntryComputation(std::move(cloned_computation));
return module;
}
-HloComputation* HloModule::DeepCloneComputation(HloComputation* computation) {
- HloComputation* clone = AddEmbeddedComputation(computation->Clone("", this));
- TF_CHECK_OK(
- clone->root_instruction()->Accept([this](HloInstruction* instruction) {
- instruction->ReplaceCalledComputations([this](HloComputation* callee) {
- return DeepCloneComputation(callee);
- });
- return Status::OK();
- }));
- return clone;
+HloComputation* HloModule::DeepCloneComputation(HloComputation* computation,
+ HloCloneContext* context) {
+ HloComputation* new_computation;
+ if (context != nullptr) {
+ if ((new_computation = context->FindComputation(computation)) != nullptr) {
+ return new_computation;
+ }
+ new_computation =
+ AddEmbeddedComputation(computation->Clone(context->suffix(), context));
+ } else {
+ new_computation = AddEmbeddedComputation(computation->Clone(""));
+ }
+ return new_computation;
}
uint64 HloModule::RandomNew64() const {
diff --git a/tensorflow/compiler/xla/service/hlo_module.h b/tensorflow/compiler/xla/service/hlo_module.h
index 02918c3777..c93c74d34a 100644
--- a/tensorflow/compiler/xla/service/hlo_module.h
+++ b/tensorflow/compiler/xla/service/hlo_module.h
@@ -26,6 +26,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/iterator_util.h"
#include "tensorflow/compiler/xla/service/hlo.pb.h"
+#include "tensorflow/compiler/xla/service/hlo_clone_context.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_module_config.h"
@@ -94,8 +95,10 @@ class HloModule {
std::unique_ptr<HloModule> Clone(const string& suffix = "clone") const;
// Performs a deep clone of the computation, by recursively cloning all
- // the called computations as well.
- HloComputation* DeepCloneComputation(HloComputation* computation);
+ // the called computations as well. If the clone context is specified, it
+ // will be populated with the cloned object mappings.
+ HloComputation* DeepCloneComputation(HloComputation* computation,
+ HloCloneContext* context = nullptr);
// Return a pointer to the entry computation of the module..
const HloComputation* entry_computation() const {
diff --git a/tensorflow/compiler/xla/service/hlo_module_group_metadata.cc b/tensorflow/compiler/xla/service/hlo_module_group_metadata.cc
index b4cd3c730e..7d706b5fd0 100644
--- a/tensorflow/compiler/xla/service/hlo_module_group_metadata.cc
+++ b/tensorflow/compiler/xla/service/hlo_module_group_metadata.cc
@@ -87,6 +87,7 @@ Status HloModuleGroupMetadata::Build() {
<< "Peer instruction does not match the computation kind";
TF_RETURN_IF_ERROR(
AddCompanion(tracked->instruction(), peer_tracked->instruction()));
+ tracked_instructions_comms_[tracked->instruction()].push_back(hlo);
}
// Add the parents of companion instructions (they must be all of the same
@@ -116,23 +117,31 @@ Status HloModuleGroupMetadata::Build() {
}
Status HloModuleGroupMetadata::VerifyCompanionSets() const {
- // TODO(dlibenzi): Migrate this to use the device instead of module ID, once
- // the kDomain CL goes in.
for (const auto& companions : companion_sets_) {
// A companion set must be composed at most of an instruction per
// device/module.
std::unordered_set<int64> devices;
for (HloInstruction* instruction : *companions) {
- int64 device = GetModuleId(instruction->parent()->parent());
- if (!devices.insert(device).second) {
- std::stringstream ss;
- ss << "Companion set:" << std::endl;
- for (HloInstruction* hlo : *companions) {
- ss << " " << hlo->name() << " ("
- << GetModuleId(hlo->parent()->parent()) << ")" << std::endl;
+ // Go through all the communicating instructions (send, recv) of the given
+ // companion, and record their device.
+ std::unordered_set<int64> comm_devices;
+ for (HloInstruction* comm_instruction :
+ tracked_instructions_comms_.at(instruction)) {
+ auto device = GetInstructionDevice(*comm_instruction);
+ TF_RET_CHECK(device) << "Instruction " << comm_instruction->ToString()
+ << " does not have a device";
+ comm_devices.insert(*device);
+ }
+ for (int64 device : comm_devices) {
+ if (!devices.insert(device).second) {
+ std::stringstream ss;
+ ss << "Companion set:" << std::endl;
+ for (HloInstruction* hlo : *companions) {
+ ss << " " << hlo->name() << std::endl;
+ }
+ ss << "has multiple instructions on the same device";
+ return FailedPrecondition("%s", ss.str().c_str());
}
- ss << "has multiple instructions on the same device";
- return FailedPrecondition("%s", ss.str().c_str());
}
}
}
@@ -223,6 +232,21 @@ int64 HloModuleGroupMetadata::GetModuleId(const HloModule* module) const {
LOG(FATAL) << "unknown module";
}
+tensorflow::gtl::optional<int64> HloModuleGroupMetadata::GetInstructionDevice(
+ const HloInstruction& instruction) const {
+ // The module group metadata can be created in both "single module, multiple
+ // devices" and "multiple modules, no explicit devices" fashions.
+ // The API returns an optional even though the current implementation always
+ // returns a device, to account for cases where we cannot guess a device.
+ // In such cases the VerifyChannelInstructions() will return proper errors.
+ tensorflow::gtl::optional<int64> device =
+ instruction.sharding_unique_device();
+ if (!device) {
+ device = GetModuleId(instruction.parent()->parent());
+ }
+ return device;
+}
+
Status HloModuleGroupMetadata::RecordInstructions() {
const auto visitor = [this](HloInstruction* hlo) -> Status {
if (hlo->opcode() == HloOpcode::kWhile) {
@@ -346,26 +370,38 @@ Status HloModuleGroupMetadata::VerifyChannelInstructions() {
if (!ShapeUtil::Compatible(send_shape, recv_shape)) {
return FailedPrecondition("send/recv shapes do not match");
}
- const HloModule* send_module = channel.send->parent()->parent();
- const HloModule* send_done_module = channel.send_done->parent()->parent();
- if (send_module != send_done_module) {
+ auto send_device = GetInstructionDevice(*channel.send);
+ auto send_done_device = GetInstructionDevice(*channel.send_done);
+ if (!send_device) {
+ return FailedPrecondition("send instruction must have a device: %s",
+ channel.send->ToString().c_str());
+ }
+ if (!send_done_device) {
+ return FailedPrecondition("send_done instruction must have a device: %s",
+ channel.send_done->ToString().c_str());
+ }
+ if (*send_device != *send_done_device) {
return FailedPrecondition(
"send and send-done (channel=%lld) must be on the same device: %lld "
"vs. %lld",
- channel.id, GetModuleId(send_module), GetModuleId(send_done_module));
+ channel.id, *send_device, *send_done_device);
+ }
+ auto recv_device = GetInstructionDevice(*channel.recv);
+ auto recv_done_device = GetInstructionDevice(*channel.recv_done);
+    if (!recv_device || !recv_done_device) {
+      return FailedPrecondition(
+          "recv and recv_done instructions must have a device");
}
- const HloModule* recv_module = channel.recv->parent()->parent();
- const HloModule* recv_done_module = channel.recv_done->parent()->parent();
- if (recv_module != recv_done_module) {
+ if (*recv_device != *recv_done_device) {
return FailedPrecondition(
"recv and recv-done (channel=%lld) must be on the same device: %lld "
"vs. %lld",
- channel.id, GetModuleId(recv_module), GetModuleId(recv_done_module));
+ channel.id, *recv_device, *recv_done_device);
}
- if (send_module == recv_module) {
+ if (*send_device == *recv_device) {
return FailedPrecondition(
"send and recv (channel=%lld) must be on different devices: %lld",
- channel.id, GetModuleId(send_module));
+ channel.id, *send_device);
}
}
diff --git a/tensorflow/compiler/xla/service/hlo_module_group_metadata.h b/tensorflow/compiler/xla/service/hlo_module_group_metadata.h
index 3ef4542f91..5f5bf27479 100644
--- a/tensorflow/compiler/xla/service/hlo_module_group_metadata.h
+++ b/tensorflow/compiler/xla/service/hlo_module_group_metadata.h
@@ -29,6 +29,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/gtl/flatmap.h"
+#include "tensorflow/core/lib/gtl/optional.h"
#include "tensorflow/core/platform/types.h"
namespace xla {
@@ -148,6 +149,12 @@ class HloModuleGroupMetadata {
// the module in the module vector.
int64 GetModuleId(const HloModule* module) const;
+ // Retrieves the device an instruction is assigned to. Either from the
+ // sharding information, or from the ordinal of the module the instruction
+ // is in.
+ tensorflow::gtl::optional<int64> GetInstructionDevice(
+ const HloInstruction& instruction) const;
+
// Returns the companion instructions for the given instruction.
//
// Precondition: IsCompanionWhile(instruction) is true.
@@ -231,6 +238,11 @@ class HloModuleGroupMetadata {
tensorflow::gtl::FlatMap<const HloComputation*, TrackedInstruction>
tracked_instructions_;
+ // Maps tracked instructions (kWhile, kConditional, kCall, ...) to the set of
+ // communicating instructions within the proper called computation(s).
+ tensorflow::gtl::FlatMap<HloInstruction*, std::vector<HloInstruction*>>
+ tracked_instructions_comms_;
+
// All channels in the module.
std::vector<Channel> channels_;
diff --git a/tensorflow/compiler/xla/service/hlo_opcode.h b/tensorflow/compiler/xla/service/hlo_opcode.h
index ac7cd2f2f5..1fe06ee0c0 100644
--- a/tensorflow/compiler/xla/service/hlo_opcode.h
+++ b/tensorflow/compiler/xla/service/hlo_opcode.h
@@ -69,6 +69,7 @@ namespace xla {
V(kCrossReplicaSum, "cross-replica-sum") \
V(kCustomCall, "custom-call") \
V(kDivide, "divide") \
+ V(kDomain, "domain") \
V(kDot, "dot") \
V(kDynamicSlice, "dynamic-slice") \
V(kDynamicUpdateSlice, "dynamic-update-slice") \
diff --git a/tensorflow/compiler/xla/service/hlo_sharding.cc b/tensorflow/compiler/xla/service/hlo_sharding.cc
index 7708422ce1..58224ef870 100644
--- a/tensorflow/compiler/xla/service/hlo_sharding.cc
+++ b/tensorflow/compiler/xla/service/hlo_sharding.cc
@@ -123,6 +123,24 @@ std::vector<int64> HloSharding::TileLimitForDevice(int64 device) const {
return index;
}
+StatusOr<ShapeTree<HloSharding>> HloSharding::AsShapeTree(
+ const Shape& shape) const {
+ if (IsTuple()) {
+ ShapeTree<HloSharding> result(shape, HloSharding::Replicate());
+ int64 num_leaves = result.leaf_count();
+ TF_RET_CHECK(num_leaves == tuple_elements_.size())
+ << "Shape " << ShapeUtil::HumanString(shape) << " has " << num_leaves
+ << " leaf nodes while this sharding has " << tuple_elements_.size();
+ auto it = tuple_elements_.begin();
+ for (auto& index_to_sharding : result.leaves()) {
+ index_to_sharding.second = *it++;
+ }
+ return std::move(result);
+ } else {
+ return ShapeTree<HloSharding>(shape, *this);
+ }
+}
+
StatusOr<int64> HloSharding::UniqueDevice() const {
if (IsTuple()) {
if (tuple_elements_.empty()) {
@@ -367,11 +385,8 @@ HloSharding HloSharding::GetSubSharding(const Shape& shape,
Shape sub_shape = ShapeUtil::GetSubshape(shape, index);
ShapeTree<HloSharding> sub_shape_tree(sub_shape, Replicate());
sub_shape_tree.CopySubtreeFrom(GetAsShapeTree(shape), index, {});
- if (ShapeUtil::IsTuple(sub_shape)) {
- return Tuple(sub_shape_tree);
- } else {
- return sub_shape_tree.element({});
- }
+ return ShapeUtil::IsTuple(sub_shape) ? Tuple(sub_shape_tree)
+ : sub_shape_tree.element(ShapeIndex({}));
}
std::ostream& operator<<(std::ostream& out, const HloSharding& sharding) {
diff --git a/tensorflow/compiler/xla/service/hlo_sharding.h b/tensorflow/compiler/xla/service/hlo_sharding.h
index e8bb06c8f7..f4a0fb626f 100644
--- a/tensorflow/compiler/xla/service/hlo_sharding.h
+++ b/tensorflow/compiler/xla/service/hlo_sharding.h
@@ -163,19 +163,9 @@ class HloSharding {
// tuple, if IsTuple, or a ShapeTree with a single element containing this
// sharding. Only the leaf elements are populated. This creates a new
// ShapeTree object so is not cheap.
+ StatusOr<ShapeTree<HloSharding>> AsShapeTree(const Shape& shape) const;
ShapeTree<HloSharding> GetAsShapeTree(const Shape& shape) const {
- if (IsTuple()) {
- ShapeTree<HloSharding> result(shape, HloSharding::Replicate());
- CHECK_EQ(std::distance(result.leaf_begin(), result.leaf_end()),
- tuple_elements_.size());
- auto it = tuple_elements_.begin();
- for (auto& index_to_sharding : result.leaves()) {
- index_to_sharding.second = *it++;
- }
- return result;
- } else {
- return ShapeTree<HloSharding>(shape, *this);
- }
+ return AsShapeTree(shape).ValueOrDie();
}
// Retrieves the sub sharding at a given index, out of a tuple sharding.
diff --git a/tensorflow/compiler/xla/service/hlo_sharding_metadata.cc b/tensorflow/compiler/xla/service/hlo_sharding_metadata.cc
new file mode 100644
index 0000000000..82cff2a4b7
--- /dev/null
+++ b/tensorflow/compiler/xla/service/hlo_sharding_metadata.cc
@@ -0,0 +1,401 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/xla/service/hlo_sharding_metadata.h"
+
+#include "tensorflow/compiler/xla/service/hlo_computation.h"
+#include "tensorflow/compiler/xla/shape_tree.h"
+#include "tensorflow/compiler/xla/shape_util.h"
+
+namespace xla {
+
+namespace {
+
+struct PassThrough {
+ PassThrough(HloInstruction* user, HloInstruction* operand)
+ : user(user), operand(operand) {}
+
+ HloInstruction* user = nullptr;
+ HloInstruction* operand = nullptr;
+};
+
+void SetDeviceSharding(HloInstruction* instruction, int64 device) {
+ VLOG(4) << " " << instruction->name() << " to device " << device;
+ instruction->set_device_sharding(device);
+}
+
+tensorflow::gtl::optional<int64> ShardingUniqueDevice(
+ const HloSharding& sharding) {
+ if (sharding.IsTileMaximal()) {
+ auto device = sharding.UniqueDevice();
+ if (device.ok()) {
+ return device.ValueOrDie();
+ }
+ }
+ return tensorflow::gtl::optional<int64>();
+}
+
+bool ShardingMatches(const HloSharding& sharding1,
+ const HloSharding& sharding2) {
+ auto device1 = ShardingUniqueDevice(sharding1);
+ if (device1) {
+ auto device2 = ShardingUniqueDevice(sharding2);
+ if (device2) {
+ return *device1 == *device2;
+ }
+ }
+ // Anything which is not tile maximal with unique device, gets a full sharding
+ // compare.
+ return sharding1 == sharding2;
+}
+
+// When we create domains, they are never "empty", where with empty we mean
+// that a kDomain instruction has as operand another kDomain instruction of the
+// same kind.
+// But when the HLO optimizations are run, empty domains can be created.
+// For example:
+//
+// Domain(device=None, device=0) ->
+// Tuple(device=0) ->
+// GTE(device=0) ->
+// Domain(device=0, device=None)
+//
+// In that case the tuple simplifier could create something like:
+//
+// Domain(device=None, device=0) -> Domain(device=0, device=None)
+//
+// Which is a so called empty domain.
+// In the case above, crossing an empty domain which was transiting through
+// device 0, requires the normalization phase to fixup the empty domain by
+// adding back a Tuple+GTE pair with the proper device.
+// One particular case where this can create problems is the result of the
+// entry computation, where the GTE assignments are used by TF to tell the
+// XLA where the results should be sent.
+std::vector<PassThrough> LocatePassThroughDomainLinks(
+ const DomainMetadata::Domain& domain) {
+ std::vector<PassThrough> pass_through;
+ for (HloInstruction* instruction : domain.enter_domains) {
+ CHECK(instruction->opcode() == HloOpcode::kDomain)
+ << "Instruction is not a kDomain: " << instruction->ToString();
+ for (HloInstruction* user : instruction->users()) {
+ if (user->opcode() == HloOpcode::kDomain &&
+ domain.exit_domains.count(user) != 0) {
+ pass_through.emplace_back(user, instruction);
+ VLOG(2) << "Found passthrough domain link:";
+ VLOG(2) << " " << user->ToString();
+ VLOG(2) << " " << instruction->ToString();
+ }
+ }
+ }
+ return pass_through;
+}
+
+Status FixupPassThroughDomainLinks(const DomainMetadata::Domain& domain,
+ const HloSharding& sharding) {
+ for (auto& pass_through : LocatePassThroughDomainLinks(domain)) {
+ HloInstruction* tuple = pass_through.operand->parent()->AddInstruction(
+ HloInstruction::CreateTuple({pass_through.operand}));
+ HloInstruction* gte = pass_through.operand->parent()->AddInstruction(
+ HloInstruction::CreateGetTupleElement(pass_through.operand->shape(),
+ tuple, 0));
+ gte->set_sharding(sharding);
+ TF_RETURN_IF_ERROR(
+ pass_through.operand->ReplaceUseWith(pass_through.user, gte));
+ }
+ return Status::OK();
+}
+
+std::unique_ptr<HloSharding> CloneShardingForDomain(
+ const HloSharding& sharding) {
+ auto device = ShardingUniqueDevice(sharding);
+ if (!device) {
+ return MakeUnique<HloSharding>(sharding);
+ }
+ return MakeUnique<HloSharding>(HloSharding::AssignDevice(*device));
+}
+
+Status ApplyDomainDeviceSharding(const DomainMetadata::Domain& domain,
+ int64 device) {
+ VLOG(4) << "Applying device " << device << " sharding";
+ for (HloInstruction* instruction : domain.instructions) {
+ // We only change instructions without sharding, since otherwise we might
+    // interfere with subsequent HLO passes which have knowledge of it.
+ if (!instruction->has_sharding()) {
+ SetDeviceSharding(instruction, device);
+ } else {
+ VLOG(4) << " " << instruction->name() << " already has sharding "
+ << instruction->sharding();
+ }
+ }
+ return Status::OK();
+}
+
+// Retrieves the sharding of a tuple shaped instruction in form of a ShapeTree.
+// If the instruction has no sharding, a ShapeTree with HloSharding::Replicate()
+// sharding will be returned.
+ShapeTree<HloSharding> GetTupleSharding(HloInstruction* tuple) {
+ if (tuple->has_sharding()) {
+ return tuple->sharding().GetAsShapeTree(tuple->shape());
+ }
+ return ShapeTree<HloSharding>(tuple->shape(), HloSharding::Replicate());
+}
+
+// Retrieves the sharding of operand, asked from a user instruction which is
+// within domain. If operand is a kDomain, it means that sharding argument is
+// the operand sharding, otherwise the operand's own sharding will be returned.
+const HloSharding* GetOperandSharding(const HloInstruction* operand,
+ const DomainMetadata::Domain& domain,
+ const HloSharding& sharding) {
+ DCHECK_EQ(domain.reach_set.count(const_cast<HloInstruction*>(operand)), 1);
+ // Here the user of operand is within the domain instruction set, and since it
+ // is user of operand, we need to look into the enter_domains set. If this is
+ // not a kDomain within the user domains set, then return the operand
+ // sharding, if any.
+ if (operand->opcode() != HloOpcode::kDomain ||
+ domain.enter_domains.count(const_cast<HloInstruction*>(operand)) == 0) {
+ return operand->has_sharding() ? &operand->sharding() : nullptr;
+ }
+ // At this point operand is a kDomain of the currently processed domain, so we
+ // can refer to sharding as the domain sharding.
+ return &sharding;
+}
+
+// Tries to propagate the sharding information into the instructions that are
+// part of the domain, in a post order manner (operand propagate to user).
+StatusOr<int64> ApplyDomainShardingPass(const DomainMetadata::Domain& domain,
+ const HloSharding& sharding) {
+ int64 assigned = 0;
+ for (HloInstruction* instruction : domain.instructions) {
+ if (instruction->has_sharding()) {
+ continue;
+ }
+ if (instruction->opcode() == HloOpcode::kGetTupleElement) {
+ HloInstruction* tuple = instruction->mutable_operand(0);
+ const HloSharding* tuple_sharding =
+ GetOperandSharding(tuple, domain, sharding);
+ if (tuple_sharding != nullptr) {
+ TF_RET_CHECK(tuple_sharding->IsTuple()) << tuple->ToString();
+ HloSharding sub_sharding = tuple_sharding->GetSubSharding(
+ tuple->shape(), {instruction->tuple_index()});
+ VLOG(4) << " " << instruction->name() << " to sharding "
+ << sub_sharding;
+ instruction->set_sharding(sub_sharding);
+ ++assigned;
+ }
+ } else if (instruction->opcode() == HloOpcode::kTuple) {
+ int64 tuple_assigned = 0;
+ ShapeTree<HloSharding> shape_tree = GetTupleSharding(instruction);
+ for (int64 i = 0; i < instruction->operand_count(); ++i) {
+ const HloSharding* operand_sharding =
+ GetOperandSharding(instruction->operand(i), domain, sharding);
+ if (operand_sharding != nullptr &&
+ shape_tree.element({i}) != *operand_sharding) {
+ *shape_tree.mutable_element({i}) = *operand_sharding;
+ ++tuple_assigned;
+ }
+ }
+ if (tuple_assigned > 0) {
+ HloSharding tuple_sharding = HloSharding::Tuple(shape_tree);
+ VLOG(4) << " " << instruction->name() << " to sharding "
+ << tuple_sharding;
+ instruction->set_sharding(tuple_sharding);
+ ++assigned;
+ }
+ } else {
+      // If all the operands of the given instruction have the same single device
+ // assignment, assign that device to this instruction as well.
+ const HloSharding* common_sharding = nullptr;
+ for (const HloInstruction* operand : instruction->operands()) {
+ const HloSharding* operand_sharding =
+ GetOperandSharding(operand, domain, sharding);
+ if (operand_sharding != nullptr) {
+ if (common_sharding != nullptr &&
+ *common_sharding != *operand_sharding) {
+ common_sharding = nullptr;
+ break;
+ }
+ common_sharding = operand_sharding;
+ }
+ }
+ if (common_sharding != nullptr) {
+ VLOG(4) << " " << instruction->name() << " to sharding "
+ << *common_sharding;
+ instruction->set_sharding(*common_sharding);
+ ++assigned;
+ }
+ }
+ }
+ return assigned;
+}
+
+Status ApplyDomainSharding(const DomainMetadata::Domain& domain,
+ const HloSharding& sharding) {
+ auto device = ShardingUniqueDevice(sharding);
+ if (device) {
+ // Shortcut the simple case. We have a unique device sharding, so we call
+ // the ApplyDomainDeviceSharding() API which will apply array or tuple
+ // shaped device sharding to the domain instructions.
+ return ApplyDomainDeviceSharding(domain, *device);
+ }
+ VLOG(1) << "Assigning non-trivial sharding " << sharding;
+ for (;;) {
+ TF_ASSIGN_OR_RETURN(int64 assigned,
+ ApplyDomainShardingPass(domain, sharding));
+ if (assigned == 0) {
+ break;
+ }
+ }
+ int64 unassigned = 0;
+ for (HloInstruction* instruction : domain.instructions) {
+ if (!instruction->has_sharding()) {
+ LOG(WARNING) << "Unassigned instruction: " << instruction->ToString();
+ ++unassigned;
+ }
+ }
+ // Should we error out if unassigned > 0?
+ return Status::OK();
+}
+
+// Creates a kDomain instruction to be placed between instruction and operand.
+// The kDomain instruction will be created only if the sharding differ between
+// the instruction and the operand.
+std::unique_ptr<HloInstruction> CreateDomain(HloInstruction* instruction,
+ HloInstruction* operand) {
+ const HloSharding* instruction_sharding =
+ instruction->has_sharding() ? &instruction->sharding() : nullptr;
+ const HloSharding* operand_sharding =
+ operand->has_sharding() ? &operand->sharding() : nullptr;
+ // No need for domain if they both have no sharding.
+ if (instruction_sharding == nullptr && operand_sharding == nullptr) {
+ return nullptr;
+ }
+ // No need for domain if they match.
+ if (instruction_sharding != nullptr && operand_sharding != nullptr &&
+ ShardingMatches(*instruction_sharding, *operand_sharding)) {
+ return nullptr;
+ }
+ std::unique_ptr<HloSharding> real_instruction_sharding;
+ std::unique_ptr<HloSharding> real_operand_sharding;
+ if (instruction_sharding != nullptr) {
+ real_instruction_sharding = CloneShardingForDomain(*instruction_sharding);
+ }
+ if (operand_sharding != nullptr) {
+ real_operand_sharding = CloneShardingForDomain(*operand_sharding);
+ }
+ VLOG(3) << "Creating domain:";
+ VLOG(3) << " Instruction: " << instruction->name();
+ VLOG(3) << " Operand: " << operand->name();
+ VLOG(3) << " User side sharding: "
+ << (real_instruction_sharding != nullptr
+ ? real_instruction_sharding->ToString()
+ : "None");
+ VLOG(3) << " Operand side sharding: "
+ << (real_operand_sharding != nullptr
+ ? real_operand_sharding->ToString()
+ : "None");
+
+ std::unique_ptr<DomainMetadata> operand_side_metadata =
+ MakeUnique<ShardingMetadata>(std::move(real_operand_sharding));
+ std::unique_ptr<DomainMetadata> user_side_metadata =
+ MakeUnique<ShardingMetadata>(std::move(real_instruction_sharding));
+ return HloInstruction::CreateDomain(operand->shape(), operand,
+ std::move(operand_side_metadata),
+ std::move(user_side_metadata));
+}
+
+StatusOr<std::unique_ptr<HloSharding>> ExtractOriginalCommonSharding(
+ tensorflow::gtl::ArraySlice<HloInstruction*> instructions) {
+ // If we are here, all the instructions being passed had the same sharding
+ // (or no sharding), by the means of the ShardingMatches() API.
+ // As such, no kDomain was inserted, and here we are asked to extract the
+ // original common sharding.
+ // All the instructions passed to this API are part of the same computation.
+ const HloSharding* sharding = nullptr;
+ for (HloInstruction* instruction : instructions) {
+ if (instruction->has_sharding()) {
+ if (sharding == nullptr) {
+ sharding = &instruction->sharding();
+ } else {
+ TF_RET_CHECK(ShardingMatches(*sharding, instruction->sharding()))
+ << "Sharding " << *sharding << " does not match the one in "
+ << instruction->ToString();
+ }
+ }
+ }
+ if (sharding == nullptr) {
+ return std::unique_ptr<HloSharding>();
+ }
+ VLOG(4) << "Extracted sharding is " << *sharding;
+ return CloneShardingForDomain(*sharding);
+}
+
+} // namespace
+
+std::unique_ptr<DomainMetadata> ShardingMetadata::Clone() const {
+ std::unique_ptr<HloSharding> sharding;
+ if (sharding_ != nullptr) {
+ sharding = MakeUnique<HloSharding>(*sharding_);
+ }
+ return MakeUnique<ShardingMetadata>(std::move(sharding));
+}
+
+bool ShardingMetadata::Matches(const DomainMetadata& other) const {
+ const ShardingMetadata* other_ptr =
+ dynamic_cast<const ShardingMetadata*>(&other);
+ if (other_ptr == nullptr) {
+ // If other is not a ShardingMetadata, then it is clearly a no match.
+ return false;
+ }
+ if (sharding_ == nullptr) {
+ return other_ptr->sharding_ == nullptr;
+ }
+ return other_ptr->sharding_ != nullptr
+ ? ShardingMatches(*sharding_, *other_ptr->sharding_)
+ : false;
+}
+
+string ShardingMetadata::ToString() const {
+ return sharding_ != nullptr ? sharding_->ToString() : "None";
+}
+
+Status ShardingMetadata::NormalizeInstructions(
+ const DomainMetadata::Domain& domain) const {
+ if (sharding_ != nullptr) {
+ VLOG(4) << "Normalizing sharding to " << sharding_->ToString() << ":";
+ TF_RETURN_IF_ERROR(ApplyDomainSharding(domain, *sharding_));
+ TF_RETURN_IF_ERROR(FixupPassThroughDomainLinks(domain, *sharding_));
+ }
+ return Status::OK();
+}
+
+Status NormalizeShardingDomain(const DomainMetadata::Domain& domain) {
+ TF_ASSIGN_OR_RETURN(std::unique_ptr<HloSharding> sharding,
+ ExtractOriginalCommonSharding(domain.instructions));
+ if (sharding != nullptr) {
+ VLOG(4) << "Normalizing sharding-less domain to " << sharding->ToString()
+ << ":";
+ TF_RETURN_IF_ERROR(ApplyDomainSharding(domain, *sharding));
+ } else {
+ VLOG(1) << "Unable to find common sharding";
+ }
+ return Status::OK();
+}
+
+std::unique_ptr<HloInstruction> CreateShardingDomain(
+ HloInstruction* instruction, HloInstruction* operand) {
+ return CreateDomain(instruction, operand);
+}
+
+} // namespace xla
diff --git a/tensorflow/compiler/xla/service/hlo_sharding_metadata.h b/tensorflow/compiler/xla/service/hlo_sharding_metadata.h
new file mode 100644
index 0000000000..ec162c3490
--- /dev/null
+++ b/tensorflow/compiler/xla/service/hlo_sharding_metadata.h
@@ -0,0 +1,67 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_HLO_SHARDING_METADATA_H_
+#define TENSORFLOW_COMPILER_XLA_SERVICE_HLO_SHARDING_METADATA_H_
+
+#include "tensorflow/compiler/xla/service/hlo_domain_metadata.h"
+#include "tensorflow/compiler/xla/service/hlo_instruction.h"
+#include "tensorflow/compiler/xla/service/hlo_sharding.h"
+#include "tensorflow/core/lib/core/status.h"
+#include "tensorflow/core/lib/gtl/array_slice.h"
+
+namespace xla {
+
+// A DomainMetadata implementation that internally wraps a sharding attribute.
+class ShardingMetadata : public DomainMetadata {
+ public:
+ explicit ShardingMetadata(std::unique_ptr<HloSharding> sharding)
+ : sharding_(std::move(sharding)) {}
+
+ std::unique_ptr<DomainMetadata> Clone() const override;
+
+ tensorflow::StringPiece Kind() const override { return KindName(); }
+
+ bool Matches(const DomainMetadata& other) const override;
+
+ string ToString() const override;
+
+ Status NormalizeInstructions(
+ const DomainMetadata::Domain& domain) const override;
+
+ static tensorflow::StringPiece KindName() { return "sharding"; }
+
+ private:
+ std::unique_ptr<HloSharding> sharding_;
+};
+
+// Within a set of instructions which had common sharding attributes before
+// entering the HLO passes pipeline, apply sharding heuristics and normalize the
+// instructions whose sharding deviates from the one which is inferred as to be
+// the original one.
+// Policy wise, HLO passes are allowed to create new unassigned instructions,
+// but if they do create assigned ones, they have to conform to the ones around.
+Status NormalizeShardingDomain(const DomainMetadata::Domain& domain);
+
+// Given an HLO graph edge between instruction and one of its operands, creates
+// a ShardingMetadata based kDomain instruction if the sharding between
+// instruction and operand changes. Returns nullptr if there is no need for a
+// domain separation.
+std::unique_ptr<HloInstruction> CreateShardingDomain(
+ HloInstruction* instruction, HloInstruction* operand);
+
+} // namespace xla
+
+#endif // TENSORFLOW_COMPILER_XLA_SERVICE_HLO_SHARDING_METADATA_H_
diff --git a/tensorflow/compiler/xla/service/hlo_verifier.cc b/tensorflow/compiler/xla/service/hlo_verifier.cc
index 7d6d0d9eaf..9cfd8a9bf7 100644
--- a/tensorflow/compiler/xla/service/hlo_verifier.cc
+++ b/tensorflow/compiler/xla/service/hlo_verifier.cc
@@ -376,6 +376,7 @@ Status CheckMixedPrecisionOperands(const HloInstruction* instruction) {
case HloOpcode::kConstant:
case HloOpcode::kCrossReplicaSum:
case HloOpcode::kCustomCall:
+ case HloOpcode::kDomain:
case HloOpcode::kFusion:
case HloOpcode::kGetTupleElement:
case HloOpcode::kInfeed:
diff --git a/tensorflow/compiler/xla/service/instruction_fusion.cc b/tensorflow/compiler/xla/service/instruction_fusion.cc
index 1912b8f2c7..429c850343 100644
--- a/tensorflow/compiler/xla/service/instruction_fusion.cc
+++ b/tensorflow/compiler/xla/service/instruction_fusion.cc
@@ -118,6 +118,7 @@ bool IsAlwaysDuplicable(const HloInstruction& instruction) {
case HloOpcode::kCrossReplicaSum:
case HloOpcode::kCustomCall:
case HloOpcode::kDivide:
+ case HloOpcode::kDomain:
case HloOpcode::kDot:
case HloOpcode::kExp:
case HloOpcode::kExpm1:
diff --git a/tensorflow/compiler/xla/service/logical_buffer_analysis.cc b/tensorflow/compiler/xla/service/logical_buffer_analysis.cc
index 6aca6ba385..f410921b4b 100644
--- a/tensorflow/compiler/xla/service/logical_buffer_analysis.cc
+++ b/tensorflow/compiler/xla/service/logical_buffer_analysis.cc
@@ -125,6 +125,12 @@ Status LogicalBufferAnalysis::HandleBitcast(HloInstruction*) {
return Status::OK();
}
+Status LogicalBufferAnalysis::HandleDomain(HloInstruction*) {
+ // A kDomain instruction aliases its operand. That is, the buffer of its
+ // result *is* the buffer of its operand.
+ return Status::OK();
+}
+
Status LogicalBufferAnalysis::HandleRecvDone(HloInstruction*) {
// RecvDone doesn't create a new buffer but rather aliases its input (Recv)
// tuple element at {0} to its output.
diff --git a/tensorflow/compiler/xla/service/logical_buffer_analysis.h b/tensorflow/compiler/xla/service/logical_buffer_analysis.h
index f4c63dd86b..b5ef396787 100644
--- a/tensorflow/compiler/xla/service/logical_buffer_analysis.h
+++ b/tensorflow/compiler/xla/service/logical_buffer_analysis.h
@@ -59,6 +59,7 @@ class LogicalBufferAnalysis : public DfsHloVisitorWithDefault {
Status HandleTuple(HloInstruction* tuple) override;
Status HandleGetTupleElement(HloInstruction* get_tuple_element) override;
Status HandleBitcast(HloInstruction* bitcast) override;
+ Status HandleDomain(HloInstruction* domain) override;
Status HandleCopy(HloInstruction* copy) override;
Status HandleRecvDone(HloInstruction* recv_done) override;
Status HandleSend(HloInstruction* send) override;
diff --git a/tensorflow/compiler/xla/service/shape_inference.cc b/tensorflow/compiler/xla/service/shape_inference.cc
index 3500978bdd..d624f548b1 100644
--- a/tensorflow/compiler/xla/service/shape_inference.cc
+++ b/tensorflow/compiler/xla/service/shape_inference.cc
@@ -316,7 +316,8 @@ StatusOr<Shape> InferWindowOutputShape(const Shape& base_shape,
/* static */ StatusOr<Shape> ShapeInference::InferUnaryOpShape(
HloOpcode opcode, const Shape& shape) {
// There is no copy operation at the proto level, so handle copy explicitly.
- if (opcode == HloOpcode::kCopy) {
+ // A domain shape is the same as the input one.
+ if (opcode == HloOpcode::kCopy || opcode == HloOpcode::kDomain) {
return shape;
}
diff --git a/tensorflow/compiler/xla/service/tuple_points_to_analysis.cc b/tensorflow/compiler/xla/service/tuple_points_to_analysis.cc
index 8cb654493c..bb634e6573 100644
--- a/tensorflow/compiler/xla/service/tuple_points_to_analysis.cc
+++ b/tensorflow/compiler/xla/service/tuple_points_to_analysis.cc
@@ -273,6 +273,14 @@ Status TuplePointsToAnalysis::HandleBitcast(HloInstruction* bitcast) {
return Status::OK();
}
+Status TuplePointsToAnalysis::HandleDomain(HloInstruction* domain) {
+ // A kDomain instruction aliases its operand. That is, the buffer of its
+ // result *is* the buffer of its operand, so just copy the operand's points-to
+ // set.
+ CreateCopiedPointsToSet(domain, domain->operand(0));
+ return Status::OK();
+}
+
Status TuplePointsToAnalysis::HandleSlice(HloInstruction* slice) {
// A kSlice instruction aliases its operand if the backend lowers it to an
// in-place implementation.
diff --git a/tensorflow/compiler/xla/service/tuple_points_to_analysis.h b/tensorflow/compiler/xla/service/tuple_points_to_analysis.h
index 1ac7130136..c0d8241480 100644
--- a/tensorflow/compiler/xla/service/tuple_points_to_analysis.h
+++ b/tensorflow/compiler/xla/service/tuple_points_to_analysis.h
@@ -248,6 +248,7 @@ class TuplePointsToAnalysis : public DfsHloVisitorWithDefault {
Status HandleTuple(HloInstruction* tuple) override;
Status HandleGetTupleElement(HloInstruction* get_tuple_element) override;
Status HandleBitcast(HloInstruction* bitcast) override;
+ Status HandleDomain(HloInstruction* domain) override;
Status HandleSlice(HloInstruction* slice) override;
Status HandleCopy(HloInstruction* copy) override;
Status HandleRecvDone(HloInstruction* recv_done) override;
diff --git a/tensorflow/compiler/xla/shape_tree.h b/tensorflow/compiler/xla/shape_tree.h
index 37c94ac543..5b14953ebb 100644
--- a/tensorflow/compiler/xla/shape_tree.h
+++ b/tensorflow/compiler/xla/shape_tree.h
@@ -222,6 +222,9 @@ class ShapeTree {
/*iterate_leaves_only=*/false);
}
+ // Returns the number of leaf nodes in the tree.
+ int64 leaf_count() const { return std::distance(leaf_begin(), leaf_end()); }
+
// Recursively traverses the shape and calls the given function at each
// element. The function has the following arguments:
//
diff --git a/tensorflow/compiler/xla/shape_util.cc b/tensorflow/compiler/xla/shape_util.cc
index 2cdee30340..e8a28d76e9 100644
--- a/tensorflow/compiler/xla/shape_util.cc
+++ b/tensorflow/compiler/xla/shape_util.cc
@@ -880,6 +880,27 @@ bool ShapeUtil::IsLeafIndex(const Shape& shape, const ShapeIndex& index) {
return !IsTuple(GetSubshape(shape, index));
}
+/* static */ int64 ShapeUtil::GetLeafCount(const Shape& shape) {
+ int64 count = 0;
+ ForEachSubshape(shape, [&](const Shape&, const ShapeIndex& index) {
+ if (IsLeafIndex(shape, index)) {
+ ++count;
+ }
+ });
+ return count;
+}
+
+/* static */ std::vector<ShapeUtil::IndexedShape> ShapeUtil::GetLeafShapes(
+ const Shape& shape) {
+ std::vector<IndexedShape> leaves;
+ ForEachSubshape(shape, [&](const Shape& sub_shape, const ShapeIndex& index) {
+ if (IsLeafIndex(shape, index)) {
+ leaves.emplace_back(index, sub_shape);
+ }
+ });
+ return leaves;
+}
+
/* static */ Shape ShapeUtil::StripDegenerateDimensions(const Shape& shape) {
std::vector<int64> dimension_sizes;
std::vector<int64> degenerate_dimensions;
diff --git a/tensorflow/compiler/xla/shape_util.h b/tensorflow/compiler/xla/shape_util.h
index cf40068b33..9df31d5d21 100644
--- a/tensorflow/compiler/xla/shape_util.h
+++ b/tensorflow/compiler/xla/shape_util.h
@@ -154,6 +154,16 @@ std::ostream& operator<<(std::ostream& out, const ShapeIndexView& shape_index);
// properties, which do invariant checks before / after the operation.
class ShapeUtil {
public:
+ // Data structure which describes the coordinates and the shape of a tuple
+ // shaped sub-shape.
+ struct IndexedShape {
+ IndexedShape() = default;
+ IndexedShape(ShapeIndex index, Shape shape)
+ : index(std::move(index)), shape(std::move(shape)) {}
+ ShapeIndex index;
+ Shape shape;
+ };
+
// Returns the number of elements are contained within the provided shape;
// e.g. for rank 0 (scalars) the result is always 1. Note that sparse shapes
// may not actually be able to store this number of elements. See
@@ -465,6 +475,13 @@ class ShapeUtil {
// shape.
static bool IsLeafIndex(const Shape& shape, const ShapeIndex& index);
+ // Returns the number of leaves in the shape.
+ static int64 GetLeafCount(const Shape& shape);
+
+ // Retrieves all the leaf shapes and their indexes, in the order walked by
+ // the ForEachSubshape() API.
+ static std::vector<IndexedShape> GetLeafShapes(const Shape& shape);
+
// Calls the given visitor function for each subshape of the given shape.
// Subshapes are visited in DFS pre-order starting with the entire shape
// (index {}).
diff --git a/tensorflow/compiler/xla/tools/parser/hlo_parser.cc b/tensorflow/compiler/xla/tools/parser/hlo_parser.cc
index 76c870bc98..134978d21f 100644
--- a/tensorflow/compiler/xla/tools/parser/hlo_parser.cc
+++ b/tensorflow/compiler/xla/tools/parser/hlo_parser.cc
@@ -486,6 +486,7 @@ bool HloParser::ParseInstruction(HloComputation::Builder* builder,
case HloOpcode::kClz:
case HloOpcode::kCopy:
case HloOpcode::kCos:
+ case HloOpcode::kDomain:
case HloOpcode::kExp:
case HloOpcode::kExpm1:
case HloOpcode::kImag: