diff options
Diffstat (limited to 'tensorflow/contrib/periodic_resample/ops/array_ops.cc')
-rw-r--r-- | tensorflow/contrib/periodic_resample/ops/array_ops.cc | 53 |
1 files changed, 52 insertions, 1 deletions
diff --git a/tensorflow/contrib/periodic_resample/ops/array_ops.cc b/tensorflow/contrib/periodic_resample/ops/array_ops.cc index 82bd796956..fd38cd09b4 100644 --- a/tensorflow/contrib/periodic_resample/ops/array_ops.cc +++ b/tensorflow/contrib/periodic_resample/ops/array_ops.cc @@ -26,7 +26,42 @@ REGISTER_OP("PeriodicResample") .Input("values: T") .Attr("shape: shape") .Output("output: T") - .SetShapeFn(shape_inference::ExplicitShape) + .SetShapeFn([](shape_inference::InferenceContext* c) { + tensorflow::PartialTensorShape desired_shape; + TF_RETURN_IF_ERROR(c->GetAttr("shape", &desired_shape)); + shape_inference::ShapeHandle input_tensor_shape = c->input(0); + shape_inference::DimensionHandle num_input_elements = + c->NumElements(input_tensor_shape); + shape_inference::ShapeHandle result_shape_handle; + if (!shape_inference::InferenceContext::ValueKnown(num_input_elements)) { + TF_RETURN_IF_ERROR(c->MakeShapeFromPartialTensorShape( + desired_shape, &result_shape_handle)); + } else { + const int rank = c->Rank(input_tensor_shape); + std::vector<tensorflow::int64> target_dimensions(rank); + tensorflow::int64 new_sliced_size = 1; + int adjustable_dimension = 0; + for (int i = 0; i < rank; ++i) { + if (desired_shape.dim_size(i) < 1) { + adjustable_dimension = i; + } else { + target_dimensions[i] = desired_shape.dim_size(i); + new_sliced_size *= target_dimensions[i]; + } + } + target_dimensions[adjustable_dimension] = + shape_inference::InferenceContext::Value( + num_input_elements) / new_sliced_size; + tensorflow::TensorShape result_shape; + for (int i = 0; i < rank; ++i) { + result_shape.AddDim(target_dimensions[i]); + } + TF_RETURN_IF_ERROR(c->MakeShapeFromTensorShape( + result_shape, &result_shape_handle)); + } + c->set_output(0, result_shape_handle); + return Status::OK(); + }) .Doc(R"doc( Periodically resample elements of a tensor to conform to `shape`. @@ -101,4 +136,20 @@ output: Periodically resampled tensor that has dimensions specified as in )doc"); + +REGISTER_OP("PeriodicResampleOpGrad") + .Attr("T: numbertype") + .Input("grad: T") + .Attr("original_shape: shape") + .Attr("desired_shape: shape") + .Output("grad_values: T") + .SetShapeFn([](shape_inference::InferenceContext* c) { + tensorflow::TensorShape original_shape; + TF_RETURN_IF_ERROR(c->GetAttr("original_shape", &original_shape)); + shape_inference::ShapeHandle s; + TF_RETURN_IF_ERROR(c->MakeShapeFromTensorShape(original_shape, &s)); + c->set_output(0, s); + return Status::OK(); +}); + } // namespace tensorflow |