diff options
13 files changed, 46 insertions, 66 deletions
diff --git a/tensorflow/compiler/xla/service/algebraic_simplifier.cc b/tensorflow/compiler/xla/service/algebraic_simplifier.cc index 1302026ccf..0187c09d7b 100644 --- a/tensorflow/compiler/xla/service/algebraic_simplifier.cc +++ b/tensorflow/compiler/xla/service/algebraic_simplifier.cc @@ -126,10 +126,9 @@ class AlgebraicSimplifierVisitor : public DfsHloVisitorWithDefault { HloInstruction* concatenate, tensorflow::gtl::ArraySlice<HloInstruction*> operands) override; - Status HandleCopy(HloInstruction* copy, HloInstruction* operand) override; + Status HandleCopy(HloInstruction* copy) override; - Status HandleConvert(HloInstruction* convert, - HloInstruction* operand) override; + Status HandleConvert(HloInstruction* convert) override; Status HandleConvolution(HloInstruction* convolution, HloInstruction* lhs, HloInstruction* rhs, const Window& window) override; @@ -179,11 +178,8 @@ class AlgebraicSimplifierVisitor : public DfsHloVisitorWithDefault { Status HandleSubtract(HloInstruction* sub, HloInstruction* lhs, HloInstruction* rhs) override; - Status HandleMaximum(HloInstruction* maximum, HloInstruction* lhs, - HloInstruction* rhs) override; - - Status HandleMinimum(HloInstruction* minimum, HloInstruction* lhs, - HloInstruction* rhs) override; + Status HandleMaximum(HloInstruction* maximum) override; + Status HandleMinimum(HloInstruction* minimum) override; // Returns whether algebraic simplification has occurred. const bool changed() const { return changed_; } @@ -334,16 +330,16 @@ Status AlgebraicSimplifierVisitor::HandleAdd(HloInstruction* add, return Status::OK(); } -Status AlgebraicSimplifierVisitor::HandleCopy(HloInstruction* copy, - HloInstruction* operand) { +Status AlgebraicSimplifierVisitor::HandleCopy(HloInstruction* copy) { // If a copy feeds a copy, make it a single copy. - if (operand->opcode() == HloOpcode::kCopy) { + if (copy->operand(0)->opcode() == HloOpcode::kCopy) { return ReplaceWithNewInstruction( - copy, HloInstruction::CreateUnary(copy->shape(), HloOpcode::kCopy, - operand->operands()[0])); + copy, HloInstruction::CreateUnary( + copy->shape(), HloOpcode::kCopy, + copy->mutable_operand(0)->mutable_operand(0))); } // All copies can be eliminated (assuming layout constraints are satisified). - ReplaceInstructionIfSameShape(copy, operand); + ReplaceInstructionIfSameShape(copy, copy->mutable_operand(0)); return Status::OK(); } @@ -792,12 +788,11 @@ Status AlgebraicSimplifierVisitor::HandleBroadcast(HloInstruction* broadcast) { // A conversion to the same element type as the operand is a nop and can be // removed. A conversion of a constant can be simplified by making a new // constant. -Status AlgebraicSimplifierVisitor::HandleConvert(HloInstruction* convert, - HloInstruction* operand) { - PrimitiveType src_type = operand->shape().element_type(); +Status AlgebraicSimplifierVisitor::HandleConvert(HloInstruction* convert) { + PrimitiveType src_type = convert->operand(0)->shape().element_type(); PrimitiveType dest_type = convert->shape().element_type(); if (src_type == dest_type) { - return ReplaceInstruction(convert, operand); + return ReplaceInstruction(convert, convert->mutable_operand(0)); } return Status::OK(); } @@ -1391,9 +1386,7 @@ bool AlgebraicSimplifierVisitor::TransformToClampIfSameShape( return true; } -Status AlgebraicSimplifierVisitor::HandleMaximum(HloInstruction* maximum, - HloInstruction* lhs, - HloInstruction* rhs) { +Status AlgebraicSimplifierVisitor::HandleMaximum(HloInstruction* maximum) { // Match the following tree: // min_operand operand // \ / @@ -1424,9 +1417,7 @@ Status AlgebraicSimplifierVisitor::HandleMaximum(HloInstruction* maximum, return Status::OK(); } -Status AlgebraicSimplifierVisitor::HandleMinimum(HloInstruction* minimum, - HloInstruction* lhs, - HloInstruction* rhs) { +Status AlgebraicSimplifierVisitor::HandleMinimum(HloInstruction* minimum) { // Match the following tree: // max_operand operand // \ / diff --git a/tensorflow/compiler/xla/service/cpu/ir_emitter.cc b/tensorflow/compiler/xla/service/cpu/ir_emitter.cc index 421618d819..fee5fd8830 100644 --- a/tensorflow/compiler/xla/service/cpu/ir_emitter.cc +++ b/tensorflow/compiler/xla/service/cpu/ir_emitter.cc @@ -259,12 +259,12 @@ Status IrEmitter::HandleConstant(HloInstruction* constant, return Status::OK(); } -Status IrEmitter::HandleCopy(HloInstruction* copy, HloInstruction* operand) { +Status IrEmitter::HandleCopy(HloInstruction* copy) { if (ShapeUtil::IsTuple(copy->shape())) { // kCopy shallow copies a tuple so just memcpy the top-level buffer. TF_ASSIGN_OR_RETURN(llvm::Value * copy_value, EmitTargetAddressForOp(copy)); emitted_value_[copy] = copy_value; - return EmitMemcpy(*operand, *copy); + return EmitMemcpy(*(copy->operand(0)), *copy); } else { // Use the elemental emitter for non-tuple shapes. return DefaultAction(copy); diff --git a/tensorflow/compiler/xla/service/cpu/ir_emitter.h b/tensorflow/compiler/xla/service/cpu/ir_emitter.h index 1a8c91efd4..a1b7bd9e6d 100644 --- a/tensorflow/compiler/xla/service/cpu/ir_emitter.h +++ b/tensorflow/compiler/xla/service/cpu/ir_emitter.h @@ -96,7 +96,7 @@ class IrEmitter : public DfsHloVisitorWithDefault { Status HandleBitcast(HloInstruction* bitcast) override; Status HandleConstant(HloInstruction* constant, const Literal& literal) override; - Status HandleCopy(HloInstruction* copy, HloInstruction* operand) override; + Status HandleCopy(HloInstruction* copy) override; Status HandleGetTupleElement(HloInstruction* get_tuple_element, HloInstruction* operand) override; Status HandleSelect(HloInstruction* select, HloInstruction* pred, diff --git a/tensorflow/compiler/xla/service/dfs_hlo_visitor.h b/tensorflow/compiler/xla/service/dfs_hlo_visitor.h index ea7c22737f..3f9b71cf2b 100644 --- a/tensorflow/compiler/xla/service/dfs_hlo_visitor.h +++ b/tensorflow/compiler/xla/service/dfs_hlo_visitor.h @@ -72,22 +72,19 @@ class DfsHloVisitor { virtual Status HandleSelect(HloInstruction* select, HloInstruction* pred, HloInstruction* on_true, HloInstruction* on_false) = 0; - virtual Status HandleMaximum(HloInstruction* maximum, HloInstruction* lhs, - HloInstruction* rhs) { + virtual Status HandleMaximum(HloInstruction* maximum) { return HandleElementwiseBinary(maximum, HloOpcode::kMaximum); } - virtual Status HandleMinimum(HloInstruction* minimum, HloInstruction* lhs, - HloInstruction* rhs) { + virtual Status HandleMinimum(HloInstruction* minimum) { return HandleElementwiseBinary(minimum, HloOpcode::kMinimum); } virtual Status HandleConcatenate( HloInstruction* concatenate, tensorflow::gtl::ArraySlice<HloInstruction*> operands) = 0; - virtual Status HandleConvert(HloInstruction* convert, - HloInstruction* operand) { + virtual Status HandleConvert(HloInstruction* convert) { return HandleElementwiseUnary(convert, HloOpcode::kConvert); } - virtual Status HandleCopy(HloInstruction* copy, HloInstruction* operand) { + virtual Status HandleCopy(HloInstruction* copy) { return HandleElementwiseUnary(copy, HloOpcode::kCopy); } virtual Status HandleMultiply(HloInstruction* multiply, HloInstruction* lhs, diff --git a/tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h b/tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h index 75910b8cbb..2970ba8cc4 100644 --- a/tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h +++ b/tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h @@ -64,12 +64,10 @@ class DfsHloVisitorWithDefault : public DfsHloVisitor { tensorflow::gtl::ArraySlice<HloInstruction*> /*operands*/) override { return DefaultAction(concatenate); } - Status HandleConvert(HloInstruction* convert, - HloInstruction* /*operand*/) override { + Status HandleConvert(HloInstruction* convert) override { return DefaultAction(convert); } - Status HandleCopy(HloInstruction* copy, - HloInstruction* /*operand*/) override { + Status HandleCopy(HloInstruction* copy) override { return DefaultAction(copy); } Status HandleSelect(HloInstruction* select, HloInstruction* /*pred*/, diff --git a/tensorflow/compiler/xla/service/gpu/ir_emitter.h b/tensorflow/compiler/xla/service/gpu/ir_emitter.h index 607a366ac6..de72ac738e 100644 --- a/tensorflow/compiler/xla/service/gpu/ir_emitter.h +++ b/tensorflow/compiler/xla/service/gpu/ir_emitter.h @@ -231,7 +231,7 @@ class IrEmitterUnnested : public IrEmitter { // IrEmitterUnnested handles the following instructions differently from // IrEmitter. - Status HandleCopy(HloInstruction* copy, HloInstruction* operand) override; + Status HandleCopy(HloInstruction* copy) override; Status HandleConvolution(HloInstruction* convolution, HloInstruction* lhs, HloInstruction* rhs, const Window& window) override; Status HandleDot(HloInstruction* dot, HloInstruction* lhs_instruction, diff --git a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc index ab04d1736e..ea71d92417 100644 --- a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc +++ b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc @@ -722,8 +722,7 @@ int64 EmitTranspose021Tiled(llvm_ir::IrArray input, llvm_ir::IrArray output, } // namespace -Status IrEmitterUnnested::HandleCopy(HloInstruction* copy, - HloInstruction* operand) { +Status IrEmitterUnnested::HandleCopy(HloInstruction* copy) { if (ImplementedAsMemcpy(*copy)) { thunk_sequence_->emplace_back(BuildCopyThunk(copy)); return Status::OK(); @@ -731,7 +730,7 @@ Status IrEmitterUnnested::HandleCopy(HloInstruction* copy, bool is_transpose_021; Shape reduced_input_shape, reduced_output_shape; std::tie(is_transpose_021, reduced_input_shape, reduced_output_shape) = - IsTranspose021(operand->shape(), copy->shape()); + IsTranspose021(copy->operand(0)->shape(), copy->shape()); if (is_transpose_021 && reduced_input_shape.dimensions(1) >= kMinDimensionToTransposeTiled && reduced_input_shape.dimensions(2) >= kMinDimensionToTransposeTiled) { @@ -739,7 +738,8 @@ Status IrEmitterUnnested::HandleCopy(HloInstruction* copy, VLOG(3) << "Emitting tiled 0-2-1 transposition"; constexpr int64 tile_size = 32; int64 num_tiles = EmitTranspose021Tiled( - GetIrArray(*operand).CastToShape(reduced_input_shape, &ir_builder_), + GetIrArray(*(copy->operand(0))) + .CastToShape(reduced_input_shape, &ir_builder_), GetIrArray(*copy).CastToShape(reduced_output_shape, &ir_builder_), tile_size, &ir_builder_); UpdateLaunchDimensions(LaunchDimensions(num_tiles, tile_size), LastThunk(), @@ -747,7 +747,7 @@ Status IrEmitterUnnested::HandleCopy(HloInstruction* copy, return Status::OK(); } - return IrEmitter::HandleCopy(copy, operand); + return IrEmitter::HandleCopy(copy); } Status IrEmitterUnnested::EmitColumnReduction( diff --git a/tensorflow/compiler/xla/service/hlo_cost_analysis.cc b/tensorflow/compiler/xla/service/hlo_cost_analysis.cc index abbbbfa02b..f3a6cd43c2 100644 --- a/tensorflow/compiler/xla/service/hlo_cost_analysis.cc +++ b/tensorflow/compiler/xla/service/hlo_cost_analysis.cc @@ -166,13 +166,11 @@ Status HloCostAnalysis::HandleConcatenate( return Status::OK(); } -Status HloCostAnalysis::HandleConvert(HloInstruction* convert, - HloInstruction* operand) { +Status HloCostAnalysis::HandleConvert(HloInstruction* convert) { return HandleElementwiseOp(convert); } -Status HloCostAnalysis::HandleCopy(HloInstruction* copy, - HloInstruction* operand) { +Status HloCostAnalysis::HandleCopy(HloInstruction* copy) { return Status::OK(); } diff --git a/tensorflow/compiler/xla/service/hlo_cost_analysis.h b/tensorflow/compiler/xla/service/hlo_cost_analysis.h index 6538266864..3f0dfcc619 100644 --- a/tensorflow/compiler/xla/service/hlo_cost_analysis.h +++ b/tensorflow/compiler/xla/service/hlo_cost_analysis.h @@ -63,9 +63,8 @@ class HloCostAnalysis : public DfsHloVisitor { tensorflow::gtl::ArraySlice<HloInstruction*> operands) override; Status HandleSend(HloInstruction* send) override; Status HandleRecv(HloInstruction* recv) override; - Status HandleConvert(HloInstruction* convert, - HloInstruction* operand) override; - Status HandleCopy(HloInstruction* copy, HloInstruction* operand) override; + Status HandleConvert(HloInstruction* convert) override; + Status HandleCopy(HloInstruction* copy) override; Status HandleDot(HloInstruction* dot, HloInstruction* lhs, HloInstruction* rhs) override; Status HandleConvolution(HloInstruction* convolution, HloInstruction* lhs, diff --git a/tensorflow/compiler/xla/service/hlo_evaluator.cc b/tensorflow/compiler/xla/service/hlo_evaluator.cc index a42289590b..4936b823a2 100644 --- a/tensorflow/compiler/xla/service/hlo_evaluator.cc +++ b/tensorflow/compiler/xla/service/hlo_evaluator.cc @@ -192,7 +192,7 @@ class HloEvaluator::TypedVisitor : public DfsHloVisitorWithDefault { return Status::OK(); }; - Status HandleCopy(HloInstruction* copy, HloInstruction* operand) override { + Status HandleCopy(HloInstruction* copy) override { TF_ASSIGN_OR_RETURN(parent_->evaluated_[copy], ElementWiseUnaryOp(copy, [](ReturnT elem_operand) { return elem_operand; @@ -208,8 +208,8 @@ class HloEvaluator::TypedVisitor : public DfsHloVisitorWithDefault { typename primitive_util::PrimitiveTypeToNative<dest_type>::type>(); } - Status HandleConvert(HloInstruction* convert, - HloInstruction* operand) override { + Status HandleConvert(HloInstruction* convert) override { + const HloInstruction* operand = convert->operand(0); auto operand_literal = parent_->GetEvaluatedLiteralFor(operand); switch (operand->shape().element_type()) { @@ -337,8 +337,7 @@ class HloEvaluator::TypedVisitor : public DfsHloVisitorWithDefault { return Status::OK(); }; - Status HandleMaximum(HloInstruction* maximum, HloInstruction* lhs, - HloInstruction* rhs) override { + Status HandleMaximum(HloInstruction* maximum) override { TF_ASSIGN_OR_RETURN( parent_->evaluated_[maximum], ElementWiseBinaryOp(maximum, [](ReturnT lhs, ReturnT rhs) { @@ -347,8 +346,7 @@ class HloEvaluator::TypedVisitor : public DfsHloVisitorWithDefault { return Status::OK(); }; - Status HandleMinimum(HloInstruction* minimum, HloInstruction* lhs, - HloInstruction* rhs) override { + Status HandleMinimum(HloInstruction* minimum) override { TF_ASSIGN_OR_RETURN( parent_->evaluated_[minimum], ElementWiseBinaryOp(minimum, [](ReturnT lhs_el, ReturnT rhs_el) { diff --git a/tensorflow/compiler/xla/service/hlo_instruction.cc b/tensorflow/compiler/xla/service/hlo_instruction.cc index c49c00bac0..99b73dea29 100644 --- a/tensorflow/compiler/xla/service/hlo_instruction.cc +++ b/tensorflow/compiler/xla/service/hlo_instruction.cc @@ -1803,9 +1803,9 @@ Status HloInstruction::Visit(DfsHloVisitor* visitor) { case HloOpcode::kSubtract: return visitor->HandleSubtract(this, operands_[0], operands_[1]); case HloOpcode::kMaximum: - return visitor->HandleMaximum(this, operands_[0], operands_[1]); + return visitor->HandleMaximum(this); case HloOpcode::kMinimum: - return visitor->HandleMinimum(this, operands_[0], operands_[1]); + return visitor->HandleMinimum(this); case HloOpcode::kLogicalAnd: return visitor->HandleLogicalAnd(this, operands_[0], operands_[1]); case HloOpcode::kLogicalOr: @@ -1813,9 +1813,9 @@ Status HloInstruction::Visit(DfsHloVisitor* visitor) { case HloOpcode::kConcatenate: return visitor->HandleConcatenate(this, operands_); case HloOpcode::kConvert: - return visitor->HandleConvert(this, operands_[0]); + return visitor->HandleConvert(this); case HloOpcode::kCopy: - return visitor->HandleCopy(this, operands_[0]); + return visitor->HandleCopy(this); case HloOpcode::kMultiply: return visitor->HandleMultiply(this, operands_[0], operands_[1]); case HloOpcode::kDot: diff --git a/tensorflow/compiler/xla/service/tuple_points_to_analysis.cc b/tensorflow/compiler/xla/service/tuple_points_to_analysis.cc index ad6f015c70..8d68398450 100644 --- a/tensorflow/compiler/xla/service/tuple_points_to_analysis.cc +++ b/tensorflow/compiler/xla/service/tuple_points_to_analysis.cc @@ -243,12 +243,11 @@ Status TuplePointsToAnalysis::HandleGetTupleElement( return Status::OK(); } -Status TuplePointsToAnalysis::HandleCopy(HloInstruction* copy, - HloInstruction* operand) { +Status TuplePointsToAnalysis::HandleCopy(HloInstruction* copy) { // A kCopy instruction performs a shallow copy of the operand. The top-level // buffer (index={}) is newly created, but all other buffers (in the case of a // tuple shape) come from the operand - PointsToSet& points_to_set = CreateCopiedPointsToSet(copy, operand); + PointsToSet& points_to_set = CreateCopiedPointsToSet(copy, copy->operand(0)); points_to_set.mutable_element(/*index=*/{})->clear(); points_to_set.AddPointedToBuffer(NewLogicalBuffer(copy, /*index=*/{}), /*index=*/{}); diff --git a/tensorflow/compiler/xla/service/tuple_points_to_analysis.h b/tensorflow/compiler/xla/service/tuple_points_to_analysis.h index 4d7fc7cbc9..bab4235a28 100644 --- a/tensorflow/compiler/xla/service/tuple_points_to_analysis.h +++ b/tensorflow/compiler/xla/service/tuple_points_to_analysis.h @@ -208,7 +208,7 @@ class TuplePointsToAnalysis : public DfsHloVisitorWithDefault { Status HandleGetTupleElement(HloInstruction* get_tuple_element, HloInstruction* operand) override; Status HandleBitcast(HloInstruction* bitcast) override; - Status HandleCopy(HloInstruction* copy, HloInstruction* operand) override; + Status HandleCopy(HloInstruction* copy) override; Status HandleSelect(HloInstruction* select, HloInstruction* pred, HloInstruction* on_true, HloInstruction* on_false) override; |