aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/image/ops/image_ops.cc
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/contrib/image/ops/image_ops.cc')
-rw-r--r--tensorflow/contrib/image/ops/image_ops.cc52
1 files changed, 4 insertions, 48 deletions
diff --git a/tensorflow/contrib/image/ops/image_ops.cc b/tensorflow/contrib/image/ops/image_ops.cc
index 4c6d8c0d19..68771b3d05 100644
--- a/tensorflow/contrib/image/ops/image_ops.cc
+++ b/tensorflow/contrib/image/ops/image_ops.cc
@@ -19,55 +19,9 @@ limitations under the License.
namespace tensorflow {
-using shape_inference::DimensionHandle;
using shape_inference::InferenceContext;
using shape_inference::ShapeHandle;
-namespace {
-
-// Sets output[0] to shape [batch_dim,height,width,channel_dim], where
-// height and width come from the size_tensor.
-Status SetOutputToSizedImage(InferenceContext* c, DimensionHandle batch_dim,
- int size_input_idx, DimensionHandle channel_dim) {
- // Verify shape of size input.
- ShapeHandle size;
- TF_RETURN_IF_ERROR(c->WithRank(c->input(size_input_idx), 1, &size));
- DimensionHandle unused;
- TF_RETURN_IF_ERROR(c->WithValue(c->Dim(size, 0), 2, &unused));
-
- // Get size values from the size tensor.
- const Tensor* size_tensor = c->input_tensor(size_input_idx);
- DimensionHandle width;
- DimensionHandle height;
- if (size_tensor == nullptr) {
- width = c->UnknownDim();
- height = c->UnknownDim();
- } else {
- // TODO(petewarden) - Remove once we have constant evaluation in C++ only.
- if (size_tensor->dtype() != DT_INT32) {
- return errors::InvalidArgument(
- "Bad size input type for SetOutputToSizedImage: Expected DT_INT32 "
- "but got ",
- DataTypeString(size_tensor->dtype()), " for input #", size_input_idx,
- " in ", c->DebugString());
- }
- auto vec = size_tensor->vec<int32>();
- height = c->MakeDim(vec(0));
- width = c->MakeDim(vec(1));
- }
- c->set_output(0, c->MakeShape({batch_dim, height, width, channel_dim}));
- return Status::OK();
-}
-
-Status ResizeShapeFn(InferenceContext* c) {
- ShapeHandle input;
- TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 4, &input));
- return SetOutputToSizedImage(c, c->Dim(input, 0), 2 /* size_input_idx */,
- c->Dim(input, 3));
-}
-
-} // namespace
-
// TODO(ringwalt): Add a "fill_mode" argument with "constant", "mirror", etc.
// TODO(ringwalt): Add a "fill_constant" argument for constant mode (default 0).
// TODO(ringwalt): Add an "output_shape" argument. This is sufficient to
@@ -75,11 +29,13 @@ Status ResizeShapeFn(InferenceContext* c) {
REGISTER_OP("ImageProjectiveTransform")
.Input("images: dtype")
.Input("transforms: float32")
- .Input("output_shape: int32")
.Attr("dtype: {uint8, int32, int64, float32, float64}")
.Attr("interpolation: string")
.Output("transformed_images: dtype")
- .SetShapeFn(ResizeShapeFn)
+ .SetShapeFn([](InferenceContext* c) {
+ c->set_output(0, c->input(0));
+ return Status::OK();
+ })
.Doc(R"doc(
Applies the given transform to each of the images.