diff options
Diffstat (limited to 'tensorflow/compiler/xla/service/hlo_graph_dumper.cc')
-rw-r--r-- | tensorflow/compiler/xla/service/hlo_graph_dumper.cc | 26 |
1 files changed, 5 insertions, 21 deletions
diff --git a/tensorflow/compiler/xla/service/hlo_graph_dumper.cc b/tensorflow/compiler/xla/service/hlo_graph_dumper.cc index 04b3059fb1..fd162622ce 100644 --- a/tensorflow/compiler/xla/service/hlo_graph_dumper.cc +++ b/tensorflow/compiler/xla/service/hlo_graph_dumper.cc @@ -761,22 +761,12 @@ string HloDotDumper::DumpInstruction(const HloInstruction* instr) { string HloDotDumper::GetInstructionNodeInlinedOperands( const HloInstruction* instr) { auto stringify_constant = [](const HloInstruction* constant) { - const auto& shape = constant->shape(); - - // Print the literal value of constants with <= K elements. - optional<int64> elem_count; - if (!ShapeUtil::IsOpaque(shape) && !ShapeUtil::IsTuple(shape)) { - elem_count = 1; - for (int64 dim : shape.dimensions()) { - *elem_count *= dim; - } - } - if (elem_count.has_value() && *elem_count <= 8) { - return Printf("%s (%s)", constant->literal().ToString(), + if (ShapeUtil::IsEffectiveScalar(constant->shape())) { + auto elem_idx = IndexUtil::LinearIndexToMultidimensionalIndex( + constant->shape(), /*linear_index=*/0); + return Printf("%s (%s)", constant->literal().GetAsString(elem_idx), ShapeUtil::HumanString(constant->shape())); } - - // Otherwise, print e.g. "%constant.42 (s32[100])". string constant_name; if (tensorflow::StringPiece(constant->name()).starts_with("%constant")) { constant_name = constant->name(); @@ -943,9 +933,7 @@ ColorScheme HloDotDumper::GetInstructionColor(const HloInstruction* instr) { case HloOpcode::kFusion: return kGray; case HloOpcode::kSend: - case HloOpcode::kSendDone: case HloOpcode::kRecv: - case HloOpcode::kRecvDone: case HloOpcode::kInfeed: case HloOpcode::kOutfeed: case HloOpcode::kCrossReplicaSum: @@ -1039,9 +1027,7 @@ string HloDotDumper::GetInstructionNodeExtraInfo(const HloInstruction* instr) { ? "" : StrCat("stride=", VectorString(instr->slice_strides())); case HloOpcode::kSend: - case HloOpcode::kSendDone: case HloOpcode::kRecv: - case HloOpcode::kRecvDone: return StrCat("channel_id=", instr->channel_id()); default: return ""; @@ -1303,9 +1289,7 @@ NodeFilter MakeNodeFilter(const HloInstruction* root, int64 radius) { auto is_displayed = [&](const HloInstruction* instr) { // Constants are displayed inline with their users; they're never omitted. - // Nodes in subcomputations are always shown. - return nodes.count(instr) > 0 || instr->opcode() == HloOpcode::kConstant || - instr->parent() != root->parent(); + return nodes.count(instr) > 0 || instr->opcode() == HloOpcode::kConstant; }; // Make a second pass over 'nodes' to fix up the NodeFilterResults now that we |