aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/compiler/tf2xla
diff options
context:
space:
mode:
authorGravatar Benjamin Kramer <kramerb@google.com>2018-09-14 01:24:52 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-09-14 01:29:01 -0700
commit3a2276ced02b217596080fb34654d2dce5069f81 (patch)
tree02a91869529186e4b634c764d68373be4a290579 /tensorflow/compiler/tf2xla
parentb43aeb053ec440ea5205a09c229339c10a962af4 (diff)
[XLA:TF] Make FloorDiv not crash on unsigned types
FloorDiv (which corresponds to the // operator in python) supports uint8 and uint16 (but not uint32) in TF. Using xla::Abs on unsigned types throws an error, but the rounding logic is trivial for unsigned types so just do a plain Div. This isn't tested yet because we don't have any targets supporting uint8 or uint16 yet. PiperOrigin-RevId: 212946132
Diffstat (limited to 'tensorflow/compiler/tf2xla')
-rw-r--r--tensorflow/compiler/tf2xla/kernels/binary_ops.cc3
1 files changed, 3 insertions, 0 deletions
diff --git a/tensorflow/compiler/tf2xla/kernels/binary_ops.cc b/tensorflow/compiler/tf2xla/kernels/binary_ops.cc
index df17da4c1c..0d9a768a6f 100644
--- a/tensorflow/compiler/tf2xla/kernels/binary_ops.cc
+++ b/tensorflow/compiler/tf2xla/kernels/binary_ops.cc
@@ -66,6 +66,9 @@ XLA_MAKE_BINARY(Complex, xla::Complex(lhs, rhs, extend_dimensions));
static xla::XlaOp FloorDivImpl(xla::XlaBuilder* b, DataType dtype, xla::XlaOp x,
xla::XlaOp y, const BCast& broadcast_helper) {
std::tie(x, y) = XlaBinaryOp::Broadcast(b, x, y, broadcast_helper);
+ if (DataTypeIsUnsigned(dtype)) {
+ return xla::Div(x, y);
+ }
auto zero = XlaHelpers::Zero(b, dtype);
auto one = XlaHelpers::One(b, dtype);
auto different_sign = xla::Ne(xla::Lt(x, zero), xla::Lt(y, zero));