aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/image/ops/image_ops.cc
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/contrib/image/ops/image_ops.cc')
-rw-r--r--tensorflow/contrib/image/ops/image_ops.cc57
1 files changed, 50 insertions, 7 deletions
diff --git a/tensorflow/contrib/image/ops/image_ops.cc b/tensorflow/contrib/image/ops/image_ops.cc
index e59f1bf844..4969ac58f9 100644
--- a/tensorflow/contrib/image/ops/image_ops.cc
+++ b/tensorflow/contrib/image/ops/image_ops.cc
@@ -19,23 +19,66 @@ limitations under the License.
namespace tensorflow {
+using shape_inference::DimensionHandle;
using shape_inference::InferenceContext;
using shape_inference::ShapeHandle;
+namespace {
+
+// Sets output[0] to shape [batch_dim, height, width, channel_dim], where height
+// and width are read from input `size_input_idx`, or unknown if it has no value.
+Status SetOutputToSizedImage(InferenceContext* c, DimensionHandle batch_dim,
+ int size_input_idx, DimensionHandle channel_dim) {
+ // Verify shape of size input: it must be a rank-1 tensor with exactly 2 elements.
+ ShapeHandle size;
+ TF_RETURN_IF_ERROR(c->WithRank(c->input(size_input_idx), 1, &size));  // Rank must be 1.
+ DimensionHandle unused;
+ TF_RETURN_IF_ERROR(c->WithValue(c->Dim(size, 0), 2, &unused));  // Length must be 2.
+
+ // Get size values from the size tensor, if the context has a value for it.
+ const Tensor* size_tensor = c->input_tensor(size_input_idx);
+ DimensionHandle width;
+ DimensionHandle height;
+ if (size_tensor == nullptr) {  // No value available: leave both dims unknown.
+ width = c->UnknownDim();
+ height = c->UnknownDim();
+ } else {
+ // TODO(petewarden) - Remove once we have constant evaluation in C++ only.
+ if (size_tensor->dtype() != DT_INT32) {  // Only int32 values are read below.
+ return errors::InvalidArgument(
+ "Bad size input type for SetOutputToSizedImage: Expected DT_INT32 "
+ "but got ",
+ DataTypeString(size_tensor->dtype()), " for input #", size_input_idx,
+ " in ", c->DebugString());
+ }
+ auto vec = size_tensor->vec<int32>();  // Element order is [height, width].
+ height = c->MakeDim(vec(0));
+ width = c->MakeDim(vec(1));
+ }
+ c->set_output(0, c->MakeShape({batch_dim, height, width, channel_dim}));
+ return Status::OK();
+}
+
+// TODO(qyu): Move this to core/framework/common_shape_fns.h
+Status ResizeShapeFn(InferenceContext* c) {  // Shape fn: output 0 is [batch, height, width, channels].
+ ShapeHandle input;
+ TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 4, &input));  // Input 0 must be rank 4.
+ return SetOutputToSizedImage(c, c->Dim(input, 0), 2 /* size_input_idx */,
+ c->Dim(input, 3));  // Dims 0 and 3 of input 0 pass through; height/width come from input 2.
+}
+
+} // namespace
+
// TODO(ringwalt): Add a "fill_mode" argument with "constant", "mirror", etc.
// TODO(ringwalt): Add a "fill_constant" argument for constant mode (default 0).
-// TODO(ringwalt): Add an "output_shape" argument. This is sufficient to
-// implement "same" and "valid" modes in the Python function.
REGISTER_OP("ImageProjectiveTransform")
.Input("images: dtype")
.Input("transforms: float32")
+ .Input("output_shape: int32")
.Attr("dtype: {uint8, int32, int64, float16, float32, float64}")
.Attr("interpolation: string")
.Output("transformed_images: dtype")
- .SetShapeFn([](InferenceContext* c) {
- c->set_output(0, c->input(0));
- return Status::OK();
- })
+ .SetShapeFn(ResizeShapeFn)
.Doc(R"doc(
Applies the given transform to each of the images.
@@ -49,7 +92,7 @@ If one row of `transforms` is `[a0, a1, a2, b0, b1, b2, c0, c1]`, then it maps
the *output* point `(x, y)` to a transformed *input* point
`(x', y') = ((a0 x + a1 y + a2) / k, (b0 x + b1 y + b2) / k)`, where
`k = c0 x + c1 y + 1`. If the transformed point lays outside of the input
-image, the output pixel is set to 0. The output is the same size as the input,
+image, the output pixel is set to 0.
images: 4D `Tensor`, input image(s) in NHWC format.
transforms: 2D `Tensor`, projective transform(s) to apply to the image(s).