aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/compiler/xla/service/hlo_graph_dumper.cc
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/compiler/xla/service/hlo_graph_dumper.cc')
-rw-r--r--tensorflow/compiler/xla/service/hlo_graph_dumper.cc26
1 files changed, 5 insertions, 21 deletions
diff --git a/tensorflow/compiler/xla/service/hlo_graph_dumper.cc b/tensorflow/compiler/xla/service/hlo_graph_dumper.cc
index 04b3059fb1..fd162622ce 100644
--- a/tensorflow/compiler/xla/service/hlo_graph_dumper.cc
+++ b/tensorflow/compiler/xla/service/hlo_graph_dumper.cc
@@ -761,22 +761,12 @@ string HloDotDumper::DumpInstruction(const HloInstruction* instr) {
string HloDotDumper::GetInstructionNodeInlinedOperands(
const HloInstruction* instr) {
auto stringify_constant = [](const HloInstruction* constant) {
- const auto& shape = constant->shape();
-
- // Print the literal value of constants with <= K elements.
- optional<int64> elem_count;
- if (!ShapeUtil::IsOpaque(shape) && !ShapeUtil::IsTuple(shape)) {
- elem_count = 1;
- for (int64 dim : shape.dimensions()) {
- *elem_count *= dim;
- }
- }
- if (elem_count.has_value() && *elem_count <= 8) {
- return Printf("%s (%s)", constant->literal().ToString(),
+ if (ShapeUtil::IsEffectiveScalar(constant->shape())) {
+ auto elem_idx = IndexUtil::LinearIndexToMultidimensionalIndex(
+ constant->shape(), /*linear_index=*/0);
+ return Printf("%s (%s)", constant->literal().GetAsString(elem_idx),
ShapeUtil::HumanString(constant->shape()));
}
-
- // Otherwise, print e.g. "%constant.42 (s32[100])".
string constant_name;
if (tensorflow::StringPiece(constant->name()).starts_with("%constant")) {
constant_name = constant->name();
@@ -943,9 +933,7 @@ ColorScheme HloDotDumper::GetInstructionColor(const HloInstruction* instr) {
case HloOpcode::kFusion:
return kGray;
case HloOpcode::kSend:
- case HloOpcode::kSendDone:
case HloOpcode::kRecv:
- case HloOpcode::kRecvDone:
case HloOpcode::kInfeed:
case HloOpcode::kOutfeed:
case HloOpcode::kCrossReplicaSum:
@@ -1039,9 +1027,7 @@ string HloDotDumper::GetInstructionNodeExtraInfo(const HloInstruction* instr) {
? ""
: StrCat("stride=", VectorString(instr->slice_strides()));
case HloOpcode::kSend:
- case HloOpcode::kSendDone:
case HloOpcode::kRecv:
- case HloOpcode::kRecvDone:
return StrCat("channel_id=", instr->channel_id());
default:
return "";
@@ -1303,9 +1289,7 @@ NodeFilter MakeNodeFilter(const HloInstruction* root, int64 radius) {
auto is_displayed = [&](const HloInstruction* instr) {
// Constants are displayed inline with their users; they're never omitted.
- // Nodes in subcomputations are always shown.
- return nodes.count(instr) > 0 || instr->opcode() == HloOpcode::kConstant ||
- instr->parent() != root->parent();
+ return nodes.count(instr) > 0 || instr->opcode() == HloOpcode::kConstant;
};
// Make a second pass over 'nodes' to fix up the NodeFilterResults now that we