diff options
author | Michael Kuperstein <mkuper@google.com> | 2018-08-29 13:23:20 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-08-29 13:26:41 -0700 |
commit | 065f9b833ffbb3b2f03d63febb186275674ba133 (patch) | |
tree | b4521b490c7b1af24cec410e2d3d02419428df64 | |
parent | 52773e6765649b9963985c81fc0612742ffac73b (diff) |
Automated rollback of commit 150dee25d82589ca109957cc996efbd2a236e044
PiperOrigin-RevId: 210778248
-rw-r--r-- | tensorflow/compiler/xla/client/xla_builder.cc | 44 | ||||
-rw-r--r-- | tensorflow/compiler/xla/client/xla_builder.h | 19 | ||||
-rw-r--r-- | tensorflow/compiler/xla/service/bfloat16_normalization.cc | 1 | ||||
-rw-r--r-- | tensorflow/compiler/xla/service/hlo_cost_analysis.cc | 12 | ||||
-rw-r--r-- | tensorflow/compiler/xla/service/hlo_evaluator.cc | 18 | ||||
-rw-r--r-- | tensorflow/compiler/xla/service/hlo_evaluator.h | 2 | ||||
-rw-r--r-- | tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor.h | 193 | ||||
-rw-r--r-- | tensorflow/compiler/xla/service/hlo_instruction.cc | 29 | ||||
-rw-r--r-- | tensorflow/compiler/xla/service/hlo_instructions.cc | 2 | ||||
-rw-r--r-- | tensorflow/compiler/xla/service/hlo_instructions.h | 8 | ||||
-rw-r--r-- | tensorflow/compiler/xla/service/hlo_verifier.cc | 13 | ||||
-rw-r--r-- | tensorflow/compiler/xla/tests/client_library_test_base.cc | 6 | ||||
-rw-r--r-- | tensorflow/compiler/xla/tests/test_utils.cc | 9 |
13 files changed, 109 insertions, 247 deletions
diff --git a/tensorflow/compiler/xla/client/xla_builder.cc b/tensorflow/compiler/xla/client/xla_builder.cc index ea53287068..819d324927 100644 --- a/tensorflow/compiler/xla/client/xla_builder.cc +++ b/tensorflow/compiler/xla/client/xla_builder.cc @@ -1733,37 +1733,18 @@ XlaOp XlaBuilder::Reduce( const XlaOp& operand, const XlaOp& init_value, const XlaComputation& computation, tensorflow::gtl::ArraySlice<int64> dimensions_to_reduce) { - return Reduce(tensorflow::gtl::ArraySlice<XlaOp>({operand}), - tensorflow::gtl::ArraySlice<XlaOp>({init_value}), computation, - dimensions_to_reduce); -} - -XlaOp XlaBuilder::Reduce( - tensorflow::gtl::ArraySlice<XlaOp> operands, - tensorflow::gtl::ArraySlice<XlaOp> init_values, - const XlaComputation& computation, - tensorflow::gtl::ArraySlice<int64> dimensions_to_reduce) { return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> { HloInstructionProto instr; + TF_ASSIGN_OR_RETURN(const Shape& operand_shape, GetShape(operand)); + TF_ASSIGN_OR_RETURN(const Shape& init_shape, GetShape(init_value)); TF_ASSIGN_OR_RETURN(const ProgramShape& called_program_shape, computation.GetProgramShape()); - std::vector<XlaOp> all_operands; - all_operands.insert(all_operands.end(), operands.begin(), operands.end()); - all_operands.insert(all_operands.end(), init_values.begin(), - init_values.end()); - - std::vector<const Shape*> operand_shape_ptrs; - TF_ASSIGN_OR_RETURN(const auto& operand_shapes, - GetOperandShapes(all_operands)); - absl::c_transform(operand_shapes, std::back_inserter(operand_shape_ptrs), - [](const Shape& shape) { return &shape; }); - - TF_ASSIGN_OR_RETURN( - *instr.mutable_shape(), - ShapeInference::InferReduceShape( - operand_shape_ptrs, dimensions_to_reduce, called_program_shape)); + TF_ASSIGN_OR_RETURN(*instr.mutable_shape(), + ShapeInference::InferReduceShape( + {&operand_shape, &init_shape}, dimensions_to_reduce, + called_program_shape)); for (int64 dim : dimensions_to_reduce) { instr.add_dimensions(dim); @@ -1771,7 +1752,8 @@ XlaOp XlaBuilder::Reduce( AddCalledComputation(computation, &instr); - return AddInstruction(std::move(instr), HloOpcode::kReduce, all_operands); + return AddInstruction(std::move(instr), HloOpcode::kReduce, + {operand, init_value}); }); } @@ -2788,16 +2770,6 @@ XlaOp Reduce(const XlaOp& operand, const XlaOp& init_value, dimensions_to_reduce); } -// Reduces several arrays simultaneously among the provided dimensions, given -// "computation" as a reduction operator. -XlaOp Reduce(XlaBuilder* builder, tensorflow::gtl::ArraySlice<XlaOp> operands, - tensorflow::gtl::ArraySlice<XlaOp> init_values, - const XlaComputation& computation, - tensorflow::gtl::ArraySlice<int64> dimensions_to_reduce) { - return builder->Reduce(operands, init_values, computation, - dimensions_to_reduce); -} - XlaOp ReduceAll(const XlaOp& operand, const XlaOp& init_value, const XlaComputation& computation) { return operand.builder()->ReduceAll(operand, init_value, computation); diff --git a/tensorflow/compiler/xla/client/xla_builder.h b/tensorflow/compiler/xla/client/xla_builder.h index 9b82cc03b3..193d8ed071 100644 --- a/tensorflow/compiler/xla/client/xla_builder.h +++ b/tensorflow/compiler/xla/client/xla_builder.h @@ -659,13 +659,6 @@ class XlaBuilder { const XlaComputation& computation, tensorflow::gtl::ArraySlice<int64> dimensions_to_reduce); - // Reduces several arrays simultaneously among the provided dimensions, given - // "computation" as a reduction operator. - XlaOp Reduce(tensorflow::gtl::ArraySlice<XlaOp> operands, - tensorflow::gtl::ArraySlice<XlaOp> init_values, - const XlaComputation& computation, - tensorflow::gtl::ArraySlice<int64> dimensions_to_reduce); - // Convenience wrapper around the above that reduces all the dimensions in the // operand shape. XlaOp ReduceAll(const XlaOp& operand, const XlaOp& init_value, @@ -1256,11 +1249,6 @@ class XlaBuilder { friend XlaOp Reduce(const XlaOp& operand, const XlaOp& init_value, const XlaComputation& computation, tensorflow::gtl::ArraySlice<int64> dimensions_to_reduce); - friend XlaOp Reduce(XlaBuilder* builder, - tensorflow::gtl::ArraySlice<XlaOp> operands, - tensorflow::gtl::ArraySlice<XlaOp> init_values, - const XlaComputation& computation, - tensorflow::gtl::ArraySlice<int64> dimensions_to_reduce); friend XlaOp ReduceAll(const XlaOp& operand, const XlaOp& init_value, const XlaComputation& computation); friend XlaOp ReduceWindow( @@ -1835,13 +1823,6 @@ XlaOp Reduce(const XlaOp& operand, const XlaOp& init_value, const XlaComputation& computation, tensorflow::gtl::ArraySlice<int64> dimensions_to_reduce); -// Reduces several arrays simultaneously among the provided dimensions, given -// "computation" as a reduction operator. -XlaOp Reduce(XlaBuilder* builder, tensorflow::gtl::ArraySlice<XlaOp> operands, - tensorflow::gtl::ArraySlice<XlaOp> init_values, - const XlaComputation& computation, - tensorflow::gtl::ArraySlice<int64> dimensions_to_reduce); - // Convenience wrapper around the above that reduces all the dimensions in the // operand shape. XlaOp ReduceAll(const XlaOp& operand, const XlaOp& init_value, diff --git a/tensorflow/compiler/xla/service/bfloat16_normalization.cc b/tensorflow/compiler/xla/service/bfloat16_normalization.cc index a6f77db3b0..32573ed355 100644 --- a/tensorflow/compiler/xla/service/bfloat16_normalization.cc +++ b/tensorflow/compiler/xla/service/bfloat16_normalization.cc @@ -359,7 +359,6 @@ Status BFloat16NormalizationVisitor::DefaultAction(HloInstruction* hlo) { hlo->opcode() == HloOpcode::kConditional) { return Status::OK(); } - // TODO(b/112040122): Correctly normalize variadic reduce. if ((hlo->opcode() == HloOpcode::kSort || hlo->opcode() == HloOpcode::kCrossReplicaSum) && ShapeUtil::IsTuple(hlo->shape())) { diff --git a/tensorflow/compiler/xla/service/hlo_cost_analysis.cc b/tensorflow/compiler/xla/service/hlo_cost_analysis.cc index 939b5114c3..0e12a1ee03 100644 --- a/tensorflow/compiler/xla/service/hlo_cost_analysis.cc +++ b/tensorflow/compiler/xla/service/hlo_cost_analysis.cc @@ -274,21 +274,15 @@ Status HloCostAnalysis::HandleMap(const HloInstruction* map) { } Status HloCostAnalysis::HandleReduce(const HloInstruction* reduce) { + auto arg = reduce->operand(0); HloComputation* function = reduce->to_apply(); // Compute the cost of the user function. TF_ASSIGN_OR_RETURN(const Properties sub_properties, ProcessSubcomputation(function)); // Compute the cost of all elements for this Reduce operation. - // This counts the number of times the reduction function is applied, so it - // does not need to be multiplied by the number of input tensors - that's - // already "priced in" by the sub-computation doing more work. - auto arg = reduce->operand(0); - auto output_shape = ShapeUtil::IsArray(reduce->shape()) - ? reduce->shape() - : reduce->shape().tuple_shapes(0); - int64 reduction_count = - ShapeUtil::ElementsIn(arg->shape()) - ShapeUtil::ElementsIn(output_shape); + int64 reduction_count = ShapeUtil::ElementsIn(arg->shape()) - + ShapeUtil::ElementsIn(reduce->shape()); for (const auto& property : sub_properties) { if (property.first != kBytesAccessedKey) { current_properties_[property.first] = property.second * reduction_count; diff --git a/tensorflow/compiler/xla/service/hlo_evaluator.cc b/tensorflow/compiler/xla/service/hlo_evaluator.cc index c25869f87b..71f91fde93 100644 --- a/tensorflow/compiler/xla/service/hlo_evaluator.cc +++ b/tensorflow/compiler/xla/service/hlo_evaluator.cc @@ -1262,7 +1262,7 @@ Status HloEvaluator::HandleSort(HloInstruction* sort) { const int64 rank = ShapeUtil::Rank(sort->operand(0)->shape()); if (sort_dim != rank - 1) { return Unimplemented( - "Trying to sort along dimension %d, which is not the last " + "Trying to support along dimension %d, which is not the last " "dimension", sort_dim); } @@ -1281,22 +1281,6 @@ Status HloEvaluator::HandleSort(HloInstruction* sort) { } } -Status HloEvaluator::HandleReduce(HloInstruction* reduce) { - if (!ShapeUtil::IsTuple(reduce->shape())) { - return DefaultAction(reduce); - } else { - auto first_element_type = reduce->shape().tuple_shapes(0).element_type(); - for (const auto& tuple_shape : reduce->shape().tuple_shapes()) { - if (tuple_shape.element_type() != first_element_type) { - return Unimplemented( - "Reduce with several outputs that have mixed element types is " - "unsupported"); - } - } - return reduce->Visit(typed_visitors_.at(first_element_type).get()); - } -} - Status HloEvaluator::Preprocess(HloInstruction* hlo) { VLOG(2) << "About to visit HLO: " << hlo->ToString(); return ShapeUtil::ValidateShape(hlo->shape()); diff --git a/tensorflow/compiler/xla/service/hlo_evaluator.h b/tensorflow/compiler/xla/service/hlo_evaluator.h index 980a7fb9fa..0ea7089552 100644 --- a/tensorflow/compiler/xla/service/hlo_evaluator.h +++ b/tensorflow/compiler/xla/service/hlo_evaluator.h @@ -185,8 +185,6 @@ class HloEvaluator : public DfsHloVisitorWithDefault { Status HandleSort(HloInstruction* sort) override; - Status HandleReduce(HloInstruction* reduce) override; - // Returns the already-evaluated literal result for the instruction. // A Constant instruction is considered evaluated and its literal will be // returned directly without looking up the cache. diff --git a/tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor.h b/tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor.h index 4edcb05f83..f682e69ee9 100644 --- a/tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor.h +++ b/tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor.h @@ -1577,20 +1577,20 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { return HandleSort<ReturnT>(sort); } - Status HandleReduce(HloInstruction* hlo) override { - HloReduceInstruction* reduce = Cast<HloReduceInstruction>(hlo); - int64 num_args = reduce->inputs().size(); - bool has_tuple_output = ShapeUtil::IsTuple(reduce->shape()); + Status HandleReduce(HloInstruction* reduce) override { + // TODO(b/112040122): Support variadic reduce. + if (!ShapeUtil::IsArray(reduce->shape())) { + return Unimplemented("Variadic reduce is not supported in the Evaluator"); + } + auto arg = reduce->operand(0); + auto init_value = reduce->operand(1); tensorflow::gtl::ArraySlice<int64> dimensions(reduce->dimensions()); HloComputation* function = reduce->to_apply(); - - absl::InlinedVector<const Shape*, 1> operand_shapes; - for (const HloInstruction* operand : reduce->operands()) { - operand_shapes.push_back(&operand->shape()); - } + TF_RET_CHECK(ShapeUtil::Rank(reduce->shape()) == + ShapeUtil::Rank(arg->shape()) - dimensions.size()); TF_ASSIGN_OR_RETURN(auto inferred_return_shape, ShapeInference::InferReduceShape( - operand_shapes, + {&arg->shape(), &init_value->shape()}, /*dimensions_to_reduce=*/dimensions, /*to_apply=*/function->ComputeProgramShape())); TF_RET_CHECK(ShapeUtil::Compatible(reduce->shape(), inferred_return_shape)) @@ -1598,23 +1598,14 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { << " but is inferred to be: " << ShapeUtil::HumanString(inferred_return_shape); - absl::InlinedVector<const Literal*, 1> arg_literals(num_args); - absl::InlinedVector<const Literal*, 1> init_literals(num_args); - for (int64 i = 0; i < num_args; ++i) { - arg_literals[i] = &parent_->GetEvaluatedLiteralFor(reduce->inputs()[i]); - VLOG(3) << "HandleReduce arg_literal: " << arg_literals[i]->ToString(); - init_literals[i] = - &parent_->GetEvaluatedLiteralFor(reduce->init_values()[i]); - VLOG(3) << "HandleReduce init_literal: " << init_literals[i]->ToString(); - TF_RET_CHECK(ShapeUtil::IsScalar(init_literals[i]->shape())); - } + const Literal& arg_literal = parent_->GetEvaluatedLiteralFor(arg); + VLOG(3) << "HandleReduce arg_literal: " << arg_literal.ToString(); + const Literal& init_literal = parent_->GetEvaluatedLiteralFor(init_value); + VLOG(3) << "HandleReduce init_literal: " << init_literal.ToString(); + TF_RET_CHECK(ShapeUtil::IsScalar(init_literal.shape())); + auto init_scalar = init_literal.Get<ReturnT>({}); - // All args and results have the same dimensions, so pick an arbitrary one. - const Shape& arg_shape = arg_literals[0]->shape(); - const Shape& result_shape = ShapeUtil::IsTuple(reduce->shape()) - ? reduce->shape().tuple_shapes(0) - : reduce->shape(); - const auto arg_dimensions = AsInt64Slice(arg_shape.dimensions()); + const auto arg_dimensions = AsInt64Slice(arg_literal.shape().dimensions()); std::vector<int64> arg_dim_steps(arg_dimensions.size()); std::vector<int64> arg_dim_counts(arg_dimensions.size()); for (const int64 dim : dimensions) { @@ -1632,109 +1623,63 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { } HloEvaluator embedded_evaluator(parent_->max_loop_iterations_); - absl::InlinedVector<std::unique_ptr<Literal>, 1> results(num_args); - for (int64 i = 0; i < num_args; ++i) { - results[i] = absl::make_unique<Literal>(result_shape); - } - + auto result = absl::make_unique<Literal>(reduce->shape()); Status eval_status; - // For each resulting dimension, calculate and assign computed values. - // This is really wasteful when num_args > 1, since we re-run the - // reduction num_args time. The alternative is to teach Populate() about - // tuples, which we should probably do. - absl::InlinedVector<ReturnT, 1> init_scalars(num_args); - for (int i = 0; i < num_args; ++i) { - init_scalars[i] = init_literals[i]->Get<ReturnT>({}); - } - - for (int64 input = 0; input < num_args; ++input) { - TF_RETURN_IF_ERROR(results[input]->Populate<ReturnT>( - [&](tensorflow::gtl::ArraySlice<int64> multi_index) { - if (!eval_status.ok()) { - return init_scalars[input]; - } - absl::InlinedVector<ReturnT, 1> result_values(init_scalars.begin(), - init_scalars.end()); - std::vector<int64> base(arg_dimensions.size()); - for (int64 i = 0; i < multi_index.size(); ++i) { - base[result_to_arg_index[i]] = multi_index[i]; - } + // For each resulting dimension, calculate and assign computed value. + TF_RETURN_IF_ERROR(result->Populate<ReturnT>( + [&](tensorflow::gtl::ArraySlice<int64> multi_index) { + ReturnT result_val = init_scalar; + if (!eval_status.ok()) { + return result_val; + } - // When the reduction is addition of floats, accumulate in a double - // for better precision. Also, avoid creating Literals for the - // intermediate results; it's much faster. - if (ShapeUtil::ElementIsFloating(init_literals[0]->shape()) && - IsScalarAdd(function)) { - CHECK_EQ(num_args, 1); - double computed_result = 0; - auto func = [&](tensorflow::gtl::ArraySlice<int64> input_index) { - computed_result += - GetAsDouble<ReturnT>(*arg_literals[0], input_index); - return true; - }; - ShapeUtil::ForEachIndex(arg_literals[0]->shape(), base, - arg_dim_counts, arg_dim_steps, func); - return static_cast<ReturnT>(computed_result); - } - auto func = [&](tensorflow::gtl::ArraySlice<int64> input_index) - -> StatusOr<bool> { - absl::InlinedVector<ReturnT, 1> arg_values(num_args); - for (int64 i = 0; i < num_args; ++i) { - arg_values[i] = arg_literals[i]->Get<ReturnT>(input_index); - } + std::vector<int64> base(arg_dimensions.size()); + for (int64 i = 0; i < multi_index.size(); ++i) { + base[result_to_arg_index[i]] = multi_index[i]; + } - // Evaluate computation with specified literal operands. - absl::InlinedVector<std::unique_ptr<Literal>, 1> - embedded_operands; - for (ReturnT value : result_values) { - embedded_operands.push_back( - LiteralUtil::CreateR0<ReturnT>(value)); - } - for (ReturnT value : arg_values) { - embedded_operands.push_back( - LiteralUtil::CreateR0<ReturnT>(value)); - } - absl::InlinedVector<Literal*, 1> embedded_operands_ptrs( - embedded_operands.size()); - std::transform(embedded_operands.begin(), embedded_operands.end(), - embedded_operands_ptrs.begin(), - [](const std::unique_ptr<Literal>& ptr) { - return ptr.get(); - }); - - TF_ASSIGN_OR_RETURN(std::unique_ptr<Literal> computed_result, - embedded_evaluator.Evaluate<const Literal*>( - *function, embedded_operands_ptrs)); - // Clear visit states so that we can use the evaluator again on - // the same computation. - embedded_evaluator.ResetVisitStates(); - // Assign computed result to result_val. - if (!has_tuple_output) { - result_values[0] = computed_result->Get<ReturnT>({}); - } else { - for (int64 i = 0; i < num_args; ++i) { - result_values[i] = computed_result->Get<ReturnT>( - /*multi_index=*/{}, /*shape_index=*/{i}); - } - } + // When the reduction is addition of floats, accumulate in a double + // for better precision. Also, avoid creating Literals for the + // intermediate results; it's much faster. + if (ShapeUtil::ElementIsFloating(init_literal.shape()) && + IsScalarAdd(function)) { + double computed_result = 0; + auto func = [&](tensorflow::gtl::ArraySlice<int64> input_index) { + computed_result += GetAsDouble<ReturnT>(arg_literal, input_index); return true; }; - // Computes one element of the result, reducing all dimensions that - // contribute to that element. - eval_status = ShapeUtil::ForEachIndexWithStatus( - arg_shape, base, arg_dim_counts, arg_dim_steps, func); - return result_values[input]; - })); - } - if (!has_tuple_output) { - parent_->evaluated_[reduce] = std::move(results[0]); - } else { - auto tuple_result = absl::make_unique<Literal>(reduce->shape()); - for (int64 i = 0; i < num_args; ++i) { - TF_CHECK_OK(tuple_result->MoveFrom(std::move(*results[i]), {i})); - } - parent_->evaluated_[reduce] = std::move(tuple_result); - } + ShapeUtil::ForEachIndex(arg_literal.shape(), base, arg_dim_counts, + arg_dim_steps, func); + return static_cast<ReturnT>(computed_result); + } + auto func = [&](tensorflow::gtl::ArraySlice<int64> input_index) + -> StatusOr<bool> { + auto curr_val = arg_literal.Get<ReturnT>(input_index); + + // Evaluate computation with specified literal operands. + auto curr_val_literal = LiteralUtil::CreateR0<ReturnT>(curr_val); + auto result_val_literal = + LiteralUtil::CreateR0<ReturnT>(result_val); + + TF_ASSIGN_OR_RETURN(std::unique_ptr<Literal> computed_result, + embedded_evaluator.Evaluate<const Literal*>( + *function, {result_val_literal.get(), + curr_val_literal.get()})); + // Clear visit states so that we can use the evaluator again on + // the same computation. + embedded_evaluator.ResetVisitStates(); + // Assign computed result to result_val. + result_val = computed_result->Get<ReturnT>({}); + return true; + }; + // Computes one element of the result, reducing all dimensions that + // contribute to that element. + eval_status = ShapeUtil::ForEachIndexWithStatus( + arg_literal.shape(), base, arg_dim_counts, arg_dim_steps, func); + return result_val; + })); + + parent_->evaluated_[reduce] = std::move(result); return eval_status; } diff --git a/tensorflow/compiler/xla/service/hlo_instruction.cc b/tensorflow/compiler/xla/service/hlo_instruction.cc index 8a497e6edf..ed4e159910 100644 --- a/tensorflow/compiler/xla/service/hlo_instruction.cc +++ b/tensorflow/compiler/xla/service/hlo_instruction.cc @@ -158,26 +158,16 @@ StatusOr<std::unique_ptr<HloInstruction>> HloInstruction::CreateFromProto( CreateConcatenate(proto.shape(), all_operands(), proto.dimensions(0)); break; case HloOpcode::kReduce: - TF_RET_CHECK(proto.operand_ids_size() % 2 == 0) - << "Reduce instruction should have an even number of operands but " - "sees " + TF_RET_CHECK(proto.operand_ids_size() == 2) + << "Reduce instruction should have 2 operands but sees " << proto.operand_ids_size(); TF_RET_CHECK(proto.called_computation_ids_size() == 1) << "Reduce instruction should have 1 called computation but sees " << proto.called_computation_ids_size(); - { - const auto reduce_operands = all_operands(); - tensorflow::gtl::ArraySlice<HloInstruction*> inputs( - reduce_operands, 0, reduce_operands.size() / 2); - tensorflow::gtl::ArraySlice<HloInstruction*> init_values( - reduce_operands, reduce_operands.size() / 2, - reduce_operands.size()); - instruction = - CreateReduce(proto.shape(), inputs, init_values, - std::vector<int64>(proto.dimensions().begin(), - proto.dimensions().end()), - computations(0)); - } + instruction = CreateReduce(proto.shape(), operands(0), operands(1), + std::vector<int64>(proto.dimensions().begin(), + proto.dimensions().end()), + computations(0)); break; case HloOpcode::kSort: { TF_RET_CHECK(proto.operand_ids_size() == 1 || @@ -2759,13 +2749,10 @@ HloInstruction::UseKind HloInstruction::OperandElementUse(int64 i) const { case HloOpcode::kTranspose: return UseKind::kUsePermutingElements; case HloOpcode::kPad: + case HloOpcode::kReduce: // Pad reuses the padding value but not the padded array elements. + // Reduce reuses the init value but not the operand array elements. return i > 0 ? UseKind::kReuse : UseKind::kUsePermutingElements; - case HloOpcode::kReduce: - // Reduce reuses the init values but not the operand array elements. - return i >= Cast<HloReduceInstruction>(this)->input_count() - ? UseKind::kReuse - : UseKind::kUsePermutingElements; case HloOpcode::kFusion: // Uses the memoizing, recursive computation defined above. return FusionReusesParamElements::Compute(i, *fused_expression_root()); diff --git a/tensorflow/compiler/xla/service/hlo_instructions.cc b/tensorflow/compiler/xla/service/hlo_instructions.cc index 0b7f741d73..ffc74cfedd 100644 --- a/tensorflow/compiler/xla/service/hlo_instructions.cc +++ b/tensorflow/compiler/xla/service/hlo_instructions.cc @@ -586,7 +586,7 @@ std::unique_ptr<HloInstruction> HloReduceInstruction::CloneWithNewOperandsImpl( const Shape& shape, tensorflow::gtl::ArraySlice<HloInstruction*> new_operands, HloCloneContext* context) const { - CHECK_EQ(new_operands.size() % 2, 0); + CHECK_EQ(new_operands.size(), 2); return absl::make_unique<HloReduceInstruction>(shape, new_operands, dimensions(), to_apply()); } diff --git a/tensorflow/compiler/xla/service/hlo_instructions.h b/tensorflow/compiler/xla/service/hlo_instructions.h index c2d551fb25..ee6e337b6a 100644 --- a/tensorflow/compiler/xla/service/hlo_instructions.h +++ b/tensorflow/compiler/xla/service/hlo_instructions.h @@ -398,20 +398,16 @@ class HloReduceInstruction : public HloInstruction { // Returns a serialized representation of this instruction. HloInstructionProto ToProto() const override; - // Returns the number of input arrays (and, consequentially, the number of - // init values) this reduce has. - int64 input_count() const { return operand_count() / 2; } - // Returns the input tensors to be reduced. tensorflow::gtl::ArraySlice<HloInstruction*> inputs() const { return tensorflow::gtl::ArraySlice<HloInstruction*>(operands(), 0, - input_count()); + operand_count() / 2); } // Returns the init values of the reduction. tensorflow::gtl::ArraySlice<HloInstruction*> init_values() const { return tensorflow::gtl::ArraySlice<HloInstruction*>( - operands(), input_count(), operand_count()); + operands(), operand_count() / 2, operand_count()); } private: diff --git a/tensorflow/compiler/xla/service/hlo_verifier.cc b/tensorflow/compiler/xla/service/hlo_verifier.cc index 3f3cb2fa54..f1b29c2559 100644 --- a/tensorflow/compiler/xla/service/hlo_verifier.cc +++ b/tensorflow/compiler/xla/service/hlo_verifier.cc @@ -288,13 +288,14 @@ Status ShapeVerifier::HandleGetTupleElement(HloInstruction* get_tuple_element) { } Status ShapeVerifier::HandleReduce(HloInstruction* reduce) { - std::vector<const Shape*> operand_shapes; - for (const HloInstruction* operand : reduce->operands()) { - operand_shapes.push_back(&operand->shape()); + if (!ShapeUtil::IsArray(reduce->shape())) { + return InvalidArgument("Variadic reduce is not supported."); } - return CheckShape(reduce, ShapeInference::InferReduceShape( - operand_shapes, reduce->dimensions(), - reduce->to_apply()->ComputeProgramShape())); + return CheckShape( + reduce, + ShapeInference::InferReduceShape( + {&reduce->operand(0)->shape(), &reduce->operand(1)->shape()}, + reduce->dimensions(), reduce->to_apply()->ComputeProgramShape())); } Status ShapeVerifier::HandleBitcast(HloInstruction* bitcast) { diff --git a/tensorflow/compiler/xla/tests/client_library_test_base.cc b/tensorflow/compiler/xla/tests/client_library_test_base.cc index f1ab83df82..9cd974fd9b 100644 --- a/tensorflow/compiler/xla/tests/client_library_test_base.cc +++ b/tensorflow/compiler/xla/tests/client_library_test_base.cc @@ -290,6 +290,10 @@ Status ClientLibraryTestBase::ComputeAndCompareLiteralWithStatus( if (ShapeUtil::ElementIsFloating(expected.shape()) || ShapeUtil::ElementIsComplex(expected.shape())) { LOG(WARNING) << "performing exact comparison of floating point numbers"; + } else { + TF_RET_CHECK(ShapeUtil::ElementIsIntegral(expected.shape()) || + expected.shape().element_type() == PRED) + << ShapeUtil::HumanString(expected.shape()); } // We allow using a float expected literal for a bfloat16 output. In this // case, we need to convert the expected literal to bfloat16. @@ -346,6 +350,8 @@ Status ClientLibraryTestBase::ComputeAndCompareLiteralWithStatus( } } + TF_RET_CHECK(ShapeUtil::ElementIsFloating(expected.shape()) || + ShapeUtil::ElementIsComplex(expected.shape())); TF_ASSIGN_OR_RETURN(auto computation, builder->Build()); // We allow using a float expected literal for a bfloat16 output. In this // case, we need to convert the expected literal to bfloat16. diff --git a/tensorflow/compiler/xla/tests/test_utils.cc b/tensorflow/compiler/xla/tests/test_utils.cc index 60ada58b2a..776f93d9f7 100644 --- a/tensorflow/compiler/xla/tests/test_utils.cc +++ b/tensorflow/compiler/xla/tests/test_utils.cc @@ -203,7 +203,6 @@ enum class ConstantType { kUnknown, kZero, kOne }; // Return the constant type required by this computation, if known. ConstantType GetInitValue(const HloComputation& computation) { - // TODO(b/77635120): Add init values, for min, max, and their arg variants. const HloInstruction* const root = computation.root_instruction(); if (computation.num_parameters() != 2 || root->operand_count() != 2 || root->operand(0)->opcode() != HloOpcode::kParameter || @@ -228,10 +227,10 @@ bool NeedsInitValue(const HloUse& use) { const HloInstruction* const instruction = use.instruction; const HloOpcode opcode = instruction->opcode(); const int64 op_num = use.operand_number; - return ((opcode == HloOpcode::kReduceWindow && op_num == 1) || - (opcode == HloOpcode::kSelectAndScatter && op_num == 2) || - (opcode == HloOpcode::kReduce && - op_num >= instruction->operand_count() / 2)); + return ( + ((opcode == HloOpcode::kReduce || opcode == HloOpcode::kReduceWindow) && + op_num == 1) || + (opcode == HloOpcode::kSelectAndScatter && op_num == 2)); } // Generate random values that are constrained to the input_shape minus the |