aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorGravatar Mark Heffernan <meheff@google.com>2017-02-27 11:05:41 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2017-02-27 11:30:48 -0800
commit2dafcef71bdb03b19b37ae80fe45575c3639e4e3 (patch)
tree2f138a64e3a9ad1db196b328431d88f7ad2037bb
parentb83db97a812edc614e6891a4fd69eaf422debea3 (diff)
[XLA] Add HLO call graph.
Add an HLO call graph class. The object includes a node for each computation and forward and backwards links between them. Node include the calling context: either parallel (eg, kMap) or sequential (eg, kCall). The class is not used anywhere yet, but there are numerous potential uses throughout the HLO level. Change: 148669295
-rw-r--r--tensorflow/compiler/xla/service/BUILD30
-rw-r--r--tensorflow/compiler/xla/service/call_graph.cc258
-rw-r--r--tensorflow/compiler/xla/service/call_graph.h175
-rw-r--r--tensorflow/compiler/xla/service/call_graph_test.cc290
4 files changed, 753 insertions, 0 deletions
diff --git a/tensorflow/compiler/xla/service/BUILD b/tensorflow/compiler/xla/service/BUILD
index f5a67124e0..ccd84de654 100644
--- a/tensorflow/compiler/xla/service/BUILD
+++ b/tensorflow/compiler/xla/service/BUILD
@@ -130,6 +130,36 @@ cc_test(
)
cc_library(
+ name = "call_graph",
+ srcs = ["call_graph.cc"],
+ hdrs = ["call_graph.h"],
+ deps = [
+ ":hlo",
+ "//tensorflow/compiler/xla:status_macros",
+ "//tensorflow/compiler/xla:statusor",
+ "//tensorflow/compiler/xla:util",
+ "//tensorflow/core:lib",
+ ],
+)
+
+cc_test(
+ name = "call_graph_test",
+ srcs = ["call_graph_test.cc"],
+ deps = [
+ ":call_graph",
+ "//tensorflow/compiler/xla:literal_util",
+ "//tensorflow/compiler/xla:shape_util",
+ "//tensorflow/compiler/xla:status_macros",
+ "//tensorflow/compiler/xla:test_helpers",
+ "//tensorflow/compiler/xla:xla_data_proto",
+ "//tensorflow/compiler/xla/service:hlo",
+ "//tensorflow/compiler/xla/tests:hlo_test_base",
+ "//tensorflow/core:test",
+ "//tensorflow/core:test_main",
+ ],
+)
+
+cc_library(
name = "user_computation",
srcs = ["user_computation.cc"],
hdrs = ["user_computation.h"],
diff --git a/tensorflow/compiler/xla/service/call_graph.cc b/tensorflow/compiler/xla/service/call_graph.cc
new file mode 100644
index 0000000000..ed140f728d
--- /dev/null
+++ b/tensorflow/compiler/xla/service/call_graph.cc
@@ -0,0 +1,258 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/xla/service/call_graph.h"
+
+#include <queue>
+
+#include "tensorflow/compiler/xla/status_macros.h"
+#include "tensorflow/compiler/xla/util.h"
+#include "tensorflow/core/lib/core/errors.h"
+#include "tensorflow/core/lib/core/status.h"
+#include "tensorflow/core/lib/strings/strcat.h"
+#include "tensorflow/core/lib/strings/stringprintf.h"
+#include "tensorflow/core/platform/types.h"
+
+namespace xla {
+
+using ::tensorflow::strings::Appendf;
+using ::tensorflow::strings::StrCat;
+
+string CallContextToString(CallContext context) {
+ switch (context) {
+ case CallContext::kNone:
+ return "kNone";
+ case CallContext::kSequential:
+ return "kSequential";
+ case CallContext::kParallel:
+ return "kParallel";
+ case CallContext::kBoth:
+ return "kBoth";
+ }
+}
+
+std::ostream& operator<<(std::ostream& out, const CallContext& context) {
+ out << CallContextToString(context);
+ return out;
+}
+
+string CallSite::ToString() const {
+ return StrCat(instruction->name(), " calls ", called_computation->name(),
+ ", ", CallContextToString(context));
+}
+
+CallGraphNode::CallGraphNode(HloComputation* computation)
+ : computation_(computation) {}
+
+void CallGraphNode::AddCallSite(const CallSite& callsite) {
+ callsites_.push_back(callsite);
+ HloComputation* callee = callsite.called_computation;
+ if (callee_set_.count(callee) == 0) {
+ callees_.push_back(callee);
+ callee_set_.insert(callee);
+ }
+}
+
+void CallGraphNode::AddCallerCallSite(const CallSite& caller_callsite) {
+ caller_callsites_.push_back(caller_callsite);
+ HloComputation* caller = caller_callsite.instruction->parent();
+ if (caller_set_.count(caller) == 0) {
+ callers_.push_back(caller);
+ caller_set_.insert(caller);
+ }
+}
+
+void CallGraphNode::AddCallSitesInInstruction(HloInstruction* instruction) {
+ switch (instruction->opcode()) {
+ case HloOpcode::kCall:
+ AddCallSite(
+ {instruction, instruction->to_apply(), CallContext::kSequential});
+ break;
+ case HloOpcode::kMap:
+ case HloOpcode::kReduce:
+ case HloOpcode::kReduceWindow:
+ AddCallSite(
+ {instruction, instruction->to_apply(), CallContext::kParallel});
+ break;
+ case HloOpcode::kSelectAndScatter:
+ AddCallSite({instruction, instruction->select(), CallContext::kParallel});
+ AddCallSite(
+ {instruction, instruction->scatter(), CallContext::kParallel});
+ break;
+ case HloOpcode::kWhile:
+ AddCallSite({instruction, instruction->while_condition(),
+ CallContext::kSequential});
+ AddCallSite(
+ {instruction, instruction->while_body(), CallContext::kSequential});
+ break;
+ case HloOpcode::kFusion:
+ for (const auto& fused_instruction : instruction->fused_instructions()) {
+ AddCallSitesInInstruction(fused_instruction.get());
+ }
+ break;
+ default:
+ break;
+ }
+}
+
+CallGraph::CallGraph(const HloModule* module) : module_(module) {}
+
+StatusOr<const CallGraphNode*> CallGraph::GetNode(
+ const HloComputation* computation) const {
+ auto it = node_indices_.find(computation);
+ TF_RET_CHECK(it != node_indices_.end());
+ return &nodes_[it->second];
+}
+
+StatusOr<CallGraphNode*> CallGraph::GetNode(const HloComputation* computation) {
+ auto it = node_indices_.find(computation);
+ TF_RET_CHECK(it != node_indices_.end());
+ return &nodes_[it->second];
+}
+
+namespace {
+
+// Returns the call context of a computation which is called from contexts 'a'
+// and 'b'.
+CallContext UnionContexts(CallContext a, CallContext b) {
+ if (a == CallContext::kNone) {
+ return b;
+ } else if (b == CallContext::kNone) {
+ return a;
+ } else if (a == b) {
+ return a;
+ } else {
+    // Contexts are different and neither is kNone, i.e., one is kSequential and
+ // the other is kParallel.
+ return CallContext::kBoth;
+ }
+}
+
+} // namespace
+
+Status CallGraph::SetCallContexts() {
+ std::queue<CallGraphNode*> worklist;
+
+ // Initialize worklist with all roots of the call graph (computations without
+ // callers).
+ for (const std::unique_ptr<HloComputation>& computation :
+ module_->computations()) {
+ TF_ASSIGN_OR_RETURN(CallGraphNode * node, GetNode(computation.get()));
+ if (node->callers().empty()) {
+ node->set_context(CallContext::kSequential);
+ worklist.push(node);
+ }
+ }
+
+ while (!worklist.empty()) {
+ CallGraphNode* node = worklist.front();
+ worklist.pop();
+
+ for (const CallSite& callsite : node->callsites()) {
+ TF_ASSIGN_OR_RETURN(CallGraphNode * callee_node,
+ GetNode(callsite.called_computation));
+
+ // Update context of callee computation based on the callsite and its
+ // current context.
+ CallContext context_to_add;
+ if (callsite.context == CallContext::kParallel) {
+ context_to_add = CallContext::kParallel;
+ } else {
+ TF_RET_CHECK(callsite.context == CallContext::kSequential);
+ context_to_add = node->context();
+ }
+ CallContext new_context =
+ UnionContexts(context_to_add, callee_node->context());
+
+ if (new_context != callee_node->context()) {
+ // Context of computation has been changed so add node to worklist.
+ callee_node->set_context(new_context);
+ worklist.push(callee_node);
+ }
+ }
+ }
+
+ // No node should have a kNone calling context.
+ for (const std::unique_ptr<HloComputation>& computation :
+ module_->computations()) {
+ TF_ASSIGN_OR_RETURN(CallGraphNode * node, GetNode(computation.get()));
+ TF_RET_CHECK(node->context() != CallContext::kNone);
+ }
+ return Status::OK();
+}
+
+/* static */
+StatusOr<CallGraph> CallGraph::Build(const HloModule* module) {
+ CallGraph call_graph(module);
+
+ // Construct nodes of the call graph and populate the callsites.
+ for (const std::unique_ptr<HloComputation>& computation :
+ module->computations()) {
+ auto it_added = call_graph.node_indices_.insert(
+ {computation.get(), call_graph.nodes_.size()});
+    // All computations should be unique, so the computation should not already
+ // exist in the map.
+ TF_RET_CHECK(it_added.second);
+ call_graph.nodes_.emplace_back(computation.get());
+
+ // Add all callsites in this computation.
+ for (const std::unique_ptr<HloInstruction>& instruction :
+ computation->instructions()) {
+ call_graph.nodes_.back().AddCallSitesInInstruction(instruction.get());
+ }
+ }
+
+ // Add caller callsites to each node.
+ for (const std::unique_ptr<HloComputation>& computation :
+ module->computations()) {
+ TF_ASSIGN_OR_RETURN(CallGraphNode * caller_node,
+ call_graph.GetNode(computation.get()));
+ for (const CallSite& callsite : caller_node->callsites()) {
+ // Add caller callsites.
+ TF_ASSIGN_OR_RETURN(CallGraphNode * callee_node,
+ call_graph.GetNode(callsite.called_computation));
+ callee_node->AddCallerCallSite(callsite);
+ }
+ }
+
+ TF_RETURN_IF_ERROR(call_graph.SetCallContexts());
+
+ XLA_VLOG_LINES(1, call_graph.ToString());
+
+ return std::move(call_graph);
+}
+
+string CallGraph::ToString() const {
+ string out;
+ Appendf(&out, "Call graph for module %s:\n", module_->name().c_str());
+ for (const CallGraphNode& node : nodes()) {
+ Appendf(&out, "Computation %s:\n", node.computation()->name().c_str());
+ Appendf(&out, " calls:\n");
+ for (const HloComputation* callee : node.callees()) {
+ Appendf(&out, " %s\n", callee->name().c_str());
+ }
+ Appendf(&out, " called by:\n");
+ for (const HloComputation* caller : node.callers()) {
+ Appendf(&out, " %s\n", caller->name().c_str());
+ }
+ Appendf(&out, " callsites:\n");
+ for (const CallSite& callsite : node.callsites()) {
+ Appendf(&out, " %s\n", callsite.ToString().c_str());
+ }
+ }
+ return out;
+}
+
+} // namespace xla
diff --git a/tensorflow/compiler/xla/service/call_graph.h b/tensorflow/compiler/xla/service/call_graph.h
new file mode 100644
index 0000000000..f9291684bf
--- /dev/null
+++ b/tensorflow/compiler/xla/service/call_graph.h
@@ -0,0 +1,175 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+// Call graph for an HLO module.
+
+#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_HLO_CALL_GRAPH_H_
+#define TENSORFLOW_COMPILER_XLA_SERVICE_HLO_CALL_GRAPH_H_
+
+#include <ostream>
+
+#include "tensorflow/compiler/xla/service/hlo_computation.h"
+#include "tensorflow/compiler/xla/service/hlo_instruction.h"
+#include "tensorflow/compiler/xla/service/hlo_module.h"
+#include "tensorflow/compiler/xla/statusor.h"
+
+namespace xla {
+
+// The context in which a computation is called by another computation.
+enum class CallContext {
+  // In a parallel context the computation is applied to each element of the
+ // array argument(s). kMap and kReduce instructions call computations in
+ // parallel context.
+ kParallel,
+
+ // In a sequential context the computation is applied to the entire argument
+ // shape(s). kCall and kWhile (body and condition) call computations in
+ // sequential context.
+ kSequential,
+
+ // A computation is called from both a parallel and sequential context.
+ kBoth,
+
+ // During call graph construction kNone is used to indicate that the context
+ // has not been determined. This is the top value for the context
+ // lattice. After construction, no call sites or call graph nodes should have
+ // this value.
+ kNone
+};
+
+string CallContextToString(CallContext context);
+std::ostream& operator<<(std::ostream& out, const CallContext& context);
+
+// Represents an instruction calling a particular computation in an HLO
+// module. Some instructions such as kWhile can call more than one computation
+// and may be represented with more than one CallSite, one for each computation
+// called.
+struct CallSite {
+ // The calling instruction.
+ HloInstruction* instruction;
+
+ // The computation the instruction is calling.
+ HloComputation* called_computation;
+
+ // The context in which the computation is called.
+ CallContext context;
+
+ string ToString() const;
+};
+
+// A node in the call graph representing an HLO computation.
+class CallGraphNode {
+ public:
+ CallGraphNode(HloComputation* computation);
+
+ // Return the computation represented by this call graph node.
+ HloComputation* computation() const { return computation_; }
+
+ // Return the call sites in this computation. These are the instructions in
+ // this computation which call other computations.
+ const std::vector<CallSite>& callsites() const { return callsites_; }
+
+ // Return the computations called by this computation.
+ const std::vector<HloComputation*>& callees() const { return callees_; }
+
+ // Return the call sites in other computations which call this computation.
+ const std::vector<CallSite>& caller_callsites() const {
+ return caller_callsites_;
+ }
+
+ // Return the computations which call this computation.
+ const std::vector<HloComputation*>& callers() const { return callers_; }
+
+ // Return or set the context in which this computation is called.
+ CallContext context() const { return context_; }
+ void set_context(CallContext value) { context_ = value; }
+
+ // Add a callsite which calls this computation. Updates callers to include the
+ // calling computation.
+ void AddCallerCallSite(const CallSite& caller_callsite);
+
+ // Add a call site to this computation. Updates callees to include the called
+ // computation.
+ void AddCallSite(const CallSite& callsite);
+
+ // Add all the call sites (if any) for this instruction. Instruction must be
+ // an instruction in this node's computation.
+ void AddCallSitesInInstruction(HloInstruction* instruction);
+
+ string ToString() const;
+
+ private:
+ // Computation represented by this call graph node.
+ HloComputation* computation_;
+
+ // The computations called by this computation. The vector is used for a
+ // stable ordering and the set enables fast membership testing.
+ std::vector<HloComputation*> callees_;
+ std::unordered_set<HloComputation*> callee_set_;
+
+ // The computations which call this computation. The vector is used for a
+ // stable ordering and the set enables fast membership testing.
+ std::vector<HloComputation*> callers_;
+ std::unordered_set<HloComputation*> caller_set_;
+
+  // The call sites in this computation.
+ std::vector<CallSite> callsites_;
+
+ // The call sites in other computations which call this computation.
+ std::vector<CallSite> caller_callsites_;
+
+ // The context in which this computation is called.
+ CallContext context_ = CallContext::kNone;
+};
+
+// The call graph for an HLO module. The graph includes a node for each
+// computation in the module.
+class CallGraph {
+ public:
+ // Build and return a call graph for the given HLO module.
+ static StatusOr<CallGraph> Build(const HloModule* module);
+
+ // Public default constructor required for StatusOr<CallGraph>.
+ CallGraph() = default;
+
+ // Return the node associated with the given computation.
+ StatusOr<const CallGraphNode*> GetNode(
+ const HloComputation* computation) const;
+ StatusOr<CallGraphNode*> GetNode(const HloComputation* computation);
+
+ // Return the vector of all nodes in the call graph.
+ const std::vector<CallGraphNode>& nodes() const { return nodes_; }
+
+ string ToString() const;
+
+ private:
+ CallGraph(const HloModule* module);
+
+ // Sets the call contexts for every node in the graph.
+ Status SetCallContexts();
+
+ const HloModule* module_ = nullptr;
+
+ // Vector of all nodes in the call graph.
+ std::vector<CallGraphNode> nodes_;
+
+ // Map from HLO computation to the index of the corresponding call graph node
+ // in nodes_.
+ std::unordered_map<const HloComputation*, int64> node_indices_;
+};
+
+} // namespace xla
+
+#endif // TENSORFLOW_COMPILER_XLA_SERVICE_HLO_CALL_GRAPH_H_
diff --git a/tensorflow/compiler/xla/service/call_graph_test.cc b/tensorflow/compiler/xla/service/call_graph_test.cc
new file mode 100644
index 0000000000..c63a1bef4e
--- /dev/null
+++ b/tensorflow/compiler/xla/service/call_graph_test.cc
@@ -0,0 +1,290 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/xla/service/call_graph.h"
+
+#include "tensorflow/compiler/xla/literal_util.h"
+#include "tensorflow/compiler/xla/service/hlo_computation.h"
+#include "tensorflow/compiler/xla/shape_util.h"
+#include "tensorflow/compiler/xla/status_macros.h"
+#include "tensorflow/compiler/xla/test_helpers.h"
+#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
+#include "tensorflow/compiler/xla/xla_data.pb.h"
+#include "tensorflow/core/lib/core/status_test_util.h"
+
+namespace xla {
+namespace {
+
+class CallGraphTest : public HloTestBase {
+ protected:
+ // Build and return a trivial computation taking and returning a scalar.
+ std::unique_ptr<HloComputation> MakeScalarComputation() {
+ HloComputation::Builder builder(TestName() + ".ScalarComputation");
+ HloInstruction* param0 = builder.AddInstruction(
+ HloInstruction::CreateParameter(0, kScalarShape, "param0"));
+ builder.AddInstruction(
+ HloInstruction::CreateUnary(kScalarShape, HloOpcode::kNegate, param0));
+ return builder.Build();
+ }
+
+ // Build and return a computation which takes a scalar and maps (kMap) the
+ // given computation to the value 'callsites' number of times.
+ std::unique_ptr<HloComputation> MakeMappingComputation(
+ HloComputation* map_computation, int64 callsites) {
+ HloComputation::Builder builder(TestName() + ".MappingComputation");
+ HloInstruction* param0 = builder.AddInstruction(
+ HloInstruction::CreateParameter(0, kScalarShape, "param0"));
+ HloInstruction* last_value = param0;
+ for (int64 i = 0; i < callsites; ++i) {
+ last_value = builder.AddInstruction(HloInstruction::CreateMap(
+ kScalarShape, {last_value}, map_computation));
+ }
+ return builder.Build();
+ }
+
+ // Build and return a computation which takes a scalar and calls (kCall) the
+ // given computation with value 'callsites' number of times.
+ std::unique_ptr<HloComputation> MakeCallingComputation(
+ HloComputation* map_computation, int64 callsites) {
+ HloComputation::Builder builder(TestName() + ".CallingComputation");
+ HloInstruction* param0 = builder.AddInstruction(
+ HloInstruction::CreateParameter(0, kScalarShape, "param0"));
+ HloInstruction* last_value = param0;
+ for (int64 i = 0; i < callsites; ++i) {
+ last_value = builder.AddInstruction(HloInstruction::CreateCall(
+ kScalarShape, {last_value}, map_computation));
+ }
+ return builder.Build();
+ }
+
+ // Build and return a computation which takes a scalar and returns a PRED
+ // value.
+ std::unique_ptr<HloComputation> MakeConditionComputation() {
+ HloComputation::Builder builder(TestName() + ".ConditionComputation");
+ HloInstruction* param0 = builder.AddInstruction(
+ HloInstruction::CreateParameter(0, kScalarShape, "param0"));
+ HloInstruction* zero = builder.AddInstruction(
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(0.0f)));
+ builder.AddInstruction(HloInstruction::CreateBinary(
+ ShapeUtil::MakeShape(PRED, {}), HloOpcode::kGt, param0, zero));
+ return builder.Build();
+ }
+
+ const Shape kScalarShape = ShapeUtil::MakeShape(F32, {});
+};
+
+TEST_F(CallGraphTest, SingletonComputation) {
+ // Test the call graph of a module with a single computation.
+ HloModule module(TestName());
+ HloComputation* computation =
+ module.AddEntryComputation(MakeScalarComputation());
+ TF_ASSIGN_OR_ASSERT_OK(const CallGraph call_graph, CallGraph::Build(&module));
+ EXPECT_EQ(1, call_graph.nodes().size());
+ TF_ASSIGN_OR_ASSERT_OK(const CallGraphNode* node,
+ call_graph.GetNode(computation));
+ EXPECT_EQ(computation, node->computation());
+ EXPECT_TRUE(node->callsites().empty());
+ EXPECT_TRUE(node->callees().empty());
+ EXPECT_TRUE(node->caller_callsites().empty());
+ EXPECT_TRUE(node->callers().empty());
+ EXPECT_EQ(CallContext::kSequential, node->context());
+}
+
+TEST_F(CallGraphTest, UnreachableComputation) {
+ // Test the call graph of a module with an entry computation and an
+ // unreachable computation.
+ HloModule module(TestName());
+ HloComputation* entry_computation =
+ module.AddEntryComputation(MakeScalarComputation());
+ HloComputation* unreachable_computation =
+ module.AddEmbeddedComputation(MakeScalarComputation());
+
+ TF_ASSIGN_OR_ASSERT_OK(const CallGraph call_graph, CallGraph::Build(&module));
+ EXPECT_EQ(2, call_graph.nodes().size());
+
+ TF_ASSIGN_OR_ASSERT_OK(const CallGraphNode* entry_node,
+ call_graph.GetNode(entry_computation));
+ EXPECT_EQ(entry_computation, entry_node->computation());
+ EXPECT_EQ(CallContext::kSequential, entry_node->context());
+
+ TF_ASSIGN_OR_ASSERT_OK(const CallGraphNode* unreachable_node,
+ call_graph.GetNode(unreachable_computation));
+ EXPECT_EQ(unreachable_computation, unreachable_node->computation());
+ EXPECT_EQ(CallContext::kSequential, unreachable_node->context());
+}
+
+TEST_F(CallGraphTest, ParallelComputation) {
+ // Test a call graph of a module with an entry computation which calls another
+ // computation in a parallel context via kMap.
+ HloModule module(TestName());
+ HloComputation* map_computation =
+ module.AddEmbeddedComputation(MakeScalarComputation());
+ HloComputation* entry_computation = module.AddEmbeddedComputation(
+ MakeMappingComputation(map_computation, /*callsites=*/5));
+
+ TF_ASSIGN_OR_ASSERT_OK(const CallGraph call_graph, CallGraph::Build(&module));
+ EXPECT_EQ(2, call_graph.nodes().size());
+
+ TF_ASSIGN_OR_ASSERT_OK(const CallGraphNode* entry_node,
+ call_graph.GetNode(entry_computation));
+ EXPECT_EQ(entry_computation, entry_node->computation());
+ EXPECT_EQ(CallContext::kSequential, entry_node->context());
+ EXPECT_EQ(5, entry_node->callsites().size());
+ EXPECT_EQ(1, entry_node->callees().size());
+ EXPECT_TRUE(entry_node->caller_callsites().empty());
+ EXPECT_TRUE(entry_node->callers().empty());
+
+ TF_ASSIGN_OR_ASSERT_OK(const CallGraphNode* map_node,
+ call_graph.GetNode(map_computation));
+ EXPECT_EQ(map_computation, map_node->computation());
+ EXPECT_EQ(CallContext::kParallel, map_node->context());
+ EXPECT_TRUE(map_node->callsites().empty());
+ EXPECT_TRUE(map_node->callees().empty());
+ EXPECT_EQ(5, map_node->caller_callsites().size());
+ EXPECT_EQ(1, map_node->callers().size());
+}
+
+TEST_F(CallGraphTest, SequentialComputations) {
+ // Test a call graph of a module with an entry computation which calls another
+ // computation in a sequential context via kCall.
+ HloModule module(TestName());
+ HloComputation* called_computation =
+ module.AddEmbeddedComputation(MakeScalarComputation());
+ HloComputation* entry_computation = module.AddEmbeddedComputation(
+ MakeCallingComputation(called_computation, /*callsites=*/3));
+
+ TF_ASSIGN_OR_ASSERT_OK(const CallGraph call_graph, CallGraph::Build(&module));
+ EXPECT_EQ(2, call_graph.nodes().size());
+
+ TF_ASSIGN_OR_ASSERT_OK(const CallGraphNode* entry_node,
+ call_graph.GetNode(entry_computation));
+ EXPECT_EQ(entry_computation, entry_node->computation());
+ EXPECT_EQ(CallContext::kSequential, entry_node->context());
+ EXPECT_EQ(3, entry_node->callsites().size());
+ EXPECT_EQ(1, entry_node->callees().size());
+ EXPECT_TRUE(entry_node->caller_callsites().empty());
+ EXPECT_TRUE(entry_node->callers().empty());
+
+ TF_ASSIGN_OR_ASSERT_OK(const CallGraphNode* called_node,
+ call_graph.GetNode(called_computation));
+ EXPECT_EQ(called_computation, called_node->computation());
+ EXPECT_EQ(CallContext::kSequential, called_node->context());
+ EXPECT_TRUE(called_node->callsites().empty());
+ EXPECT_TRUE(called_node->callees().empty());
+ EXPECT_EQ(3, called_node->caller_callsites().size());
+ EXPECT_EQ(1, called_node->callers().size());
+}
+
+TEST_F(CallGraphTest, ContextBothComputations) {
+ // Test a call graph of a module with an entry computation which calls another
+ // computation in both a parallel and sequential context.
+ HloModule module(TestName());
+ HloComputation* subcomputation =
+ module.AddEmbeddedComputation(MakeScalarComputation());
+
+ HloComputation::Builder builder(TestName());
+ HloInstruction* param0 = builder.AddInstruction(
+ HloInstruction::CreateParameter(0, kScalarShape, "param0"));
+ HloInstruction* call = builder.AddInstruction(
+ HloInstruction::CreateCall(kScalarShape, {param0}, subcomputation));
+ HloInstruction* map = builder.AddInstruction(
+ HloInstruction::CreateMap(kScalarShape, {call}, subcomputation));
+ HloComputation* entry_computation =
+ module.AddEmbeddedComputation(builder.Build());
+
+ TF_ASSIGN_OR_ASSERT_OK(const CallGraph call_graph, CallGraph::Build(&module));
+ EXPECT_EQ(2, call_graph.nodes().size());
+
+ TF_ASSIGN_OR_ASSERT_OK(const CallGraphNode* entry_node,
+ call_graph.GetNode(entry_computation));
+ EXPECT_EQ(entry_computation, entry_node->computation());
+ EXPECT_EQ(2, entry_node->callsites().size());
+
+ const CallSite& call_callsite = entry_node->callsites()[0];
+ EXPECT_EQ(call, call_callsite.instruction);
+ EXPECT_EQ(subcomputation, call_callsite.called_computation);
+ EXPECT_EQ(CallContext::kSequential, call_callsite.context);
+
+ const CallSite& map_callsite = entry_node->callsites()[1];
+ EXPECT_EQ(map, map_callsite.instruction);
+ EXPECT_EQ(subcomputation, map_callsite.called_computation);
+ EXPECT_EQ(CallContext::kParallel, map_callsite.context);
+
+ TF_ASSIGN_OR_ASSERT_OK(const CallGraphNode* sub_node,
+ call_graph.GetNode(subcomputation));
+ EXPECT_EQ(CallContext::kBoth, sub_node->context());
+}
+
+TEST_F(CallGraphTest, ComplexGraph) {
+  // Test a call graph of a module with several computations called in various
+ // contexts. The call graph looks like:
+ //
+ // entry
+ // / |
+ // a |
+ // / | \ |
+ // b | cond
+ // \ |
+ // c
+ //
+ // Calls are made via kCall, kWhile, and kMap instructions.
+ HloModule module(TestName());
+ HloComputation* cond_computation =
+ module.AddEmbeddedComputation(MakeConditionComputation());
+ HloComputation* c_computation =
+ module.AddEmbeddedComputation(MakeScalarComputation());
+ HloComputation* b_computation = module.AddEmbeddedComputation(
+ MakeMappingComputation(c_computation, /*callsites=*/1));
+
+ HloComputation* computation_a;
+ {
+ HloComputation::Builder builder(TestName() + ".a");
+ HloInstruction* param0 = builder.AddInstruction(
+ HloInstruction::CreateParameter(0, kScalarShape, "param0"));
+ HloInstruction* call = builder.AddInstruction(
+ HloInstruction::CreateCall(kScalarShape, {param0}, c_computation));
+ builder.AddInstruction(HloInstruction::CreateWhile(
+ kScalarShape, cond_computation, b_computation, call));
+ computation_a = module.AddEmbeddedComputation(builder.Build());
+ }
+
+ HloComputation* entry_computation;
+ {
+ HloComputation::Builder builder(TestName() + ".entry");
+ HloInstruction* param0 = builder.AddInstruction(
+ HloInstruction::CreateParameter(0, kScalarShape, "param0"));
+ builder.AddInstruction(HloInstruction::CreateWhile(
+ kScalarShape, cond_computation, computation_a, param0));
+ entry_computation = module.AddEntryComputation(builder.Build());
+ }
+
+ TF_ASSIGN_OR_ASSERT_OK(const CallGraph call_graph, CallGraph::Build(&module));
+ EXPECT_EQ(5, call_graph.nodes().size());
+
+ TF_ASSIGN_OR_ASSERT_OK(const CallGraphNode* entry_node,
+ call_graph.GetNode(entry_computation));
+ // Entry computation has one while instruction (two callsites).
+ EXPECT_EQ(2, entry_node->callsites().size());
+ EXPECT_EQ(CallContext::kSequential, entry_node->context());
+
+ TF_ASSIGN_OR_ASSERT_OK(const CallGraphNode* c_node,
+ call_graph.GetNode(c_computation));
+ EXPECT_TRUE(c_node->callsites().empty());
+ EXPECT_EQ(2, c_node->callers().size());
+ EXPECT_EQ(CallContext::kBoth, c_node->context());
+}
+
+} // namespace
+} // namespace xla