 tensorflow/compiler/xla/service/algebraic_simplifier.cc        | 39
 tensorflow/compiler/xla/service/cpu/ir_emitter.cc              |  4
 tensorflow/compiler/xla/service/cpu/ir_emitter.h               |  2
 tensorflow/compiler/xla/service/dfs_hlo_visitor.h              | 11
 tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h |  6
 tensorflow/compiler/xla/service/gpu/ir_emitter.h               |  2
 tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc     | 10
 tensorflow/compiler/xla/service/hlo_cost_analysis.cc           |  6
 tensorflow/compiler/xla/service/hlo_cost_analysis.h            |  5
 tensorflow/compiler/xla/service/hlo_evaluator.cc               | 12
 tensorflow/compiler/xla/service/hlo_instruction.cc             |  8
 tensorflow/compiler/xla/service/tuple_points_to_analysis.cc    |  5
 tensorflow/compiler/xla/service/tuple_points_to_analysis.h     |  2
 13 files changed, 46 insertions(+), 66 deletions(-)
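The changes below all follow one refactoring: the DfsHloVisitor handlers for kCopy, kConvert, kMaximum, and kMinimum no longer receive their operands as separate parameters; each handler now takes only the instruction and queries it via operand(0) / mutable_operand(0). A minimal C++ sketch of the before/after shape of such a handler, using simplified stand-in types rather than the actual XLA classes:

#include <vector>

struct Status {};                       // stand-in for tensorflow::Status
struct Instr {                          // stand-in for HloInstruction
  std::vector<Instr*> operands;
  const Instr* operand(int i) const { return operands[i]; }
  Instr* mutable_operand(int i) { return operands[i]; }
};

// Before: the operand was threaded through as an extra parameter.
Status HandleCopyBefore(Instr* /*copy*/, Instr* /*operand*/) { return {}; }

// After: only the instruction is passed; the handler fetches operands itself.
Status HandleCopyAfter(Instr* copy) {
  Instr* operand = copy->mutable_operand(0);  // same information, looked up locally
  (void)operand;
  return {};
}

Call sites such as HloInstruction::Visit (see hlo_instruction.cc below) shrink correspondingly, passing only the instruction pointer.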
diff --git a/tensorflow/compiler/xla/service/algebraic_simplifier.cc b/tensorflow/compiler/xla/service/algebraic_simplifier.cc
index 1302026ccf..0187c09d7b 100644
--- a/tensorflow/compiler/xla/service/algebraic_simplifier.cc
+++ b/tensorflow/compiler/xla/service/algebraic_simplifier.cc
@@ -126,10 +126,9 @@ class AlgebraicSimplifierVisitor : public DfsHloVisitorWithDefault {
HloInstruction* concatenate,
tensorflow::gtl::ArraySlice<HloInstruction*> operands) override;
- Status HandleCopy(HloInstruction* copy, HloInstruction* operand) override;
+ Status HandleCopy(HloInstruction* copy) override;
- Status HandleConvert(HloInstruction* convert,
- HloInstruction* operand) override;
+ Status HandleConvert(HloInstruction* convert) override;
Status HandleConvolution(HloInstruction* convolution, HloInstruction* lhs,
HloInstruction* rhs, const Window& window) override;
@@ -179,11 +178,8 @@ class AlgebraicSimplifierVisitor : public DfsHloVisitorWithDefault {
Status HandleSubtract(HloInstruction* sub, HloInstruction* lhs,
HloInstruction* rhs) override;
- Status HandleMaximum(HloInstruction* maximum, HloInstruction* lhs,
- HloInstruction* rhs) override;
-
- Status HandleMinimum(HloInstruction* minimum, HloInstruction* lhs,
- HloInstruction* rhs) override;
+ Status HandleMaximum(HloInstruction* maximum) override;
+ Status HandleMinimum(HloInstruction* minimum) override;
// Returns whether algebraic simplification has occurred.
const bool changed() const { return changed_; }
@@ -334,16 +330,16 @@ Status AlgebraicSimplifierVisitor::HandleAdd(HloInstruction* add,
return Status::OK();
}
-Status AlgebraicSimplifierVisitor::HandleCopy(HloInstruction* copy,
- HloInstruction* operand) {
+Status AlgebraicSimplifierVisitor::HandleCopy(HloInstruction* copy) {
// If a copy feeds a copy, make it a single copy.
- if (operand->opcode() == HloOpcode::kCopy) {
+ if (copy->operand(0)->opcode() == HloOpcode::kCopy) {
return ReplaceWithNewInstruction(
- copy, HloInstruction::CreateUnary(copy->shape(), HloOpcode::kCopy,
- operand->operands()[0]));
+ copy, HloInstruction::CreateUnary(
+ copy->shape(), HloOpcode::kCopy,
+ copy->mutable_operand(0)->mutable_operand(0)));
}
// All copies can be eliminated (assuming layout constraints are satisfied).
- ReplaceInstructionIfSameShape(copy, operand);
+ ReplaceInstructionIfSameShape(copy, copy->mutable_operand(0));
return Status::OK();
}
@@ -792,12 +788,11 @@ Status AlgebraicSimplifierVisitor::HandleBroadcast(HloInstruction* broadcast) {
// A conversion to the same element type as the operand is a nop and can be
// removed. A conversion of a constant can be simplified by making a new
// constant.
-Status AlgebraicSimplifierVisitor::HandleConvert(HloInstruction* convert,
- HloInstruction* operand) {
- PrimitiveType src_type = operand->shape().element_type();
+Status AlgebraicSimplifierVisitor::HandleConvert(HloInstruction* convert) {
+ PrimitiveType src_type = convert->operand(0)->shape().element_type();
PrimitiveType dest_type = convert->shape().element_type();
if (src_type == dest_type) {
- return ReplaceInstruction(convert, operand);
+ return ReplaceInstruction(convert, convert->mutable_operand(0));
}
return Status::OK();
}
@@ -1391,9 +1386,7 @@ bool AlgebraicSimplifierVisitor::TransformToClampIfSameShape(
return true;
}
-Status AlgebraicSimplifierVisitor::HandleMaximum(HloInstruction* maximum,
- HloInstruction* lhs,
- HloInstruction* rhs) {
+Status AlgebraicSimplifierVisitor::HandleMaximum(HloInstruction* maximum) {
// Match the following tree:
// min_operand operand
// \ /
@@ -1424,9 +1417,7 @@ Status AlgebraicSimplifierVisitor::HandleMaximum(HloInstruction* maximum,
return Status::OK();
}
-Status AlgebraicSimplifierVisitor::HandleMinimum(HloInstruction* minimum,
- HloInstruction* lhs,
- HloInstruction* rhs) {
+Status AlgebraicSimplifierVisitor::HandleMinimum(HloInstruction* minimum) {
// Match the following tree:
// max_operand operand
// \ /
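For orientation, the copy and convert rules touched in this file behave roughly as follows on a toy expression tree: copy(copy(x)) collapses to copy(x), and convert(x) is dropped when the destination element type equals x's element type. A self-contained sketch with an illustrative node type, not the XLA API:

#include <string>

struct Node {
  std::string op;          // "copy", "convert", "param", ...
  std::string type;        // element type, e.g. "f32"
  Node* operand = nullptr;
};

// copy(copy(x)) -> copy(x); the remaining copy may later be elided entirely
// when shapes and layout constraints allow.
Node* SimplifyCopy(Node* copy) {
  if (copy->operand && copy->operand->op == "copy") {
    copy->operand = copy->operand->operand;
  }
  return copy;
}

// convert(x) with matching source and destination element types is a no-op.
Node* SimplifyConvert(Node* convert) {
  return convert->operand->type == convert->type ? convert->operand : convert;
}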
diff --git a/tensorflow/compiler/xla/service/cpu/ir_emitter.cc b/tensorflow/compiler/xla/service/cpu/ir_emitter.cc
index 421618d819..fee5fd8830 100644
--- a/tensorflow/compiler/xla/service/cpu/ir_emitter.cc
+++ b/tensorflow/compiler/xla/service/cpu/ir_emitter.cc
@@ -259,12 +259,12 @@ Status IrEmitter::HandleConstant(HloInstruction* constant,
return Status::OK();
}
-Status IrEmitter::HandleCopy(HloInstruction* copy, HloInstruction* operand) {
+Status IrEmitter::HandleCopy(HloInstruction* copy) {
if (ShapeUtil::IsTuple(copy->shape())) {
// kCopy shallow copies a tuple so just memcpy the top-level buffer.
TF_ASSIGN_OR_RETURN(llvm::Value * copy_value, EmitTargetAddressForOp(copy));
emitted_value_[copy] = copy_value;
- return EmitMemcpy(*operand, *copy);
+ return EmitMemcpy(*(copy->operand(0)), *copy);
} else {
// Use the elemental emitter for non-tuple shapes.
return DefaultAction(copy);
diff --git a/tensorflow/compiler/xla/service/cpu/ir_emitter.h b/tensorflow/compiler/xla/service/cpu/ir_emitter.h
index 1a8c91efd4..a1b7bd9e6d 100644
--- a/tensorflow/compiler/xla/service/cpu/ir_emitter.h
+++ b/tensorflow/compiler/xla/service/cpu/ir_emitter.h
@@ -96,7 +96,7 @@ class IrEmitter : public DfsHloVisitorWithDefault {
Status HandleBitcast(HloInstruction* bitcast) override;
Status HandleConstant(HloInstruction* constant,
const Literal& literal) override;
- Status HandleCopy(HloInstruction* copy, HloInstruction* operand) override;
+ Status HandleCopy(HloInstruction* copy) override;
Status HandleGetTupleElement(HloInstruction* get_tuple_element,
HloInstruction* operand) override;
Status HandleSelect(HloInstruction* select, HloInstruction* pred,
diff --git a/tensorflow/compiler/xla/service/dfs_hlo_visitor.h b/tensorflow/compiler/xla/service/dfs_hlo_visitor.h
index ea7c22737f..3f9b71cf2b 100644
--- a/tensorflow/compiler/xla/service/dfs_hlo_visitor.h
+++ b/tensorflow/compiler/xla/service/dfs_hlo_visitor.h
@@ -72,22 +72,19 @@ class DfsHloVisitor {
virtual Status HandleSelect(HloInstruction* select, HloInstruction* pred,
HloInstruction* on_true,
HloInstruction* on_false) = 0;
- virtual Status HandleMaximum(HloInstruction* maximum, HloInstruction* lhs,
- HloInstruction* rhs) {
+ virtual Status HandleMaximum(HloInstruction* maximum) {
return HandleElementwiseBinary(maximum, HloOpcode::kMaximum);
}
- virtual Status HandleMinimum(HloInstruction* minimum, HloInstruction* lhs,
- HloInstruction* rhs) {
+ virtual Status HandleMinimum(HloInstruction* minimum) {
return HandleElementwiseBinary(minimum, HloOpcode::kMinimum);
}
virtual Status HandleConcatenate(
HloInstruction* concatenate,
tensorflow::gtl::ArraySlice<HloInstruction*> operands) = 0;
- virtual Status HandleConvert(HloInstruction* convert,
- HloInstruction* operand) {
+ virtual Status HandleConvert(HloInstruction* convert) {
return HandleElementwiseUnary(convert, HloOpcode::kConvert);
}
- virtual Status HandleCopy(HloInstruction* copy, HloInstruction* operand) {
+ virtual Status HandleCopy(HloInstruction* copy) {
return HandleElementwiseUnary(copy, HloOpcode::kCopy);
}
virtual Status HandleMultiply(HloInstruction* multiply, HloInstruction* lhs,
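The pattern visible in this header: each per-opcode handler is virtual and defaults to a generic elementwise handler, so a subclass overrides only the ops it cares about, and after this change the override signature carries just the instruction. A compact sketch with simplified stand-in types, not the real DfsHloVisitor:

struct Status {};
enum class Opcode { kMaximum, kMinimum, kConvert, kCopy };
struct Instr { Opcode opcode; };

class Visitor {
 public:
  virtual ~Visitor() = default;
  virtual Status HandleElementwiseBinary(Instr*, Opcode) { return {}; }
  virtual Status HandleElementwiseUnary(Instr*, Opcode) { return {}; }
  // Post-refactor signatures: the instruction alone is enough.
  virtual Status HandleMaximum(Instr* max) { return HandleElementwiseBinary(max, Opcode::kMaximum); }
  virtual Status HandleMinimum(Instr* min) { return HandleElementwiseBinary(min, Opcode::kMinimum); }
  virtual Status HandleConvert(Instr* cvt) { return HandleElementwiseUnary(cvt, Opcode::kConvert); }
  virtual Status HandleCopy(Instr* cpy) { return HandleElementwiseUnary(cpy, Opcode::kCopy); }
};

// A subclass that only cares about kMaximum overrides just that one handler.
class MaxOnlyVisitor : public Visitor {
  Status HandleMaximum(Instr* /*max*/) override { return {}; }
};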
diff --git a/tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h b/tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h
index 75910b8cbb..2970ba8cc4 100644
--- a/tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h
+++ b/tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h
@@ -64,12 +64,10 @@ class DfsHloVisitorWithDefault : public DfsHloVisitor {
tensorflow::gtl::ArraySlice<HloInstruction*> /*operands*/) override {
return DefaultAction(concatenate);
}
- Status HandleConvert(HloInstruction* convert,
- HloInstruction* /*operand*/) override {
+ Status HandleConvert(HloInstruction* convert) override {
return DefaultAction(convert);
}
- Status HandleCopy(HloInstruction* copy,
- HloInstruction* /*operand*/) override {
+ Status HandleCopy(HloInstruction* copy) override {
return DefaultAction(copy);
}
Status HandleSelect(HloInstruction* select, HloInstruction* /*pred*/,
diff --git a/tensorflow/compiler/xla/service/gpu/ir_emitter.h b/tensorflow/compiler/xla/service/gpu/ir_emitter.h
index 607a366ac6..de72ac738e 100644
--- a/tensorflow/compiler/xla/service/gpu/ir_emitter.h
+++ b/tensorflow/compiler/xla/service/gpu/ir_emitter.h
@@ -231,7 +231,7 @@ class IrEmitterUnnested : public IrEmitter {
// IrEmitterUnnested handles the following instructions differently from
// IrEmitter.
- Status HandleCopy(HloInstruction* copy, HloInstruction* operand) override;
+ Status HandleCopy(HloInstruction* copy) override;
Status HandleConvolution(HloInstruction* convolution, HloInstruction* lhs,
HloInstruction* rhs, const Window& window) override;
Status HandleDot(HloInstruction* dot, HloInstruction* lhs_instruction,
diff --git a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc
index ab04d1736e..ea71d92417 100644
--- a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc
+++ b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc
@@ -722,8 +722,7 @@ int64 EmitTranspose021Tiled(llvm_ir::IrArray input, llvm_ir::IrArray output,
} // namespace
-Status IrEmitterUnnested::HandleCopy(HloInstruction* copy,
- HloInstruction* operand) {
+Status IrEmitterUnnested::HandleCopy(HloInstruction* copy) {
if (ImplementedAsMemcpy(*copy)) {
thunk_sequence_->emplace_back(BuildCopyThunk(copy));
return Status::OK();
@@ -731,7 +730,7 @@ Status IrEmitterUnnested::HandleCopy(HloInstruction* copy,
bool is_transpose_021;
Shape reduced_input_shape, reduced_output_shape;
std::tie(is_transpose_021, reduced_input_shape, reduced_output_shape) =
- IsTranspose021(operand->shape(), copy->shape());
+ IsTranspose021(copy->operand(0)->shape(), copy->shape());
if (is_transpose_021 &&
reduced_input_shape.dimensions(1) >= kMinDimensionToTransposeTiled &&
reduced_input_shape.dimensions(2) >= kMinDimensionToTransposeTiled) {
@@ -739,7 +738,8 @@ Status IrEmitterUnnested::HandleCopy(HloInstruction* copy,
VLOG(3) << "Emitting tiled 0-2-1 transposition";
constexpr int64 tile_size = 32;
int64 num_tiles = EmitTranspose021Tiled(
- GetIrArray(*operand).CastToShape(reduced_input_shape, &ir_builder_),
+ GetIrArray(*(copy->operand(0)))
+ .CastToShape(reduced_input_shape, &ir_builder_),
GetIrArray(*copy).CastToShape(reduced_output_shape, &ir_builder_),
tile_size, &ir_builder_);
UpdateLaunchDimensions(LaunchDimensions(num_tiles, tile_size), LastThunk(),
@@ -747,7 +747,7 @@ Status IrEmitterUnnested::HandleCopy(HloInstruction* copy,
return Status::OK();
}
- return IrEmitter::HandleCopy(copy, operand);
+ return IrEmitter::HandleCopy(copy);
}
Status IrEmitterUnnested::EmitColumnReduction(
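As context for the tiled path above: when the copy is a 0-2-1 transpose whose two transposed dimensions are large enough, the emitter launches one thread block per 32-wide tile. The arithmetic below is a generic illustration of that tiling; it is not the EmitTranspose021Tiled implementation, just the usual ceil-division tile count under an assumed reduced shape [batch, dim1, dim2]:

#include <cstdint>
#include <iostream>

int64_t CeilDiv(int64_t a, int64_t b) { return (a + b - 1) / b; }

int main() {
  const int64_t batch = 8, dim1 = 128, dim2 = 96;  // example reduced input shape
  const int64_t tile_size = 32;                    // matches the constant in the diff
  // One tile per 32x32 patch of the dim1 x dim2 plane, for each batch element.
  const int64_t num_tiles = batch * CeilDiv(dim1, tile_size) * CeilDiv(dim2, tile_size);
  std::cout << "num_tiles = " << num_tiles << "\n";  // 8 * 4 * 3 = 96
  return 0;
}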
diff --git a/tensorflow/compiler/xla/service/hlo_cost_analysis.cc b/tensorflow/compiler/xla/service/hlo_cost_analysis.cc
index abbbbfa02b..f3a6cd43c2 100644
--- a/tensorflow/compiler/xla/service/hlo_cost_analysis.cc
+++ b/tensorflow/compiler/xla/service/hlo_cost_analysis.cc
@@ -166,13 +166,11 @@ Status HloCostAnalysis::HandleConcatenate(
return Status::OK();
}
-Status HloCostAnalysis::HandleConvert(HloInstruction* convert,
- HloInstruction* operand) {
+Status HloCostAnalysis::HandleConvert(HloInstruction* convert) {
return HandleElementwiseOp(convert);
}
-Status HloCostAnalysis::HandleCopy(HloInstruction* copy,
- HloInstruction* operand) {
+Status HloCostAnalysis::HandleCopy(HloInstruction* copy) {
return Status::OK();
}
diff --git a/tensorflow/compiler/xla/service/hlo_cost_analysis.h b/tensorflow/compiler/xla/service/hlo_cost_analysis.h
index 6538266864..3f0dfcc619 100644
--- a/tensorflow/compiler/xla/service/hlo_cost_analysis.h
+++ b/tensorflow/compiler/xla/service/hlo_cost_analysis.h
@@ -63,9 +63,8 @@ class HloCostAnalysis : public DfsHloVisitor {
tensorflow::gtl::ArraySlice<HloInstruction*> operands) override;
Status HandleSend(HloInstruction* send) override;
Status HandleRecv(HloInstruction* recv) override;
- Status HandleConvert(HloInstruction* convert,
- HloInstruction* operand) override;
- Status HandleCopy(HloInstruction* copy, HloInstruction* operand) override;
+ Status HandleConvert(HloInstruction* convert) override;
+ Status HandleCopy(HloInstruction* copy) override;
Status HandleDot(HloInstruction* dot, HloInstruction* lhs,
HloInstruction* rhs) override;
Status HandleConvolution(HloInstruction* convolution, HloInstruction* lhs,
diff --git a/tensorflow/compiler/xla/service/hlo_evaluator.cc b/tensorflow/compiler/xla/service/hlo_evaluator.cc
index a42289590b..4936b823a2 100644
--- a/tensorflow/compiler/xla/service/hlo_evaluator.cc
+++ b/tensorflow/compiler/xla/service/hlo_evaluator.cc
@@ -192,7 +192,7 @@ class HloEvaluator::TypedVisitor : public DfsHloVisitorWithDefault {
return Status::OK();
};
- Status HandleCopy(HloInstruction* copy, HloInstruction* operand) override {
+ Status HandleCopy(HloInstruction* copy) override {
TF_ASSIGN_OR_RETURN(parent_->evaluated_[copy],
ElementWiseUnaryOp(copy, [](ReturnT elem_operand) {
return elem_operand;
@@ -208,8 +208,8 @@ class HloEvaluator::TypedVisitor : public DfsHloVisitorWithDefault {
typename primitive_util::PrimitiveTypeToNative<dest_type>::type>();
}
- Status HandleConvert(HloInstruction* convert,
- HloInstruction* operand) override {
+ Status HandleConvert(HloInstruction* convert) override {
+ const HloInstruction* operand = convert->operand(0);
auto operand_literal = parent_->GetEvaluatedLiteralFor(operand);
switch (operand->shape().element_type()) {
@@ -337,8 +337,7 @@ class HloEvaluator::TypedVisitor : public DfsHloVisitorWithDefault {
return Status::OK();
};
- Status HandleMaximum(HloInstruction* maximum, HloInstruction* lhs,
- HloInstruction* rhs) override {
+ Status HandleMaximum(HloInstruction* maximum) override {
TF_ASSIGN_OR_RETURN(
parent_->evaluated_[maximum],
ElementWiseBinaryOp(maximum, [](ReturnT lhs, ReturnT rhs) {
@@ -347,8 +346,7 @@ class HloEvaluator::TypedVisitor : public DfsHloVisitorWithDefault {
return Status::OK();
};
- Status HandleMinimum(HloInstruction* minimum, HloInstruction* lhs,
- HloInstruction* rhs) override {
+ Status HandleMinimum(HloInstruction* minimum) override {
TF_ASSIGN_OR_RETURN(
parent_->evaluated_[minimum],
ElementWiseBinaryOp(minimum, [](ReturnT lhs_el, ReturnT rhs_el) {
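The evaluator handlers above reduce to a single pattern: an element-wise binary op applies the supplied lambda to corresponding elements of the two already-evaluated operands. A stripped-down sketch over plain vectors; the real code operates on Literals and is templated on ReturnT:

#include <algorithm>  // std::max / std::min for the usage example
#include <cstddef>
#include <vector>

template <typename T, typename F>
std::vector<T> ElementWiseBinaryOp(const std::vector<T>& lhs,
                                   const std::vector<T>& rhs, F&& fn) {
  std::vector<T> out(lhs.size());
  for (std::size_t i = 0; i < lhs.size(); ++i) out[i] = fn(lhs[i], rhs[i]);
  return out;
}

// Usage mirroring HandleMaximum / HandleMinimum:
//   auto maxed = ElementWiseBinaryOp(a, b, [](float x, float y) { return std::max(x, y); });
//   auto mined = ElementWiseBinaryOp(a, b, [](float x, float y) { return std::min(x, y); });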
diff --git a/tensorflow/compiler/xla/service/hlo_instruction.cc b/tensorflow/compiler/xla/service/hlo_instruction.cc
index c49c00bac0..99b73dea29 100644
--- a/tensorflow/compiler/xla/service/hlo_instruction.cc
+++ b/tensorflow/compiler/xla/service/hlo_instruction.cc
@@ -1803,9 +1803,9 @@ Status HloInstruction::Visit(DfsHloVisitor* visitor) {
case HloOpcode::kSubtract:
return visitor->HandleSubtract(this, operands_[0], operands_[1]);
case HloOpcode::kMaximum:
- return visitor->HandleMaximum(this, operands_[0], operands_[1]);
+ return visitor->HandleMaximum(this);
case HloOpcode::kMinimum:
- return visitor->HandleMinimum(this, operands_[0], operands_[1]);
+ return visitor->HandleMinimum(this);
case HloOpcode::kLogicalAnd:
return visitor->HandleLogicalAnd(this, operands_[0], operands_[1]);
case HloOpcode::kLogicalOr:
@@ -1813,9 +1813,9 @@ Status HloInstruction::Visit(DfsHloVisitor* visitor) {
case HloOpcode::kConcatenate:
return visitor->HandleConcatenate(this, operands_);
case HloOpcode::kConvert:
- return visitor->HandleConvert(this, operands_[0]);
+ return visitor->HandleConvert(this);
case HloOpcode::kCopy:
- return visitor->HandleCopy(this, operands_[0]);
+ return visitor->HandleCopy(this);
case HloOpcode::kMultiply:
return visitor->HandleMultiply(this, operands_[0], operands_[1]);
case HloOpcode::kDot:
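The dispatch side of the refactor, sketched with the same simplified stand-ins: Visit() switches on the opcode and now passes only the instruction for the handlers changed in this commit. This is the shape of HloInstruction::Visit, not the real implementation:

struct Status {};
enum class Opcode { kMaximum, kMinimum, kConvert, kCopy, kOther };
struct Instr;

struct Visitor {
  virtual ~Visitor() = default;
  virtual Status HandleMaximum(Instr*) = 0;
  virtual Status HandleMinimum(Instr*) = 0;
  virtual Status HandleConvert(Instr*) = 0;
  virtual Status HandleCopy(Instr*) = 0;
  virtual Status DefaultAction(Instr*) = 0;
};

struct Instr {
  Opcode opcode;
  Status Visit(Visitor* visitor) {
    switch (opcode) {
      case Opcode::kMaximum: return visitor->HandleMaximum(this);
      case Opcode::kMinimum: return visitor->HandleMinimum(this);
      case Opcode::kConvert: return visitor->HandleConvert(this);
      case Opcode::kCopy:    return visitor->HandleCopy(this);
      default:               return visitor->DefaultAction(this);
    }
  }
};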
diff --git a/tensorflow/compiler/xla/service/tuple_points_to_analysis.cc b/tensorflow/compiler/xla/service/tuple_points_to_analysis.cc
index ad6f015c70..8d68398450 100644
--- a/tensorflow/compiler/xla/service/tuple_points_to_analysis.cc
+++ b/tensorflow/compiler/xla/service/tuple_points_to_analysis.cc
@@ -243,12 +243,11 @@ Status TuplePointsToAnalysis::HandleGetTupleElement(
return Status::OK();
}
-Status TuplePointsToAnalysis::HandleCopy(HloInstruction* copy,
- HloInstruction* operand) {
+Status TuplePointsToAnalysis::HandleCopy(HloInstruction* copy) {
// A kCopy instruction performs a shallow copy of the operand. The top-level
// buffer (index={}) is newly created, but all other buffers (in the case of a
// tuple shape) come from the operand
- PointsToSet& points_to_set = CreateCopiedPointsToSet(copy, operand);
+ PointsToSet& points_to_set = CreateCopiedPointsToSet(copy, copy->operand(0));
points_to_set.mutable_element(/*index=*/{})->clear();
points_to_set.AddPointedToBuffer(NewLogicalBuffer(copy, /*index=*/{}),
/*index=*/{});
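To make the comment above concrete: a shallow copy of a tuple-shaped value gets a fresh top-level buffer, while every leaf buffer continues to be the operand's. An illustrative sketch, not the XLA points-to machinery:

#include <memory>
#include <vector>

struct Buffer { int id; };

struct TupleValue {
  std::shared_ptr<Buffer> top_level;            // buffer holding the tuple index table
  std::vector<std::shared_ptr<Buffer>> leaves;  // buffers for the tuple's elements
};

TupleValue ShallowCopy(const TupleValue& operand, int new_buffer_id) {
  TupleValue copy;
  copy.top_level = std::make_shared<Buffer>(Buffer{new_buffer_id});  // newly created
  copy.leaves = operand.leaves;  // element buffers are shared with the operand
  return copy;
}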
diff --git a/tensorflow/compiler/xla/service/tuple_points_to_analysis.h b/tensorflow/compiler/xla/service/tuple_points_to_analysis.h
index 4d7fc7cbc9..bab4235a28 100644
--- a/tensorflow/compiler/xla/service/tuple_points_to_analysis.h
+++ b/tensorflow/compiler/xla/service/tuple_points_to_analysis.h
@@ -208,7 +208,7 @@ class TuplePointsToAnalysis : public DfsHloVisitorWithDefault {
Status HandleGetTupleElement(HloInstruction* get_tuple_element,
HloInstruction* operand) override;
Status HandleBitcast(HloInstruction* bitcast) override;
- Status HandleCopy(HloInstruction* copy, HloInstruction* operand) override;
+ Status HandleCopy(HloInstruction* copy) override;
Status HandleSelect(HloInstruction* select, HloInstruction* pred,
HloInstruction* on_true,
HloInstruction* on_false) override;