diff options
author | 2017-12-04 12:33:58 -0800 | |
---|---|---|
committer | 2017-12-04 12:45:09 -0800 | |
commit | 71cd06c608d1cb6fb23f63cf20403b1958965c43 (patch) | |
tree | 5118aaf04d6e73f667bddd50c1e1bfd94b6e56a4 | |
parent | 5917d48293a5582d625f015e4862b2d370b75079 (diff) |
[TF:XLA] Fix wrong output of FloorDiv op for DT_HALF values.
PiperOrigin-RevId: 177851804
-rw-r--r-- | tensorflow/compiler/tf2xla/kernels/binary_ops.cc | 3 | ||||
-rw-r--r-- | tensorflow/core/framework/types.cc | 12 | ||||
-rw-r--r-- | tensorflow/core/framework/types.h | 3 |
3 files changed, 17 insertions, 1 deletions
diff --git a/tensorflow/compiler/tf2xla/kernels/binary_ops.cc b/tensorflow/compiler/tf2xla/kernels/binary_ops.cc index 1de9192432..2436a6074a 100644 --- a/tensorflow/compiler/tf2xla/kernels/binary_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/binary_ops.cc @@ -22,6 +22,7 @@ limitations under the License. #include "tensorflow/compiler/xla/client/computation_builder.h" #include "tensorflow/core/framework/kernel_def_builder.h" #include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/framework/types.h" namespace tensorflow { namespace { @@ -75,7 +76,7 @@ static xla::ComputationDataHandle FloorDivImpl(xla::ComputationBuilder* b, auto abs_y = b->Abs(y); auto t = b->Neg(b->Sub(b->Add(abs_x, abs_y), one)); auto result = b->Select(different_sign, b->Div(t, abs_y), b->Div(x, y)); - if (dtype == DT_FLOAT || dtype == DT_DOUBLE) { + if (DataTypeIsFloating(dtype)) { result = b->Floor(result); } return result; diff --git a/tensorflow/core/framework/types.cc b/tensorflow/core/framework/types.cc index 48849f9dda..02b2df448a 100644 --- a/tensorflow/core/framework/types.cc +++ b/tensorflow/core/framework/types.cc @@ -306,6 +306,18 @@ bool DataTypeCanUseMemcpy(DataType dt) { } } +bool DataTypeIsFloating(DataType dt) { + switch (dt) { + case DT_HALF: + case DT_BFLOAT16: + case DT_FLOAT: + case DT_DOUBLE: + return true; + default: + return false; + } +} + bool DataTypeIsQuantized(DataType dt) { switch (dt) { case DT_QINT8: diff --git a/tensorflow/core/framework/types.h b/tensorflow/core/framework/types.h index dc53ed4178..c27a4d4605 100644 --- a/tensorflow/core/framework/types.h +++ b/tensorflow/core/framework/types.h @@ -222,6 +222,9 @@ static_assert(IsValidDataType<int32>::value, "Incorrect impl for int32"); bool DataTypeCanUseMemcpy(DataType dt); +// Returns true iff 'dt' is a real, non-quantized floating point type. +bool DataTypeIsFloating(DataType dt); + bool DataTypeIsQuantized(DataType dt); // Is the dtype nonquantized integral? |