diff options
Diffstat (limited to 'tensorflow/compiler/xla/client/lib/arithmetic.cc')
-rw-r--r-- | tensorflow/compiler/xla/client/lib/arithmetic.cc | 12 |
1 files changed, 7 insertions, 5 deletions
diff --git a/tensorflow/compiler/xla/client/lib/arithmetic.cc b/tensorflow/compiler/xla/client/lib/arithmetic.cc index 978fc40f34..de1d785e19 100644 --- a/tensorflow/compiler/xla/client/lib/arithmetic.cc +++ b/tensorflow/compiler/xla/client/lib/arithmetic.cc @@ -94,16 +94,18 @@ XlaComputation CreateScalarMinComputation(PrimitiveType type, }); } -XlaComputation CreateScalarAndComputation(XlaBuilder* builder) { +XlaComputation CreateScalarAndComputation(PrimitiveType type, + XlaBuilder* builder) { return CreateScalarComputation( - "and", PRED, builder, + "and", type, builder, [](XlaBuilder* b, const XlaOp& lhs, const XlaOp& rhs) { return And(lhs, rhs); }); } -XlaComputation CreateScalarOrComputation(XlaBuilder* builder) { - return CreateScalarComputation("or", PRED, builder, +XlaComputation CreateScalarOrComputation(PrimitiveType type, + XlaBuilder* builder) { + return CreateScalarComputation("or", type, builder, [](XlaBuilder* b, const XlaOp& lhs, const XlaOp& rhs) { return Or(lhs, rhs); }); } @@ -112,7 +114,7 @@ XlaOp Any(XlaOp predicates) { XlaBuilder* builder = predicates.builder(); return builder->ReportErrorOrReturn([&]() -> StatusOr<XlaOp> { auto f = ConstantR0<bool>(builder, false); - XlaComputation logical_or = CreateScalarOrComputation(builder); + XlaComputation logical_or = CreateScalarOrComputation(PRED, builder); TF_ASSIGN_OR_RETURN(const Shape& predicates_shape, builder->GetShape(predicates)); std::vector<int64> all_dimensions(ShapeUtil::Rank(predicates_shape)); |