diff options
author | 2018-06-23 17:13:37 +0000 | |
---|---|---|
committer | 2018-06-23 17:13:37 +0000 | |
commit | e32b8e1d505a9b7dcfb70707a77c830271a27fcf (patch) | |
tree | 295e4c9bd3fd51bef8b37c2d6440a2bbf0ad1637 /tensorflow/contrib/image | |
parent | bc00d41b4f96043748374ec58912c9ee90cbb601 (diff) |
Register float16 for ImageProjectiveTransform kernel
Signed-off-by: Yong Tang <yong.tang.github@outlook.com>
Diffstat (limited to 'tensorflow/contrib/image')
-rw-r--r-- | tensorflow/contrib/image/kernels/image_ops.cc | 2 | ||||
-rw-r--r-- | tensorflow/contrib/image/kernels/image_ops.h | 25 |
2 files changed, 15 insertions, 12 deletions
diff --git a/tensorflow/contrib/image/kernels/image_ops.cc b/tensorflow/contrib/image/kernels/image_ops.cc index c2e32da133..022e17d139 100644 --- a/tensorflow/contrib/image/kernels/image_ops.cc +++ b/tensorflow/contrib/image/kernels/image_ops.cc @@ -35,6 +35,7 @@ typedef Eigen::ThreadPoolDevice CPUDevice; template struct FillProjectiveTransform<CPUDevice, uint8>; template struct FillProjectiveTransform<CPUDevice, int32>; template struct FillProjectiveTransform<CPUDevice, int64>; +template struct FillProjectiveTransform<CPUDevice, Eigen::half>; template struct FillProjectiveTransform<CPUDevice, float>; template struct FillProjectiveTransform<CPUDevice, double>; @@ -99,6 +100,7 @@ class ImageProjectiveTransform : public OpKernel { TF_CALL_uint8(REGISTER); TF_CALL_int32(REGISTER); TF_CALL_int64(REGISTER); +TF_CALL_half(REGISTER); TF_CALL_float(REGISTER); TF_CALL_double(REGISTER); diff --git a/tensorflow/contrib/image/kernels/image_ops.h b/tensorflow/contrib/image/kernels/image_ops.h index ad50133061..f1dbd1becc 100644 --- a/tensorflow/contrib/image/kernels/image_ops.h +++ b/tensorflow/contrib/image/kernels/image_ops.h @@ -21,6 +21,7 @@ limitations under the License. #define EIGEN_USE_THREADS #include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" + #include "tensorflow/core/framework/tensor_types.h" #include "tensorflow/core/platform/types.h" @@ -105,21 +106,21 @@ class ProjectiveGenerator { // f(x, y_floor) = (x_ceil - x) / (x_ceil - x_floor) * f(x_floor, y_floor) // + (x - x_floor) / (x_ceil - x_floor) * f(x_ceil, y_floor) const float value_yfloor = - (x_ceil - x) * read_with_fill_value(batch, DenseIndex(y_floor), - DenseIndex(x_floor), channel, - fill_value) + - (x - x_floor) * read_with_fill_value(batch, DenseIndex(y_floor), - DenseIndex(x_ceil), channel, - fill_value); + (x_ceil - x) * static_cast<float>(read_with_fill_value( + batch, DenseIndex(y_floor), DenseIndex(x_floor), + channel, fill_value)) + + (x - x_floor) * static_cast<float>(read_with_fill_value( + batch, DenseIndex(y_floor), DenseIndex(x_ceil), + channel, fill_value)); // f(x, y_ceil) = (x_ceil - x) / (x_ceil - x_floor) * f(x_floor, y_ceil) // + (x - x_floor) / (x_ceil - x_floor) * f(x_ceil, y_ceil) const float value_yceil = - (x_ceil - x) * read_with_fill_value(batch, DenseIndex(y_ceil), - DenseIndex(x_floor), channel, - fill_value) + - (x - x_floor) * read_with_fill_value(batch, DenseIndex(y_ceil), - DenseIndex(x_ceil), channel, - fill_value); + (x_ceil - x) * static_cast<float>(read_with_fill_value( + batch, DenseIndex(y_ceil), DenseIndex(x_floor), + channel, fill_value)) + + (x - x_floor) * static_cast<float>(read_with_fill_value( + batch, DenseIndex(y_ceil), DenseIndex(x_ceil), + channel, fill_value)); // f(x, y) = (y_ceil - y) / (y_ceil - y_floor) * f(x, y_floor) // + (y - y_floor) / (y_ceil - y_floor) * f(x, y_ceil) return T((y_ceil - y) * value_yfloor + (y - y_floor) * value_yceil); |