author    Michael Kuperstein <mkuper@google.com>  2018-08-29 10:34:08 -0700
committer TensorFlower Gardener <gardener@tensorflow.org>  2018-08-29 10:38:34 -0700
commit    150dee25d82589ca109957cc996efbd2a236e044 (patch)
tree      ef608bb703027596ed47ea4f30fa6eab93389005
parent    aca93368a979419360c1fd84b53b1766b19ba81a (diff)
[XLA] Implement variadic reduce in the evaluator. It is currently supported only for the case where all of the inputs have the same element type.
PiperOrigin-RevId: 210746149
-rw-r--r--  tensorflow/compiler/xla/client/xla_builder.cc | 44
-rw-r--r--  tensorflow/compiler/xla/client/xla_builder.h | 19
-rw-r--r--  tensorflow/compiler/xla/service/bfloat16_normalization.cc | 1
-rw-r--r--  tensorflow/compiler/xla/service/hlo_cost_analysis.cc | 12
-rw-r--r--  tensorflow/compiler/xla/service/hlo_evaluator.cc | 18
-rw-r--r--  tensorflow/compiler/xla/service/hlo_evaluator.h | 2
-rw-r--r--  tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor.h | 193
-rw-r--r--  tensorflow/compiler/xla/service/hlo_instruction.cc | 29
-rw-r--r--  tensorflow/compiler/xla/service/hlo_instructions.cc | 2
-rw-r--r--  tensorflow/compiler/xla/service/hlo_instructions.h | 8
-rw-r--r--  tensorflow/compiler/xla/service/hlo_verifier.cc | 13
-rw-r--r--  tensorflow/compiler/xla/tests/client_library_test_base.cc | 6
-rw-r--r--  tensorflow/compiler/xla/tests/test_utils.cc | 9
13 files changed, 247 insertions(+), 109 deletions(-)
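A minimal usage sketch of the variadic overload added below (illustrative only, not part of this patch; the helper name MakePairwiseSum and all shapes are hypothetical). Note that the evaluator currently requires all inputs to share one element type, which is why both operands here are F32:

  // Builds a reducer (acc0, acc1, x0, x1) -> tuple(acc0 + x0, acc1 + x1),
  // matching the operand order a variadic reduce passes to its computation:
  // all accumulators first, then the current element of each input.
  XlaComputation MakePairwiseSum(PrimitiveType type) {
    XlaBuilder b("pairwise_sum");
    const Shape scalar = ShapeUtil::MakeShape(type, {});
    auto acc0 = Parameter(&b, 0, scalar, "acc0");
    auto acc1 = Parameter(&b, 1, scalar, "acc1");
    auto x0 = Parameter(&b, 2, scalar, "x0");
    auto x1 = Parameter(&b, 3, scalar, "x1");
    Tuple(&b, {Add(acc0, x0), Add(acc1, x1)});
    return b.Build().ConsumeValueOrDie();
  }

  // Reduces two f32[8] arrays along dimension 0 in a single instruction,
  // producing a (f32[], f32[]) tuple.
  XlaBuilder builder("variadic_reduce_example");
  auto in0 = Parameter(&builder, 0, ShapeUtil::MakeShape(F32, {8}), "in0");
  auto in1 = Parameter(&builder, 1, ShapeUtil::MakeShape(F32, {8}), "in1");
  auto zero = ConstantR0<float>(&builder, 0.0f);
  Reduce(&builder, {in0, in1}, {zero, zero}, MakePairwiseSum(F32),
         /*dimensions_to_reduce=*/{0});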
diff --git a/tensorflow/compiler/xla/client/xla_builder.cc b/tensorflow/compiler/xla/client/xla_builder.cc
index 819d324927..ea53287068 100644
--- a/tensorflow/compiler/xla/client/xla_builder.cc
+++ b/tensorflow/compiler/xla/client/xla_builder.cc
@@ -1733,18 +1733,37 @@ XlaOp XlaBuilder::Reduce(
const XlaOp& operand, const XlaOp& init_value,
const XlaComputation& computation,
tensorflow::gtl::ArraySlice<int64> dimensions_to_reduce) {
+ return Reduce(tensorflow::gtl::ArraySlice<XlaOp>({operand}),
+ tensorflow::gtl::ArraySlice<XlaOp>({init_value}), computation,
+ dimensions_to_reduce);
+}
+
+XlaOp XlaBuilder::Reduce(
+ tensorflow::gtl::ArraySlice<XlaOp> operands,
+ tensorflow::gtl::ArraySlice<XlaOp> init_values,
+ const XlaComputation& computation,
+ tensorflow::gtl::ArraySlice<int64> dimensions_to_reduce) {
return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
HloInstructionProto instr;
- TF_ASSIGN_OR_RETURN(const Shape& operand_shape, GetShape(operand));
- TF_ASSIGN_OR_RETURN(const Shape& init_shape, GetShape(init_value));
TF_ASSIGN_OR_RETURN(const ProgramShape& called_program_shape,
computation.GetProgramShape());
- TF_ASSIGN_OR_RETURN(*instr.mutable_shape(),
- ShapeInference::InferReduceShape(
- {&operand_shape, &init_shape}, dimensions_to_reduce,
- called_program_shape));
+ std::vector<XlaOp> all_operands;
+ all_operands.insert(all_operands.end(), operands.begin(), operands.end());
+ all_operands.insert(all_operands.end(), init_values.begin(),
+ init_values.end());
+
+ std::vector<const Shape*> operand_shape_ptrs;
+ TF_ASSIGN_OR_RETURN(const auto& operand_shapes,
+ GetOperandShapes(all_operands));
+ absl::c_transform(operand_shapes, std::back_inserter(operand_shape_ptrs),
+ [](const Shape& shape) { return &shape; });
+
+ TF_ASSIGN_OR_RETURN(
+ *instr.mutable_shape(),
+ ShapeInference::InferReduceShape(
+ operand_shape_ptrs, dimensions_to_reduce, called_program_shape));
for (int64 dim : dimensions_to_reduce) {
instr.add_dimensions(dim);
@@ -1752,8 +1771,7 @@ XlaOp XlaBuilder::Reduce(
AddCalledComputation(computation, &instr);
- return AddInstruction(std::move(instr), HloOpcode::kReduce,
- {operand, init_value});
+ return AddInstruction(std::move(instr), HloOpcode::kReduce, all_operands);
});
}
@@ -2770,6 +2788,16 @@ XlaOp Reduce(const XlaOp& operand, const XlaOp& init_value,
dimensions_to_reduce);
}
+// Reduces several arrays simultaneously among the provided dimensions, given
+// "computation" as a reduction operator.
+XlaOp Reduce(XlaBuilder* builder, tensorflow::gtl::ArraySlice<XlaOp> operands,
+ tensorflow::gtl::ArraySlice<XlaOp> init_values,
+ const XlaComputation& computation,
+ tensorflow::gtl::ArraySlice<int64> dimensions_to_reduce) {
+ return builder->Reduce(operands, init_values, computation,
+ dimensions_to_reduce);
+}
+
XlaOp ReduceAll(const XlaOp& operand, const XlaOp& init_value,
const XlaComputation& computation) {
return operand.builder()->ReduceAll(operand, init_value, computation);
diff --git a/tensorflow/compiler/xla/client/xla_builder.h b/tensorflow/compiler/xla/client/xla_builder.h
index 193d8ed071..9b82cc03b3 100644
--- a/tensorflow/compiler/xla/client/xla_builder.h
+++ b/tensorflow/compiler/xla/client/xla_builder.h
@@ -659,6 +659,13 @@ class XlaBuilder {
const XlaComputation& computation,
tensorflow::gtl::ArraySlice<int64> dimensions_to_reduce);
+ // Reduces several arrays simultaneously among the provided dimensions, given
+ // "computation" as a reduction operator.
+ XlaOp Reduce(tensorflow::gtl::ArraySlice<XlaOp> operands,
+ tensorflow::gtl::ArraySlice<XlaOp> init_values,
+ const XlaComputation& computation,
+ tensorflow::gtl::ArraySlice<int64> dimensions_to_reduce);
+
// Convenience wrapper around the above that reduces all the dimensions in the
// operand shape.
XlaOp ReduceAll(const XlaOp& operand, const XlaOp& init_value,
@@ -1249,6 +1256,11 @@ class XlaBuilder {
friend XlaOp Reduce(const XlaOp& operand, const XlaOp& init_value,
const XlaComputation& computation,
tensorflow::gtl::ArraySlice<int64> dimensions_to_reduce);
+ friend XlaOp Reduce(XlaBuilder* builder,
+ tensorflow::gtl::ArraySlice<XlaOp> operands,
+ tensorflow::gtl::ArraySlice<XlaOp> init_values,
+ const XlaComputation& computation,
+ tensorflow::gtl::ArraySlice<int64> dimensions_to_reduce);
friend XlaOp ReduceAll(const XlaOp& operand, const XlaOp& init_value,
const XlaComputation& computation);
friend XlaOp ReduceWindow(
@@ -1823,6 +1835,13 @@ XlaOp Reduce(const XlaOp& operand, const XlaOp& init_value,
const XlaComputation& computation,
tensorflow::gtl::ArraySlice<int64> dimensions_to_reduce);
+// Reduces several arrays simultaneously among the provided dimensions, given
+// "computation" as a reduction operator.
+XlaOp Reduce(XlaBuilder* builder, tensorflow::gtl::ArraySlice<XlaOp> operands,
+ tensorflow::gtl::ArraySlice<XlaOp> init_values,
+ const XlaComputation& computation,
+ tensorflow::gtl::ArraySlice<int64> dimensions_to_reduce);
+
// Convenience wrapper around the above that reduces all the dimensions in the
// operand shape.
XlaOp ReduceAll(const XlaOp& operand, const XlaOp& init_value,
diff --git a/tensorflow/compiler/xla/service/bfloat16_normalization.cc b/tensorflow/compiler/xla/service/bfloat16_normalization.cc
index 32573ed355..a6f77db3b0 100644
--- a/tensorflow/compiler/xla/service/bfloat16_normalization.cc
+++ b/tensorflow/compiler/xla/service/bfloat16_normalization.cc
@@ -359,6 +359,7 @@ Status BFloat16NormalizationVisitor::DefaultAction(HloInstruction* hlo) {
hlo->opcode() == HloOpcode::kConditional) {
return Status::OK();
}
+ // TODO(b/112040122): Correctly normalize variadic reduce.
if ((hlo->opcode() == HloOpcode::kSort ||
hlo->opcode() == HloOpcode::kCrossReplicaSum) &&
ShapeUtil::IsTuple(hlo->shape())) {
diff --git a/tensorflow/compiler/xla/service/hlo_cost_analysis.cc b/tensorflow/compiler/xla/service/hlo_cost_analysis.cc
index 0e12a1ee03..939b5114c3 100644
--- a/tensorflow/compiler/xla/service/hlo_cost_analysis.cc
+++ b/tensorflow/compiler/xla/service/hlo_cost_analysis.cc
@@ -274,15 +274,21 @@ Status HloCostAnalysis::HandleMap(const HloInstruction* map) {
}
Status HloCostAnalysis::HandleReduce(const HloInstruction* reduce) {
- auto arg = reduce->operand(0);
HloComputation* function = reduce->to_apply();
// Compute the cost of the user function.
TF_ASSIGN_OR_RETURN(const Properties sub_properties,
ProcessSubcomputation(function));
// Compute the cost of all elements for this Reduce operation.
- int64 reduction_count = ShapeUtil::ElementsIn(arg->shape()) -
- ShapeUtil::ElementsIn(reduce->shape());
+ // This counts the number of times the reduction function is applied, so it
+ // does not need to be multiplied by the number of input tensors - that's
+ // already "priced in" by the sub-computation doing more work.
+ auto arg = reduce->operand(0);
+ auto output_shape = ShapeUtil::IsArray(reduce->shape())
+ ? reduce->shape()
+ : reduce->shape().tuple_shapes(0);
+ int64 reduction_count =
+ ShapeUtil::ElementsIn(arg->shape()) - ShapeUtil::ElementsIn(output_shape);
for (const auto& property : sub_properties) {
if (property.first != kBytesAccessedKey) {
current_properties_[property.first] = property.second * reduction_count;
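To make the counting concrete (illustrative numbers, not from this change): reducing an f32[10,20] input along dimension 0 yields an f32[20] output, so reduction_count = 200 - 20 = 180. A variadic reduce over two such inputs still applies its reduction function 180 times; each application simply does twice the work, and that extra work is already captured by the sub-computation's own cost.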
diff --git a/tensorflow/compiler/xla/service/hlo_evaluator.cc b/tensorflow/compiler/xla/service/hlo_evaluator.cc
index 71f91fde93..c25869f87b 100644
--- a/tensorflow/compiler/xla/service/hlo_evaluator.cc
+++ b/tensorflow/compiler/xla/service/hlo_evaluator.cc
@@ -1262,7 +1262,7 @@ Status HloEvaluator::HandleSort(HloInstruction* sort) {
const int64 rank = ShapeUtil::Rank(sort->operand(0)->shape());
if (sort_dim != rank - 1) {
return Unimplemented(
- "Trying to support along dimension %d, which is not the last "
+ "Trying to sort along dimension %d, which is not the last "
"dimension",
sort_dim);
}
@@ -1281,6 +1281,22 @@ Status HloEvaluator::HandleSort(HloInstruction* sort) {
}
}
+Status HloEvaluator::HandleReduce(HloInstruction* reduce) {
+ if (!ShapeUtil::IsTuple(reduce->shape())) {
+ return DefaultAction(reduce);
+ } else {
+ auto first_element_type = reduce->shape().tuple_shapes(0).element_type();
+ for (const auto& tuple_shape : reduce->shape().tuple_shapes()) {
+ if (tuple_shape.element_type() != first_element_type) {
+ return Unimplemented(
+ "Reduce with several outputs that have mixed element types is "
+ "unsupported");
+ }
+ }
+ return reduce->Visit(typed_visitors_.at(first_element_type).get());
+ }
+}
+
Status HloEvaluator::Preprocess(HloInstruction* hlo) {
VLOG(2) << "About to visit HLO: " << hlo->ToString();
return ShapeUtil::ValidateShape(hlo->shape());
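In other words, a reduce producing the tuple (f32[10], f32[10]) is dispatched once to the F32 typed visitor for the whole instruction, while one producing (f32[10], s32[10]), for example an argmax-style value-plus-index reduction, still fails with Unimplemented here.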
diff --git a/tensorflow/compiler/xla/service/hlo_evaluator.h b/tensorflow/compiler/xla/service/hlo_evaluator.h
index 0ea7089552..980a7fb9fa 100644
--- a/tensorflow/compiler/xla/service/hlo_evaluator.h
+++ b/tensorflow/compiler/xla/service/hlo_evaluator.h
@@ -185,6 +185,8 @@ class HloEvaluator : public DfsHloVisitorWithDefault {
Status HandleSort(HloInstruction* sort) override;
+ Status HandleReduce(HloInstruction* reduce) override;
+
// Returns the already-evaluated literal result for the instruction.
// A Constant instruction is considered evaluated and its literal will be
// returned directly without looking up the cache.
diff --git a/tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor.h b/tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor.h
index f682e69ee9..4edcb05f83 100644
--- a/tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor.h
+++ b/tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor.h
@@ -1577,20 +1577,20 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault {
return HandleSort<ReturnT>(sort);
}
- Status HandleReduce(HloInstruction* reduce) override {
- // TODO(b/112040122): Support variadic reduce.
- if (!ShapeUtil::IsArray(reduce->shape())) {
- return Unimplemented("Variadic reduce is not supported in the Evaluator");
- }
- auto arg = reduce->operand(0);
- auto init_value = reduce->operand(1);
+ Status HandleReduce(HloInstruction* hlo) override {
+ HloReduceInstruction* reduce = Cast<HloReduceInstruction>(hlo);
+ int64 num_args = reduce->inputs().size();
+ bool has_tuple_output = ShapeUtil::IsTuple(reduce->shape());
tensorflow::gtl::ArraySlice<int64> dimensions(reduce->dimensions());
HloComputation* function = reduce->to_apply();
- TF_RET_CHECK(ShapeUtil::Rank(reduce->shape()) ==
- ShapeUtil::Rank(arg->shape()) - dimensions.size());
+
+ absl::InlinedVector<const Shape*, 1> operand_shapes;
+ for (const HloInstruction* operand : reduce->operands()) {
+ operand_shapes.push_back(&operand->shape());
+ }
TF_ASSIGN_OR_RETURN(auto inferred_return_shape,
ShapeInference::InferReduceShape(
- {&arg->shape(), &init_value->shape()},
+ operand_shapes,
/*dimensions_to_reduce=*/dimensions,
/*to_apply=*/function->ComputeProgramShape()));
TF_RET_CHECK(ShapeUtil::Compatible(reduce->shape(), inferred_return_shape))
@@ -1598,14 +1598,23 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault {
<< " but is inferred to be: "
<< ShapeUtil::HumanString(inferred_return_shape);
- const Literal& arg_literal = parent_->GetEvaluatedLiteralFor(arg);
- VLOG(3) << "HandleReduce arg_literal: " << arg_literal.ToString();
- const Literal& init_literal = parent_->GetEvaluatedLiteralFor(init_value);
- VLOG(3) << "HandleReduce init_literal: " << init_literal.ToString();
- TF_RET_CHECK(ShapeUtil::IsScalar(init_literal.shape()));
- auto init_scalar = init_literal.Get<ReturnT>({});
+ absl::InlinedVector<const Literal*, 1> arg_literals(num_args);
+ absl::InlinedVector<const Literal*, 1> init_literals(num_args);
+ for (int64 i = 0; i < num_args; ++i) {
+ arg_literals[i] = &parent_->GetEvaluatedLiteralFor(reduce->inputs()[i]);
+ VLOG(3) << "HandleReduce arg_literal: " << arg_literals[i]->ToString();
+ init_literals[i] =
+ &parent_->GetEvaluatedLiteralFor(reduce->init_values()[i]);
+ VLOG(3) << "HandleReduce init_literal: " << init_literals[i]->ToString();
+ TF_RET_CHECK(ShapeUtil::IsScalar(init_literals[i]->shape()));
+ }
- const auto arg_dimensions = AsInt64Slice(arg_literal.shape().dimensions());
+ // All args and results have the same dimensions, so pick an arbitrary one.
+ const Shape& arg_shape = arg_literals[0]->shape();
+ const Shape& result_shape = ShapeUtil::IsTuple(reduce->shape())
+ ? reduce->shape().tuple_shapes(0)
+ : reduce->shape();
+ const auto arg_dimensions = AsInt64Slice(arg_shape.dimensions());
std::vector<int64> arg_dim_steps(arg_dimensions.size());
std::vector<int64> arg_dim_counts(arg_dimensions.size());
for (const int64 dim : dimensions) {
@@ -1623,63 +1632,109 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault {
}
HloEvaluator embedded_evaluator(parent_->max_loop_iterations_);
- auto result = absl::make_unique<Literal>(reduce->shape());
+ absl::InlinedVector<std::unique_ptr<Literal>, 1> results(num_args);
+ for (int64 i = 0; i < num_args; ++i) {
+ results[i] = absl::make_unique<Literal>(result_shape);
+ }
+
Status eval_status;
- // For each resulting dimension, calculate and assign computed value.
- TF_RETURN_IF_ERROR(result->Populate<ReturnT>(
- [&](tensorflow::gtl::ArraySlice<int64> multi_index) {
- ReturnT result_val = init_scalar;
- if (!eval_status.ok()) {
- return result_val;
- }
+ // For each resulting dimension, calculate and assign computed values.
+ // This is really wasteful when num_args > 1, since we re-run the
+ // reduction num_args times. The alternative is to teach Populate() about
+ // tuples, which we should probably do.
+ absl::InlinedVector<ReturnT, 1> init_scalars(num_args);
+ for (int i = 0; i < num_args; ++i) {
+ init_scalars[i] = init_literals[i]->Get<ReturnT>({});
+ }
- std::vector<int64> base(arg_dimensions.size());
- for (int64 i = 0; i < multi_index.size(); ++i) {
- base[result_to_arg_index[i]] = multi_index[i];
- }
+ for (int64 input = 0; input < num_args; ++input) {
+ TF_RETURN_IF_ERROR(results[input]->Populate<ReturnT>(
+ [&](tensorflow::gtl::ArraySlice<int64> multi_index) {
+ if (!eval_status.ok()) {
+ return init_scalars[input];
+ }
+ absl::InlinedVector<ReturnT, 1> result_values(init_scalars.begin(),
+ init_scalars.end());
+ std::vector<int64> base(arg_dimensions.size());
+ for (int64 i = 0; i < multi_index.size(); ++i) {
+ base[result_to_arg_index[i]] = multi_index[i];
+ }
- // When the reduction is addition of floats, accumulate in a double
- // for better precision. Also, avoid creating Literals for the
- // intermediate results; it's much faster.
- if (ShapeUtil::ElementIsFloating(init_literal.shape()) &&
- IsScalarAdd(function)) {
- double computed_result = 0;
- auto func = [&](tensorflow::gtl::ArraySlice<int64> input_index) {
- computed_result += GetAsDouble<ReturnT>(arg_literal, input_index);
+ // When the reduction is addition of floats, accumulate in a double
+ // for better precision. Also, avoid creating Literals for the
+ // intermediate results; it's much faster.
+ if (ShapeUtil::ElementIsFloating(init_literals[0]->shape()) &&
+ IsScalarAdd(function)) {
+ CHECK_EQ(num_args, 1);
+ double computed_result = 0;
+ auto func = [&](tensorflow::gtl::ArraySlice<int64> input_index) {
+ computed_result +=
+ GetAsDouble<ReturnT>(*arg_literals[0], input_index);
+ return true;
+ };
+ ShapeUtil::ForEachIndex(arg_literals[0]->shape(), base,
+ arg_dim_counts, arg_dim_steps, func);
+ return static_cast<ReturnT>(computed_result);
+ }
+ auto func = [&](tensorflow::gtl::ArraySlice<int64> input_index)
+ -> StatusOr<bool> {
+ absl::InlinedVector<ReturnT, 1> arg_values(num_args);
+ for (int64 i = 0; i < num_args; ++i) {
+ arg_values[i] = arg_literals[i]->Get<ReturnT>(input_index);
+ }
+
+ // Evaluate computation with specified literal operands.
+ absl::InlinedVector<std::unique_ptr<Literal>, 1>
+ embedded_operands;
+ for (ReturnT value : result_values) {
+ embedded_operands.push_back(
+ LiteralUtil::CreateR0<ReturnT>(value));
+ }
+ for (ReturnT value : arg_values) {
+ embedded_operands.push_back(
+ LiteralUtil::CreateR0<ReturnT>(value));
+ }
+ absl::InlinedVector<Literal*, 1> embedded_operands_ptrs(
+ embedded_operands.size());
+ std::transform(embedded_operands.begin(), embedded_operands.end(),
+ embedded_operands_ptrs.begin(),
+ [](const std::unique_ptr<Literal>& ptr) {
+ return ptr.get();
+ });
+
+ TF_ASSIGN_OR_RETURN(std::unique_ptr<Literal> computed_result,
+ embedded_evaluator.Evaluate<const Literal*>(
+ *function, embedded_operands_ptrs));
+ // Clear visit states so that we can use the evaluator again on
+ // the same computation.
+ embedded_evaluator.ResetVisitStates();
+ // Assign computed result to result_val.
+ if (!has_tuple_output) {
+ result_values[0] = computed_result->Get<ReturnT>({});
+ } else {
+ for (int64 i = 0; i < num_args; ++i) {
+ result_values[i] = computed_result->Get<ReturnT>(
+ /*multi_index=*/{}, /*shape_index=*/{i});
+ }
+ }
return true;
};
- ShapeUtil::ForEachIndex(arg_literal.shape(), base, arg_dim_counts,
- arg_dim_steps, func);
- return static_cast<ReturnT>(computed_result);
- }
- auto func = [&](tensorflow::gtl::ArraySlice<int64> input_index)
- -> StatusOr<bool> {
- auto curr_val = arg_literal.Get<ReturnT>(input_index);
-
- // Evaluate computation with specified literal operands.
- auto curr_val_literal = LiteralUtil::CreateR0<ReturnT>(curr_val);
- auto result_val_literal =
- LiteralUtil::CreateR0<ReturnT>(result_val);
-
- TF_ASSIGN_OR_RETURN(std::unique_ptr<Literal> computed_result,
- embedded_evaluator.Evaluate<const Literal*>(
- *function, {result_val_literal.get(),
- curr_val_literal.get()}));
- // Clear visit states so that we can use the evaluator again on
- // the same computation.
- embedded_evaluator.ResetVisitStates();
- // Assign computed result to result_val.
- result_val = computed_result->Get<ReturnT>({});
- return true;
- };
- // Computes one element of the result, reducing all dimensions that
- // contribute to that element.
- eval_status = ShapeUtil::ForEachIndexWithStatus(
- arg_literal.shape(), base, arg_dim_counts, arg_dim_steps, func);
- return result_val;
- }));
-
- parent_->evaluated_[reduce] = std::move(result);
+ // Computes one element of the result, reducing all dimensions that
+ // contribute to that element.
+ eval_status = ShapeUtil::ForEachIndexWithStatus(
+ arg_shape, base, arg_dim_counts, arg_dim_steps, func);
+ return result_values[input];
+ }));
+ }
+ if (!has_tuple_output) {
+ parent_->evaluated_[reduce] = std::move(results[0]);
+ } else {
+ auto tuple_result = absl::make_unique<Literal>(reduce->shape());
+ for (int64 i = 0; i < num_args; ++i) {
+ TF_CHECK_OK(tuple_result->MoveFrom(std::move(*results[i]), {i}));
+ }
+ parent_->evaluated_[reduce] = std::move(tuple_result);
+ }
return eval_status;
}
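As a concrete trace (illustrative, not from this patch): a 2-input reduce over f32[3] arrays a and b with a pairwise-add reducer and init values (0, 0) folds computation(acc0, acc1, a[i], b[i]) over i for each output element. The Populate() loop above runs that fold once per output, first to fill results[0] and again for results[1], which is exactly the duplication the comment about teaching Populate() about tuples refers to.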
diff --git a/tensorflow/compiler/xla/service/hlo_instruction.cc b/tensorflow/compiler/xla/service/hlo_instruction.cc
index ed4e159910..8a497e6edf 100644
--- a/tensorflow/compiler/xla/service/hlo_instruction.cc
+++ b/tensorflow/compiler/xla/service/hlo_instruction.cc
@@ -158,16 +158,26 @@ StatusOr<std::unique_ptr<HloInstruction>> HloInstruction::CreateFromProto(
CreateConcatenate(proto.shape(), all_operands(), proto.dimensions(0));
break;
case HloOpcode::kReduce:
- TF_RET_CHECK(proto.operand_ids_size() == 2)
- << "Reduce instruction should have 2 operands but sees "
+ TF_RET_CHECK(proto.operand_ids_size() % 2 == 0)
+ << "Reduce instruction should have an even number of operands but "
+ "sees "
<< proto.operand_ids_size();
TF_RET_CHECK(proto.called_computation_ids_size() == 1)
<< "Reduce instruction should have 1 called computation but sees "
<< proto.called_computation_ids_size();
- instruction = CreateReduce(proto.shape(), operands(0), operands(1),
- std::vector<int64>(proto.dimensions().begin(),
- proto.dimensions().end()),
- computations(0));
+ {
+ const auto reduce_operands = all_operands();
+ tensorflow::gtl::ArraySlice<HloInstruction*> inputs(
+ reduce_operands, 0, reduce_operands.size() / 2);
+ tensorflow::gtl::ArraySlice<HloInstruction*> init_values(
+ reduce_operands, reduce_operands.size() / 2,
+ reduce_operands.size());
+ instruction =
+ CreateReduce(proto.shape(), inputs, init_values,
+ std::vector<int64>(proto.dimensions().begin(),
+ proto.dimensions().end()),
+ computations(0));
+ }
break;
case HloOpcode::kSort: {
TF_RET_CHECK(proto.operand_ids_size() == 1 ||
@@ -2749,10 +2759,13 @@ HloInstruction::UseKind HloInstruction::OperandElementUse(int64 i) const {
case HloOpcode::kTranspose:
return UseKind::kUsePermutingElements;
case HloOpcode::kPad:
- case HloOpcode::kReduce:
// Pad reuses the padding value but not the padded array elements.
- // Reduce reuses the init value but not the operand array elements.
return i > 0 ? UseKind::kReuse : UseKind::kUsePermutingElements;
+ case HloOpcode::kReduce:
+ // Reduce reuses the init values but not the operand array elements.
+ return i >= Cast<HloReduceInstruction>(this)->input_count()
+ ? UseKind::kReuse
+ : UseKind::kUsePermutingElements;
case HloOpcode::kFusion:
// Uses the memoizing, recursive computation defined above.
return FusionReusesParamElements::Compute(i, *fused_expression_root());
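For example, a variadic reduce with two inputs has input_count() == 2, so operands 0 and 1 (the arrays being reduced) report kUsePermutingElements, while operands 2 and 3 (the init values) report kReuse.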
diff --git a/tensorflow/compiler/xla/service/hlo_instructions.cc b/tensorflow/compiler/xla/service/hlo_instructions.cc
index ffc74cfedd..0b7f741d73 100644
--- a/tensorflow/compiler/xla/service/hlo_instructions.cc
+++ b/tensorflow/compiler/xla/service/hlo_instructions.cc
@@ -586,7 +586,7 @@ std::unique_ptr<HloInstruction> HloReduceInstruction::CloneWithNewOperandsImpl(
const Shape& shape,
tensorflow::gtl::ArraySlice<HloInstruction*> new_operands,
HloCloneContext* context) const {
- CHECK_EQ(new_operands.size(), 2);
+ CHECK_EQ(new_operands.size() % 2, 0);
return absl::make_unique<HloReduceInstruction>(shape, new_operands,
dimensions(), to_apply());
}
diff --git a/tensorflow/compiler/xla/service/hlo_instructions.h b/tensorflow/compiler/xla/service/hlo_instructions.h
index ee6e337b6a..c2d551fb25 100644
--- a/tensorflow/compiler/xla/service/hlo_instructions.h
+++ b/tensorflow/compiler/xla/service/hlo_instructions.h
@@ -398,16 +398,20 @@ class HloReduceInstruction : public HloInstruction {
// Returns a serialized representation of this instruction.
HloInstructionProto ToProto() const override;
+ // Returns the number of input arrays (and, consequently, the number of
+ // init values) this reduce has.
+ int64 input_count() const { return operand_count() / 2; }
+
// Returns the input tensors to be reduced.
tensorflow::gtl::ArraySlice<HloInstruction*> inputs() const {
return tensorflow::gtl::ArraySlice<HloInstruction*>(operands(), 0,
- operand_count() / 2);
+ input_count());
}
// Returns the init values of the reduction.
tensorflow::gtl::ArraySlice<HloInstruction*> init_values() const {
return tensorflow::gtl::ArraySlice<HloInstruction*>(
- operands(), operand_count() / 2, operand_count());
+ operands(), input_count(), operand_count());
}
private:
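In other words, the operands of an n-input reduce are laid out as {input_0, ..., input_{n-1}, init_0, ..., init_{n-1}}; for the classic single-input case, inputs() is {operand(0)} and init_values() is {operand(1)}.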
diff --git a/tensorflow/compiler/xla/service/hlo_verifier.cc b/tensorflow/compiler/xla/service/hlo_verifier.cc
index f1b29c2559..3f3cb2fa54 100644
--- a/tensorflow/compiler/xla/service/hlo_verifier.cc
+++ b/tensorflow/compiler/xla/service/hlo_verifier.cc
@@ -288,14 +288,13 @@ Status ShapeVerifier::HandleGetTupleElement(HloInstruction* get_tuple_element) {
}
Status ShapeVerifier::HandleReduce(HloInstruction* reduce) {
- if (!ShapeUtil::IsArray(reduce->shape())) {
- return InvalidArgument("Variadic reduce is not supported.");
+ std::vector<const Shape*> operand_shapes;
+ for (const HloInstruction* operand : reduce->operands()) {
+ operand_shapes.push_back(&operand->shape());
}
- return CheckShape(
- reduce,
- ShapeInference::InferReduceShape(
- {&reduce->operand(0)->shape(), &reduce->operand(1)->shape()},
- reduce->dimensions(), reduce->to_apply()->ComputeProgramShape()));
+ return CheckShape(reduce, ShapeInference::InferReduceShape(
+ operand_shapes, reduce->dimensions(),
+ reduce->to_apply()->ComputeProgramShape()));
}
Status ShapeVerifier::HandleBitcast(HloInstruction* bitcast) {
diff --git a/tensorflow/compiler/xla/tests/client_library_test_base.cc b/tensorflow/compiler/xla/tests/client_library_test_base.cc
index 9cd974fd9b..f1ab83df82 100644
--- a/tensorflow/compiler/xla/tests/client_library_test_base.cc
+++ b/tensorflow/compiler/xla/tests/client_library_test_base.cc
@@ -290,10 +290,6 @@ Status ClientLibraryTestBase::ComputeAndCompareLiteralWithStatus(
if (ShapeUtil::ElementIsFloating(expected.shape()) ||
ShapeUtil::ElementIsComplex(expected.shape())) {
LOG(WARNING) << "performing exact comparison of floating point numbers";
- } else {
- TF_RET_CHECK(ShapeUtil::ElementIsIntegral(expected.shape()) ||
- expected.shape().element_type() == PRED)
- << ShapeUtil::HumanString(expected.shape());
}
// We allow using a float expected literal for a bfloat16 output. In this
// case, we need to convert the expected literal to bfloat16.
@@ -350,8 +346,6 @@ Status ClientLibraryTestBase::ComputeAndCompareLiteralWithStatus(
}
}
- TF_RET_CHECK(ShapeUtil::ElementIsFloating(expected.shape()) ||
- ShapeUtil::ElementIsComplex(expected.shape()));
TF_ASSIGN_OR_RETURN(auto computation, builder->Build());
// We allow using a float expected literal for a bfloat16 output. In this
// case, we need to convert the expected literal to bfloat16.
diff --git a/tensorflow/compiler/xla/tests/test_utils.cc b/tensorflow/compiler/xla/tests/test_utils.cc
index 776f93d9f7..60ada58b2a 100644
--- a/tensorflow/compiler/xla/tests/test_utils.cc
+++ b/tensorflow/compiler/xla/tests/test_utils.cc
@@ -203,6 +203,7 @@ enum class ConstantType { kUnknown, kZero, kOne };
// Return the constant type required by this computation, if known.
ConstantType GetInitValue(const HloComputation& computation) {
+ // TODO(b/77635120): Add init values, for min, max, and their arg variants.
const HloInstruction* const root = computation.root_instruction();
if (computation.num_parameters() != 2 || root->operand_count() != 2 ||
root->operand(0)->opcode() != HloOpcode::kParameter ||
@@ -227,10 +228,10 @@ bool NeedsInitValue(const HloUse& use) {
const HloInstruction* const instruction = use.instruction;
const HloOpcode opcode = instruction->opcode();
const int64 op_num = use.operand_number;
- return (
- ((opcode == HloOpcode::kReduce || opcode == HloOpcode::kReduceWindow) &&
- op_num == 1) ||
- (opcode == HloOpcode::kSelectAndScatter && op_num == 2));
+ return ((opcode == HloOpcode::kReduceWindow && op_num == 1) ||
+ (opcode == HloOpcode::kSelectAndScatter && op_num == 2) ||
+ (opcode == HloOpcode::kReduce &&
+ op_num >= instruction->operand_count() / 2));
}
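So for a variadic reduce with two inputs, only operands 2 and 3 (the init values) are constrained to the reduction's identity; the input arrays themselves keep unconstrained random data, matching the operand layout introduced in hlo_instructions.h above.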
// Generate random values that are constrained to the input_shape minus the