author    Akshay Modi <nareshmodi@google.com> 2018-06-18 11:48:36 -0700
committer TensorFlower Gardener <gardener@tensorflow.org> 2018-06-18 11:55:03 -0700
commit    148b4381fd0259cae441e459ec8ebe2c5d557722 (patch)
tree      c66c96ea6c60c63385b528dce195af802b8acf3b /tensorflow/contrib/periodic_resample
parent    fc03fbff3dd7a58fa4f16226df4ada1f21f8b53f (diff)
Automated g4 rollback of changelist 201011811
PiperOrigin-RevId: 201033171
Diffstat (limited to 'tensorflow/contrib/periodic_resample')
-rw-r--r-- tensorflow/contrib/periodic_resample/BUILD | 20
-rw-r--r-- tensorflow/contrib/periodic_resample/kernels/periodic_resample_op.cc | 5
-rw-r--r-- tensorflow/contrib/periodic_resample/kernels/periodic_resample_op.h | 415
-rw-r--r-- tensorflow/contrib/periodic_resample/ops/array_ops.cc | 53
-rw-r--r-- tensorflow/contrib/periodic_resample/ops/array_ops_test.cc | 41
-rw-r--r-- tensorflow/contrib/periodic_resample/python/kernel_tests/periodic_resample_op_test.py | 27
-rw-r--r-- tensorflow/contrib/periodic_resample/python/ops/periodic_resample_op.py | 8
7 files changed, 120 insertions(+), 449 deletions(-)
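
For context, a minimal forward-usage sketch of the op this rollback touches, with shapes taken from the Python test that remains in this diff; it assumes the contrib op library builds and loads under a TF 1.x graph session, and it is not part of the change itself:

    import numpy
    import tensorflow as tf
    from tensorflow.contrib.periodic_resample import periodic_resample

    # A [2, 2, 4] input resampled to the desired shape [4, 4, None]: the None
    # (adjustable) dimension is filled in from the element count at runtime,
    # giving a [4, 4, 1] result.
    x = tf.constant(numpy.arange(16).reshape((2, 2, 4)), dtype=tf.float32)
    y = periodic_resample(x, [4, 4, None])

    with tf.Session() as sess:
        print(sess.run(y).shape)  # (4, 4, 1)

    # After this rollback the PeriodicResampleOpGrad kernel, its op registration,
    # and the Python gradient registration are removed, so tf.gradients(y, x)
    # raises LookupError, and the shape function reverts to
    # shape_inference::ExplicitShape (static output shape [4, 4, ?]).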
diff --git a/tensorflow/contrib/periodic_resample/BUILD b/tensorflow/contrib/periodic_resample/BUILD
index aad1ca04c5..6ca7fe8b6e 100644
--- a/tensorflow/contrib/periodic_resample/BUILD
+++ b/tensorflow/contrib/periodic_resample/BUILD
@@ -6,13 +6,12 @@ exports_files(["LICENSE"])
load(
"//tensorflow:tensorflow.bzl",
- "tf_cc_test",
+ "py_test",
"tf_gen_op_libs",
"tf_custom_op_library",
"tf_custom_op_py_library",
"tf_gen_op_wrapper_py",
)
-load("//tensorflow:tensorflow.bzl", "py_test")
cc_library(
name = "all_ops",
@@ -85,23 +84,6 @@ py_test(
":init_py",
"//tensorflow/contrib/util:util_py",
"//tensorflow/python:framework_test_lib",
- "//tensorflow/python:gradient_checker",
- ],
-)
-
-tf_cc_test(
- name = "periodic_resample_op_cc_test",
- size = "small",
- srcs = [
- "ops/array_ops_test.cc",
- ],
- deps = [
- ":all_ops",
- "//tensorflow/core:framework",
- "//tensorflow/core:protos_all_proto",
- "//tensorflow/core:test",
- "//tensorflow/core:test_main",
- "//tensorflow/core:testlib",
],
)
diff --git a/tensorflow/contrib/periodic_resample/kernels/periodic_resample_op.cc b/tensorflow/contrib/periodic_resample/kernels/periodic_resample_op.cc
index 514689cf45..e18923c8aa 100644
--- a/tensorflow/contrib/periodic_resample/kernels/periodic_resample_op.cc
+++ b/tensorflow/contrib/periodic_resample/kernels/periodic_resample_op.cc
@@ -22,9 +22,4 @@ namespace tensorflow {
REGISTER_KERNEL_BUILDER(Name("PeriodicResample").Device(DEVICE_CPU),
PeriodicResampleOp);
-
-REGISTER_KERNEL_BUILDER(Name("PeriodicResampleOpGrad")
- .Device(DEVICE_CPU),
- PeriodicResampleOpGrad);
-
} // namespace tensorflow
diff --git a/tensorflow/contrib/periodic_resample/kernels/periodic_resample_op.h b/tensorflow/contrib/periodic_resample/kernels/periodic_resample_op.h
index 42fba81a5c..3ab588c458 100644
--- a/tensorflow/contrib/periodic_resample/kernels/periodic_resample_op.h
+++ b/tensorflow/contrib/periodic_resample/kernels/periodic_resample_op.h
@@ -25,202 +25,92 @@
#include "tensorflow/core/framework/shape_inference.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/lib/core/status.h"
-#include "tensorflow/core/util/work_sharder.h"
namespace {
-// Computes input tensor index for given output index during forward
-// propagation through periodic_resample operation.
-class InputIndexer {
- public:
- InputIndexer(const std::vector<tensorflow::int64>& output_dimensions,
- const tensorflow::TensorShape& input_shape,
- int adjustable_dimension)
- : output_dimensions_(output_dimensions),
- adjustable_dimension_(adjustable_dimension),
- rank_(input_shape.dims()),
- linear_output_index_(0),
- linear_input_index_(0),
- adjustable_dimension_carriage_sum_(0) {
- auto input_dimensions = TensorShapeToVector(input_shape);
- // factors by which input_dimensions increases/decreases w.r.t.
- // output_dimensions
- dimension_ceiling_ =
- ComputeDimensionCeiling(output_dimensions, input_dimensions);
- cumulative_dimensions_ = ComputeCumulativeDimensions();
-
- output_indices_.resize(output_dimensions_.size());
- input_indices_.resize(output_dimensions_.size());
-
- // Compute index_factors
- index_factors_.resize(rank_);
- tensorflow::int64 last_index_factor = 1;
- for (auto r = rank_ - 1; r >= 0; --r) {
- index_factors_[r] = last_index_factor;
- last_index_factor *= input_dimensions[r];
- }
- }
-
- tensorflow::int64 linear_input_index() const { return linear_input_index_; }
-
- void MoveToOutputIndex(tensorflow::int64 output_index);
- void IncrementOutputIndex();
-
- private:
- void RecomputeInputAdjustableDimensionIndex() {
- tensorflow::int64 index = adjustable_dimension_carriage_sum_;
- index *= output_dimensions_[adjustable_dimension_];
- index += output_indices_[adjustable_dimension_];
- input_indices_[adjustable_dimension_] = index;
- }
-
- std::vector<tensorflow::int64> TensorShapeToVector(
- const tensorflow::TensorShape& tensor_shape);
-
- std::vector<tensorflow::int64> ComputeDimensionCeiling(
- const std::vector<tensorflow::int64>& output_dimensions,
- const std::vector<tensorflow::int64>& input_dimensions);
-
- std::vector<tensorflow::int64> ComputeCumulativeDimensions();
-
- const std::vector<tensorflow::int64> output_dimensions_;
- std::vector<tensorflow::int64> dimension_ceiling_;
- std::vector<tensorflow::int64> index_factors_;
- std::vector<tensorflow::int64> cumulative_dimensions_;
- std::vector<tensorflow::int64> output_indices_;
- std::vector<tensorflow::int64> input_indices_;
-
- const int adjustable_dimension_;
- const int rank_;
- tensorflow::int64 linear_output_index_;
- tensorflow::int64 linear_input_index_;
- tensorflow::int64 adjustable_dimension_carriage_sum_;
-};
-
-void InputIndexer::MoveToOutputIndex(tensorflow::int64 output_index) {
- linear_output_index_ = output_index;
- linear_input_index_ = 0;
+template <class IndexVecT, class IndexT>
+IndexT compute_input_index(
+ IndexVecT* target_dimensions, const IndexT& output_index,
+ const IndexVecT& original_dimensions, const int& adjustable_dimension,
+ const std::vector<tensorflow::int64>& dimension_ceiling,
+ const std::vector<tensorflow::int64>& cumulative_dimensions, IndexT* result,
+ std::vector<IndexT>* output_indices, const int& rank) {
+ *result = 0;
+ output_indices->clear();
// un-rasterize the output index
auto last_reduced_i = output_index;
- for (auto r = rank_ - 1; r >= 0; --r) {
- output_indices_[r] = last_reduced_i % output_dimensions_[r];
+ for (auto r = rank - 1; r >= 0; --r) {
+ (*output_indices)[r] = last_reduced_i % (*target_dimensions)[r];
last_reduced_i =
- (last_reduced_i - output_indices_[r]) / output_dimensions_[r];
+ (last_reduced_i - (*output_indices)[r]) / (*target_dimensions)[r];
}
- tensorflow::int64 carriage_sum = 0;
- for (int qi = 0; qi < rank_; ++qi) {
- if (qi == adjustable_dimension_) continue;
- carriage_sum += cumulative_dimensions_[qi] *
- (output_indices_[qi] % dimension_ceiling_[qi]);
- }
- adjustable_dimension_carriage_sum_ = carriage_sum;
-
// rasterize the input index
- for (auto r = rank_ - 1; r >= 0; --r) {
- if (r != adjustable_dimension_) {
- input_indices_[r] = output_indices_[r] / dimension_ceiling_[r];
- } else {
- RecomputeInputAdjustableDimensionIndex();
- }
- }
- for (auto r = rank_ - 1; r >= 0; --r) {
- linear_input_index_ += index_factors_[r] * input_indices_[r];
- }
-}
-
-void InputIndexer::IncrementOutputIndex() {
- linear_output_index_++;
- for (auto r = rank_ - 1; r >= 0; --r) {
- auto old_carriage_sum_increment =
- cumulative_dimensions_[r] *
- (output_indices_[r] % dimension_ceiling_[r]);
- output_indices_[r] = (output_indices_[r] + 1) % output_dimensions_[r];
- if (r != adjustable_dimension_) {
- auto new_input_index = output_indices_[r] / dimension_ceiling_[r];
- linear_input_index_ +=
- (new_input_index - input_indices_[r]) * index_factors_[r];
-
- input_indices_[r] = new_input_index;
-
- auto new_carriage_sum_increment =
- cumulative_dimensions_[r] *
- (output_indices_[r] % dimension_ceiling_[r]);
-
- adjustable_dimension_carriage_sum_ = adjustable_dimension_carriage_sum_ -
- old_carriage_sum_increment +
- new_carriage_sum_increment;
- }
-
- if (output_indices_[r] != 0) {
- // No more carries to higher indices.
- break;
+ IndexT last_index_factor = 1;
+ for (auto r = rank - 1; r >= 0; --r) {
+ IndexT index = 0;
+ if (r != adjustable_dimension)
+ index = (*output_indices)[r] / dimension_ceiling[r];
+ else {
+ for (int qi = 0; qi < rank; ++qi) {
+ if (qi == adjustable_dimension) continue;
+ index += cumulative_dimensions[qi] *
+ ((*output_indices)[qi] % dimension_ceiling[qi]);
+ }
+ index *= (*target_dimensions)[adjustable_dimension];
+ index += (*output_indices)[r];
}
+ *result += last_index_factor * index;
+ last_index_factor *= original_dimensions[r];
}
- auto old_adjustable_dimension_input_index =
- input_indices_[adjustable_dimension_];
- RecomputeInputAdjustableDimensionIndex();
- linear_input_index_ += (input_indices_[adjustable_dimension_] -
- old_adjustable_dimension_input_index) *
- index_factors_[adjustable_dimension_];
-}
-std::vector<tensorflow::int64> InputIndexer::TensorShapeToVector(
- const tensorflow::TensorShape& tensor_shape) {
- std::vector<tensorflow::int64> result(tensor_shape.dims());
- int count = 0;
- for (const auto dim_info : tensor_shape) {
- result[count] = dim_info.size;
- ++count;
- }
- return result;
+ return *result;
}
-std::vector<tensorflow::int64> InputIndexer::ComputeDimensionCeiling(
- const std::vector<tensorflow::int64>& output_dimensions,
- const std::vector<tensorflow::int64>& input_dimensions) {
- std::vector<tensorflow::int64> dimension_ceiling(input_dimensions.size());
- for (size_t i = 0; i < input_dimensions.size(); ++i) {
- dimension_ceiling[i] = (output_dimensions[i] + input_dimensions[i] - 1) /
- input_dimensions[i];
- }
- return dimension_ceiling;
-}
+template <class InputDataT,
+ class IndexVecT> // both types are needed here b/c IndexVecT and
+ // InputDataT are not related
+ void
+ fill_periodic_tensor(
+ tensorflow::OpKernelContext* context,
+ const IndexVecT& desired_shape,
+ const tensorflow::Tensor& input_tensor) {
+ // input is a strided array (last index is fastest, C-ordered)
+ auto input = input_tensor.flat<InputDataT>();
+ const int rank = input_tensor.dims();
+ // original and target dimensions
+ std::vector<tensorflow::int64> original_dimensions(rank),
+ target_dimensions(rank);
+ tensorflow::int64 total_size(input_tensor.NumElements()), new_sliced_size(1);
+ // factors by which original_dimensions increases/decreases w.r.t.
+ // target_dimensions
+ std::vector<tensorflow::int64> dimension_ceiling(rank),
+ cumulative_dimensions(rank);
+ // index of adjustable dimension
+ int adjustable_dimension;
+ tensorflow::TensorShape output_shape;
-std::vector<tensorflow::int64> InputIndexer::ComputeCumulativeDimensions() {
- std::vector<tensorflow::int64> cumulative_dimensions(rank_);
- int count = 0;
- for (int i = 0; i < rank_; ++i) {
- if (count == 0) {
- cumulative_dimensions[count] = 1;
- } else {
- cumulative_dimensions[count] =
- cumulative_dimensions[count - 1] * dimension_ceiling_[count - 1];
- }
- ++count;
- }
- return cumulative_dimensions;
-}
+ // requires that the rank of the input tensor and length of the desired shape
+ // are equal
+ OP_REQUIRES(context, rank == desired_shape.size(),
+ tensorflow::errors::InvalidArgument(
+ "periodic_resample expects the rank of the input tensor, ",
+ rank, ", to be the same as the length of the desired shape, ",
+ desired_shape.size(), "."));
-template <typename IndexVecT>
-void process_desired_shape(tensorflow::OpKernelContext* context,
- const tensorflow::TensorShape& input_tensor_shape,
- const IndexVecT& desired_shape,
- int* adjustable_dimension,
- std::vector<tensorflow::int64>* target_dimensions,
- tensorflow::int64* output_size) {
- tensorflow::int64 new_sliced_size = 1;
bool found = false;
- const int rank = input_tensor_shape.dims();
+ const auto& input_tensor_shape = input_tensor.shape();
+
for (int i = 0; i < rank; ++i) {
+ // if (desired_shape(i) < 1) {
if (desired_shape[i] < 1) {
// only one index can be adjustable
OP_REQUIRES(context, !found,
tensorflow::errors::InvalidArgument(
"periodic_resample expects only "
"one index to be marked as adjustable."));
- *adjustable_dimension = i;
+ adjustable_dimension = i;
found = true;
} else {
OP_REQUIRES(
@@ -232,8 +122,9 @@ void process_desired_shape(tensorflow::OpKernelContext* context,
i, " input tensor has size ", input_tensor_shape.dim_size(i),
", desired shape has size ", desired_shape[i], "."));
- (*target_dimensions)[i] = desired_shape[i];
- new_sliced_size *= (*target_dimensions)[i];
+ // target_dimensions[i] = desired_shape(i);
+ target_dimensions[i] = desired_shape[i];
+ new_sliced_size *= target_dimensions[i];
}
}
// at least one index needs to be adjustable
@@ -241,50 +132,26 @@ void process_desired_shape(tensorflow::OpKernelContext* context,
tensorflow::errors::InvalidArgument(
"periodic_resample expects at least "
"one index to be marked as adjustable."));
- (*target_dimensions)[*adjustable_dimension] =
- input_tensor_shape.num_elements() / new_sliced_size;
-
- *output_size = new_sliced_size * (*target_dimensions)[*adjustable_dimension];
-}
-
-// Heuristic number based on measurements on
-// Intel(R) Core(TM) i7-4930K CPU @ 3.40GHz
-const tensorflow::int64 costPerFillIndex = 35;
-enum class Mode {
- kForward,
- kGradient
-};
-
-// Computes either periodic_resample operation output or gradients for it,
-// depending on |mode|.
-// |original_shape| is always shape of input to periodic_resample operation.
-// |source_tensor| is either source for periodic_resample (for forward mode)
-// or gradients tensor.
-// |desired_shape| is always shape, provided by user, to which forward
-// propagation attempts resample input tensor.
-template <class InputDataT, Mode mode>
-void
-do_periodic_resample_op(tensorflow::OpKernelContext* context,
- const tensorflow::TensorShape& original_shape,
- const tensorflow::PartialTensorShape& desired_shape,
- const tensorflow::Tensor& source_tensor) {
- const int rank = source_tensor.dims();
+ int count = 0;
+ for (const auto dim_info : input_tensor.shape()) {
+ original_dimensions[count] = dim_info.size;
+ ++count;
+ }
- // requires that the rank of the input tensor and length of the desired shape
- // are equal
- OP_REQUIRES(context, rank == desired_shape.dims(),
- tensorflow::errors::InvalidArgument(
- "periodic_resample expects the rank of the input tensor, ",
- rank, ", to be the same as the length of the desired shape, ",
- desired_shape.dims(), "."));
+ target_dimensions[adjustable_dimension] = total_size / new_sliced_size;
- std::vector<tensorflow::int64> target_dimensions(rank);
- tensorflow::int64 new_size = 0;
- // index of adjustable dimension
- int adjustable_dimension = 0;
- process_desired_shape(context, original_shape, desired_shape.dim_sizes(),
- &adjustable_dimension, &target_dimensions, &new_size);
+ count = 0;
+ for (int i = 0; i < input_tensor.shape().dims(); ++i) {
+ dimension_ceiling[count] = tensorflow::int64(std::ceil(
+ float(target_dimensions[count]) / float(original_dimensions[count])));
+ if (count == 0)
+ cumulative_dimensions[count] = 1;
+ else
+ cumulative_dimensions[count] =
+ cumulative_dimensions[count - 1] * dimension_ceiling[count - 1];
+ ++count;
+ }
// ensure that the new dimension is greater than zero
OP_REQUIRES(context, target_dimensions[adjustable_dimension] > 0,
@@ -293,14 +160,11 @@ do_periodic_resample_op(tensorflow::OpKernelContext* context,
"adjustable dimension, ",
adjustable_dimension, ", isn't greater than zero, ",
target_dimensions[adjustable_dimension], "."));
- tensorflow::TensorShape output_shape;
- if (mode == Mode::kForward) {
- for (int i = 0; i < rank; ++i) {
- output_shape.AddDim(target_dimensions[i]);
- }
- } else {
- output_shape = original_shape;
+ for (int i = 0; i < rank; ++i) {
+ output_shape.AddDim(target_dimensions[i]);
}
+ const auto new_size =
+ new_sliced_size * target_dimensions[adjustable_dimension];
// Create an output tensor and attach it to the current context
tensorflow::Tensor* output_tensor = nullptr;
@@ -308,73 +172,47 @@ do_periodic_resample_op(tensorflow::OpKernelContext* context,
context->allocate_output(0, output_shape, &output_tensor));
auto output = output_tensor->flat<InputDataT>();
- // input is a strided array (last index is fastest, C-ordered)
- auto input = source_tensor.flat<InputDataT>();
+ // memory is allocated for these variables outside the inner loop for
+ // efficiency (although, I could create a separate class scope for
+ // this purpose instead)
+ tensorflow::int64 result = 0;
+ std::vector<tensorflow::int64> output_indices(target_dimensions.size());
// Fill output tensor with periodically resampled input tensor values
- InputIndexer input_indexer(target_dimensions, original_shape,
- adjustable_dimension);
-
- auto worker_threads = *(context->device()->tensorflow_cpu_worker_threads());
- auto fill_output_tensor = [&input_indexer, &output, &input](
- tensorflow::int64 start, tensorflow::int64 limit) {
- InputIndexer local_indexer(input_indexer);
- local_indexer.MoveToOutputIndex(start);
- for (tensorflow::int64 output_index = start; output_index < limit;
- ++output_index) {
- if (mode == Mode::kForward) {
- output(output_index) = input(local_indexer.linear_input_index());
- } else {
- output(local_indexer.linear_input_index()) = input(output_index);
- }
- local_indexer.IncrementOutputIndex();
- }
- };
- ::tensorflow::Shard(worker_threads.num_threads, worker_threads.workers,
- new_size, costPerFillIndex, fill_output_tensor);
-}
-
-#define DATA_TYPE_SWITCH(data_type, context, CASE) \
- switch (data_type) { \
- CASE(float) \
- CASE(double) \
- CASE(tensorflow::int32) \
- CASE(tensorflow::int64) \
- default: \
- context->CtxFailure(__FILE__, __LINE__, \
- tensorflow::errors::InvalidArgument( \
- "Unsuppored tensor elements type")); \
- break; \
+ for (tensorflow::int64 output_index = 0; output_index < new_size;
+ ++output_index) {
+ output(output_index) = input(compute_input_index(
+ &target_dimensions, output_index, original_dimensions,
+ adjustable_dimension, dimension_ceiling, cumulative_dimensions, &result,
+ &output_indices, rank));
}
+}
void create_output_tensor(
tensorflow::OpKernelContext* context,
const tensorflow::Tensor& input_tensor,
const tensorflow::DataType& input_tensor_type,
- const tensorflow::PartialTensorShape& desired_shape) {
-#define CASE(type) \
- case tensorflow::DataTypeToEnum<type>::value: \
- do_periodic_resample_op<type, Mode::kForward>( \
- context, input_tensor.shape(), desired_shape, input_tensor); \
- break;
+ const tensorflow::PartialTensorShape& desired_shape_tensor) {
+ auto desired_shape = desired_shape_tensor.dim_sizes();
- DATA_TYPE_SWITCH(input_tensor_type, context, CASE);
-#undef CASE
-}
-
-void create_grad_tensor(tensorflow::OpKernelContext* context,
- const tensorflow::Tensor& grad_tensor,
- const tensorflow::DataType& grad_tensor_type,
- const tensorflow::TensorShape& original_shape,
- const tensorflow::PartialTensorShape& desired_shape) {
-#define CASE(type) \
- case tensorflow::DataTypeToEnum<type>::value: \
- do_periodic_resample_op<type, Mode::kGradient>( \
- context, original_shape, desired_shape, grad_tensor); \
+ // obligatory type switch
+ switch (input_tensor_type) {
+ case tensorflow::DataTypeToEnum<float>::value:
+ fill_periodic_tensor<float>(context, desired_shape, input_tensor);
break;
-
- DATA_TYPE_SWITCH(grad_tensor_type, context, CASE);
-#undef CASE
+ case tensorflow::DataTypeToEnum<double>::value:
+ fill_periodic_tensor<double>(context, desired_shape, input_tensor);
+ break;
+ case tensorflow::DataTypeToEnum<tensorflow::int32>::value:
+ fill_periodic_tensor<tensorflow::int32>(context, desired_shape,
+ input_tensor);
+ break;
+ case tensorflow::DataTypeToEnum<tensorflow::int64>::value:
+ fill_periodic_tensor<tensorflow::int64>(context, desired_shape,
+ input_tensor);
+ break;
+ default:;
+ }
}
} // namespace
@@ -400,25 +238,4 @@ class PeriodicResampleOp : public tensorflow::OpKernel {
tensorflow::PartialTensorShape desired_shape;
};
-class PeriodicResampleOpGrad : public tensorflow::OpKernel {
- public:
- explicit PeriodicResampleOpGrad(tensorflow::OpKernelConstruction* context)
- : tensorflow::OpKernel(context) {
- OP_REQUIRES_OK(context,
- context->GetAttr("original_shape", &original_shape));
- OP_REQUIRES_OK(context, context->GetAttr("desired_shape", &desired_shape));
- }
-
- void Compute(tensorflow::OpKernelContext* context) override {
- const tensorflow::Tensor& grad_tensor = context->input(0);
- const tensorflow::DataType grad_tensor_type = context->input_dtype(0);
- create_grad_tensor(context, grad_tensor, grad_tensor_type, original_shape,
- desired_shape);
- }
-
- private:
- tensorflow::TensorShape original_shape;
- tensorflow::PartialTensorShape desired_shape;
-};
-
#endif // TENSORFLOW_KERNELS_PERIODICRESAMPLE_OP_H_
diff --git a/tensorflow/contrib/periodic_resample/ops/array_ops.cc b/tensorflow/contrib/periodic_resample/ops/array_ops.cc
index fd38cd09b4..82bd796956 100644
--- a/tensorflow/contrib/periodic_resample/ops/array_ops.cc
+++ b/tensorflow/contrib/periodic_resample/ops/array_ops.cc
@@ -26,42 +26,7 @@ REGISTER_OP("PeriodicResample")
.Input("values: T")
.Attr("shape: shape")
.Output("output: T")
- .SetShapeFn([](shape_inference::InferenceContext* c) {
- tensorflow::PartialTensorShape desired_shape;
- TF_RETURN_IF_ERROR(c->GetAttr("shape", &desired_shape));
- shape_inference::ShapeHandle input_tensor_shape = c->input(0);
- shape_inference::DimensionHandle num_input_elements =
- c->NumElements(input_tensor_shape);
- shape_inference::ShapeHandle result_shape_handle;
- if (!shape_inference::InferenceContext::ValueKnown(num_input_elements)) {
- TF_RETURN_IF_ERROR(c->MakeShapeFromPartialTensorShape(
- desired_shape, &result_shape_handle));
- } else {
- const int rank = c->Rank(input_tensor_shape);
- std::vector<tensorflow::int64> target_dimensions(rank);
- tensorflow::int64 new_sliced_size = 1;
- int adjustable_dimension = 0;
- for (int i = 0; i < rank; ++i) {
- if (desired_shape.dim_size(i) < 1) {
- adjustable_dimension = i;
- } else {
- target_dimensions[i] = desired_shape.dim_size(i);
- new_sliced_size *= target_dimensions[i];
- }
- }
- target_dimensions[adjustable_dimension] =
- shape_inference::InferenceContext::Value(
- num_input_elements) / new_sliced_size;
- tensorflow::TensorShape result_shape;
- for (int i = 0; i < rank; ++i) {
- result_shape.AddDim(target_dimensions[i]);
- }
- TF_RETURN_IF_ERROR(c->MakeShapeFromTensorShape(
- result_shape, &result_shape_handle));
- }
- c->set_output(0, result_shape_handle);
- return Status::OK();
- })
+ .SetShapeFn(shape_inference::ExplicitShape)
.Doc(R"doc(
Periodically resample elements of a tensor to conform to `shape`.
@@ -136,20 +101,4 @@ output: Periodically resampled tensor that has dimensions specified as in
)doc");
-
-REGISTER_OP("PeriodicResampleOpGrad")
- .Attr("T: numbertype")
- .Input("grad: T")
- .Attr("original_shape: shape")
- .Attr("desired_shape: shape")
- .Output("grad_values: T")
- .SetShapeFn([](shape_inference::InferenceContext* c) {
- tensorflow::TensorShape original_shape;
- TF_RETURN_IF_ERROR(c->GetAttr("original_shape", &original_shape));
- shape_inference::ShapeHandle s;
- TF_RETURN_IF_ERROR(c->MakeShapeFromTensorShape(original_shape, &s));
- c->set_output(0, s);
- return Status::OK();
-});
-
} // namespace tensorflow
diff --git a/tensorflow/contrib/periodic_resample/ops/array_ops_test.cc b/tensorflow/contrib/periodic_resample/ops/array_ops_test.cc
deleted file mode 100644
index 43b7c1799f..0000000000
--- a/tensorflow/contrib/periodic_resample/ops/array_ops_test.cc
+++ /dev/null
@@ -1,41 +0,0 @@
-/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
- http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#include "tensorflow/core/framework/node_def_builder.h"
-#include "tensorflow/core/framework/shape_inference_testutil.h"
-#include "tensorflow/core/framework/tensor_shape.pb.h"
-#include "tensorflow/core/framework/tensor_testutil.h"
-#include "tensorflow/core/lib/core/status_test_util.h"
-#include "tensorflow/core/platform/test.h"
-
-namespace tensorflow {
-
-TEST(ArrayOpsTest, PeriodicResample_ShapeFn) {
- ShapeInferenceTestOp op("PeriodicResample");
- // Case 1: output shape can be fully inferreed.
- PartialTensorShape shape({4, 4, -1});
- TensorShapeProto shape_proto;
- shape.AsProto(&shape_proto);
-
- TF_ASSERT_OK(NodeDefBuilder("test", "PeriodicResample")
- .Input({"values", 0, DT_INT32})
- .Attr("shape", shape_proto)
- .Finalize(&op.node_def));
- INFER_OK(op, "[2,2,4]", "[4,4,1]");
- // Case 2: output shape can not be inferred - report desired shape.
- INFER_OK(op, "[2,2,?]", "[4,4,?]");
-}
-
-} // end namespace tensorflow
diff --git a/tensorflow/contrib/periodic_resample/python/kernel_tests/periodic_resample_op_test.py b/tensorflow/contrib/periodic_resample/python/kernel_tests/periodic_resample_op_test.py
index 31a6fe1d94..a25de55e18 100644
--- a/tensorflow/contrib/periodic_resample/python/kernel_tests/periodic_resample_op_test.py
+++ b/tensorflow/contrib/periodic_resample/python/kernel_tests/periodic_resample_op_test.py
@@ -21,11 +21,8 @@ from __future__ import print_function
import numpy
from tensorflow.contrib.periodic_resample import periodic_resample
-from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors_impl
from tensorflow.python.framework import test_util
-from tensorflow.python.ops import array_ops
-from tensorflow.python.ops import gradient_checker
from tensorflow.python.ops import variables
from tensorflow.python.platform import googletest
@@ -96,6 +93,7 @@ class PeriodicResampleTest(test_util.TensorFlowTestCase):
def testPeriodicResampleErrors(self):
input_tensor = numpy.zeros(shape=[1, 2, 2, 4])
with self.test_session():
+ variables.global_variables_initializer().run()
with self.assertRaisesWithPredicateMatch(
errors_impl.InvalidArgumentError,
'Dimension 3 input tensor has size 4, desired shape has size 1'):
@@ -105,29 +103,6 @@ class PeriodicResampleTest(test_util.TensorFlowTestCase):
'4, to be the same as the length of the desired shape, 3'):
periodic_resample(input_tensor, [None, 4, 4]).eval()
- def testPeriodicResampleGradient(self):
- desired_shape = numpy.array([4, 4, None])
- result_shape = (4, 4, 1)
- input_shape = (2, 2, 4)
- with self.test_session() as sess:
- x = array_ops.placeholder(dtypes.float32, shape=input_shape)
- output = periodic_resample(x, desired_shape)
- error = gradient_checker.compute_gradient_error(
- x, input_shape, output, result_shape)
- self.assertLess(error, 1e-4)
-
- def testPeriodicResampleShapeInference(self):
- with self.test_session() as sess:
- # Case 1: output shape can be fully inferreed.
- x = array_ops.placeholder(dtypes.float32, shape=(2, 2, 4))
- output = periodic_resample(x, [4, 4, None])
- self.assertEqual(output.shape, [4, 4, 1])
- # Case 2: output shape can not be inferred - report desired shape.
- x = array_ops.placeholder(dtypes.float32, shape=(2, 2, None))
- output = periodic_resample(x, [4, 4, None])
- self.assertTrue(output.shape.is_compatible_with([4, 4, None]))
- self.assertEqual(output.shape[2].value, None)
-
if __name__ == '__main__':
googletest.main()
diff --git a/tensorflow/contrib/periodic_resample/python/ops/periodic_resample_op.py b/tensorflow/contrib/periodic_resample/python/ops/periodic_resample_op.py
index 470e300ccb..348623d8f8 100644
--- a/tensorflow/contrib/periodic_resample/python/ops/periodic_resample_op.py
+++ b/tensorflow/contrib/periodic_resample/python/ops/periodic_resample_op.py
@@ -21,17 +21,11 @@ from __future__ import print_function
# pylint: disable=unused-import
from tensorflow.contrib.periodic_resample.python.ops import gen_periodic_resample_op
-from tensorflow.contrib.periodic_resample.python.ops.gen_periodic_resample_op import periodic_resample, periodic_resample_op_grad
+from tensorflow.contrib.periodic_resample.python.ops.gen_periodic_resample_op import periodic_resample
from tensorflow.contrib.util import loader
-from tensorflow.python.framework import ops
from tensorflow.python.platform import resource_loader
# pylint: enable=unused-import
_periodic_resample_op = loader.load_op_library(
resource_loader.get_path_to_datafile('_periodic_resample_op.so'))
-
-@ops.RegisterGradient("PeriodicResample")
-def _periodic_resample_grad_cc(op, grad):
- return periodic_resample_op_grad(
- grad, op.inputs[0].shape, op.get_attr('shape'))