author     Dimitris Vardoulakis <dimvar@google.com>         2018-04-18 15:30:30 -0700
committer  TensorFlower Gardener <gardener@tensorflow.org>  2018-04-18 15:33:28 -0700
commit     8c66f2223078dca765e7817f26f66e61fe819715 (patch)
tree       d191e8b2b35748faa594380254f1be09d0fcac23 /tensorflow/compiler
parent     80f60ea37ed77b3dbe1d983f101a5efba2fd4f2e (diff)
Automated g4 rollback of changelist 192180356
PiperOrigin-RevId: 193427566
Diffstat (limited to 'tensorflow/compiler')
-rw-r--r--  tensorflow/compiler/xla/service/algebraic_simplifier.cc          1
-rw-r--r--  tensorflow/compiler/xla/service/dfs_hlo_visitor.h                 1
-rw-r--r--  tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h    3
-rw-r--r--  tensorflow/compiler/xla/service/hlo_constant_folding.cc           3
-rw-r--r--  tensorflow/compiler/xla/service/hlo_cost_analysis.cc              5
-rw-r--r--  tensorflow/compiler/xla/service/hlo_cost_analysis.h               1
-rw-r--r--  tensorflow/compiler/xla/service/hlo_graph_dumper.cc               1
-rw-r--r--  tensorflow/compiler/xla/service/hlo_instruction.cc               19
-rw-r--r--  tensorflow/compiler/xla/service/hlo_instruction.h                 4
-rw-r--r--  tensorflow/compiler/xla/service/hlo_opcode.h                      1
-rw-r--r--  tensorflow/compiler/xla/service/hlo_verifier.cc                  27
-rw-r--r--  tensorflow/compiler/xla/service/hlo_verifier.h                    1
-rw-r--r--  tensorflow/compiler/xla/service/instruction_fusion.cc             7
-rw-r--r--  tensorflow/compiler/xla/service/pattern_matcher.h                 1
-rw-r--r--  tensorflow/compiler/xla/tools/parser/hlo_parser.cc                9
-rw-r--r--  tensorflow/compiler/xla/tools/parser/hlo_parser_test.cc          12
16 files changed, 10 insertions, 86 deletions
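
This rollback deletes the short-lived kBroadcastDimOne opcode everywhere it appears. With the opcode gone, a broadcast along size-one dimensions has to be spelled with the existing kReshape and kBroadcast opcodes. A minimal sketch of that equivalent, using the factory methods visible in this diff; the surrounding computation plumbing is assumed, and "constant" is a stand-in for the f32[1,2] operand from the deleted parser test below:

// Hypothetical sketch: f32[1,2] -> f32[2,2] without kBroadcastDimOne.
// Step 1: drop the size-1 dimension; step 2: broadcast the f32[2] result,
// mapping its single dimension to output dimension 1.
std::unique_ptr<HloInstruction> reshape = HloInstruction::CreateReshape(
    ShapeUtil::MakeShape(F32, {2}), constant);
std::unique_ptr<HloInstruction> broadcast = HloInstruction::CreateBroadcast(
    ShapeUtil::MakeShape(F32, {2, 2}), reshape.get(),
    /*broadcast_dimensions=*/{1});

CreateBroadcastSequence, which survives this change, exists to build this kind of explicit broadcast for callers.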
diff --git a/tensorflow/compiler/xla/service/algebraic_simplifier.cc b/tensorflow/compiler/xla/service/algebraic_simplifier.cc
index 8d26938c6e..8e785de68c 100644
--- a/tensorflow/compiler/xla/service/algebraic_simplifier.cc
+++ b/tensorflow/compiler/xla/service/algebraic_simplifier.cc
@@ -1412,7 +1412,6 @@ Status AlgebraicSimplifierVisitor::HandlePower(HloInstruction* power) {
return Status::OK();
}
-// TODO(b/74536353): do this simplification for BroadcastDimOne as well.
StatusOr<bool> AlgebraicSimplifierVisitor::
TryToSinkReshapeOrBroadcastAfterOpWithUniqueNonScalarOperand(
HloInstruction* reshape_or_broadcast) {
diff --git a/tensorflow/compiler/xla/service/dfs_hlo_visitor.h b/tensorflow/compiler/xla/service/dfs_hlo_visitor.h
index 3f7089d6ca..56723e7650 100644
--- a/tensorflow/compiler/xla/service/dfs_hlo_visitor.h
+++ b/tensorflow/compiler/xla/service/dfs_hlo_visitor.h
@@ -199,7 +199,6 @@ class DfsHloVisitorBase {
virtual Status HandleReduce(HloInstructionPtr hlo) = 0;
virtual Status HandleBitcast(HloInstructionPtr hlo) = 0;
virtual Status HandleBroadcast(HloInstructionPtr hlo) = 0;
- virtual Status HandleBroadcastDimOne(HloInstructionPtr hlo) = 0;
virtual Status HandleReshape(HloInstructionPtr hlo) = 0;
virtual Status HandleTranspose(HloInstructionPtr hlo) = 0;
virtual Status HandleParameter(HloInstructionPtr hlo) = 0;
diff --git a/tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h b/tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h
index e6680ee9b8..240faebe62 100644
--- a/tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h
+++ b/tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h
@@ -158,9 +158,6 @@ class DfsHloVisitorWithDefaultBase
Status HandleBroadcast(HloInstructionPtr broadcast) override {
return DefaultAction(broadcast);
}
- Status HandleBroadcastDimOne(HloInstructionPtr broadcastDimOne) override {
- return DefaultAction(broadcastDimOne);
- }
Status HandlePad(HloInstructionPtr pad) override {
return DefaultAction(pad);
}
diff --git a/tensorflow/compiler/xla/service/hlo_constant_folding.cc b/tensorflow/compiler/xla/service/hlo_constant_folding.cc
index 7aa38c6b79..35ecd4428d 100644
--- a/tensorflow/compiler/xla/service/hlo_constant_folding.cc
+++ b/tensorflow/compiler/xla/service/hlo_constant_folding.cc
@@ -69,8 +69,7 @@ StatusOr<bool> HloConstantFolding::Run(HloModule* module) {
// Broadcasts dramatically increase the size of constants, which is often
// detrimental to performance and memory capacity, so do not fold
// broadcasts.
- if (instruction->opcode() == HloOpcode::kBroadcast ||
- instruction->opcode() == HloOpcode::kBroadcastDimOne) {
+ if (instruction->opcode() == HloOpcode::kBroadcast) {
continue;
}
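
The comment retained by this hunk is the entire rationale: folding a constant through a broadcast replaces a small literal with one the size of the broadcast's output shape. A standalone back-of-the-envelope illustration (not XLA code):

#include <cstdint>
#include <iostream>

int main() {
  // Folding a 4-byte f32 scalar through a broadcast to f32[1024,1024]
  // would materialize a roughly 4 MiB literal, a million-fold blow-up.
  const int64_t elements = int64_t{1024} * 1024;
  std::cout << elements * sizeof(float) << " bytes\n";  // prints 4194304
  return 0;
}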
diff --git a/tensorflow/compiler/xla/service/hlo_cost_analysis.cc b/tensorflow/compiler/xla/service/hlo_cost_analysis.cc
index ea4dd62fdb..44e4f75f75 100644
--- a/tensorflow/compiler/xla/service/hlo_cost_analysis.cc
+++ b/tensorflow/compiler/xla/service/hlo_cost_analysis.cc
@@ -336,11 +336,6 @@ Status HloCostAnalysis::HandleBroadcast(const HloInstruction*) {
return Status::OK();
}
-Status HloCostAnalysis::HandleBroadcastDimOne(
- const HloInstruction* broadcastDimOne) {
- return Status::OK();
-}
-
Status HloCostAnalysis::HandlePad(const HloInstruction*) {
return Status::OK();
}
diff --git a/tensorflow/compiler/xla/service/hlo_cost_analysis.h b/tensorflow/compiler/xla/service/hlo_cost_analysis.h
index a9f6845747..d17678d20f 100644
--- a/tensorflow/compiler/xla/service/hlo_cost_analysis.h
+++ b/tensorflow/compiler/xla/service/hlo_cost_analysis.h
@@ -95,7 +95,6 @@ class HloCostAnalysis : public ConstDfsHloVisitor {
Status HandleSelectAndScatter(const HloInstruction* instruction) override;
Status HandleBitcast(const HloInstruction* bitcast) override;
Status HandleBroadcast(const HloInstruction* broadcast) override;
- Status HandleBroadcastDimOne(const HloInstruction* broadcastDimOne) override;
Status HandlePad(const HloInstruction* pad) override;
Status HandleReshape(const HloInstruction* reshape) override;
Status HandleTranspose(const HloInstruction* transpose) override;
diff --git a/tensorflow/compiler/xla/service/hlo_graph_dumper.cc b/tensorflow/compiler/xla/service/hlo_graph_dumper.cc
index c35783c456..25702dc65e 100644
--- a/tensorflow/compiler/xla/service/hlo_graph_dumper.cc
+++ b/tensorflow/compiler/xla/service/hlo_graph_dumper.cc
@@ -956,7 +956,6 @@ ColorScheme HloDotDumper::GetInstructionColor(const HloInstruction* instr) {
case HloOpcode::kTuple:
return kWhite;
case HloOpcode::kBroadcast:
- case HloOpcode::kBroadcastDimOne:
// De-emphasize nodes which broadcast a scalar within a fusion node --
// these are essentially free.
if (instr->IsFused() &&
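
The comment in this case explains the test that follows it: a broadcast of a scalar inside a fusion costs nothing at runtime, so the dumper de-emphasizes it. That test, restated as a standalone predicate; IsFused appears in the hunk, ShapeUtil::IsScalar is an assumed helper, and this is illustrative rather than the dumper's actual code:

bool IsFreeScalarBroadcast(const HloInstruction* instr) {
  // Fused broadcasts of scalar operands are essentially free.
  return instr->opcode() == HloOpcode::kBroadcast && instr->IsFused() &&
         ShapeUtil::IsScalar(instr->operand(0)->shape());
}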
diff --git a/tensorflow/compiler/xla/service/hlo_instruction.cc b/tensorflow/compiler/xla/service/hlo_instruction.cc
index 56cb241087..a445380817 100644
--- a/tensorflow/compiler/xla/service/hlo_instruction.cc
+++ b/tensorflow/compiler/xla/service/hlo_instruction.cc
@@ -701,15 +701,6 @@ HloInstruction::CreateSelectAndScatter(
}
/* static */ std::unique_ptr<HloInstruction>
-HloInstruction::CreateBroadcastDimOne(const Shape& shape,
- HloInstruction* operand) {
- auto instruction =
- WrapUnique(new HloInstruction(HloOpcode::kBroadcastDimOne, shape));
- instruction->AppendOperand(operand);
- return instruction;
-}
-
-/* static */ std::unique_ptr<HloInstruction>
HloInstruction::CreateBroadcastSequence(
const Shape& output_shape, HloInstruction* operand,
const std::function<HloInstruction*(std::unique_ptr<HloInstruction>)>&
@@ -1311,10 +1302,6 @@ std::unique_ptr<HloInstruction> HloInstruction::CloneWithNewOperands(
CHECK_EQ(new_operands.size(), 1);
clone = CreateBroadcast(shape, new_operands[0], dimensions_);
break;
- case HloOpcode::kBroadcastDimOne:
- CHECK_EQ(new_operands.size(), 1);
- clone = CreateBroadcastDimOne(shape, new_operands[0]);
- break;
case HloOpcode::kCall:
clone = CreateCall(shape, new_operands, to_apply());
break;
@@ -1863,8 +1850,6 @@ bool HloInstruction::IdenticalSlowPath(
// Remaining instructions with special values.
case HloOpcode::kBitcast:
- case HloOpcode::kBroadcastDimOne:
- case HloOpcode::kDynamicUpdateSlice:
return eq_shapes(shape(), other.shape());
case HloOpcode::kBroadcast:
return eq_shapes(shape(), other.shape()) &&
@@ -1883,6 +1868,8 @@ bool HloInstruction::IdenticalSlowPath(
case HloOpcode::kDynamicSlice:
return eq_shapes(shape(), other.shape()) &&
dynamic_slice_sizes_ == other.dynamic_slice_sizes_;
+ case HloOpcode::kDynamicUpdateSlice:
+ return eq_shapes(shape(), other.shape());
case HloOpcode::kCall:
case HloOpcode::kMap:
return eq_computations(to_apply(), other.to_apply());
@@ -2692,8 +2679,6 @@ Status HloInstruction::Visit(DfsHloVisitorBase<HloInstructionPtr>* visitor) {
return visitor->HandleBitcast(this);
case HloOpcode::kBroadcast:
return visitor->HandleBroadcast(this);
- case HloOpcode::kBroadcastDimOne:
- return visitor->HandleBroadcastDimOne(this);
case HloOpcode::kPad:
return visitor->HandlePad(this);
case HloOpcode::kReshape:
diff --git a/tensorflow/compiler/xla/service/hlo_instruction.h b/tensorflow/compiler/xla/service/hlo_instruction.h
index 49aa075029..5a7394f7a6 100644
--- a/tensorflow/compiler/xla/service/hlo_instruction.h
+++ b/tensorflow/compiler/xla/service/hlo_instruction.h
@@ -401,10 +401,6 @@ class HloInstruction {
const Shape& shape, HloInstruction* operand,
tensorflow::gtl::ArraySlice<int64> broadcast_dimensions);
- // Creates a broadcast-size-one-dimensions instruction.
- static std::unique_ptr<HloInstruction> CreateBroadcastDimOne(
- const Shape& shape, HloInstruction* operand);
-
// Creates a sequence of instructions that performs an explicit broadcast of
// the operand to the target shape.
//
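
CreateBroadcastSequence, whose definition is visible in the hlo_instruction.cc hunk above, remains the supported way to build an explicit broadcast to a target shape; the caller supplies an adder callback that takes ownership of each intermediate instruction. A hedged usage sketch, assuming operand (an HloInstruction*) and computation (an HloComputation*) already exist:

// Sketch only; shape derivation and error handling omitted.
std::unique_ptr<HloInstruction> root =
    HloInstruction::CreateBroadcastSequence(
        /*output_shape=*/ShapeUtil::MakeShape(F32, {2, 2}), operand,
        [&](std::unique_ptr<HloInstruction> instr) {
          // Hand each intermediate instruction to the computation.
          return computation->AddInstruction(std::move(instr));
        });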
diff --git a/tensorflow/compiler/xla/service/hlo_opcode.h b/tensorflow/compiler/xla/service/hlo_opcode.h
index dddc72480f..af24604c39 100644
--- a/tensorflow/compiler/xla/service/hlo_opcode.h
+++ b/tensorflow/compiler/xla/service/hlo_opcode.h
@@ -54,7 +54,6 @@ namespace xla {
V(kBitcast, "bitcast") \
V(kBitcastConvert, "bitcast-convert") \
V(kBroadcast, "broadcast") \
- V(kBroadcastDimOne, "broadcast-dim-one") \
V(kCall, "call", kHloOpcodeIsVariadic) \
V(kCeil, "ceil") \
V(kClamp, "clamp") \
diff --git a/tensorflow/compiler/xla/service/hlo_verifier.cc b/tensorflow/compiler/xla/service/hlo_verifier.cc
index 63ec5964eb..8c875698eb 100644
--- a/tensorflow/compiler/xla/service/hlo_verifier.cc
+++ b/tensorflow/compiler/xla/service/hlo_verifier.cc
@@ -174,34 +174,17 @@ Status ShapeVerifier::HandleBroadcast(HloInstruction* broadcast) {
TF_RETURN_IF_ERROR(CheckShape(broadcast, broadcast->shape()));
TF_RET_CHECK(ShapeUtil::Rank(operand_shape) ==
broadcast->dimensions().size());
- for (int64 i = 0; i < ShapeUtil::Rank(operand_shape); ++i) {
- int64 output_dimension = broadcast->dimensions()[i];
+ for (int64 operand_dimension = 0;
+ operand_dimension < ShapeUtil::Rank(operand_shape);
+ ++operand_dimension) {
+ int64 output_dimension = broadcast->dimensions()[operand_dimension];
TF_RET_CHECK(broadcast->shape().dimensions(output_dimension) ==
- operand_shape.dimensions(i))
+ operand_shape.dimensions(operand_dimension))
<< broadcast->ToString() << " operand shape " << operand_shape;
}
return tensorflow::Status::OK();
}
-Status ShapeVerifier::HandleBroadcastDimOne(HloInstruction* broadcastDimOne) {
- const Shape& operand_shape = broadcastDimOne->operand(0)->shape();
- int64 operand_rank = ShapeUtil::Rank(operand_shape);
- const Shape& output_shape = broadcastDimOne->shape();
- // Check for mixed precision.
- TF_RETURN_IF_ERROR(CheckShape(broadcastDimOne, output_shape));
- TF_RET_CHECK(operand_rank == ShapeUtil::Rank(output_shape));
- for (int64 i = 0; i < operand_rank; ++i) {
- int64 operand_dimension = operand_shape.dimensions(i);
- int64 output_dimension = output_shape.dimensions(i);
- TF_RET_CHECK(operand_dimension == 1 ||
- operand_dimension == output_dimension)
- << "Dimension " << i << " of broadcastDimOne "
- << broadcastDimOne->ToString() << " is " << operand_dimension
- << ", expected 1 or " << output_dimension;
- }
- return tensorflow::Status::OK();
-}
-
Status ShapeVerifier::HandleReshape(HloInstruction* reshape) {
// Check for mixed precision.
TF_RETURN_IF_ERROR(CheckShape(reshape, reshape->shape()));
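
The loop rewritten above enforces the kBroadcast mapping invariant: entry i of dimensions() names the output dimension that operand dimension i maps to, and the two extents must agree. This is also why plain kBroadcast cannot stretch a size-one dimension directly, hence the reshape-then-broadcast equivalent sketched after the diffstat. The same check as a self-contained function over plain vectors, with XLA's Shape and TF_RET_CHECK machinery swapped out for illustration:

#include <cstdint>
#include <vector>

bool BroadcastDimensionsOk(const std::vector<int64_t>& operand_dims,
                           const std::vector<int64_t>& output_dims,
                           const std::vector<int64_t>& dims) {
  // One mapping entry per operand dimension.
  if (dims.size() != operand_dims.size()) return false;
  for (size_t i = 0; i < operand_dims.size(); ++i) {
    const int64_t out = dims[i];
    if (out < 0 || static_cast<size_t>(out) >= output_dims.size()) {
      return false;
    }
    // The output extent at the mapped index must equal the operand extent.
    if (output_dims[out] != operand_dims[i]) return false;
  }
  return true;
}

// e.g. f32[2,3] broadcast to f32[2,4,3] with dimensions={0,2} passes:
// BroadcastDimensionsOk({2, 3}, {2, 4, 3}, {0, 2}) == true.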
diff --git a/tensorflow/compiler/xla/service/hlo_verifier.h b/tensorflow/compiler/xla/service/hlo_verifier.h
index a4dff977ba..1dd7ec3c51 100644
--- a/tensorflow/compiler/xla/service/hlo_verifier.h
+++ b/tensorflow/compiler/xla/service/hlo_verifier.h
@@ -54,7 +54,6 @@ class ShapeVerifier : public DfsHloVisitor {
Status HandleReduce(HloInstruction* reduce) override;
Status HandleBitcast(HloInstruction* bitcast) override;
Status HandleBroadcast(HloInstruction* broadcast) override;
- Status HandleBroadcastDimOne(HloInstruction* broadcastDimOne) override;
Status HandleReshape(HloInstruction* reshape) override;
Status HandleTranspose(HloInstruction* transpose) override;
Status HandleParameter(HloInstruction*) override;
diff --git a/tensorflow/compiler/xla/service/instruction_fusion.cc b/tensorflow/compiler/xla/service/instruction_fusion.cc
index 3f4dbf897d..d69ad80bdb 100644
--- a/tensorflow/compiler/xla/service/instruction_fusion.cc
+++ b/tensorflow/compiler/xla/service/instruction_fusion.cc
@@ -37,7 +37,6 @@ namespace xla {
case HloOpcode::kBitcast:
case HloOpcode::kBitcastConvert:
case HloOpcode::kBroadcast:
- case HloOpcode::kBroadcastDimOne:
case HloOpcode::kCeil:
case HloOpcode::kClamp:
case HloOpcode::kComplex:
@@ -143,8 +142,7 @@ bool InstructionFusion::EffectivelyUnary(HloInstruction* hlo) {
});
return std::count_if(hlo->operands().begin(), hlo->operands().end(),
[output_rank](HloInstruction* operand) {
- if (operand->opcode() == HloOpcode::kBroadcast ||
- operand->opcode() == HloOpcode::kBroadcastDimOne) {
+ if (operand->opcode() == HloOpcode::kBroadcast) {
return false;
}
if (operand->opcode() == HloOpcode::kConstant &&
@@ -249,8 +247,7 @@ StatusOr<bool> InstructionFusion::Run(HloModule* module) {
auto reachability = computation->ComputeReachability();
auto cheap_to_duplicate = [this](HloInstruction* producer) {
- if (producer->opcode() == HloOpcode::kBroadcast ||
- producer->opcode() == HloOpcode::kBroadcastDimOne) {
+ if (producer->opcode() == HloOpcode::kBroadcast) {
return true;
}
if (producer->opcode() == HloOpcode::kConstant &&
diff --git a/tensorflow/compiler/xla/service/pattern_matcher.h b/tensorflow/compiler/xla/service/pattern_matcher.h
index f5a4f2c9df..586f6ef7a9 100644
--- a/tensorflow/compiler/xla/service/pattern_matcher.h
+++ b/tensorflow/compiler/xla/service/pattern_matcher.h
@@ -879,7 +879,6 @@ XLA_UNOP_PATTERN(Abs)
XLA_UNOP_PATTERN(RoundNearestAfz)
XLA_UNOP_PATTERN(Bitcast)
XLA_UNOP_PATTERN(Broadcast)
-XLA_UNOP_PATTERN(BroadcastDimOne)
XLA_UNOP_PATTERN(Ceil)
XLA_UNOP_PATTERN(Copy)
XLA_UNOP_PATTERN(Cos)
diff --git a/tensorflow/compiler/xla/tools/parser/hlo_parser.cc b/tensorflow/compiler/xla/tools/parser/hlo_parser.cc
index b2f122982a..e60a5a4919 100644
--- a/tensorflow/compiler/xla/tools/parser/hlo_parser.cc
+++ b/tensorflow/compiler/xla/tools/parser/hlo_parser.cc
@@ -724,15 +724,6 @@ bool HloParser::ParseInstruction(HloComputation::Builder* builder,
shape, operands[0], *broadcast_dimensions));
break;
}
- case HloOpcode::kBroadcastDimOne: {
- if (!ParseOperands(&operands, /*expected_size=*/1) ||
- !ParseAttributes(attrs)) {
- return false;
- }
- instruction = builder->AddInstruction(
- HloInstruction::CreateBroadcastDimOne(shape, operands[0]));
- break;
- }
case HloOpcode::kConcatenate: {
optional<std::vector<int64>> dimensions;
attrs["dimensions"] = {/*required=*/true, AttrTy::kBracedInt64List,
diff --git a/tensorflow/compiler/xla/tools/parser/hlo_parser_test.cc b/tensorflow/compiler/xla/tools/parser/hlo_parser_test.cc
index 57684b5834..adc8b1d620 100644
--- a/tensorflow/compiler/xla/tools/parser/hlo_parser_test.cc
+++ b/tensorflow/compiler/xla/tools/parser/hlo_parser_test.cc
@@ -59,18 +59,6 @@ ENTRY %axpy.v5 (alpha: f32[], x: f32[2,4], y: f32[2,4]) -> f32[2,4] {
)"
},
-// broadcast size-one dimensions
-{
-"BroadcastDimOne",
-R"(HloModule broadcast_dim_one_module
-
-ENTRY %broadcast-dim-one () -> f32[2,2] {
- %constant = f32[1,2]{1,0} constant(f32[1,2] { { 1.1, 2.2 } })
- ROOT %broadcast-dim-one = f32[2,2]{1,0} broadcast-dim-one(f32[1,2]{1,0} %constant)
-}
-
-)"
-},
// pred constant
{
"ConstantPred",