aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/compiler/xla/client/lib/arithmetic.cc
diff options
context:
space:
mode:
authorGravatar Avijit <Avijit.Chakraborty@intel.com>2018-07-25 01:08:01 -0700
committerGravatar Avijit <Avijit.Chakraborty@intel.com>2018-07-25 01:08:01 -0700
commit1cdacb8b10d0b4687387be5fd8be978d68602a1d (patch)
treea2bf88798854a426f073325eb85d85b3ab914418 /tensorflow/compiler/xla/client/lib/arithmetic.cc
parentf88a6f93bee89c610fa8b399d037c7a33c1a0a3e (diff)
parent3f454e4060d855f43eebe0cdc27d8c24f906d430 (diff)
Merge remote-tracking branch 'upstream/master'
Diffstat (limited to 'tensorflow/compiler/xla/client/lib/arithmetic.cc')
-rw-r--r--tensorflow/compiler/xla/client/lib/arithmetic.cc12
1 files changed, 7 insertions, 5 deletions
diff --git a/tensorflow/compiler/xla/client/lib/arithmetic.cc b/tensorflow/compiler/xla/client/lib/arithmetic.cc
index 978fc40f34..de1d785e19 100644
--- a/tensorflow/compiler/xla/client/lib/arithmetic.cc
+++ b/tensorflow/compiler/xla/client/lib/arithmetic.cc
@@ -94,16 +94,18 @@ XlaComputation CreateScalarMinComputation(PrimitiveType type,
});
}
-XlaComputation CreateScalarAndComputation(XlaBuilder* builder) {
+XlaComputation CreateScalarAndComputation(PrimitiveType type,
+ XlaBuilder* builder) {
return CreateScalarComputation(
- "and", PRED, builder,
+ "and", type, builder,
[](XlaBuilder* b, const XlaOp& lhs, const XlaOp& rhs) {
return And(lhs, rhs);
});
}
-XlaComputation CreateScalarOrComputation(XlaBuilder* builder) {
- return CreateScalarComputation("or", PRED, builder,
+XlaComputation CreateScalarOrComputation(PrimitiveType type,
+ XlaBuilder* builder) {
+ return CreateScalarComputation("or", type, builder,
[](XlaBuilder* b, const XlaOp& lhs,
const XlaOp& rhs) { return Or(lhs, rhs); });
}
@@ -112,7 +114,7 @@ XlaOp Any(XlaOp predicates) {
XlaBuilder* builder = predicates.builder();
return builder->ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
auto f = ConstantR0<bool>(builder, false);
- XlaComputation logical_or = CreateScalarOrComputation(builder);
+ XlaComputation logical_or = CreateScalarOrComputation(PRED, builder);
TF_ASSIGN_OR_RETURN(const Shape& predicates_shape,
builder->GetShape(predicates));
std::vector<int64> all_dimensions(ShapeUtil::Rank(predicates_shape));