aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/compiler/xla/client/lib/arithmetic.cc
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/compiler/xla/client/lib/arithmetic.cc')
-rw-r--r--tensorflow/compiler/xla/client/lib/arithmetic.cc12
1 files changed, 7 insertions, 5 deletions
diff --git a/tensorflow/compiler/xla/client/lib/arithmetic.cc b/tensorflow/compiler/xla/client/lib/arithmetic.cc
index 978fc40f34..de1d785e19 100644
--- a/tensorflow/compiler/xla/client/lib/arithmetic.cc
+++ b/tensorflow/compiler/xla/client/lib/arithmetic.cc
@@ -94,16 +94,18 @@ XlaComputation CreateScalarMinComputation(PrimitiveType type,
});
}
-XlaComputation CreateScalarAndComputation(XlaBuilder* builder) {
+XlaComputation CreateScalarAndComputation(PrimitiveType type,
+ XlaBuilder* builder) {
return CreateScalarComputation(
- "and", PRED, builder,
+ "and", type, builder,
[](XlaBuilder* b, const XlaOp& lhs, const XlaOp& rhs) {
return And(lhs, rhs);
});
}
-XlaComputation CreateScalarOrComputation(XlaBuilder* builder) {
- return CreateScalarComputation("or", PRED, builder,
+XlaComputation CreateScalarOrComputation(PrimitiveType type,
+ XlaBuilder* builder) {
+ return CreateScalarComputation("or", type, builder,
[](XlaBuilder* b, const XlaOp& lhs,
const XlaOp& rhs) { return Or(lhs, rhs); });
}
@@ -112,7 +114,7 @@ XlaOp Any(XlaOp predicates) {
XlaBuilder* builder = predicates.builder();
return builder->ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
auto f = ConstantR0<bool>(builder, false);
- XlaComputation logical_or = CreateScalarOrComputation(builder);
+ XlaComputation logical_or = CreateScalarOrComputation(PRED, builder);
TF_ASSIGN_OR_RETURN(const Shape& predicates_shape,
builder->GetShape(predicates));
std::vector<int64> all_dimensions(ShapeUtil::Rank(predicates_shape));