aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/periodic_resample
diff options
context:
space:
mode:
authorGravatar Akshay Modi <nareshmodi@google.com>2018-06-18 09:57:19 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-06-18 09:59:59 -0700
commite80732c9895d1283af9b98d6277ad1a1015e2e9a (patch)
tree14895657394f9cdfed8435460e37fe89a45ba599 /tensorflow/contrib/periodic_resample
parent8ecf506fb8464dd273ce59f512f5e20d37dd5cfd (diff)
Merge changes from github.
PiperOrigin-RevId: 201011811
Diffstat (limited to 'tensorflow/contrib/periodic_resample')
-rw-r--r--tensorflow/contrib/periodic_resample/BUILD20
-rw-r--r--tensorflow/contrib/periodic_resample/kernels/periodic_resample_op.cc5
-rw-r--r--tensorflow/contrib/periodic_resample/kernels/periodic_resample_op.h415
-rw-r--r--tensorflow/contrib/periodic_resample/ops/array_ops.cc53
-rw-r--r--tensorflow/contrib/periodic_resample/ops/array_ops_test.cc41
-rw-r--r--tensorflow/contrib/periodic_resample/python/kernel_tests/periodic_resample_op_test.py27
-rw-r--r--tensorflow/contrib/periodic_resample/python/ops/periodic_resample_op.py8
7 files changed, 449 insertions, 120 deletions
diff --git a/tensorflow/contrib/periodic_resample/BUILD b/tensorflow/contrib/periodic_resample/BUILD
index 6ca7fe8b6e..aad1ca04c5 100644
--- a/tensorflow/contrib/periodic_resample/BUILD
+++ b/tensorflow/contrib/periodic_resample/BUILD
@@ -6,12 +6,13 @@ exports_files(["LICENSE"])
load(
"//tensorflow:tensorflow.bzl",
- "py_test",
+ "tf_cc_test",
"tf_gen_op_libs",
"tf_custom_op_library",
"tf_custom_op_py_library",
"tf_gen_op_wrapper_py",
)
+load("//tensorflow:tensorflow.bzl", "py_test")
cc_library(
name = "all_ops",
@@ -84,6 +85,23 @@ py_test(
":init_py",
"//tensorflow/contrib/util:util_py",
"//tensorflow/python:framework_test_lib",
+ "//tensorflow/python:gradient_checker",
+ ],
+)
+
+tf_cc_test(
+ name = "periodic_resample_op_cc_test",
+ size = "small",
+ srcs = [
+ "ops/array_ops_test.cc",
+ ],
+ deps = [
+ ":all_ops",
+ "//tensorflow/core:framework",
+ "//tensorflow/core:protos_all_proto",
+ "//tensorflow/core:test",
+ "//tensorflow/core:test_main",
+ "//tensorflow/core:testlib",
],
)
diff --git a/tensorflow/contrib/periodic_resample/kernels/periodic_resample_op.cc b/tensorflow/contrib/periodic_resample/kernels/periodic_resample_op.cc
index e18923c8aa..514689cf45 100644
--- a/tensorflow/contrib/periodic_resample/kernels/periodic_resample_op.cc
+++ b/tensorflow/contrib/periodic_resample/kernels/periodic_resample_op.cc
@@ -22,4 +22,9 @@ namespace tensorflow {
REGISTER_KERNEL_BUILDER(Name("PeriodicResample").Device(DEVICE_CPU),
PeriodicResampleOp);
+
+REGISTER_KERNEL_BUILDER(Name("PeriodicResampleOpGrad")
+ .Device(DEVICE_CPU),
+ PeriodicResampleOpGrad);
+
} // namespace tensorflow
diff --git a/tensorflow/contrib/periodic_resample/kernels/periodic_resample_op.h b/tensorflow/contrib/periodic_resample/kernels/periodic_resample_op.h
index 3ab588c458..42fba81a5c 100644
--- a/tensorflow/contrib/periodic_resample/kernels/periodic_resample_op.h
+++ b/tensorflow/contrib/periodic_resample/kernels/periodic_resample_op.h
@@ -25,92 +25,202 @@
#include "tensorflow/core/framework/shape_inference.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/lib/core/status.h"
+#include "tensorflow/core/util/work_sharder.h"
namespace {
-template <class IndexVecT, class IndexT>
-IndexT compute_input_index(
- IndexVecT* target_dimensions, const IndexT& output_index,
- const IndexVecT& original_dimensions, const int& adjustable_dimension,
- const std::vector<tensorflow::int64>& dimension_ceiling,
- const std::vector<tensorflow::int64>& cumulative_dimensions, IndexT* result,
- std::vector<IndexT>* output_indices, const int& rank) {
- *result = 0;
- output_indices->clear();
+// Computes input tensor index for given output index during forward
+// propagation through periodic_resample operation.
+class InputIndexer {
+ public:
+ InputIndexer(const std::vector<tensorflow::int64>& output_dimensions,
+ const tensorflow::TensorShape& input_shape,
+ int adjustable_dimension)
+ : output_dimensions_(output_dimensions),
+ adjustable_dimension_(adjustable_dimension),
+ rank_(input_shape.dims()),
+ linear_output_index_(0),
+ linear_input_index_(0),
+ adjustable_dimension_carriage_sum_(0) {
+ auto input_dimensions = TensorShapeToVector(input_shape);
+ // factors by which input_dimensions increases/decreases w.r.t.
+ // output_dimensions
+ dimension_ceiling_ =
+ ComputeDimensionCeiling(output_dimensions, input_dimensions);
+ cumulative_dimensions_ = ComputeCumulativeDimensions();
+
+ output_indices_.resize(output_dimensions_.size());
+ input_indices_.resize(output_dimensions_.size());
+
+ // Compute index_factors
+ index_factors_.resize(rank_);
+ tensorflow::int64 last_index_factor = 1;
+ for (auto r = rank_ - 1; r >= 0; --r) {
+ index_factors_[r] = last_index_factor;
+ last_index_factor *= input_dimensions[r];
+ }
+ }
+
+ tensorflow::int64 linear_input_index() const { return linear_input_index_; }
+
+ void MoveToOutputIndex(tensorflow::int64 output_index);
+ void IncrementOutputIndex();
+
+ private:
+ void RecomputeInputAdjustableDimensionIndex() {
+ tensorflow::int64 index = adjustable_dimension_carriage_sum_;
+ index *= output_dimensions_[adjustable_dimension_];
+ index += output_indices_[adjustable_dimension_];
+ input_indices_[adjustable_dimension_] = index;
+ }
+
+ std::vector<tensorflow::int64> TensorShapeToVector(
+ const tensorflow::TensorShape& tensor_shape);
+
+ std::vector<tensorflow::int64> ComputeDimensionCeiling(
+ const std::vector<tensorflow::int64>& output_dimensions,
+ const std::vector<tensorflow::int64>& input_dimensions);
+
+ std::vector<tensorflow::int64> ComputeCumulativeDimensions();
+
+ const std::vector<tensorflow::int64> output_dimensions_;
+ std::vector<tensorflow::int64> dimension_ceiling_;
+ std::vector<tensorflow::int64> index_factors_;
+ std::vector<tensorflow::int64> cumulative_dimensions_;
+ std::vector<tensorflow::int64> output_indices_;
+ std::vector<tensorflow::int64> input_indices_;
+
+ const int adjustable_dimension_;
+ const int rank_;
+ tensorflow::int64 linear_output_index_;
+ tensorflow::int64 linear_input_index_;
+ tensorflow::int64 adjustable_dimension_carriage_sum_;
+};
+
+void InputIndexer::MoveToOutputIndex(tensorflow::int64 output_index) {
+ linear_output_index_ = output_index;
+ linear_input_index_ = 0;
// un-rasterize the output index
auto last_reduced_i = output_index;
- for (auto r = rank - 1; r >= 0; --r) {
- (*output_indices)[r] = last_reduced_i % (*target_dimensions)[r];
+ for (auto r = rank_ - 1; r >= 0; --r) {
+ output_indices_[r] = last_reduced_i % output_dimensions_[r];
last_reduced_i =
- (last_reduced_i - (*output_indices)[r]) / (*target_dimensions)[r];
+ (last_reduced_i - output_indices_[r]) / output_dimensions_[r];
}
+ tensorflow::int64 carriage_sum = 0;
+ for (int qi = 0; qi < rank_; ++qi) {
+ if (qi == adjustable_dimension_) continue;
+ carriage_sum += cumulative_dimensions_[qi] *
+ (output_indices_[qi] % dimension_ceiling_[qi]);
+ }
+ adjustable_dimension_carriage_sum_ = carriage_sum;
+
// rasterize the input index
- IndexT last_index_factor = 1;
- for (auto r = rank - 1; r >= 0; --r) {
- IndexT index = 0;
- if (r != adjustable_dimension)
- index = (*output_indices)[r] / dimension_ceiling[r];
- else {
- for (int qi = 0; qi < rank; ++qi) {
- if (qi == adjustable_dimension) continue;
- index += cumulative_dimensions[qi] *
- ((*output_indices)[qi] % dimension_ceiling[qi]);
- }
- index *= (*target_dimensions)[adjustable_dimension];
- index += (*output_indices)[r];
+ for (auto r = rank_ - 1; r >= 0; --r) {
+ if (r != adjustable_dimension_) {
+ input_indices_[r] = output_indices_[r] / dimension_ceiling_[r];
+ } else {
+ RecomputeInputAdjustableDimensionIndex();
}
- *result += last_index_factor * index;
- last_index_factor *= original_dimensions[r];
}
+ for (auto r = rank_ - 1; r >= 0; --r) {
+ linear_input_index_ += index_factors_[r] * input_indices_[r];
+ }
+}
+
+void InputIndexer::IncrementOutputIndex() {
+ linear_output_index_++;
+ for (auto r = rank_ - 1; r >= 0; --r) {
+ auto old_carriage_sum_increment =
+ cumulative_dimensions_[r] *
+ (output_indices_[r] % dimension_ceiling_[r]);
+ output_indices_[r] = (output_indices_[r] + 1) % output_dimensions_[r];
+ if (r != adjustable_dimension_) {
+ auto new_input_index = output_indices_[r] / dimension_ceiling_[r];
+ linear_input_index_ +=
+ (new_input_index - input_indices_[r]) * index_factors_[r];
+
+ input_indices_[r] = new_input_index;
+
+ auto new_carriage_sum_increment =
+ cumulative_dimensions_[r] *
+ (output_indices_[r] % dimension_ceiling_[r]);
- return *result;
+ adjustable_dimension_carriage_sum_ = adjustable_dimension_carriage_sum_ -
+ old_carriage_sum_increment +
+ new_carriage_sum_increment;
+ }
+
+ if (output_indices_[r] != 0) {
+ // No more carries to higher indices.
+ break;
+ }
+ }
+ auto old_adjustable_dimension_input_index =
+ input_indices_[adjustable_dimension_];
+ RecomputeInputAdjustableDimensionIndex();
+ linear_input_index_ += (input_indices_[adjustable_dimension_] -
+ old_adjustable_dimension_input_index) *
+ index_factors_[adjustable_dimension_];
}
-template <class InputDataT,
- class IndexVecT> // both types are needed here b/c IndexVecT and
- // InputDataT are not related
- void
- fill_periodic_tensor(
- tensorflow::OpKernelContext* context,
- const IndexVecT& desired_shape,
- const tensorflow::Tensor& input_tensor) {
- // input is a strided array (last index is fastest, C-ordered)
- auto input = input_tensor.flat<InputDataT>();
- const int rank = input_tensor.dims();
- // original and target dimensions
- std::vector<tensorflow::int64> original_dimensions(rank),
- target_dimensions(rank);
- tensorflow::int64 total_size(input_tensor.NumElements()), new_sliced_size(1);
- // factors by which original_dimensions increases/decreases w.r.t.
- // target_dimensions
- std::vector<tensorflow::int64> dimension_ceiling(rank),
- cumulative_dimensions(rank);
- // index of adjustable dimension
- int adjustable_dimension;
- tensorflow::TensorShape output_shape;
+std::vector<tensorflow::int64> InputIndexer::TensorShapeToVector(
+ const tensorflow::TensorShape& tensor_shape) {
+ std::vector<tensorflow::int64> result(tensor_shape.dims());
+ int count = 0;
+ for (const auto dim_info : tensor_shape) {
+ result[count] = dim_info.size;
+ ++count;
+ }
+ return result;
+}
- // requires that the rank of the input tensor and length of the desired shape
- // are equal
- OP_REQUIRES(context, rank == desired_shape.size(),
- tensorflow::errors::InvalidArgument(
- "periodic_resample expects the rank of the input tensor, ",
- rank, ", to be the same as the length of the desired shape, ",
- desired_shape.size(), "."));
+std::vector<tensorflow::int64> InputIndexer::ComputeDimensionCeiling(
+ const std::vector<tensorflow::int64>& output_dimensions,
+ const std::vector<tensorflow::int64>& input_dimensions) {
+ std::vector<tensorflow::int64> dimension_ceiling(input_dimensions.size());
+ for (size_t i = 0; i < input_dimensions.size(); ++i) {
+ dimension_ceiling[i] = (output_dimensions[i] + input_dimensions[i] - 1) /
+ input_dimensions[i];
+ }
+ return dimension_ceiling;
+}
- bool found = false;
- const auto& input_tensor_shape = input_tensor.shape();
+std::vector<tensorflow::int64> InputIndexer::ComputeCumulativeDimensions() {
+ std::vector<tensorflow::int64> cumulative_dimensions(rank_);
+ int count = 0;
+ for (int i = 0; i < rank_; ++i) {
+ if (count == 0) {
+ cumulative_dimensions[count] = 1;
+ } else {
+ cumulative_dimensions[count] =
+ cumulative_dimensions[count - 1] * dimension_ceiling_[count - 1];
+ }
+ ++count;
+ }
+ return cumulative_dimensions;
+}
+template <typename IndexVecT>
+void process_desired_shape(tensorflow::OpKernelContext* context,
+ const tensorflow::TensorShape& input_tensor_shape,
+ const IndexVecT& desired_shape,
+ int* adjustable_dimension,
+ std::vector<tensorflow::int64>* target_dimensions,
+ tensorflow::int64* output_size) {
+ tensorflow::int64 new_sliced_size = 1;
+ bool found = false;
+ const int rank = input_tensor_shape.dims();
for (int i = 0; i < rank; ++i) {
- // if (desired_shape(i) < 1) {
if (desired_shape[i] < 1) {
// only one index can be adjustable
OP_REQUIRES(context, !found,
tensorflow::errors::InvalidArgument(
"periodic_resample expects only "
"one index to be marked as adjustable."));
- adjustable_dimension = i;
+ *adjustable_dimension = i;
found = true;
} else {
OP_REQUIRES(
@@ -122,9 +232,8 @@ template <class InputDataT,
i, " input tensor has size ", input_tensor_shape.dim_size(i),
", desired shape has size ", desired_shape[i], "."));
- // target_dimensions[i] = desired_shape(i);
- target_dimensions[i] = desired_shape[i];
- new_sliced_size *= target_dimensions[i];
+ (*target_dimensions)[i] = desired_shape[i];
+ new_sliced_size *= (*target_dimensions)[i];
}
}
// at least one index needs to be adjustable
@@ -132,26 +241,50 @@ template <class InputDataT,
tensorflow::errors::InvalidArgument(
"periodic_resample expects at least "
"one index to be marked as adjustable."));
+ (*target_dimensions)[*adjustable_dimension] =
+ input_tensor_shape.num_elements() / new_sliced_size;
- int count = 0;
- for (const auto dim_info : input_tensor.shape()) {
- original_dimensions[count] = dim_info.size;
- ++count;
- }
+ *output_size = new_sliced_size * (*target_dimensions)[*adjustable_dimension];
+}
- target_dimensions[adjustable_dimension] = total_size / new_sliced_size;
+// Heuristic number based on measurements on
+// Intel(R) Core(TM) i7-4930K CPU @ 3.40GHz
+const tensorflow::int64 costPerFillIndex = 35;
- count = 0;
- for (int i = 0; i < input_tensor.shape().dims(); ++i) {
- dimension_ceiling[count] = tensorflow::int64(std::ceil(
- float(target_dimensions[count]) / float(original_dimensions[count])));
- if (count == 0)
- cumulative_dimensions[count] = 1;
- else
- cumulative_dimensions[count] =
- cumulative_dimensions[count - 1] * dimension_ceiling[count - 1];
- ++count;
- }
+enum class Mode {
+ kForward,
+ kGradient
+};
+
+// Computes either periodic_resample operation output or gradients for it,
+// depending on |mode|.
+// |original_shape| is always shape of input to periodic_resample operation.
+// |source_tensor| is either source for periodic_resample (for forward mode)
+// or gradients tensor.
+// |desired_shape| is always shape, provided by user, to which forward
+// propagation attempts resample input tensor.
+template <class InputDataT, Mode mode>
+void
+do_periodic_resample_op(tensorflow::OpKernelContext* context,
+ const tensorflow::TensorShape& original_shape,
+ const tensorflow::PartialTensorShape& desired_shape,
+ const tensorflow::Tensor& source_tensor) {
+ const int rank = source_tensor.dims();
+
+ // requires that the rank of the input tensor and length of the desired shape
+ // are equal
+ OP_REQUIRES(context, rank == desired_shape.dims(),
+ tensorflow::errors::InvalidArgument(
+ "periodic_resample expects the rank of the input tensor, ",
+ rank, ", to be the same as the length of the desired shape, ",
+ desired_shape.dims(), "."));
+
+ std::vector<tensorflow::int64> target_dimensions(rank);
+ tensorflow::int64 new_size = 0;
+ // index of adjustable dimension
+ int adjustable_dimension = 0;
+ process_desired_shape(context, original_shape, desired_shape.dim_sizes(),
+ &adjustable_dimension, &target_dimensions, &new_size);
// ensure that the new dimension is greater than zero
OP_REQUIRES(context, target_dimensions[adjustable_dimension] > 0,
@@ -160,11 +293,14 @@ template <class InputDataT,
"adjustable dimension, ",
adjustable_dimension, ", isn't greater than zero, ",
target_dimensions[adjustable_dimension], "."));
- for (int i = 0; i < rank; ++i) {
- output_shape.AddDim(target_dimensions[i]);
+ tensorflow::TensorShape output_shape;
+ if (mode == Mode::kForward) {
+ for (int i = 0; i < rank; ++i) {
+ output_shape.AddDim(target_dimensions[i]);
+ }
+ } else {
+ output_shape = original_shape;
}
- const auto new_size =
- new_sliced_size * target_dimensions[adjustable_dimension];
// Create an output tensor and attach it to the current context
tensorflow::Tensor* output_tensor = nullptr;
@@ -172,47 +308,73 @@ template <class InputDataT,
context->allocate_output(0, output_shape, &output_tensor));
auto output = output_tensor->flat<InputDataT>();
- // memory is allocated for these variables outside the inner loop for
- // efficiency (although, I could create a separate class scope for
- // this purpose instead)
- tensorflow::int64 result = 0;
- std::vector<tensorflow::int64> output_indices(target_dimensions.size());
+ // input is a strided array (last index is fastest, C-ordered)
+ auto input = source_tensor.flat<InputDataT>();
// Fill output tensor with periodically resampled input tensor values
- for (tensorflow::int64 output_index = 0; output_index < new_size;
- ++output_index) {
- output(output_index) = input(compute_input_index(
- &target_dimensions, output_index, original_dimensions,
- adjustable_dimension, dimension_ceiling, cumulative_dimensions, &result,
- &output_indices, rank));
- }
+ InputIndexer input_indexer(target_dimensions, original_shape,
+ adjustable_dimension);
+
+ auto worker_threads = *(context->device()->tensorflow_cpu_worker_threads());
+ auto fill_output_tensor = [&input_indexer, &output, &input](
+ tensorflow::int64 start, tensorflow::int64 limit) {
+ InputIndexer local_indexer(input_indexer);
+ local_indexer.MoveToOutputIndex(start);
+ for (tensorflow::int64 output_index = start; output_index < limit;
+ ++output_index) {
+ if (mode == Mode::kForward) {
+ output(output_index) = input(local_indexer.linear_input_index());
+ } else {
+ output(local_indexer.linear_input_index()) = input(output_index);
+ }
+ local_indexer.IncrementOutputIndex();
+ }
+ };
+ ::tensorflow::Shard(worker_threads.num_threads, worker_threads.workers,
+ new_size, costPerFillIndex, fill_output_tensor);
}
+#define DATA_TYPE_SWITCH(data_type, context, CASE) \
+ switch (data_type) { \
+ CASE(float) \
+ CASE(double) \
+ CASE(tensorflow::int32) \
+ CASE(tensorflow::int64) \
+ default: \
+ context->CtxFailure(__FILE__, __LINE__, \
+ tensorflow::errors::InvalidArgument( \
+ "Unsuppored tensor elements type")); \
+ break; \
+ }
+
void create_output_tensor(
tensorflow::OpKernelContext* context,
const tensorflow::Tensor& input_tensor,
const tensorflow::DataType& input_tensor_type,
- const tensorflow::PartialTensorShape& desired_shape_tensor) {
- auto desired_shape = desired_shape_tensor.dim_sizes();
-
- // obligatory type switch
- switch (input_tensor_type) {
- case tensorflow::DataTypeToEnum<float>::value:
- fill_periodic_tensor<float>(context, desired_shape, input_tensor);
+ const tensorflow::PartialTensorShape& desired_shape) {
+#define CASE(type) \
+ case tensorflow::DataTypeToEnum<type>::value: \
+ do_periodic_resample_op<type, Mode::kForward>( \
+ context, input_tensor.shape(), desired_shape, input_tensor); \
break;
- case tensorflow::DataTypeToEnum<double>::value:
- fill_periodic_tensor<double>(context, desired_shape, input_tensor);
- break;
- case tensorflow::DataTypeToEnum<tensorflow::int32>::value:
- fill_periodic_tensor<tensorflow::int32>(context, desired_shape,
- input_tensor);
- break;
- case tensorflow::DataTypeToEnum<tensorflow::int64>::value:
- fill_periodic_tensor<tensorflow::int64>(context, desired_shape,
- input_tensor);
+
+ DATA_TYPE_SWITCH(input_tensor_type, context, CASE);
+#undef CASE
+}
+
+void create_grad_tensor(tensorflow::OpKernelContext* context,
+ const tensorflow::Tensor& grad_tensor,
+ const tensorflow::DataType& grad_tensor_type,
+ const tensorflow::TensorShape& original_shape,
+ const tensorflow::PartialTensorShape& desired_shape) {
+#define CASE(type) \
+ case tensorflow::DataTypeToEnum<type>::value: \
+ do_periodic_resample_op<type, Mode::kGradient>( \
+ context, original_shape, desired_shape, grad_tensor); \
break;
- default:;
- }
+
+ DATA_TYPE_SWITCH(grad_tensor_type, context, CASE);
+#undef CASE
}
} // namespace
@@ -238,4 +400,25 @@ class PeriodicResampleOp : public tensorflow::OpKernel {
tensorflow::PartialTensorShape desired_shape;
};
+class PeriodicResampleOpGrad : public tensorflow::OpKernel {
+ public:
+ explicit PeriodicResampleOpGrad(tensorflow::OpKernelConstruction* context)
+ : tensorflow::OpKernel(context) {
+ OP_REQUIRES_OK(context,
+ context->GetAttr("original_shape", &original_shape));
+ OP_REQUIRES_OK(context, context->GetAttr("desired_shape", &desired_shape));
+ }
+
+ void Compute(tensorflow::OpKernelContext* context) override {
+ const tensorflow::Tensor& grad_tensor = context->input(0);
+ const tensorflow::DataType grad_tensor_type = context->input_dtype(0);
+ create_grad_tensor(context, grad_tensor, grad_tensor_type, original_shape,
+ desired_shape);
+ }
+
+ private:
+ tensorflow::TensorShape original_shape;
+ tensorflow::PartialTensorShape desired_shape;
+};
+
#endif // TENSORFLOW_KERNELS_PERIODICRESAMPLE_OP_H_
diff --git a/tensorflow/contrib/periodic_resample/ops/array_ops.cc b/tensorflow/contrib/periodic_resample/ops/array_ops.cc
index 82bd796956..fd38cd09b4 100644
--- a/tensorflow/contrib/periodic_resample/ops/array_ops.cc
+++ b/tensorflow/contrib/periodic_resample/ops/array_ops.cc
@@ -26,7 +26,42 @@ REGISTER_OP("PeriodicResample")
.Input("values: T")
.Attr("shape: shape")
.Output("output: T")
- .SetShapeFn(shape_inference::ExplicitShape)
+ .SetShapeFn([](shape_inference::InferenceContext* c) {
+ tensorflow::PartialTensorShape desired_shape;
+ TF_RETURN_IF_ERROR(c->GetAttr("shape", &desired_shape));
+ shape_inference::ShapeHandle input_tensor_shape = c->input(0);
+ shape_inference::DimensionHandle num_input_elements =
+ c->NumElements(input_tensor_shape);
+ shape_inference::ShapeHandle result_shape_handle;
+ if (!shape_inference::InferenceContext::ValueKnown(num_input_elements)) {
+ TF_RETURN_IF_ERROR(c->MakeShapeFromPartialTensorShape(
+ desired_shape, &result_shape_handle));
+ } else {
+ const int rank = c->Rank(input_tensor_shape);
+ std::vector<tensorflow::int64> target_dimensions(rank);
+ tensorflow::int64 new_sliced_size = 1;
+ int adjustable_dimension = 0;
+ for (int i = 0; i < rank; ++i) {
+ if (desired_shape.dim_size(i) < 1) {
+ adjustable_dimension = i;
+ } else {
+ target_dimensions[i] = desired_shape.dim_size(i);
+ new_sliced_size *= target_dimensions[i];
+ }
+ }
+ target_dimensions[adjustable_dimension] =
+ shape_inference::InferenceContext::Value(
+ num_input_elements) / new_sliced_size;
+ tensorflow::TensorShape result_shape;
+ for (int i = 0; i < rank; ++i) {
+ result_shape.AddDim(target_dimensions[i]);
+ }
+ TF_RETURN_IF_ERROR(c->MakeShapeFromTensorShape(
+ result_shape, &result_shape_handle));
+ }
+ c->set_output(0, result_shape_handle);
+ return Status::OK();
+ })
.Doc(R"doc(
Periodically resample elements of a tensor to conform to `shape`.
@@ -101,4 +136,20 @@ output: Periodically resampled tensor that has dimensions specified as in
)doc");
+
+REGISTER_OP("PeriodicResampleOpGrad")
+ .Attr("T: numbertype")
+ .Input("grad: T")
+ .Attr("original_shape: shape")
+ .Attr("desired_shape: shape")
+ .Output("grad_values: T")
+ .SetShapeFn([](shape_inference::InferenceContext* c) {
+ tensorflow::TensorShape original_shape;
+ TF_RETURN_IF_ERROR(c->GetAttr("original_shape", &original_shape));
+ shape_inference::ShapeHandle s;
+ TF_RETURN_IF_ERROR(c->MakeShapeFromTensorShape(original_shape, &s));
+ c->set_output(0, s);
+ return Status::OK();
+});
+
} // namespace tensorflow
diff --git a/tensorflow/contrib/periodic_resample/ops/array_ops_test.cc b/tensorflow/contrib/periodic_resample/ops/array_ops_test.cc
new file mode 100644
index 0000000000..43b7c1799f
--- /dev/null
+++ b/tensorflow/contrib/periodic_resample/ops/array_ops_test.cc
@@ -0,0 +1,41 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/core/framework/node_def_builder.h"
+#include "tensorflow/core/framework/shape_inference_testutil.h"
+#include "tensorflow/core/framework/tensor_shape.pb.h"
+#include "tensorflow/core/framework/tensor_testutil.h"
+#include "tensorflow/core/lib/core/status_test_util.h"
+#include "tensorflow/core/platform/test.h"
+
+namespace tensorflow {
+
+TEST(ArrayOpsTest, PeriodicResample_ShapeFn) {
+ ShapeInferenceTestOp op("PeriodicResample");
+ // Case 1: output shape can be fully inferreed.
+ PartialTensorShape shape({4, 4, -1});
+ TensorShapeProto shape_proto;
+ shape.AsProto(&shape_proto);
+
+ TF_ASSERT_OK(NodeDefBuilder("test", "PeriodicResample")
+ .Input({"values", 0, DT_INT32})
+ .Attr("shape", shape_proto)
+ .Finalize(&op.node_def));
+ INFER_OK(op, "[2,2,4]", "[4,4,1]");
+ // Case 2: output shape can not be inferred - report desired shape.
+ INFER_OK(op, "[2,2,?]", "[4,4,?]");
+}
+
+} // end namespace tensorflow
diff --git a/tensorflow/contrib/periodic_resample/python/kernel_tests/periodic_resample_op_test.py b/tensorflow/contrib/periodic_resample/python/kernel_tests/periodic_resample_op_test.py
index a25de55e18..31a6fe1d94 100644
--- a/tensorflow/contrib/periodic_resample/python/kernel_tests/periodic_resample_op_test.py
+++ b/tensorflow/contrib/periodic_resample/python/kernel_tests/periodic_resample_op_test.py
@@ -21,8 +21,11 @@ from __future__ import print_function
import numpy
from tensorflow.contrib.periodic_resample import periodic_resample
+from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors_impl
from tensorflow.python.framework import test_util
+from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import gradient_checker
from tensorflow.python.ops import variables
from tensorflow.python.platform import googletest
@@ -93,7 +96,6 @@ class PeriodicResampleTest(test_util.TensorFlowTestCase):
def testPeriodicResampleErrors(self):
input_tensor = numpy.zeros(shape=[1, 2, 2, 4])
with self.test_session():
- variables.global_variables_initializer().run()
with self.assertRaisesWithPredicateMatch(
errors_impl.InvalidArgumentError,
'Dimension 3 input tensor has size 4, desired shape has size 1'):
@@ -103,6 +105,29 @@ class PeriodicResampleTest(test_util.TensorFlowTestCase):
'4, to be the same as the length of the desired shape, 3'):
periodic_resample(input_tensor, [None, 4, 4]).eval()
+ def testPeriodicResampleGradient(self):
+ desired_shape = numpy.array([4, 4, None])
+ result_shape = (4, 4, 1)
+ input_shape = (2, 2, 4)
+ with self.test_session() as sess:
+ x = array_ops.placeholder(dtypes.float32, shape=input_shape)
+ output = periodic_resample(x, desired_shape)
+ error = gradient_checker.compute_gradient_error(
+ x, input_shape, output, result_shape)
+ self.assertLess(error, 1e-4)
+
+ def testPeriodicResampleShapeInference(self):
+ with self.test_session() as sess:
+ # Case 1: output shape can be fully inferreed.
+ x = array_ops.placeholder(dtypes.float32, shape=(2, 2, 4))
+ output = periodic_resample(x, [4, 4, None])
+ self.assertEqual(output.shape, [4, 4, 1])
+ # Case 2: output shape can not be inferred - report desired shape.
+ x = array_ops.placeholder(dtypes.float32, shape=(2, 2, None))
+ output = periodic_resample(x, [4, 4, None])
+ self.assertTrue(output.shape.is_compatible_with([4, 4, None]))
+ self.assertEqual(output.shape[2].value, None)
+
if __name__ == '__main__':
googletest.main()
diff --git a/tensorflow/contrib/periodic_resample/python/ops/periodic_resample_op.py b/tensorflow/contrib/periodic_resample/python/ops/periodic_resample_op.py
index 348623d8f8..470e300ccb 100644
--- a/tensorflow/contrib/periodic_resample/python/ops/periodic_resample_op.py
+++ b/tensorflow/contrib/periodic_resample/python/ops/periodic_resample_op.py
@@ -21,11 +21,17 @@ from __future__ import print_function
# pylint: disable=unused-import
from tensorflow.contrib.periodic_resample.python.ops import gen_periodic_resample_op
-from tensorflow.contrib.periodic_resample.python.ops.gen_periodic_resample_op import periodic_resample
+from tensorflow.contrib.periodic_resample.python.ops.gen_periodic_resample_op import periodic_resample, periodic_resample_op_grad
from tensorflow.contrib.util import loader
+from tensorflow.python.framework import ops
from tensorflow.python.platform import resource_loader
# pylint: enable=unused-import
_periodic_resample_op = loader.load_op_library(
resource_loader.get_path_to_datafile('_periodic_resample_op.so'))
+
+@ops.RegisterGradient("PeriodicResample")
+def _periodic_resample_grad_cc(op, grad):
+ return periodic_resample_op_grad(
+ grad, op.inputs[0].shape, op.get_attr('shape'))