aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/compiler/xla/service/hlo_graph_dumper.cc
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2018-06-11 16:44:29 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-06-11 16:47:38 -0700
commit734ce1d8e5991c8e7b243b0bab37c074864c0eea (patch)
treecf3f5203935150883e7afd853c441434b00fa58a /tensorflow/compiler/xla/service/hlo_graph_dumper.cc
parent95345968a2445c75eaeaa22659b7e574aafe25a7 (diff)
Split out HloConstantInstruction and HloTraceInstruction as subclasses from HloInstruction.
PiperOrigin-RevId: 200135616
Diffstat (limited to 'tensorflow/compiler/xla/service/hlo_graph_dumper.cc')
-rw-r--r--tensorflow/compiler/xla/service/hlo_graph_dumper.cc20
1 files changed, 10 insertions, 10 deletions
diff --git a/tensorflow/compiler/xla/service/hlo_graph_dumper.cc b/tensorflow/compiler/xla/service/hlo_graph_dumper.cc
index 05aab9a2cd..28fc6c4209 100644
--- a/tensorflow/compiler/xla/service/hlo_graph_dumper.cc
+++ b/tensorflow/compiler/xla/service/hlo_graph_dumper.cc
@@ -28,6 +28,8 @@ limitations under the License.
#include "tensorflow/compiler/xla/layout_util.h"
#include "tensorflow/compiler/xla/literal_util.h"
+#include "tensorflow/compiler/xla/service/hlo_casting_utils.h"
+#include "tensorflow/compiler/xla/service/hlo_instructions.h"
#include "tensorflow/compiler/xla/service/hlo_module.h"
#include "tensorflow/compiler/xla/service/hlo_tfgraph_builder.h"
#include "tensorflow/compiler/xla/shape_util.h"
@@ -723,17 +725,14 @@ string HloDotDumper::DumpRootTag() {
to_id, node_body, node_shape, NodeColorAttributes(color));
}
-static const HloInstruction* TryGetFusionParameterConstant(
+static const HloConstantInstruction* TryGetFusionParameterConstant(
const HloInstruction* instr) {
if (instr->opcode() != HloOpcode::kParameter || !instr->IsFused()) {
return nullptr;
}
const HloInstruction* fusion = instr->parent()->FusionInstruction();
const HloInstruction* operand = fusion->operand(instr->parameter_number());
- if (operand->opcode() == HloOpcode::kConstant) {
- return operand;
- }
- return nullptr;
+ return DynCast<HloConstantInstruction>(operand);
}
bool HloDotDumper::ShouldMergeIntoUsers(const HloInstruction* instr) const {
@@ -826,7 +825,7 @@ string HloDotDumper::DumpInstruction(const HloInstruction* instr) {
string HloDotDumper::GetInstructionNodeInlinedOperands(
const HloInstruction* instr) {
- auto stringify_constant = [](const HloInstruction* constant) {
+ auto stringify_constant = [](const HloConstantInstruction* constant) {
const auto& shape = constant->shape();
// If the shape has a dimension of size zero, print it as e.g.
@@ -845,7 +844,7 @@ string HloDotDumper::GetInstructionNodeInlinedOperands(
*elem_count *= dim;
}
}
- if (elem_count.has_value() && *elem_count <= 8 && constant->HasLiteral()) {
+ if (elem_count.has_value() && *elem_count <= 8) {
return Printf("%s (%s)", constant->literal().ToString(),
ShapeUtil::HumanString(constant->shape()));
}
@@ -864,9 +863,10 @@ string HloDotDumper::GetInstructionNodeInlinedOperands(
std::vector<string> lines;
for (int64 i = 0; i < instr->operand_count(); ++i) {
const HloInstruction* operand = instr->operand(i);
+ const auto* constant_operand = DynCast<HloConstantInstruction>(operand);
optional<string> operand_str;
- if (operand->opcode() == HloOpcode::kConstant) {
- operand_str = stringify_constant(operand);
+ if (constant_operand != nullptr) {
+ operand_str = stringify_constant(constant_operand);
} else if (ShouldMergeIntoUsers(operand)) {
// Special case: If the operand is a parameter to a fusion node and it
// always has a constant value, display it like a regular constant.
@@ -874,7 +874,7 @@ string HloDotDumper::GetInstructionNodeInlinedOperands(
// For other parameters, use the parameter number rather than the proper
// name, because that's generally how people think of the node.
if (operand->opcode() == HloOpcode::kParameter) {
- if (const HloInstruction* constant =
+ if (const HloConstantInstruction* constant =
TryGetFusionParameterConstant(operand)) {
operand_str = stringify_constant(constant);
} else {