aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/image
diff options
context:
space:
mode:
authorGravatar Yong Tang <yong.tang.github@outlook.com>2018-06-23 17:13:37 +0000
committerGravatar Yong Tang <yong.tang.github@outlook.com>2018-06-23 17:13:37 +0000
commite32b8e1d505a9b7dcfb70707a77c830271a27fcf (patch)
tree295e4c9bd3fd51bef8b37c2d6440a2bbf0ad1637 /tensorflow/contrib/image
parentbc00d41b4f96043748374ec58912c9ee90cbb601 (diff)
Register float16 for ImageProjectiveTransform kernel
Signed-off-by: Yong Tang <yong.tang.github@outlook.com>
Diffstat (limited to 'tensorflow/contrib/image')
-rw-r--r--tensorflow/contrib/image/kernels/image_ops.cc2
-rw-r--r--tensorflow/contrib/image/kernels/image_ops.h25
2 files changed, 15 insertions, 12 deletions
diff --git a/tensorflow/contrib/image/kernels/image_ops.cc b/tensorflow/contrib/image/kernels/image_ops.cc
index c2e32da133..022e17d139 100644
--- a/tensorflow/contrib/image/kernels/image_ops.cc
+++ b/tensorflow/contrib/image/kernels/image_ops.cc
@@ -35,6 +35,7 @@ typedef Eigen::ThreadPoolDevice CPUDevice;
template struct FillProjectiveTransform<CPUDevice, uint8>;
template struct FillProjectiveTransform<CPUDevice, int32>;
template struct FillProjectiveTransform<CPUDevice, int64>;
+template struct FillProjectiveTransform<CPUDevice, Eigen::half>;
template struct FillProjectiveTransform<CPUDevice, float>;
template struct FillProjectiveTransform<CPUDevice, double>;
@@ -99,6 +100,7 @@ class ImageProjectiveTransform : public OpKernel {
TF_CALL_uint8(REGISTER);
TF_CALL_int32(REGISTER);
TF_CALL_int64(REGISTER);
+TF_CALL_half(REGISTER);
TF_CALL_float(REGISTER);
TF_CALL_double(REGISTER);
diff --git a/tensorflow/contrib/image/kernels/image_ops.h b/tensorflow/contrib/image/kernels/image_ops.h
index ad50133061..f1dbd1becc 100644
--- a/tensorflow/contrib/image/kernels/image_ops.h
+++ b/tensorflow/contrib/image/kernels/image_ops.h
@@ -21,6 +21,7 @@ limitations under the License.
#define EIGEN_USE_THREADS
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
+
#include "tensorflow/core/framework/tensor_types.h"
#include "tensorflow/core/platform/types.h"
@@ -105,21 +106,21 @@ class ProjectiveGenerator {
// f(x, y_floor) = (x_ceil - x) / (x_ceil - x_floor) * f(x_floor, y_floor)
// + (x - x_floor) / (x_ceil - x_floor) * f(x_ceil, y_floor)
const float value_yfloor =
- (x_ceil - x) * read_with_fill_value(batch, DenseIndex(y_floor),
- DenseIndex(x_floor), channel,
- fill_value) +
- (x - x_floor) * read_with_fill_value(batch, DenseIndex(y_floor),
- DenseIndex(x_ceil), channel,
- fill_value);
+ (x_ceil - x) * static_cast<float>(read_with_fill_value(
+ batch, DenseIndex(y_floor), DenseIndex(x_floor),
+ channel, fill_value)) +
+ (x - x_floor) * static_cast<float>(read_with_fill_value(
+ batch, DenseIndex(y_floor), DenseIndex(x_ceil),
+ channel, fill_value));
// f(x, y_ceil) = (x_ceil - x) / (x_ceil - x_floor) * f(x_floor, y_ceil)
// + (x - x_floor) / (x_ceil - x_floor) * f(x_ceil, y_ceil)
const float value_yceil =
- (x_ceil - x) * read_with_fill_value(batch, DenseIndex(y_ceil),
- DenseIndex(x_floor), channel,
- fill_value) +
- (x - x_floor) * read_with_fill_value(batch, DenseIndex(y_ceil),
- DenseIndex(x_ceil), channel,
- fill_value);
+ (x_ceil - x) * static_cast<float>(read_with_fill_value(
+ batch, DenseIndex(y_ceil), DenseIndex(x_floor),
+ channel, fill_value)) +
+ (x - x_floor) * static_cast<float>(read_with_fill_value(
+ batch, DenseIndex(y_ceil), DenseIndex(x_ceil),
+ channel, fill_value));
// f(x, y) = (y_ceil - y) / (y_ceil - y_floor) * f(x, y_floor)
// + (y - y_floor) / (y_ceil - y_floor) * f(x, y_ceil)
return T((y_ceil - y) * value_yfloor + (y - y_floor) * value_yceil);