aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorGravatar Peter Hawkins <phawkins@google.com>2017-12-04 12:33:58 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2017-12-04 12:45:09 -0800
commit71cd06c608d1cb6fb23f63cf20403b1958965c43 (patch)
tree5118aaf04d6e73f667bddd50c1e1bfd94b6e56a4
parent5917d48293a5582d625f015e4862b2d370b75079 (diff)
[TF:XLA] Fix wrong output of FloorDiv op for DT_HALF values.
PiperOrigin-RevId: 177851804
-rw-r--r--tensorflow/compiler/tf2xla/kernels/binary_ops.cc3
-rw-r--r--tensorflow/core/framework/types.cc12
-rw-r--r--tensorflow/core/framework/types.h3
3 files changed, 17 insertions, 1 deletions
diff --git a/tensorflow/compiler/tf2xla/kernels/binary_ops.cc b/tensorflow/compiler/tf2xla/kernels/binary_ops.cc
index 1de9192432..2436a6074a 100644
--- a/tensorflow/compiler/tf2xla/kernels/binary_ops.cc
+++ b/tensorflow/compiler/tf2xla/kernels/binary_ops.cc
@@ -22,6 +22,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/client/computation_builder.h"
#include "tensorflow/core/framework/kernel_def_builder.h"
#include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/framework/types.h"
namespace tensorflow {
namespace {
@@ -75,7 +76,7 @@ static xla::ComputationDataHandle FloorDivImpl(xla::ComputationBuilder* b,
auto abs_y = b->Abs(y);
auto t = b->Neg(b->Sub(b->Add(abs_x, abs_y), one));
auto result = b->Select(different_sign, b->Div(t, abs_y), b->Div(x, y));
- if (dtype == DT_FLOAT || dtype == DT_DOUBLE) {
+ if (DataTypeIsFloating(dtype)) {
result = b->Floor(result);
}
return result;
diff --git a/tensorflow/core/framework/types.cc b/tensorflow/core/framework/types.cc
index 48849f9dda..02b2df448a 100644
--- a/tensorflow/core/framework/types.cc
+++ b/tensorflow/core/framework/types.cc
@@ -306,6 +306,18 @@ bool DataTypeCanUseMemcpy(DataType dt) {
}
}
+bool DataTypeIsFloating(DataType dt) {
+ switch (dt) {
+ case DT_HALF:
+ case DT_BFLOAT16:
+ case DT_FLOAT:
+ case DT_DOUBLE:
+ return true;
+ default:
+ return false;
+ }
+}
+
bool DataTypeIsQuantized(DataType dt) {
switch (dt) {
case DT_QINT8:
diff --git a/tensorflow/core/framework/types.h b/tensorflow/core/framework/types.h
index dc53ed4178..c27a4d4605 100644
--- a/tensorflow/core/framework/types.h
+++ b/tensorflow/core/framework/types.h
@@ -222,6 +222,9 @@ static_assert(IsValidDataType<int32>::value, "Incorrect impl for int32");
bool DataTypeCanUseMemcpy(DataType dt);
+// Returns true iff 'dt' is a real, non-quantized floating point type.
+bool DataTypeIsFloating(DataType dt);
+
bool DataTypeIsQuantized(DataType dt);
// Is the dtype nonquantized integral?