diff options
author | A. Unique TensorFlower <gardener@tensorflow.org> | 2018-06-11 16:44:29 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-06-11 16:47:38 -0700 |
commit | 734ce1d8e5991c8e7b243b0bab37c074864c0eea (patch) | |
tree | cf3f5203935150883e7afd853c441434b00fa58a /tensorflow/compiler/xla/service/hlo_graph_dumper.cc | |
parent | 95345968a2445c75eaeaa22659b7e574aafe25a7 (diff) |
Split out HloConstantInstruction and HloTraceInstruction as subclasses from HloInstruction.
PiperOrigin-RevId: 200135616
Diffstat (limited to 'tensorflow/compiler/xla/service/hlo_graph_dumper.cc')
-rw-r--r-- | tensorflow/compiler/xla/service/hlo_graph_dumper.cc | 20 |
1 files changed, 10 insertions, 10 deletions
diff --git a/tensorflow/compiler/xla/service/hlo_graph_dumper.cc b/tensorflow/compiler/xla/service/hlo_graph_dumper.cc index 05aab9a2cd..28fc6c4209 100644 --- a/tensorflow/compiler/xla/service/hlo_graph_dumper.cc +++ b/tensorflow/compiler/xla/service/hlo_graph_dumper.cc @@ -28,6 +28,8 @@ limitations under the License. #include "tensorflow/compiler/xla/layout_util.h" #include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/compiler/xla/service/hlo_casting_utils.h" +#include "tensorflow/compiler/xla/service/hlo_instructions.h" #include "tensorflow/compiler/xla/service/hlo_module.h" #include "tensorflow/compiler/xla/service/hlo_tfgraph_builder.h" #include "tensorflow/compiler/xla/shape_util.h" @@ -723,17 +725,14 @@ string HloDotDumper::DumpRootTag() { to_id, node_body, node_shape, NodeColorAttributes(color)); } -static const HloInstruction* TryGetFusionParameterConstant( +static const HloConstantInstruction* TryGetFusionParameterConstant( const HloInstruction* instr) { if (instr->opcode() != HloOpcode::kParameter || !instr->IsFused()) { return nullptr; } const HloInstruction* fusion = instr->parent()->FusionInstruction(); const HloInstruction* operand = fusion->operand(instr->parameter_number()); - if (operand->opcode() == HloOpcode::kConstant) { - return operand; - } - return nullptr; + return DynCast<HloConstantInstruction>(operand); } bool HloDotDumper::ShouldMergeIntoUsers(const HloInstruction* instr) const { @@ -826,7 +825,7 @@ string HloDotDumper::DumpInstruction(const HloInstruction* instr) { string HloDotDumper::GetInstructionNodeInlinedOperands( const HloInstruction* instr) { - auto stringify_constant = [](const HloInstruction* constant) { + auto stringify_constant = [](const HloConstantInstruction* constant) { const auto& shape = constant->shape(); // If the shape has a dimension of size zero, print it as e.g. @@ -845,7 +844,7 @@ string HloDotDumper::GetInstructionNodeInlinedOperands( *elem_count *= dim; } } - if (elem_count.has_value() && *elem_count <= 8 && constant->HasLiteral()) { + if (elem_count.has_value() && *elem_count <= 8) { return Printf("%s (%s)", constant->literal().ToString(), ShapeUtil::HumanString(constant->shape())); } @@ -864,9 +863,10 @@ string HloDotDumper::GetInstructionNodeInlinedOperands( std::vector<string> lines; for (int64 i = 0; i < instr->operand_count(); ++i) { const HloInstruction* operand = instr->operand(i); + const auto* constant_operand = DynCast<HloConstantInstruction>(operand); optional<string> operand_str; - if (operand->opcode() == HloOpcode::kConstant) { - operand_str = stringify_constant(operand); + if (constant_operand != nullptr) { + operand_str = stringify_constant(constant_operand); } else if (ShouldMergeIntoUsers(operand)) { // Special case: If the operand is a parameter to a fusion node and it // always has a constant value, display it like a regular constant. @@ -874,7 +874,7 @@ string HloDotDumper::GetInstructionNodeInlinedOperands( // For other parameters, use the parameter number rather than the proper // name, because that's generally how people think of the node. if (operand->opcode() == HloOpcode::kParameter) { - if (const HloInstruction* constant = + if (const HloConstantInstruction* constant = TryGetFusionParameterConstant(operand)) { operand_str = stringify_constant(constant); } else { |