From 76ee2cf584269c782961a7d835e9febd15522188 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Mon, 19 Jun 2017 16:23:20 -0700 Subject: Remove operand parameters from HandleElementwiseUnary and HandleElementwiseBinary functions. This allows incremental cleanup of the individual unary and binary operators. PiperOrigin-RevId: 159495454 --- tensorflow/compiler/xla/service/dfs_hlo_visitor.cc | 7 +-- tensorflow/compiler/xla/service/dfs_hlo_visitor.h | 60 ++++++++++------------ .../xla/service/dfs_hlo_visitor_with_default.h | 9 ++-- .../compiler/xla/service/hlo_cost_analysis.cc | 7 +-- .../compiler/xla/service/hlo_cost_analysis.h | 8 ++- 5 files changed, 39 insertions(+), 52 deletions(-) diff --git a/tensorflow/compiler/xla/service/dfs_hlo_visitor.cc b/tensorflow/compiler/xla/service/dfs_hlo_visitor.cc index b9a496be43..5121d36866 100644 --- a/tensorflow/compiler/xla/service/dfs_hlo_visitor.cc +++ b/tensorflow/compiler/xla/service/dfs_hlo_visitor.cc @@ -24,16 +24,13 @@ limitations under the License. namespace xla { Status DfsHloVisitor::HandleElementwiseUnary(HloInstruction* hlo, - HloOpcode opcode, - HloInstruction* operand) { + HloOpcode opcode) { return Unimplemented("DfsHloVisitor::HandleElementwiseUnary: %s", HloOpcodeString(opcode).c_str()); } Status DfsHloVisitor::HandleElementwiseBinary(HloInstruction* hlo, - HloOpcode opcode, - HloInstruction* lhs, - HloInstruction* rhs) { + HloOpcode opcode) { return Unimplemented("DfsHloVisitor::HandleElementwiseBinary: %s", HloOpcodeString(opcode).c_str()); } diff --git a/tensorflow/compiler/xla/service/dfs_hlo_visitor.h b/tensorflow/compiler/xla/service/dfs_hlo_visitor.h index 1f58562ac2..ea7c22737f 100644 --- a/tensorflow/compiler/xla/service/dfs_hlo_visitor.h +++ b/tensorflow/compiler/xla/service/dfs_hlo_visitor.h @@ -65,11 +65,8 @@ class DfsHloVisitor { // These routines are self-descriptive, see class comment for usage // information. - virtual Status HandleElementwiseUnary(HloInstruction* hlo, HloOpcode opcode, - HloInstruction* operand); - virtual Status HandleElementwiseBinary(HloInstruction* hlo, HloOpcode opcode, - HloInstruction* lhs, - HloInstruction* rhs); + virtual Status HandleElementwiseUnary(HloInstruction* hlo, HloOpcode opcode); + virtual Status HandleElementwiseBinary(HloInstruction* hlo, HloOpcode opcode); virtual Status HandleClamp(HloInstruction* clamp, HloInstruction* min, HloInstruction* arg, HloInstruction* max) = 0; virtual Status HandleSelect(HloInstruction* select, HloInstruction* pred, @@ -77,31 +74,31 @@ class DfsHloVisitor { HloInstruction* on_false) = 0; virtual Status HandleMaximum(HloInstruction* maximum, HloInstruction* lhs, HloInstruction* rhs) { - return HandleElementwiseBinary(maximum, HloOpcode::kMaximum, lhs, rhs); + return HandleElementwiseBinary(maximum, HloOpcode::kMaximum); } virtual Status HandleMinimum(HloInstruction* minimum, HloInstruction* lhs, HloInstruction* rhs) { - return HandleElementwiseBinary(minimum, HloOpcode::kMinimum, lhs, rhs); + return HandleElementwiseBinary(minimum, HloOpcode::kMinimum); } virtual Status HandleConcatenate( HloInstruction* concatenate, tensorflow::gtl::ArraySlice operands) = 0; virtual Status HandleConvert(HloInstruction* convert, HloInstruction* operand) { - return HandleElementwiseUnary(convert, HloOpcode::kConvert, operand); + return HandleElementwiseUnary(convert, HloOpcode::kConvert); } virtual Status HandleCopy(HloInstruction* copy, HloInstruction* operand) { - return HandleElementwiseUnary(copy, HloOpcode::kCopy, operand); + return HandleElementwiseUnary(copy, HloOpcode::kCopy); } virtual Status HandleMultiply(HloInstruction* multiply, HloInstruction* lhs, HloInstruction* rhs) { - return HandleElementwiseBinary(multiply, HloOpcode::kMultiply, lhs, rhs); + return HandleElementwiseBinary(multiply, HloOpcode::kMultiply); } virtual Status HandleDot(HloInstruction* dot, HloInstruction* lhs, HloInstruction* rhs) = 0; virtual Status HandlePower(HloInstruction* power, HloInstruction* lhs, HloInstruction* rhs) { - return HandleElementwiseBinary(power, HloOpcode::kPower, lhs, rhs); + return HandleElementwiseBinary(power, HloOpcode::kPower); } virtual Status HandleConvolution(HloInstruction* convolution, HloInstruction* lhs, HloInstruction* rhs, @@ -109,72 +106,71 @@ class DfsHloVisitor { virtual Status HandleCrossReplicaSum(HloInstruction* crs) = 0; virtual Status HandleCompare(HloInstruction* compare, HloOpcode opcode, HloInstruction* lhs, HloInstruction* rhs) { - return HandleElementwiseBinary(compare, opcode, lhs, rhs); + return HandleElementwiseBinary(compare, opcode); } virtual Status HandleAdd(HloInstruction* add, HloInstruction* lhs, HloInstruction* rhs) { - return HandleElementwiseBinary(add, HloOpcode::kAdd, lhs, rhs); + return HandleElementwiseBinary(add, HloOpcode::kAdd); } virtual Status HandleDivide(HloInstruction* divide, HloInstruction* lhs, HloInstruction* rhs) { - return HandleElementwiseBinary(divide, HloOpcode::kDivide, lhs, rhs); + return HandleElementwiseBinary(divide, HloOpcode::kDivide); } virtual Status HandleRemainder(HloInstruction* remainder, HloInstruction* lhs, HloInstruction* rhs) { - return HandleElementwiseBinary(remainder, HloOpcode::kRemainder, lhs, rhs); + return HandleElementwiseBinary(remainder, HloOpcode::kRemainder); } virtual Status HandleSubtract(HloInstruction* subtract, HloInstruction* lhs, HloInstruction* rhs) { - return HandleElementwiseBinary(subtract, HloOpcode::kSubtract, lhs, rhs); + return HandleElementwiseBinary(subtract, HloOpcode::kSubtract); } virtual Status HandleAbs(HloInstruction* abs, HloInstruction* operand) { - return HandleElementwiseUnary(abs, HloOpcode::kAbs, operand); + return HandleElementwiseUnary(abs, HloOpcode::kAbs); } virtual Status HandleSign(HloInstruction* sign, HloInstruction* operand) { - return HandleElementwiseUnary(sign, HloOpcode::kSign, operand); + return HandleElementwiseUnary(sign, HloOpcode::kSign); } virtual Status HandleNegate(HloInstruction* negate, HloInstruction* operand) { - return HandleElementwiseUnary(negate, HloOpcode::kNegate, operand); + return HandleElementwiseUnary(negate, HloOpcode::kNegate); } virtual Status HandleExp(HloInstruction* exp, HloInstruction* operand) { - return HandleElementwiseUnary(exp, HloOpcode::kExp, operand); + return HandleElementwiseUnary(exp, HloOpcode::kExp); } virtual Status HandleFloor(HloInstruction* floor, HloInstruction* operand) { - return HandleElementwiseUnary(floor, HloOpcode::kFloor, operand); + return HandleElementwiseUnary(floor, HloOpcode::kFloor); } virtual Status HandleCeil(HloInstruction* ceil, HloInstruction* operand) { - return HandleElementwiseUnary(ceil, HloOpcode::kCeil, operand); + return HandleElementwiseUnary(ceil, HloOpcode::kCeil); } virtual Status HandleLog(HloInstruction* log, HloInstruction* operand) { - return HandleElementwiseUnary(log, HloOpcode::kLog, operand); + return HandleElementwiseUnary(log, HloOpcode::kLog); } virtual Status HandleCos(HloInstruction* cos, HloInstruction* operand) { - return HandleElementwiseUnary(cos, HloOpcode::kCos, operand); + return HandleElementwiseUnary(cos, HloOpcode::kCos); } virtual Status HandleTanh(HloInstruction* tanh, HloInstruction* operand) { - return HandleElementwiseUnary(tanh, HloOpcode::kTanh, operand); + return HandleElementwiseUnary(tanh, HloOpcode::kTanh); } virtual Status HandleIsFinite(HloInstruction* is_finite, HloInstruction* operand) { - return HandleElementwiseUnary(is_finite, HloOpcode::kIsFinite, operand); + return HandleElementwiseUnary(is_finite, HloOpcode::kIsFinite); } virtual Status HandleLogicalAnd(HloInstruction* logical_and, HloInstruction* lhs, HloInstruction* rhs) { - return HandleElementwiseBinary(logical_and, HloOpcode::kLogicalAnd, lhs, - rhs); + return HandleElementwiseBinary(logical_and, HloOpcode::kLogicalAnd); } virtual Status HandleLogicalNot(HloInstruction* logical_not, HloInstruction* operand) { - return HandleElementwiseUnary(logical_not, HloOpcode::kLogicalNot, operand); + return HandleElementwiseUnary(logical_not, HloOpcode::kLogicalNot); } virtual Status HandleLogicalOr(HloInstruction* logical_or, HloInstruction* lhs, HloInstruction* rhs) { - return HandleElementwiseBinary(logical_or, HloOpcode::kLogicalOr, lhs, rhs); + return HandleElementwiseBinary(logical_or, HloOpcode::kLogicalOr); } virtual Status HandleReducePrecision(HloInstruction* reduce_precision, HloInstruction* operand) { - return HandleElementwiseUnary(reduce_precision, HloOpcode::kReducePrecision, - operand); + return HandleElementwiseUnary(reduce_precision, + HloOpcode::kReducePrecision); } virtual Status HandleInfeed(HloInstruction* infeed) = 0; diff --git a/tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h b/tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h index 1bcc03bae1..75910b8cbb 100644 --- a/tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h +++ b/tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h @@ -41,13 +41,12 @@ class DfsHloVisitorWithDefault : public DfsHloVisitor { // Default action performed on HloInstruction. virtual Status DefaultAction(HloInstruction* hlo_instruction) = 0; - Status HandleElementwiseUnary(HloInstruction* hlo, HloOpcode opcode, - HloInstruction* operand) override { + Status HandleElementwiseUnary(HloInstruction* hlo, + HloOpcode opcode) override { return DefaultAction(hlo); } - Status HandleElementwiseBinary(HloInstruction* hlo, HloOpcode opcode, - HloInstruction* lhs, - HloInstruction* rhs) override { + Status HandleElementwiseBinary(HloInstruction* hlo, + HloOpcode opcode) override { return DefaultAction(hlo); } diff --git a/tensorflow/compiler/xla/service/hlo_cost_analysis.cc b/tensorflow/compiler/xla/service/hlo_cost_analysis.cc index cbabf00913..46ca316fe6 100644 --- a/tensorflow/compiler/xla/service/hlo_cost_analysis.cc +++ b/tensorflow/compiler/xla/service/hlo_cost_analysis.cc @@ -75,15 +75,12 @@ Status HloCostAnalysis::HandleElementwiseOp(HloInstruction* hlo_instruction) { } Status HloCostAnalysis::HandleElementwiseUnary(HloInstruction* hlo, - HloOpcode opcode, - HloInstruction* operand) { + HloOpcode opcode) { return HandleElementwiseOp(hlo); } Status HloCostAnalysis::HandleElementwiseBinary(HloInstruction* hlo, - HloOpcode opcode, - HloInstruction* lhs, - HloInstruction* rhs) { + HloOpcode opcode) { return HandleElementwiseOp(hlo); } diff --git a/tensorflow/compiler/xla/service/hlo_cost_analysis.h b/tensorflow/compiler/xla/service/hlo_cost_analysis.h index f14baf6da2..6538266864 100644 --- a/tensorflow/compiler/xla/service/hlo_cost_analysis.h +++ b/tensorflow/compiler/xla/service/hlo_cost_analysis.h @@ -42,11 +42,9 @@ class HloCostAnalysis : public DfsHloVisitor { explicit HloCostAnalysis(const ShapeSizeFunction& shape_size) : shape_size_(shape_size) {} - Status HandleElementwiseUnary(HloInstruction* hlo, HloOpcode opcode, - HloInstruction* operand) override; - Status HandleElementwiseBinary(HloInstruction* hlo, HloOpcode opcode, - HloInstruction* lhs, - HloInstruction* rhs) override; + Status HandleElementwiseUnary(HloInstruction* hlo, HloOpcode opcode) override; + Status HandleElementwiseBinary(HloInstruction* hlo, + HloOpcode opcode) override; Status HandleConstant(HloInstruction* constant, const Literal& literal) override; Status HandleGetTupleElement(HloInstruction* get_tuple_element, -- cgit v1.2.3