aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/periodic_resample/ops/array_ops.cc
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/contrib/periodic_resample/ops/array_ops.cc')
-rw-r--r--tensorflow/contrib/periodic_resample/ops/array_ops.cc53
1 files changed, 52 insertions, 1 deletions
diff --git a/tensorflow/contrib/periodic_resample/ops/array_ops.cc b/tensorflow/contrib/periodic_resample/ops/array_ops.cc
index 82bd796956..fd38cd09b4 100644
--- a/tensorflow/contrib/periodic_resample/ops/array_ops.cc
+++ b/tensorflow/contrib/periodic_resample/ops/array_ops.cc
@@ -26,7 +26,42 @@ REGISTER_OP("PeriodicResample")
.Input("values: T")
.Attr("shape: shape")
.Output("output: T")
- .SetShapeFn(shape_inference::ExplicitShape)
+ .SetShapeFn([](shape_inference::InferenceContext* c) {
+ tensorflow::PartialTensorShape desired_shape;
+ TF_RETURN_IF_ERROR(c->GetAttr("shape", &desired_shape));
+ shape_inference::ShapeHandle input_tensor_shape = c->input(0);
+ shape_inference::DimensionHandle num_input_elements =
+ c->NumElements(input_tensor_shape);
+ shape_inference::ShapeHandle result_shape_handle;
+ if (!shape_inference::InferenceContext::ValueKnown(num_input_elements)) {
+ TF_RETURN_IF_ERROR(c->MakeShapeFromPartialTensorShape(
+ desired_shape, &result_shape_handle));
+ } else {
+ const int rank = c->Rank(input_tensor_shape);
+ std::vector<tensorflow::int64> target_dimensions(rank);
+ tensorflow::int64 new_sliced_size = 1;
+ int adjustable_dimension = 0;
+ for (int i = 0; i < rank; ++i) {
+ if (desired_shape.dim_size(i) < 1) {
+ adjustable_dimension = i;
+ } else {
+ target_dimensions[i] = desired_shape.dim_size(i);
+ new_sliced_size *= target_dimensions[i];
+ }
+ }
+ target_dimensions[adjustable_dimension] =
+ shape_inference::InferenceContext::Value(
+ num_input_elements) / new_sliced_size;
+ tensorflow::TensorShape result_shape;
+ for (int i = 0; i < rank; ++i) {
+ result_shape.AddDim(target_dimensions[i]);
+ }
+ TF_RETURN_IF_ERROR(c->MakeShapeFromTensorShape(
+ result_shape, &result_shape_handle));
+ }
+ c->set_output(0, result_shape_handle);
+ return Status::OK();
+ })
.Doc(R"doc(
Periodically resample elements of a tensor to conform to `shape`.
@@ -101,4 +136,20 @@ output: Periodically resampled tensor that has dimensions specified as in
)doc");
+
+REGISTER_OP("PeriodicResampleOpGrad")
+ .Attr("T: numbertype")
+ .Input("grad: T")
+ .Attr("original_shape: shape")
+ .Attr("desired_shape: shape")
+ .Output("grad_values: T")
+ .SetShapeFn([](shape_inference::InferenceContext* c) {
+ tensorflow::TensorShape original_shape;
+ TF_RETURN_IF_ERROR(c->GetAttr("original_shape", &original_shape));
+ shape_inference::ShapeHandle s;
+ TF_RETURN_IF_ERROR(c->MakeShapeFromTensorShape(original_shape, &s));
+ c->set_output(0, s);
+ return Status::OK();
+});
+
} // namespace tensorflow