diff options
-rw-r--r-- | tensorflow/compiler/xla/service/BUILD | 30 | ||||
-rw-r--r-- | tensorflow/compiler/xla/service/call_graph.cc | 258 | ||||
-rw-r--r-- | tensorflow/compiler/xla/service/call_graph.h | 175 | ||||
-rw-r--r-- | tensorflow/compiler/xla/service/call_graph_test.cc | 290 |
4 files changed, 753 insertions, 0 deletions
diff --git a/tensorflow/compiler/xla/service/BUILD b/tensorflow/compiler/xla/service/BUILD index f5a67124e0..ccd84de654 100644 --- a/tensorflow/compiler/xla/service/BUILD +++ b/tensorflow/compiler/xla/service/BUILD @@ -130,6 +130,36 @@ cc_test( ) cc_library( + name = "call_graph", + srcs = ["call_graph.cc"], + hdrs = ["call_graph.h"], + deps = [ + ":hlo", + "//tensorflow/compiler/xla:status_macros", + "//tensorflow/compiler/xla:statusor", + "//tensorflow/compiler/xla:util", + "//tensorflow/core:lib", + ], +) + +cc_test( + name = "call_graph_test", + srcs = ["call_graph_test.cc"], + deps = [ + ":call_graph", + "//tensorflow/compiler/xla:literal_util", + "//tensorflow/compiler/xla:shape_util", + "//tensorflow/compiler/xla:status_macros", + "//tensorflow/compiler/xla:test_helpers", + "//tensorflow/compiler/xla:xla_data_proto", + "//tensorflow/compiler/xla/service:hlo", + "//tensorflow/compiler/xla/tests:hlo_test_base", + "//tensorflow/core:test", + "//tensorflow/core:test_main", + ], +) + +cc_library( name = "user_computation", srcs = ["user_computation.cc"], hdrs = ["user_computation.h"], diff --git a/tensorflow/compiler/xla/service/call_graph.cc b/tensorflow/compiler/xla/service/call_graph.cc new file mode 100644 index 0000000000..ed140f728d --- /dev/null +++ b/tensorflow/compiler/xla/service/call_graph.cc @@ -0,0 +1,258 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/service/call_graph.h" + +#include <queue> + +#include "tensorflow/compiler/xla/status_macros.h" +#include "tensorflow/compiler/xla/util.h" +#include "tensorflow/core/lib/core/errors.h" +#include "tensorflow/core/lib/core/status.h" +#include "tensorflow/core/lib/strings/strcat.h" +#include "tensorflow/core/lib/strings/stringprintf.h" +#include "tensorflow/core/platform/types.h" + +namespace xla { + +using ::tensorflow::strings::Appendf; +using ::tensorflow::strings::StrCat; + +string CallContextToString(CallContext context) { + switch (context) { + case CallContext::kNone: + return "kNone"; + case CallContext::kSequential: + return "kSequential"; + case CallContext::kParallel: + return "kParallel"; + case CallContext::kBoth: + return "kBoth"; + } +} + +std::ostream& operator<<(std::ostream& out, const CallContext& context) { + out << CallContextToString(context); + return out; +} + +string CallSite::ToString() const { + return StrCat(instruction->name(), " calls ", called_computation->name(), + ", ", CallContextToString(context)); +} + +CallGraphNode::CallGraphNode(HloComputation* computation) + : computation_(computation) {} + +void CallGraphNode::AddCallSite(const CallSite& callsite) { + callsites_.push_back(callsite); + HloComputation* callee = callsite.called_computation; + if (callee_set_.count(callee) == 0) { + callees_.push_back(callee); + callee_set_.insert(callee); + } +} + +void CallGraphNode::AddCallerCallSite(const CallSite& caller_callsite) { + caller_callsites_.push_back(caller_callsite); + HloComputation* caller = caller_callsite.instruction->parent(); + if (caller_set_.count(caller) == 0) { + callers_.push_back(caller); + caller_set_.insert(caller); + } +} + +void CallGraphNode::AddCallSitesInInstruction(HloInstruction* instruction) { + switch (instruction->opcode()) { + case HloOpcode::kCall: + AddCallSite( + {instruction, instruction->to_apply(), CallContext::kSequential}); + break; + case HloOpcode::kMap: + case HloOpcode::kReduce: + case HloOpcode::kReduceWindow: + AddCallSite( + {instruction, instruction->to_apply(), CallContext::kParallel}); + break; + case HloOpcode::kSelectAndScatter: + AddCallSite({instruction, instruction->select(), CallContext::kParallel}); + AddCallSite( + {instruction, instruction->scatter(), CallContext::kParallel}); + break; + case HloOpcode::kWhile: + AddCallSite({instruction, instruction->while_condition(), + CallContext::kSequential}); + AddCallSite( + {instruction, instruction->while_body(), CallContext::kSequential}); + break; + case HloOpcode::kFusion: + for (const auto& fused_instruction : instruction->fused_instructions()) { + AddCallSitesInInstruction(fused_instruction.get()); + } + break; + default: + break; + } +} + +CallGraph::CallGraph(const HloModule* module) : module_(module) {} + +StatusOr<const CallGraphNode*> CallGraph::GetNode( + const HloComputation* computation) const { + auto it = node_indices_.find(computation); + TF_RET_CHECK(it != node_indices_.end()); + return &nodes_[it->second]; +} + +StatusOr<CallGraphNode*> CallGraph::GetNode(const HloComputation* computation) { + auto it = node_indices_.find(computation); + TF_RET_CHECK(it != node_indices_.end()); + return &nodes_[it->second]; +} + +namespace { + +// Returns the call context of a computation which is called from contexts 'a' +// and 'b'. +CallContext UnionContexts(CallContext a, CallContext b) { + if (a == CallContext::kNone) { + return b; + } else if (b == CallContext::kNone) { + return a; + } else if (a == b) { + return a; + } else { + // Contexts are different and neither is kNone, ie one is kSequential and + // the other is kParallel. + return CallContext::kBoth; + } +} + +} // namespace + +Status CallGraph::SetCallContexts() { + std::queue<CallGraphNode*> worklist; + + // Initialize worklist with all roots of the call graph (computations without + // callers). + for (const std::unique_ptr<HloComputation>& computation : + module_->computations()) { + TF_ASSIGN_OR_RETURN(CallGraphNode * node, GetNode(computation.get())); + if (node->callers().empty()) { + node->set_context(CallContext::kSequential); + worklist.push(node); + } + } + + while (!worklist.empty()) { + CallGraphNode* node = worklist.front(); + worklist.pop(); + + for (const CallSite& callsite : node->callsites()) { + TF_ASSIGN_OR_RETURN(CallGraphNode * callee_node, + GetNode(callsite.called_computation)); + + // Update context of callee computation based on the callsite and its + // current context. + CallContext context_to_add; + if (callsite.context == CallContext::kParallel) { + context_to_add = CallContext::kParallel; + } else { + TF_RET_CHECK(callsite.context == CallContext::kSequential); + context_to_add = node->context(); + } + CallContext new_context = + UnionContexts(context_to_add, callee_node->context()); + + if (new_context != callee_node->context()) { + // Context of computation has been changed so add node to worklist. + callee_node->set_context(new_context); + worklist.push(callee_node); + } + } + } + + // No node should have a kNone calling context. + for (const std::unique_ptr<HloComputation>& computation : + module_->computations()) { + TF_ASSIGN_OR_RETURN(CallGraphNode * node, GetNode(computation.get())); + TF_RET_CHECK(node->context() != CallContext::kNone); + } + return Status::OK(); +} + +/* static */ +StatusOr<CallGraph> CallGraph::Build(const HloModule* module) { + CallGraph call_graph(module); + + // Construct nodes of the call graph and populate the callsites. + for (const std::unique_ptr<HloComputation>& computation : + module->computations()) { + auto it_added = call_graph.node_indices_.insert( + {computation.get(), call_graph.nodes_.size()}); + // All computation should be unique, so the computation should not already + // exist in the map. + TF_RET_CHECK(it_added.second); + call_graph.nodes_.emplace_back(computation.get()); + + // Add all callsites in this computation. + for (const std::unique_ptr<HloInstruction>& instruction : + computation->instructions()) { + call_graph.nodes_.back().AddCallSitesInInstruction(instruction.get()); + } + } + + // Add caller callsites to each node. + for (const std::unique_ptr<HloComputation>& computation : + module->computations()) { + TF_ASSIGN_OR_RETURN(CallGraphNode * caller_node, + call_graph.GetNode(computation.get())); + for (const CallSite& callsite : caller_node->callsites()) { + // Add caller callsites. + TF_ASSIGN_OR_RETURN(CallGraphNode * callee_node, + call_graph.GetNode(callsite.called_computation)); + callee_node->AddCallerCallSite(callsite); + } + } + + TF_RETURN_IF_ERROR(call_graph.SetCallContexts()); + + XLA_VLOG_LINES(1, call_graph.ToString()); + + return std::move(call_graph); +} + +string CallGraph::ToString() const { + string out; + Appendf(&out, "Call graph for module %s:\n", module_->name().c_str()); + for (const CallGraphNode& node : nodes()) { + Appendf(&out, "Computation %s:\n", node.computation()->name().c_str()); + Appendf(&out, " calls:\n"); + for (const HloComputation* callee : node.callees()) { + Appendf(&out, " %s\n", callee->name().c_str()); + } + Appendf(&out, " called by:\n"); + for (const HloComputation* caller : node.callers()) { + Appendf(&out, " %s\n", caller->name().c_str()); + } + Appendf(&out, " callsites:\n"); + for (const CallSite& callsite : node.callsites()) { + Appendf(&out, " %s\n", callsite.ToString().c_str()); + } + } + return out; +} + +} // namespace xla diff --git a/tensorflow/compiler/xla/service/call_graph.h b/tensorflow/compiler/xla/service/call_graph.h new file mode 100644 index 0000000000..f9291684bf --- /dev/null +++ b/tensorflow/compiler/xla/service/call_graph.h @@ -0,0 +1,175 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// Call graph for an HLO module. + +#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_HLO_CALL_GRAPH_H_ +#define TENSORFLOW_COMPILER_XLA_SERVICE_HLO_CALL_GRAPH_H_ + +#include <ostream> + +#include "tensorflow/compiler/xla/service/hlo_computation.h" +#include "tensorflow/compiler/xla/service/hlo_instruction.h" +#include "tensorflow/compiler/xla/service/hlo_module.h" +#include "tensorflow/compiler/xla/statusor.h" + +namespace xla { + +// The context in which a computation is called by another computation. +enum class CallContext { + // In a parallel contex the computation is applied to each element of the + // array argument(s). kMap and kReduce instructions call computations in + // parallel context. + kParallel, + + // In a sequential context the computation is applied to the entire argument + // shape(s). kCall and kWhile (body and condition) call computations in + // sequential context. + kSequential, + + // A computation is called from both a parallel and sequential context. + kBoth, + + // During call graph construction kNone is used to indicate that the context + // has not been determined. This is the top value for the context + // lattice. After construction, no call sites or call graph nodes should have + // this value. + kNone +}; + +string CallContextToString(CallContext context); +std::ostream& operator<<(std::ostream& out, const CallContext& context); + +// Represents an instruction calling a particular computation in an HLO +// module. Some instructions such as kWhile can call more than one computation +// and may be represented with more than one CallSite, one for each computation +// called. +struct CallSite { + // The calling instruction. + HloInstruction* instruction; + + // The computation the instruction is calling. + HloComputation* called_computation; + + // The context in which the computation is called. + CallContext context; + + string ToString() const; +}; + +// A node in the call graph representing an HLO computation. +class CallGraphNode { + public: + CallGraphNode(HloComputation* computation); + + // Return the computation represented by this call graph node. + HloComputation* computation() const { return computation_; } + + // Return the call sites in this computation. These are the instructions in + // this computation which call other computations. + const std::vector<CallSite>& callsites() const { return callsites_; } + + // Return the computations called by this computation. + const std::vector<HloComputation*>& callees() const { return callees_; } + + // Return the call sites in other computations which call this computation. + const std::vector<CallSite>& caller_callsites() const { + return caller_callsites_; + } + + // Return the computations which call this computation. + const std::vector<HloComputation*>& callers() const { return callers_; } + + // Return or set the context in which this computation is called. + CallContext context() const { return context_; } + void set_context(CallContext value) { context_ = value; } + + // Add a callsite which calls this computation. Updates callers to include the + // calling computation. + void AddCallerCallSite(const CallSite& caller_callsite); + + // Add a call site to this computation. Updates callees to include the called + // computation. + void AddCallSite(const CallSite& callsite); + + // Add all the call sites (if any) for this instruction. Instruction must be + // an instruction in this node's computation. + void AddCallSitesInInstruction(HloInstruction* instruction); + + string ToString() const; + + private: + // Computation represented by this call graph node. + HloComputation* computation_; + + // The computations called by this computation. The vector is used for a + // stable ordering and the set enables fast membership testing. + std::vector<HloComputation*> callees_; + std::unordered_set<HloComputation*> callee_set_; + + // The computations which call this computation. The vector is used for a + // stable ordering and the set enables fast membership testing. + std::vector<HloComputation*> callers_; + std::unordered_set<HloComputation*> caller_set_; + + // The call sites in this computation + std::vector<CallSite> callsites_; + + // The call sites in other computations which call this computation. + std::vector<CallSite> caller_callsites_; + + // The context in which this computation is called. + CallContext context_ = CallContext::kNone; +}; + +// The call graph for an HLO module. The graph includes a node for each +// computation in the module. +class CallGraph { + public: + // Build and return a call graph for the given HLO module. + static StatusOr<CallGraph> Build(const HloModule* module); + + // Public default constructor required for StatusOr<CallGraph>. + CallGraph() = default; + + // Return the node associated with the given computation. + StatusOr<const CallGraphNode*> GetNode( + const HloComputation* computation) const; + StatusOr<CallGraphNode*> GetNode(const HloComputation* computation); + + // Return the vector of all nodes in the call graph. + const std::vector<CallGraphNode>& nodes() const { return nodes_; } + + string ToString() const; + + private: + CallGraph(const HloModule* module); + + // Sets the call contexts for every node in the graph. + Status SetCallContexts(); + + const HloModule* module_ = nullptr; + + // Vector of all nodes in the call graph. + std::vector<CallGraphNode> nodes_; + + // Map from HLO computation to the index of the corresponding call graph node + // in nodes_. + std::unordered_map<const HloComputation*, int64> node_indices_; +}; + +} // namespace xla + +#endif // TENSORFLOW_COMPILER_XLA_SERVICE_HLO_CALL_GRAPH_H_ diff --git a/tensorflow/compiler/xla/service/call_graph_test.cc b/tensorflow/compiler/xla/service/call_graph_test.cc new file mode 100644 index 0000000000..c63a1bef4e --- /dev/null +++ b/tensorflow/compiler/xla/service/call_graph_test.cc @@ -0,0 +1,290 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/service/call_graph.h" + +#include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/compiler/xla/service/hlo_computation.h" +#include "tensorflow/compiler/xla/shape_util.h" +#include "tensorflow/compiler/xla/status_macros.h" +#include "tensorflow/compiler/xla/test_helpers.h" +#include "tensorflow/compiler/xla/tests/hlo_test_base.h" +#include "tensorflow/compiler/xla/xla_data.pb.h" +#include "tensorflow/core/lib/core/status_test_util.h" + +namespace xla { +namespace { + +class CallGraphTest : public HloTestBase { + protected: + // Build and return a trivial computation taking and returning a scalar. + std::unique_ptr<HloComputation> MakeScalarComputation() { + HloComputation::Builder builder(TestName() + ".ScalarComputation"); + HloInstruction* param0 = builder.AddInstruction( + HloInstruction::CreateParameter(0, kScalarShape, "param0")); + builder.AddInstruction( + HloInstruction::CreateUnary(kScalarShape, HloOpcode::kNegate, param0)); + return builder.Build(); + } + + // Build and return a computation which takes a scalar and maps (kMap) the + // given computation to the value 'callsites' number of times. + std::unique_ptr<HloComputation> MakeMappingComputation( + HloComputation* map_computation, int64 callsites) { + HloComputation::Builder builder(TestName() + ".MappingComputation"); + HloInstruction* param0 = builder.AddInstruction( + HloInstruction::CreateParameter(0, kScalarShape, "param0")); + HloInstruction* last_value = param0; + for (int64 i = 0; i < callsites; ++i) { + last_value = builder.AddInstruction(HloInstruction::CreateMap( + kScalarShape, {last_value}, map_computation)); + } + return builder.Build(); + } + + // Build and return a computation which takes a scalar and calls (kCall) the + // given computation with value 'callsites' number of times. + std::unique_ptr<HloComputation> MakeCallingComputation( + HloComputation* map_computation, int64 callsites) { + HloComputation::Builder builder(TestName() + ".CallingComputation"); + HloInstruction* param0 = builder.AddInstruction( + HloInstruction::CreateParameter(0, kScalarShape, "param0")); + HloInstruction* last_value = param0; + for (int64 i = 0; i < callsites; ++i) { + last_value = builder.AddInstruction(HloInstruction::CreateCall( + kScalarShape, {last_value}, map_computation)); + } + return builder.Build(); + } + + // Build and return a computation which takes a scalar and returns a PRED + // value. + std::unique_ptr<HloComputation> MakeConditionComputation() { + HloComputation::Builder builder(TestName() + ".ConditionComputation"); + HloInstruction* param0 = builder.AddInstruction( + HloInstruction::CreateParameter(0, kScalarShape, "param0")); + HloInstruction* zero = builder.AddInstruction( + HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(0.0f))); + builder.AddInstruction(HloInstruction::CreateBinary( + ShapeUtil::MakeShape(PRED, {}), HloOpcode::kGt, param0, zero)); + return builder.Build(); + } + + const Shape kScalarShape = ShapeUtil::MakeShape(F32, {}); +}; + +TEST_F(CallGraphTest, SingletonComputation) { + // Test the call graph of a module with a single computation. + HloModule module(TestName()); + HloComputation* computation = + module.AddEntryComputation(MakeScalarComputation()); + TF_ASSIGN_OR_ASSERT_OK(const CallGraph call_graph, CallGraph::Build(&module)); + EXPECT_EQ(1, call_graph.nodes().size()); + TF_ASSIGN_OR_ASSERT_OK(const CallGraphNode* node, + call_graph.GetNode(computation)); + EXPECT_EQ(computation, node->computation()); + EXPECT_TRUE(node->callsites().empty()); + EXPECT_TRUE(node->callees().empty()); + EXPECT_TRUE(node->caller_callsites().empty()); + EXPECT_TRUE(node->callers().empty()); + EXPECT_EQ(CallContext::kSequential, node->context()); +} + +TEST_F(CallGraphTest, UnreachableComputation) { + // Test the call graph of a module with an entry computation and an + // unreachable computation. + HloModule module(TestName()); + HloComputation* entry_computation = + module.AddEntryComputation(MakeScalarComputation()); + HloComputation* unreachable_computation = + module.AddEmbeddedComputation(MakeScalarComputation()); + + TF_ASSIGN_OR_ASSERT_OK(const CallGraph call_graph, CallGraph::Build(&module)); + EXPECT_EQ(2, call_graph.nodes().size()); + + TF_ASSIGN_OR_ASSERT_OK(const CallGraphNode* entry_node, + call_graph.GetNode(entry_computation)); + EXPECT_EQ(entry_computation, entry_node->computation()); + EXPECT_EQ(CallContext::kSequential, entry_node->context()); + + TF_ASSIGN_OR_ASSERT_OK(const CallGraphNode* unreachable_node, + call_graph.GetNode(unreachable_computation)); + EXPECT_EQ(unreachable_computation, unreachable_node->computation()); + EXPECT_EQ(CallContext::kSequential, unreachable_node->context()); +} + +TEST_F(CallGraphTest, ParallelComputation) { + // Test a call graph of a module with an entry computation which calls another + // computation in a parallel context via kMap. + HloModule module(TestName()); + HloComputation* map_computation = + module.AddEmbeddedComputation(MakeScalarComputation()); + HloComputation* entry_computation = module.AddEmbeddedComputation( + MakeMappingComputation(map_computation, /*callsites=*/5)); + + TF_ASSIGN_OR_ASSERT_OK(const CallGraph call_graph, CallGraph::Build(&module)); + EXPECT_EQ(2, call_graph.nodes().size()); + + TF_ASSIGN_OR_ASSERT_OK(const CallGraphNode* entry_node, + call_graph.GetNode(entry_computation)); + EXPECT_EQ(entry_computation, entry_node->computation()); + EXPECT_EQ(CallContext::kSequential, entry_node->context()); + EXPECT_EQ(5, entry_node->callsites().size()); + EXPECT_EQ(1, entry_node->callees().size()); + EXPECT_TRUE(entry_node->caller_callsites().empty()); + EXPECT_TRUE(entry_node->callers().empty()); + + TF_ASSIGN_OR_ASSERT_OK(const CallGraphNode* map_node, + call_graph.GetNode(map_computation)); + EXPECT_EQ(map_computation, map_node->computation()); + EXPECT_EQ(CallContext::kParallel, map_node->context()); + EXPECT_TRUE(map_node->callsites().empty()); + EXPECT_TRUE(map_node->callees().empty()); + EXPECT_EQ(5, map_node->caller_callsites().size()); + EXPECT_EQ(1, map_node->callers().size()); +} + +TEST_F(CallGraphTest, SequentialComputations) { + // Test a call graph of a module with an entry computation which calls another + // computation in a sequential context via kCall. + HloModule module(TestName()); + HloComputation* called_computation = + module.AddEmbeddedComputation(MakeScalarComputation()); + HloComputation* entry_computation = module.AddEmbeddedComputation( + MakeCallingComputation(called_computation, /*callsites=*/3)); + + TF_ASSIGN_OR_ASSERT_OK(const CallGraph call_graph, CallGraph::Build(&module)); + EXPECT_EQ(2, call_graph.nodes().size()); + + TF_ASSIGN_OR_ASSERT_OK(const CallGraphNode* entry_node, + call_graph.GetNode(entry_computation)); + EXPECT_EQ(entry_computation, entry_node->computation()); + EXPECT_EQ(CallContext::kSequential, entry_node->context()); + EXPECT_EQ(3, entry_node->callsites().size()); + EXPECT_EQ(1, entry_node->callees().size()); + EXPECT_TRUE(entry_node->caller_callsites().empty()); + EXPECT_TRUE(entry_node->callers().empty()); + + TF_ASSIGN_OR_ASSERT_OK(const CallGraphNode* called_node, + call_graph.GetNode(called_computation)); + EXPECT_EQ(called_computation, called_node->computation()); + EXPECT_EQ(CallContext::kSequential, called_node->context()); + EXPECT_TRUE(called_node->callsites().empty()); + EXPECT_TRUE(called_node->callees().empty()); + EXPECT_EQ(3, called_node->caller_callsites().size()); + EXPECT_EQ(1, called_node->callers().size()); +} + +TEST_F(CallGraphTest, ContextBothComputations) { + // Test a call graph of a module with an entry computation which calls another + // computation in both a parallel and sequential context. + HloModule module(TestName()); + HloComputation* subcomputation = + module.AddEmbeddedComputation(MakeScalarComputation()); + + HloComputation::Builder builder(TestName()); + HloInstruction* param0 = builder.AddInstruction( + HloInstruction::CreateParameter(0, kScalarShape, "param0")); + HloInstruction* call = builder.AddInstruction( + HloInstruction::CreateCall(kScalarShape, {param0}, subcomputation)); + HloInstruction* map = builder.AddInstruction( + HloInstruction::CreateMap(kScalarShape, {call}, subcomputation)); + HloComputation* entry_computation = + module.AddEmbeddedComputation(builder.Build()); + + TF_ASSIGN_OR_ASSERT_OK(const CallGraph call_graph, CallGraph::Build(&module)); + EXPECT_EQ(2, call_graph.nodes().size()); + + TF_ASSIGN_OR_ASSERT_OK(const CallGraphNode* entry_node, + call_graph.GetNode(entry_computation)); + EXPECT_EQ(entry_computation, entry_node->computation()); + EXPECT_EQ(2, entry_node->callsites().size()); + + const CallSite& call_callsite = entry_node->callsites()[0]; + EXPECT_EQ(call, call_callsite.instruction); + EXPECT_EQ(subcomputation, call_callsite.called_computation); + EXPECT_EQ(CallContext::kSequential, call_callsite.context); + + const CallSite& map_callsite = entry_node->callsites()[1]; + EXPECT_EQ(map, map_callsite.instruction); + EXPECT_EQ(subcomputation, map_callsite.called_computation); + EXPECT_EQ(CallContext::kParallel, map_callsite.context); + + TF_ASSIGN_OR_ASSERT_OK(const CallGraphNode* sub_node, + call_graph.GetNode(subcomputation)); + EXPECT_EQ(CallContext::kBoth, sub_node->context()); +} + +TEST_F(CallGraphTest, ComplexGraph) { + // Test a call graph of a module with several computation called in various + // contexts. The call graph looks like: + // + // entry + // / | + // a | + // / | \ | + // b | cond + // \ | + // c + // + // Calls are made via kCall, kWhile, and kMap instructions. + HloModule module(TestName()); + HloComputation* cond_computation = + module.AddEmbeddedComputation(MakeConditionComputation()); + HloComputation* c_computation = + module.AddEmbeddedComputation(MakeScalarComputation()); + HloComputation* b_computation = module.AddEmbeddedComputation( + MakeMappingComputation(c_computation, /*callsites=*/1)); + + HloComputation* computation_a; + { + HloComputation::Builder builder(TestName() + ".a"); + HloInstruction* param0 = builder.AddInstruction( + HloInstruction::CreateParameter(0, kScalarShape, "param0")); + HloInstruction* call = builder.AddInstruction( + HloInstruction::CreateCall(kScalarShape, {param0}, c_computation)); + builder.AddInstruction(HloInstruction::CreateWhile( + kScalarShape, cond_computation, b_computation, call)); + computation_a = module.AddEmbeddedComputation(builder.Build()); + } + + HloComputation* entry_computation; + { + HloComputation::Builder builder(TestName() + ".entry"); + HloInstruction* param0 = builder.AddInstruction( + HloInstruction::CreateParameter(0, kScalarShape, "param0")); + builder.AddInstruction(HloInstruction::CreateWhile( + kScalarShape, cond_computation, computation_a, param0)); + entry_computation = module.AddEntryComputation(builder.Build()); + } + + TF_ASSIGN_OR_ASSERT_OK(const CallGraph call_graph, CallGraph::Build(&module)); + EXPECT_EQ(5, call_graph.nodes().size()); + + TF_ASSIGN_OR_ASSERT_OK(const CallGraphNode* entry_node, + call_graph.GetNode(entry_computation)); + // Entry computation has one while instruction (two callsites). + EXPECT_EQ(2, entry_node->callsites().size()); + EXPECT_EQ(CallContext::kSequential, entry_node->context()); + + TF_ASSIGN_OR_ASSERT_OK(const CallGraphNode* c_node, + call_graph.GetNode(c_computation)); + EXPECT_TRUE(c_node->callsites().empty()); + EXPECT_EQ(2, c_node->callers().size()); + EXPECT_EQ(CallContext::kBoth, c_node->context()); +} + +} // namespace +} // namespace xla |