aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2017-04-07 00:46:42 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2017-04-07 02:15:41 -0700
commitbea7cabfd0c32638d4c102c636270cbfecc6665b (patch)
tree9c8ffffb9e705dadd11f58956bbd5e89fa257e5c
parent2dca3420ff04a4bd1e520e895845c9b7fda22972 (diff)
Add a tool that converts HLO computations to tensorflow GraphDef which can be visualized on Tensorboard.
This CL defines basic tensorflow::OpDef for each HLO instruction/node. More attributes (e.g. shapes, colors) will be added in the future. Change: 152477918
-rw-r--r--tensorflow/compiler/xla/service/hlo_opcode.cc13
-rw-r--r--tensorflow/compiler/xla/service/hlo_opcode.h3
-rw-r--r--tensorflow/compiler/xla/tools/BUILD44
-rw-r--r--tensorflow/compiler/xla/tools/dumped_computation_to_tf_graphdef.cc139
-rw-r--r--tensorflow/compiler/xla/tools/hlo_tfgraph_builder.cc129
-rw-r--r--tensorflow/compiler/xla/tools/hlo_tfgraph_builder.h59
-rw-r--r--tensorflow/compiler/xla/tools/hlo_tfgraph_builder_test.cc107
7 files changed, 494 insertions, 0 deletions
diff --git a/tensorflow/compiler/xla/service/hlo_opcode.cc b/tensorflow/compiler/xla/service/hlo_opcode.cc
index 616b239a93..ceb0cdaa31 100644
--- a/tensorflow/compiler/xla/service/hlo_opcode.cc
+++ b/tensorflow/compiler/xla/service/hlo_opcode.cc
@@ -165,4 +165,17 @@ bool HloOpcodeIsComparison(HloOpcode opcode) {
}
}
+bool HloOpcodeIsVariadic(HloOpcode opcode) {
+ switch (opcode) {
+ case HloOpcode::kCall:
+ case HloOpcode::kConcatenate:
+ case HloOpcode::kFusion:
+ case HloOpcode::kMap:
+ case HloOpcode::kTuple:
+ return true;
+ default:
+ return false;
+ }
+}
+
} // namespace xla
diff --git a/tensorflow/compiler/xla/service/hlo_opcode.h b/tensorflow/compiler/xla/service/hlo_opcode.h
index 978ed5e79b..e2cdbfdfa7 100644
--- a/tensorflow/compiler/xla/service/hlo_opcode.h
+++ b/tensorflow/compiler/xla/service/hlo_opcode.h
@@ -104,6 +104,9 @@ inline std::ostream& operator<<(std::ostream& os, HloOpcode opcode) {
// Returns true iff the given opcode is a comparison operation.
bool HloOpcodeIsComparison(HloOpcode opcode);
+// Returns true iff the given opcode has variadic operands.
+bool HloOpcodeIsVariadic(HloOpcode opcode);
+
} // namespace xla
#endif // TENSORFLOW_COMPILER_XLA_SERVICE_HLO_OPCODE_H_
diff --git a/tensorflow/compiler/xla/tools/BUILD b/tensorflow/compiler/xla/tools/BUILD
index 46eab7f02b..1c4ca44631 100644
--- a/tensorflow/compiler/xla/tools/BUILD
+++ b/tensorflow/compiler/xla/tools/BUILD
@@ -176,6 +176,50 @@ cc_binary(
],
)
+cc_library(
+ name = "hlo_tfgraph_builder",
+ srcs = ["hlo_tfgraph_builder.cc"],
+ hdrs = ["hlo_tfgraph_builder.h"],
+ deps = [
+ "//tensorflow/compiler/xla/service:hlo",
+ "//tensorflow/core:core_cpu",
+ "//tensorflow/core:framework",
+ "//tensorflow/core:lib",
+ "//tensorflow/core:protos_all_cc",
+ ],
+)
+
+cc_test(
+ name = "hlo_tfgraph_builder_test",
+ srcs = ["hlo_tfgraph_builder_test.cc"],
+ deps = [
+ ":hlo_tfgraph_builder",
+ "//tensorflow/compiler/xla/client:computation_builder",
+ "//tensorflow/compiler/xla/tests:hlo_test_base",
+ "//tensorflow/core:test_main",
+ ],
+)
+
+cc_binary(
+ name = "dumped_computation_to_tf_graphdef",
+ srcs = ["dumped_computation_to_tf_graphdef.cc"],
+ deps = [
+ ":hlo_tfgraph_builder",
+ "//tensorflow/compiler/xla:statusor",
+ "//tensorflow/compiler/xla:types",
+ "//tensorflow/compiler/xla:xla_data_proto",
+ "//tensorflow/compiler/xla/client",
+ "//tensorflow/compiler/xla/client:client_library",
+ "//tensorflow/compiler/xla/client:computation",
+ "//tensorflow/compiler/xla/client:local_client",
+ "//tensorflow/compiler/xla/service",
+ "//tensorflow/compiler/xla/service:hlo_graph_dumper",
+ "//tensorflow/compiler/xla/service:session_proto",
+ "//tensorflow/core:lib",
+ "//tensorflow/core:protos_all_cc",
+ ],
+)
+
# -----------------------------------------------------------------------------
filegroup(
diff --git a/tensorflow/compiler/xla/tools/dumped_computation_to_tf_graphdef.cc b/tensorflow/compiler/xla/tools/dumped_computation_to_tf_graphdef.cc
new file mode 100644
index 0000000000..1aa769ee5a
--- /dev/null
+++ b/tensorflow/compiler/xla/tools/dumped_computation_to_tf_graphdef.cc
@@ -0,0 +1,139 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+// Usage: dumped_computation_to_tf_graphdef \
+// --output_dir=/tmp/graphs/ some_binary_snapshot_proto*
+//
+// Dumps a tensorflow GraphDef in text format for a snapshot computation. The
+// dumped graph is an HLO computation with HLO instructions as nodes and can be
+// visualized on TensorBoard. Upload the dumped files to TensorBoard.
+//
+// some_binary_snapshot_proto is obtained by serializing the SessionModule from
+// ServiceInterface::SnapshotComputation to disk.
+
+#include <stdio.h>
+#include <memory>
+#include <string>
+
+#include "tensorflow/compiler/xla/client/client.h"
+#include "tensorflow/compiler/xla/client/client_library.h"
+#include "tensorflow/compiler/xla/client/computation.h"
+#include "tensorflow/compiler/xla/client/local_client.h"
+#include "tensorflow/compiler/xla/service/service.h"
+#include "tensorflow/compiler/xla/service/session.pb.h"
+#include "tensorflow/compiler/xla/statusor.h"
+#include "tensorflow/compiler/xla/tools/hlo_tfgraph_builder.h"
+#include "tensorflow/compiler/xla/types.h"
+#include "tensorflow/compiler/xla/xla_data.pb.h"
+#include "tensorflow/core/framework/graph.pb.h"
+#include "tensorflow/core/lib/gtl/array_slice.h"
+#include "tensorflow/core/lib/io/path.h"
+#include "tensorflow/core/platform/env.h"
+#include "tensorflow/core/platform/init_main.h"
+#include "tensorflow/core/platform/logging.h"
+
+using tensorflow::Env;
+using tensorflow::io::JoinPath;
+using tensorflow::strings::StrAppend;
+
+namespace xla {
+namespace tools {
+namespace {
+
+// Dumps all computations in the module to the given directory.
+void DumpTfGraph(const HloModule& module, const string& directory_path) {
+ Env* env = Env::Default();
+ TF_CHECK_OK(env->RecursivelyCreateDir(directory_path));
+ string fname = module.name();
+ std::replace(fname.begin(), fname.end(), '/', '_');
+ // Since the file name will be used as the top-level scope name, clean it up
+ // to make it a valid scope name.
+ CleanNodeName(&fname);
+ StrAppend(&fname, ".pbtxt");
+ string path = JoinPath(directory_path, fname);
+ HloTfGraphBuilder builder;
+ TF_CHECK_OK(builder.AddComputation(*module.entry_computation()));
+ std::cout << "Dumping " << module.name() << " to " << path << std::endl;
+ TF_CHECK_OK(WriteTextProto(env, path, builder.GetGraphDef()));
+}
+
+} // namespace
+
+void RealMain(tensorflow::gtl::ArraySlice<char*> args,
+ const string& output_dir) {
+ LocalClient* client = ClientLibrary::LocalClientOrDie();
+ // To avoid adding a new flag, use local service and lower the computations
+ // locally.
+ LocalService* local_service =
+ ClientLibrary::GetXlaService(client->platform());
+ // Build HloModule for each Computation and dump to file.
+ for (char* arg : args) {
+ SessionModule session_module;
+ TF_CHECK_OK(tensorflow::ReadBinaryProto(tensorflow::Env::Default(), arg,
+ &session_module));
+ auto computation_status = client->LoadSnapshot(session_module);
+ if (!computation_status.ok()) {
+ fprintf(stderr, "could not load snapshot for %s: %s\n", arg,
+ computation_status.status().ToString().c_str());
+ continue;
+ }
+ Computation computation = computation_status.ConsumeValueOrDie();
+
+ StatusOr<UserComputation*> user_computation_status =
+ local_service->computation_tracker().Resolve(computation.handle());
+ if (!user_computation_status.ok()) {
+ fprintf(stderr,
+ "failed to resolve computation to UserComputation %s: %s\n", arg,
+ user_computation_status.status().ToString().c_str());
+ continue;
+ }
+
+ auto* user_computation = user_computation_status.ValueOrDie();
+ StatusOr<std::unique_ptr<HloModule>> module_status =
+ local_service->computation_tracker().BuildHloModule(
+ user_computation->GetVersionedHandle());
+
+ if (!module_status.ok()) {
+ fprintf(stderr, "failed to build HloModule %s: %s\n", arg,
+ module_status.status().ToString().c_str());
+ continue;
+ }
+
+ DumpTfGraph(*module_status.ValueOrDie(), output_dir);
+ }
+}
+
+} // namespace tools
+} // namespace xla
+
+int main(int argc, char** argv) {
+ string output_dir = "";
+ const std::vector<tensorflow::Flag> flag_list = {
+ tensorflow::Flag("output_dir", &output_dir,
+ "Directory to write GraphDef data to."),
+ };
+
+ string usage = tensorflow::Flags::Usage(argv[0], flag_list);
+ bool parse_ok = tensorflow::Flags::Parse(&argc, argv, flag_list);
+ if (!parse_ok || output_dir.empty()) {
+ LOG(QFATAL) << usage;
+ }
+ tensorflow::port::InitMain(argv[0], &argc, &argv);
+
+ tensorflow::gtl::ArraySlice<char*> args(argv, argc);
+ args.pop_front(); // Pop off the binary name, argv[0]
+ xla::tools::RealMain(args, output_dir);
+ return 0;
+}
diff --git a/tensorflow/compiler/xla/tools/hlo_tfgraph_builder.cc b/tensorflow/compiler/xla/tools/hlo_tfgraph_builder.cc
new file mode 100644
index 0000000000..8e04ea3ae4
--- /dev/null
+++ b/tensorflow/compiler/xla/tools/hlo_tfgraph_builder.cc
@@ -0,0 +1,129 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/xla/tools/hlo_tfgraph_builder.h"
+#include "tensorflow/compiler/xla/service/hlo_opcode.h"
+#include "tensorflow/core/framework/op.h"
+#include "tensorflow/core/lib/strings/strcat.h"
+
+using ::tensorflow::GraphDef;
+using ::tensorflow::NodeDef;
+using ::tensorflow::strings::StrAppend;
+using ::tensorflow::strings::StrCat;
+
+namespace xla {
+namespace tools {
+
+static string GetOpDefName(const HloInstruction* instruction) {
+ string name = StrCat("hlo-", HloOpcodeString(instruction->opcode()));
+ tensorflow::str_util::TitlecaseString(&name, "-");
+ name.erase(std::remove(name.begin(), name.end(), '-'), name.end());
+
+ if (instruction->opcode() == HloOpcode::kFusion) {
+ string fusion_name = ToString(instruction->fusion_kind());
+ StrAppend(&name, tensorflow::StringPiece(fusion_name).substr(1));
+ }
+ return name;
+}
+
+void CleanNodeName(string* name) {
+ name->erase(std::remove(name->begin(), name->end(), '%'), name->end());
+ const string chars_to_replace = "<>[]";
+ auto pred = [&](char c) {
+ return std::find(chars_to_replace.begin(), chars_to_replace.end(), c) !=
+ chars_to_replace.end();
+ };
+ std::replace_if(name->begin(), name->end(), pred, '_');
+}
+
+Status HloTfGraphBuilder::AddComputation(const HloComputation& computation) {
+ LOG(INFO) << "Adding computation " << computation.name();
+ for (auto embedded : computation.MakeEmbeddedComputationsList()) {
+ LOG(INFO) << "Adding embedded computation " << embedded->name();
+ for (auto& instruction : embedded->instructions()) {
+ TF_RETURN_IF_ERROR(AddInstruction(instruction.get()));
+ }
+ }
+ for (auto& instruction : computation.instructions()) {
+ TF_RETURN_IF_ERROR(AddInstruction(instruction.get()));
+ }
+ return Status::OK();
+}
+
+const GraphDef& HloTfGraphBuilder::GetGraphDef() const { return graph_def_; }
+
+const string& HloTfGraphBuilder::GetNodeNameForInstruction(
+ const HloInstruction* instruction) {
+ if (ContainsKey(instruction_to_node_name_, instruction)) {
+ return instruction_to_node_name_[instruction];
+ }
+ // If an instruction is fused, put it in the subgraph of the fusion;
+ // otherwise, put it in the computation subgraph.
+ string node_name =
+ instruction->IsFused()
+ ? GetNodeNameForInstruction(instruction->fusion_instruction())
+ : instruction->parent()->name();
+ string instruction_name = instruction->name();
+ if (instruction->opcode() == HloOpcode::kParameter) {
+ StrAppend(&instruction_name, ".", instruction->parameter_number());
+ }
+ StrAppend(&node_name, "/", instruction_name);
+ CleanNodeName(&node_name);
+ auto ret =
+ instruction_to_node_name_.insert(std::make_pair(instruction, node_name));
+ CHECK(ret.second);
+ return ret.first->second;
+}
+
+// TODO(b/36987876): Add more attribute information e.g. shapes, dimensions etc.
+void HloTfGraphBuilder::SetNodeAttrs(const HloInstruction* instruction,
+ NodeDef* node_def) const {
+ // Set the number of arguments for instructions that have variadic operands.
+ if (HloOpcodeIsVariadic(instruction->opcode())) {
+ tensorflow::AttrValue attr_value;
+ attr_value.set_i(instruction->operands().size());
+ (*node_def->mutable_attr())["ArgNum"] = attr_value;
+ }
+}
+
+Status HloTfGraphBuilder::AddInstruction(const HloInstruction* instruction) {
+ if (!visited_instructions_.insert(instruction).second) {
+ // Skip instructions that have already been added.
+ return Status::OK();
+ }
+
+ NodeDef* node_def = graph_def_.add_node();
+ node_def->set_name(GetNodeNameForInstruction(instruction));
+ node_def->set_op(GetOpDefName(instruction));
+ SetNodeAttrs(instruction, node_def);
+ if (instruction->opcode() == HloOpcode::kFusion) {
+ for (auto& fused_instruction : instruction->fused_instructions()) {
+ TF_RETURN_IF_ERROR(AddInstruction(fused_instruction.get()));
+ }
+ }
+ // Add all edges including control edges.
+ for (unsigned i = 0; i < instruction->operands().size(); ++i) {
+ *node_def->add_input() = GetNodeNameForInstruction(instruction->operand(i));
+ }
+ // Called computations are control dependencies.
+ for (const auto* called_computation : instruction->called_computations()) {
+ *node_def->add_input() = StrCat(
+ "^", GetNodeNameForInstruction(called_computation->root_instruction()));
+ }
+ return Status::OK();
+}
+
+} // namespace tools
+} // namespace xla
diff --git a/tensorflow/compiler/xla/tools/hlo_tfgraph_builder.h b/tensorflow/compiler/xla/tools/hlo_tfgraph_builder.h
new file mode 100644
index 0000000000..3052eae113
--- /dev/null
+++ b/tensorflow/compiler/xla/tools/hlo_tfgraph_builder.h
@@ -0,0 +1,59 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef THIRD_PARTY_TENSORFLOW_COMPILER_XLA_TOOLS_HLO_TFGRAPH_BUILDER_H_
+#define THIRD_PARTY_TENSORFLOW_COMPILER_XLA_TOOLS_HLO_TFGRAPH_BUILDER_H_
+
+#include "tensorflow/compiler/xla/service/hlo_computation.h"
+#include "tensorflow/core/framework/graph.pb.h"
+#include "tensorflow/core/framework/node_def.pb.h"
+#include "tensorflow/core/graph/graph.h"
+
+namespace xla {
+namespace tools {
+
+// This constructs a tensorflow graph for HLO computations.
+class HloTfGraphBuilder {
+ public:
+ // Adds a computation to the graph.
+ Status AddComputation(const HloComputation& computation);
+
+ const tensorflow::GraphDef& GetGraphDef() const;
+
+ private:
+ // Gets the node name of an instruction. The node name is hierarchical. For
+ // example, if an instruction is fused, it will be put in a subgraph of the
+ // fusion instruction.
+ const string& GetNodeNameForInstruction(const HloInstruction* instruction);
+
+ void SetNodeAttrs(const HloInstruction* instruction,
+ tensorflow::NodeDef* node_def) const;
+
+ Status AddInstruction(const HloInstruction* instruction);
+
+ tensorflow::GraphDef graph_def_;
+ // This records instructions that have been visited.
+ std::unordered_set<const HloInstruction*> visited_instructions_;
+ // A cache that maps instruction to the node name.
+ std::unordered_map<const HloInstruction*, string> instruction_to_node_name_;
+};
+
+// Cleans the node name to make it a valid name in a tensorflow graph.
+void CleanNodeName(string* name);
+
+} // namespace tools
+} // namespace xla
+
+#endif // THIRD_PARTY_TENSORFLOW_COMPILER_XLA_TOOLS_HLO_TFGRAPH_BUILDER_H_
diff --git a/tensorflow/compiler/xla/tools/hlo_tfgraph_builder_test.cc b/tensorflow/compiler/xla/tools/hlo_tfgraph_builder_test.cc
new file mode 100644
index 0000000000..57144d0385
--- /dev/null
+++ b/tensorflow/compiler/xla/tools/hlo_tfgraph_builder_test.cc
@@ -0,0 +1,107 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/xla/tools/hlo_tfgraph_builder.h"
+#include "tensorflow/compiler/xla/client/computation_builder.h"
+#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
+
+namespace xla {
+namespace tools {
+namespace {
+
+using ::tensorflow::GraphDef;
+
+class HloTfGraphBuilderTest : public HloTestBase {
+ protected:
+ HloTfGraphBuilderTest() {}
+ HloTfGraphBuilder generator_;
+
+ // Create a computation which takes a scalar and returns its negation.
+ std::unique_ptr<HloComputation> CreateNegateComputation() {
+ auto builder = HloComputation::Builder("Negate");
+ auto param = builder.AddInstruction(
+ HloInstruction::CreateParameter(0, r0f32_, "param0"));
+ builder.AddInstruction(
+ HloInstruction::CreateUnary(r0f32_, HloOpcode::kNegate, param));
+ return builder.Build();
+ }
+
+ // Creates a computation which calls map with the given computation.
+ std::unique_ptr<HloComputation> CreateMapComputation(
+ HloComputation* map_computation) {
+ auto builder = HloComputation::Builder("Map");
+ auto param = builder.AddInstruction(
+ HloInstruction::CreateParameter(0, r0f32_, "param0"));
+ builder.AddInstruction(
+ HloInstruction::CreateMap(r0f32_, {param}, map_computation));
+ return builder.Build();
+ }
+ Shape r0f32_ = ShapeUtil::MakeShape(F32, {});
+};
+
+TEST_F(HloTfGraphBuilderTest, SimpleNegateComputation) {
+ auto negate_computation = CreateNegateComputation();
+ TF_CHECK_OK(generator_.AddComputation(*negate_computation));
+ GraphDef graph_def = generator_.GetGraphDef();
+ EXPECT_EQ(graph_def.node_size(), 2);
+ EXPECT_EQ(graph_def.node(0).name(), "Negate/param0.0");
+ EXPECT_EQ(graph_def.node(0).op(), "HloParameter");
+ EXPECT_EQ(graph_def.node(1).name(), "Negate/negate");
+ EXPECT_EQ(graph_def.node(1).op(), "HloNegate");
+ EXPECT_EQ(graph_def.node(1).input_size(), 1);
+ EXPECT_EQ(graph_def.node(1).input(0), "Negate/param0.0");
+}
+
+TEST_F(HloTfGraphBuilderTest, GreaterThanOrEqualTo) {
+ auto builder = HloComputation::Builder("GE");
+ auto param_1 = builder.AddInstruction(
+ HloInstruction::CreateParameter(0, r0f32_, "param0"));
+ auto param_2 = builder.AddInstruction(
+ HloInstruction::CreateParameter(1, r0f32_, "param1"));
+ builder.AddInstruction(
+ HloInstruction::CreateBinary(r0f32_, HloOpcode::kGe, param_1, param_2));
+ TF_CHECK_OK(generator_.AddComputation(*builder.Build()));
+ GraphDef graph_def = generator_.GetGraphDef();
+ EXPECT_EQ(graph_def.node_size(), 3);
+ EXPECT_EQ(graph_def.node(0).name(), "GE/param0.0");
+ EXPECT_EQ(graph_def.node(1).name(), "GE/param1.1");
+ EXPECT_EQ(graph_def.node(2).input_size(), 2);
+ EXPECT_EQ(graph_def.node(2).name(), "GE/greater-than-or-equal-to");
+ EXPECT_EQ(graph_def.node(2).op(), "HloGreaterThanOrEqualTo");
+}
+
+TEST_F(HloTfGraphBuilderTest, EmbeddedComputationsDiamond) {
+ // Create computations with a diamond-shaped callgraph.
+ auto negate_computation = CreateNegateComputation();
+ auto map1_computation = CreateMapComputation(negate_computation.get());
+ auto map2_computation = CreateMapComputation(negate_computation.get());
+
+ auto builder = HloComputation::Builder(TestName());
+ auto param = builder.AddInstruction(
+ HloInstruction::CreateParameter(0, r0f32_, "param0"));
+ auto map1 = builder.AddInstruction(
+ HloInstruction::CreateMap(r0f32_, {param}, map1_computation.get()));
+ auto map2 = builder.AddInstruction(
+ HloInstruction::CreateMap(r0f32_, {param}, map2_computation.get()));
+ builder.AddInstruction(
+ HloInstruction::CreateBinary(r0f32_, HloOpcode::kAdd, map1, map2));
+ auto computation = builder.Build();
+ TF_CHECK_OK(generator_.AddComputation(*computation));
+ EXPECT_GT(generator_.GetGraphDef().node_size(), 0);
+}
+
+} // namespace
+} // namespace tools
+} // namespace xla