aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
-rw-r--r--tensorflow/cc/BUILD1
-rw-r--r--tensorflow/contrib/cmake/tf_core_ops.cmake1
-rwxr-xr-xtensorflow/contrib/cmake/tf_python.cmake1
-rw-r--r--tensorflow/contrib/makefile/tf_op_files.txt2
-rw-r--r--tensorflow/core/BUILD3
-rw-r--r--tensorflow/core/api_def/base_api/api_def_Roll.pbtxt52
-rw-r--r--tensorflow/core/graph/testlib.cc12
-rw-r--r--tensorflow/core/graph/testlib.h4
-rw-r--r--tensorflow/core/kernels/BUILD39
-rw-r--r--tensorflow/core/kernels/roll_op.cc334
-rw-r--r--tensorflow/core/kernels/roll_op_test.cc484
-rw-r--r--tensorflow/core/ops/manip_ops.cc33
-rw-r--r--tensorflow/python/BUILD36
-rw-r--r--tensorflow/python/__init__.py2
-rw-r--r--tensorflow/python/kernel_tests/BUILD13
-rw-r--r--tensorflow/python/kernel_tests/manip_ops_test.py137
-rw-r--r--tensorflow/python/ops/gradients_impl.py1
-rw-r--r--tensorflow/python/ops/manip_grad.py32
-rw-r--r--tensorflow/python/ops/manip_ops.py36
-rw-r--r--tensorflow/python/ops/standard_ops.py4
-rw-r--r--tensorflow/tools/api/golden/tensorflow.manip.pbtxt7
-rw-r--r--tensorflow/tools/api/golden/tensorflow.pbtxt4
22 files changed, 1237 insertions, 1 deletions
diff --git a/tensorflow/cc/BUILD b/tensorflow/cc/BUILD
index c9ade5fb83..9060c19e9d 100644
--- a/tensorflow/cc/BUILD
+++ b/tensorflow/cc/BUILD
@@ -433,6 +433,7 @@ tf_gen_op_wrappers_cc(
"linalg_ops",
"logging_ops",
"lookup_ops",
+ "manip_ops",
"math_ops",
"nn_ops",
"no_op",
diff --git a/tensorflow/contrib/cmake/tf_core_ops.cmake b/tensorflow/contrib/cmake/tf_core_ops.cmake
index 138993db35..c42bc35ce7 100644
--- a/tensorflow/contrib/cmake/tf_core_ops.cmake
+++ b/tensorflow/contrib/cmake/tf_core_ops.cmake
@@ -30,6 +30,7 @@ set(tf_op_lib_names
"list_ops"
"lookup_ops"
"logging_ops"
+ "manip_ops"
"math_ops"
"nn_ops"
"no_op"
diff --git a/tensorflow/contrib/cmake/tf_python.cmake b/tensorflow/contrib/cmake/tf_python.cmake
index 8862390d2b..b7c816c24f 100755
--- a/tensorflow/contrib/cmake/tf_python.cmake
+++ b/tensorflow/contrib/cmake/tf_python.cmake
@@ -335,6 +335,7 @@ GENERATE_PYTHON_OP_LIB("list_ops")
GENERATE_PYTHON_OP_LIB("logging_ops")
GENERATE_PYTHON_OP_LIB("lookup_ops")
GENERATE_PYTHON_OP_LIB("nn_ops")
+GENERATE_PYTHON_OP_LIB("manip_ops")
GENERATE_PYTHON_OP_LIB("parsing_ops")
GENERATE_PYTHON_OP_LIB("random_ops")
GENERATE_PYTHON_OP_LIB("remote_fused_graph_ops"
diff --git a/tensorflow/contrib/makefile/tf_op_files.txt b/tensorflow/contrib/makefile/tf_op_files.txt
index 5f27566398..9a1ab50317 100644
--- a/tensorflow/contrib/makefile/tf_op_files.txt
+++ b/tensorflow/contrib/makefile/tf_op_files.txt
@@ -91,6 +91,7 @@ tensorflow/core/kernels/reduction_ops_max.cc
tensorflow/core/kernels/reduction_ops_common.cc
tensorflow/core/kernels/reduction_ops_any.cc
tensorflow/core/kernels/reduction_ops_all.cc
+tensorflow/core/kernels/roll_op.cc
tensorflow/core/kernels/queue_ops.cc
tensorflow/core/kernels/queue_base.cc
tensorflow/core/kernels/pooling_ops_common.cc
@@ -270,6 +271,7 @@ tensorflow/core/ops/parsing_ops.cc
tensorflow/core/ops/no_op.cc
tensorflow/core/ops/nn_ops.cc
tensorflow/core/ops/nn_grad.cc
+tensorflow/core/ops/manip_ops.cc
tensorflow/core/ops/math_ops.cc
tensorflow/core/ops/math_grad.cc
tensorflow/core/ops/logging_ops.cc
diff --git a/tensorflow/core/BUILD b/tensorflow/core/BUILD
index 3b4a10eedb..90c2823ea4 100644
--- a/tensorflow/core/BUILD
+++ b/tensorflow/core/BUILD
@@ -610,6 +610,7 @@ tf_gen_op_libs(
"list_ops",
"lookup_ops",
"logging_ops",
+ "manip_ops",
"math_ops",
"nn_ops",
"no_op",
@@ -692,6 +693,7 @@ cc_library(
":list_ops_op_lib",
":logging_ops_op_lib",
":lookup_ops_op_lib",
+ ":manip_ops_op_lib",
":math_ops_op_lib",
":nn_ops_op_lib",
":no_op_op_lib",
@@ -829,6 +831,7 @@ cc_library(
"//tensorflow/core/kernels:list_kernels",
"//tensorflow/core/kernels:lookup",
"//tensorflow/core/kernels:logging",
+ "//tensorflow/core/kernels:manip",
"//tensorflow/core/kernels:math",
"//tensorflow/core/kernels:multinomial_op",
"//tensorflow/core/kernels:nn",
diff --git a/tensorflow/core/api_def/base_api/api_def_Roll.pbtxt b/tensorflow/core/api_def/base_api/api_def_Roll.pbtxt
new file mode 100644
index 0000000000..b308ad1f9d
--- /dev/null
+++ b/tensorflow/core/api_def/base_api/api_def_Roll.pbtxt
@@ -0,0 +1,52 @@
+op {
+ graph_op_name: "Roll"
+ in_arg {
+ name: "shift"
+ description: <<END
+Dimension must be 0-D or 1-D. `shift[i]` specifies the number of places by which
+elements are shifted positively (towards larger indices) along the dimension
+specified by `axis[i]`. Negative shifts will roll the elements in the opposite
+direction.
+END
+ }
+ in_arg {
+ name: "axis"
+ description: <<END
+Dimension must be 0-D or 1-D. `axis[i]` specifies the dimension along which
+the shift `shift[i]` should occur. If the same axis is referenced more than
+once, the total shift for that axis will be the sum of all the shifts that
+belong to that axis.
+END
+ }
+ out_arg {
+ name: "output"
+ description: <<END
+Has the same shape and size as the input. The elements are shifted
+positively (towards larger indices) by the offsets of `shift` along the
+dimensions of `axis`.
+END
+ }
+ summary: "Rolls the elements of a tensor along an axis."
+ description: <<END
+The elements are shifted positively (towards larger indices) by the offset of
+`shift` along the dimension of `axis`. Negative `shift` values will shift
+elements in the opposite direction. Elements that roll past the last position
+will wrap around to the first and vice versa. Multiple shifts along multiple
+axes may be specified.
+
+For example:
+
+```
+# 't' is [0, 1, 2, 3, 4]
+roll(t, shift=2, axis=0) ==> [3, 4, 0, 1, 2]
+
+# shifting along multiple dimensions
+# 't' is [[0, 1, 2, 3, 4], [5, 6, 7, 8, 9]]
+roll(t, shift=[1, -2], axis=[0, 1]) ==> [[7, 8, 9, 5, 6], [2, 3, 4, 0, 1]]
+
+# shifting along the same axis multiple times
+# 't' is [[0, 1, 2, 3, 4], [5, 6, 7, 8, 9]]
+roll(t, shift=[2, -3], axis=[1, 1]) ==> [[1, 2, 3, 4, 0], [6, 7, 8, 9, 5]]
+```
+END
+}
diff --git a/tensorflow/core/graph/testlib.cc b/tensorflow/core/graph/testlib.cc
index 172471e34b..0d88d1ff72 100644
--- a/tensorflow/core/graph/testlib.cc
+++ b/tensorflow/core/graph/testlib.cc
@@ -40,7 +40,7 @@ REGISTER_KERNEL_BUILDER(
#ifdef TENSORFLOW_USE_SYCL
REGISTER_KERNEL_BUILDER(
Name("HostConst").Device(DEVICE_SYCL).HostMemory("output"), HostConstantOp);
-#endif // TENSORFLOW_USE_SYCL
+#endif // TENSORFLOW_USE_SYCL
// Register the HostConst Op
// Returns a constant tensor on the host. Useful for writing C++ tests
@@ -273,6 +273,16 @@ Node* Reverse(Graph* g, Node* tensor, Node* axis) {
return Binary(g, "ReverseV2", tensor, axis);
}
+Node* Roll(Graph* g, Node* input, Node* shift, Node* axis) {  // adds a "Roll" node to `g`
+ Node* ret;
+ TF_CHECK_OK(NodeBuilder(g->NewName("n"), "Roll", g->op_registry())
+ .Input(input)  // input 0: the tensor to roll
+ .Input(shift)  // input 1: the shift amount(s)
+ .Input(axis)  // input 2: the axis/axes the shifts apply to
+ .Finalize(g, &ret));
+ return ret;  // the node filled in by Finalize
+}
+
Node* Error(Graph* g, Node* input, const string& errmsg) {
Node* ret;
TF_CHECK_OK(NodeBuilder(g->NewName("n"), "Error")
diff --git a/tensorflow/core/graph/testlib.h b/tensorflow/core/graph/testlib.h
index 06597778bb..eb9038d619 100644
--- a/tensorflow/core/graph/testlib.h
+++ b/tensorflow/core/graph/testlib.h
@@ -117,6 +117,10 @@ Node* RandomGamma(Graph* g, Node* shape, Node* alpha);
// Output dtype determined by lam.
Node* RandomPoisson(Graph* g, Node* shape, Node* lam);
+// Rolls tensor by an offset of <shift> along the corresponding
+// <axis> dimensions.
+Node* Roll(Graph* g, Node* input, Node* shift, Node* axis);
+
// Generates random parameters from the truncated standard normal distribution
// of the nput shape
Node* TruncatedNormal(Graph* g, Node* input, DataType dtype);
diff --git a/tensorflow/core/kernels/BUILD b/tensorflow/core/kernels/BUILD
index db309fc9da..e7192ec42f 100644
--- a/tensorflow/core/kernels/BUILD
+++ b/tensorflow/core/kernels/BUILD
@@ -2589,6 +2589,45 @@ tf_cc_tests(
],
)
+cc_library(  # umbrella target for the "manip" kernels; currently only :roll_op
+ name = "manip",
+ deps = [
+ ":roll_op",
+ ],
+)
+
+MANIP_DEPS = [  # dependency list passed to the roll_op kernel library below
+ "//tensorflow/core:framework",
+ "//tensorflow/core:lib",
+ "//tensorflow/core:manip_ops_op_lib",
+ "//third_party/eigen3",
+]
+
+tf_kernel_library(  # kernel library for the Roll op
+ name = "roll_op",
+ prefix = "roll_op",
+ deps = MANIP_DEPS,
+)
+
+tf_cc_test(  # unit tests for the Roll kernel
+ name = "roll_op_test",
+ size = "small",
+ srcs = ["roll_op_test.cc"],
+ deps = [
+ ":ops_testutil",
+ ":ops_util",
+ ":roll_op",
+ "//tensorflow/core:core_cpu",
+ "//tensorflow/core:core_cpu_internal",
+ "//tensorflow/core:framework",
+ "//tensorflow/core:lib",
+ "//tensorflow/core:protos_all_cc",
+ "//tensorflow/core:test",
+ "//tensorflow/core:test_main",
+ "//tensorflow/core:testlib",
+ ],
+)
+
MATH_DEPS = [
":bounds_check",
":fill_functor",
diff --git a/tensorflow/core/kernels/roll_op.cc b/tensorflow/core/kernels/roll_op.cc
new file mode 100644
index 0000000000..bcbdbee058
--- /dev/null
+++ b/tensorflow/core/kernels/roll_op.cc
@@ -0,0 +1,334 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/core/framework/common_shape_fns.h"
+#include "tensorflow/core/framework/op.h"
+#include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/framework/register_types.h"
+#include "tensorflow/core/framework/register_types_traits.h"
+#include "tensorflow/core/framework/shape_inference.h"
+#include "tensorflow/core/lib/gtl/array_slice.h"
+#include "tensorflow/core/platform/types.h"
+#include "tensorflow/core/util/work_sharder.h"
+
+namespace tensorflow {
+
+#define EIGEN_USE_THREADS
+using CPUDevice = Eigen::ThreadPoolDevice;
+
+// Rolls `input` into `output` one element at a time. RollOp uses this path for
+// data types that cannot be copied with memcpy. The work is split across the
+// CPU worker threads with Shard.
+//
+// context - kernel context; only used to fetch the CPU worker threads
+// num_elements - total number of elements in the flattened tensor
+// num_dims - number of dimensions of the tensor
+// dim_size - the size of each dimension
+// input, output - flattened source and destination buffers
+// threshold - the index for each dimension that the roll starts to wrap
+//    back to the front
+// dim_range - the number of indices over in the flattened tensor
+//    you need to skip in order to make it over from one side of a dimension
+//    to the other. Used to make the shifts wrap around after a threshold.
+template <typename T>
+void DoRoll(OpKernelContext* context, const int64 num_elements,
+            const int num_dims, const gtl::ArraySlice<int>& dim_size,
+            const T* input, T* output, const gtl::ArraySlice<int>& threshold,
+            const gtl::ArraySlice<int64>& dim_range) {
+  auto work = [input, output, num_dims, &dim_size, &threshold, &dim_range](
+                  int64 start, int64 end) {
+    // array of indices for each dimension
+    gtl::InlinedVector<int, 4> indices(num_dims);
+    // the shift along the flattened tensor for the current element.
+    // It is accumulated from int64 strides and dim_range values, so it must
+    // be int64 as well: a 32-bit offset would be truncated for tensors with
+    // more than INT_MAX elements and corrupt the output index below.
+    int64 offset = 0;
+    // initialize indices and offset
+    for (int i = 0; i < num_dims; i++) {
+      // stride is the number of indices over in the flattened tensor
+      // you need to skip in order to make it over to an adjacent element
+      // along a dimension. dim_size[i] != 0 because we set it to max(dim, 1)
+      const int64 stride = dim_range[i] / dim_size[i];
+      const int shift = dim_size[i] - threshold[i];
+      const int indx = (start / stride) % dim_size[i];
+      indices[i] = indx;
+      // calculate dimension index after the shift
+      const int shifted_indx = (indx + shift) % dim_size[i];
+      offset += (shifted_indx - indx) * stride;
+    }
+
+    for (int64 i = start; i < end; i++) {
+      output[i + offset] = input[i];
+      // create next combination of indices
+      // while at it adjust offset if needed
+      for (int j = num_dims - 1; j >= 0; j--) {
+        const int indx = (indices[j] + 1) % dim_size[j];
+        indices[j] = indx;
+        if (indx != 0) {
+          if (indx == threshold[j]) {  // we've reached the threshold
+            // dim_range[j] = threshold[j] + shift[j]
+            // offset = shift[j] + ... other offsets
+            // offset - dim_range[j] = -threshold[j] + ... other offsets
+            // thus we undo our previous offset as well as add a new offset of
+            // -threshold[j] in one operation
+            offset -= dim_range[j];  // now wraps around
+          }
+          break;                         // indx != 0 don't need to carry
+        } else if (threshold[j] != 0) {  // if threshold is 0 shift is 0
+          offset += dim_range[j];        // indx became 0 so reverse wrap around
+        }
+      }
+    }
+  };
+  // Shard
+  auto worker_threads = context->device()->tensorflow_cpu_worker_threads();
+  // 15 - experimentally determined with float and bool types
+  const int cost_per_element = 15 * sizeof(T);  // rough estimate
+  Shard(worker_threads->num_threads, worker_threads->workers, num_elements,
+        cost_per_element, std::move(work));
+}
+
+// Rolls `input` into `output` by copying contiguous groups of elements with
+// memcpy instead of moving them one at a time. RollOp uses this path when the
+// data type supports memcpy. The work is split across the CPU worker threads
+// with Shard, in units of groups (two groups per pass over the isd).
+//
+// context - kernel context; only used to fetch the CPU worker threads
+// num_elements - total number of elements in the flattened tensor
+// num_dims - number of dimensions of the tensor
+// dim_size - the size of each dimension
+// input, output - flattened source and destination buffers
+// threshold - the index for each dimension that the roll starts to wrap
+//    back to the front
+// dim_range - the number of indices over in the flattened tensor
+//    you need to skip in order to make it over from one side of a dimension
+//    to the other. Used to make the shifts wrap around after a threshold.
+// isd - inner shift dimension
+template <typename T>
+void DoRollWithMemcpy(OpKernelContext* context, const int64 num_elements,
+                      const int num_dims, const gtl::ArraySlice<int>& dim_size,
+                      const T* input, T* output,
+                      const gtl::ArraySlice<int>& threshold,
+                      const gtl::ArraySlice<int64>& dim_range,
+                      const int64 isd) {
+  auto work = [input, output, num_dims, &dim_size, &threshold, &dim_range, isd](
+                  int64 start, int64 end) {
+    // the number of indices over in the flattened tensor you need to skip in
+    // order to make it over from one side of the isd to the other.
+    // Compared as int64: dim_range holds int64 values and narrowing them to
+    // int would truncate ranges larger than INT_MAX.
+    const int64 isd_range = std::max<int64>(dim_range[isd], 1);
+    // the distance along the flattened tensor to the next element in the isd
+    const int64 isd_stride = isd_range / std::max<int64>(dim_size[isd], 1);
+
+    // start and end represent the i-th group currently so we will convert
+    // them into numbers representing the i-th elements.
+    // there are 2 groups per isd one for all elements before threshold[isd]
+    // and another for all elements after threshold[isd].
+    const int64 start_remainder = (start % 2) * threshold[isd] * isd_stride;
+    const int64 end_remainder = (end % 2) * threshold[isd] * isd_stride;
+    start = (start / 2) * isd_range + start_remainder;
+    end = (end / 2) * isd_range + end_remainder;
+
+    const T* in_ptr = &input[0];
+    T* out_ptr = &output[0];
+    in_ptr += start;
+    out_ptr += start;
+
+    // array of indices for each dimension
+    // indices = [i, j, k, l, m, n]
+    gtl::InlinedVector<int, 4> indices(num_dims);
+    // the offset needed to make all inner non-shifting dimensions become 0
+    int64 remainder_offset = 0;
+    // initialize indices
+    for (int i = 0; i < num_dims; i++) {
+      // stride is the number of indices over in the flattened tensor
+      // you need to skip in order to make it over to an adjacent element
+      // along a dimension. dim_size[i] != 0 because we set it to max(dim, 1)
+      const int64 stride = dim_range[i] / dim_size[i];
+      const int shift = dim_size[i] - threshold[i];
+      const int indx = (start / stride) % dim_size[i];
+      indices[i] = indx;
+      // calculate dimension index after the shift
+      int out_indx = (indx + shift) % dim_size[i];
+      if (i > isd) {
+        // trailing zeroes for indices after the inner shifted dimension
+        out_indx = 0;
+        remainder_offset += (out_indx - indx) * stride;
+      }
+      out_ptr += (out_indx - indx) * stride;
+    }
+    // set trailing zeroes for indices after the inner shifted dimension
+    for (int i = num_dims - 1; i > isd; i--) indices[i] = 0;
+
+    // the number of indices in the isd dimension the next group will skip
+    // to make it to the next threshold or end point
+    int isd_indx_skip = 0;
+    // the size of the next group
+    int64 group_size = 0;
+    // initialize isd_indx_skip and group_size
+    if (indices[isd] < threshold[isd]) {
+      isd_indx_skip = threshold[isd] - indices[isd];
+      group_size = isd_indx_skip * isd_stride + remainder_offset;
+    } else {
+      isd_indx_skip = dim_size[isd] - indices[isd];
+      group_size = isd_indx_skip * isd_stride + remainder_offset;
+    }
+
+    int64 i = start;
+    while (i < end) {
+      // copy group of elements
+      memcpy(out_ptr, in_ptr, group_size * sizeof(T));
+
+      // shift i and the pointers over to the next group position
+      i += group_size;
+      out_ptr += group_size;
+      in_ptr += group_size;
+
+      // produce next combination of indices and adjust the out_ptr position
+      // to fix the offset if necessary
+      // the isd (inner shift dim) should skip to next threshold or endpoint
+      // all dimensions to the left increment by 1 when a digit is carried
+      // all dimensions to the right remain set to 0
+      //            +1 +1 +1 +isd_indx_skip
+      // indices = [i, j, k, l, 0, 0]
+      //                      ^isd
+      for (int j = isd; j >= 0; j--) {
+        int inc = 1;
+        if (j == isd) inc = isd_indx_skip;
+        const int indx = (indices[j] + inc) % dim_size[j];
+        indices[j] = indx;
+        if (indx != 0) {
+          if (indx == threshold[j]) {
+            out_ptr -= dim_range[j];  // now wraps around
+          }
+          break;                         // indx != 0 don't need to carry
+        } else if (threshold[j] != 0) {  // if threshold is 0 shift is 0
+          out_ptr += dim_range[j];       // indx became 0 so reverse wrap around
+        }
+      }
+
+      // set isd_indx_skip and group_size for next iteration
+      if (indices[isd] < threshold[isd]) {
+        isd_indx_skip = threshold[isd] - indices[isd];
+        group_size = isd_indx_skip * isd_stride;
+      } else {
+        isd_indx_skip = dim_size[isd] - indices[isd];
+        group_size = isd_indx_skip * isd_stride;
+      }
+    }
+  };
+  // Shard
+  auto worker_threads = context->device()->tensorflow_cpu_worker_threads();
+  const int64 ave_group_size = dim_range[isd] / 2;
+  // Both values below are derived from int64 element counts, so they are kept
+  // in int64 to avoid overflowing a 32-bit int on large tensors.
+  const int64 total_work =
+      2 * num_elements / std::max<int64>(dim_range[isd], 1);
+  // 25000 - experimentally determined with float and bool types
+  const int64 cost_per_group = 25000 * sizeof(T) * ave_group_size;
+  Shard(worker_threads->num_threads, worker_threads->workers, total_work,
+        cost_per_group, std::move(work));
+}
+
+// Kernel for the "Roll" op: cyclically shifts the elements of input 0 by the
+// amounts in `shift` (input 1) along the dimensions named in `axis` (input 2).
+//
+// Device - device tag; work is only dispatched when it is CPUDevice
+// T - element type of the rolled tensor
+// Tshift, Taxis - integer types of the shift and axis inputs
+template <typename Device, typename T, typename Tshift, typename Taxis>
+class RollOp : public OpKernel {
+ public:
+  explicit RollOp(OpKernelConstruction* context) : OpKernel(context) {}
+
+  // Validates the inputs, folds all requested shifts into one effective shift
+  // per dimension, allocates the output and dispatches to DoRollWithMemcpy or
+  // DoRoll. Reports InvalidArgument if input is a scalar, if shift or axis
+  // has more than one dimension, if their shapes differ, or if an axis value
+  // lies outside [0, input.dims()).
+  void Compute(OpKernelContext* context) override {
+    // Grab the input tensors
+    const Tensor& input = context->input(0);
+    const Tensor& shift = context->input(1);
+    const Tensor& axis = context->input(2);
+
+    OP_REQUIRES(context, TensorShapeUtils::IsVectorOrHigher(input.shape()),
+                errors::InvalidArgument("input must be 1-D or higher"));
+    OP_REQUIRES(context, shift.shape().dims() <= 1,
+                errors::InvalidArgument(
+                    "shift must be a scalar or a 1-D vector. Found: ",
+                    shift.shape().DebugString()));
+    OP_REQUIRES(context, axis.shape().dims() <= 1,
+                errors::InvalidArgument(
+                    "axis must be a scalar or a 1-D vector. Found: ",
+                    axis.shape().DebugString()));
+    OP_REQUIRES(
+        context, shift.shape() == axis.shape(),
+        errors::InvalidArgument("shift and axis must have the same size"));
+
+    auto shift_flat = shift.flat<Tshift>();
+    auto axis_flat = axis.flat<Taxis>();
+    const int64 num_elements = input.NumElements();
+    const int num_shifts = static_cast<int>(shift_flat.size());
+    const int num_dims = input.dims();
+
+    // if there are any duplicate axes, shift_mod_sum will have the
+    // total modulo sum of shifts for each dimension
+    gtl::InlinedVector<int, 4> shift_mod_sum(num_dims, 0);
+    for (int i = 0; i < num_shifts; i++) {
+      // Range-check the axis at full width and on both sides: a negative
+      // axis would index out of bounds below, and narrowing an int64 axis to
+      // int first could wrap an invalid value into the valid range.
+      const int64 axis_i = static_cast<int64>(axis_flat(i));
+      OP_REQUIRES(context, axis_i >= 0 && axis_i < num_dims,
+                  errors::InvalidArgument("axis ", axis_i, " is out of range"));
+      const int ds = std::max<int>(static_cast<int>(input.dim_size(axis_i)), 1);
+      // Reduce the shift modulo ds while it still has its full width, so an
+      // int64 shift is not truncated; the sum is kept in int64 so that adding
+      // two values of magnitude < ds cannot overflow.
+      const int64 sum =
+          shift_mod_sum[axis_i] + static_cast<int64>(shift_flat(i)) % ds;
+      // modulo that works with negatives: ((x % y) + y) % y
+      shift_mod_sum[axis_i] = static_cast<int>((sum % ds + ds) % ds);
+    }
+    // the size of each dimension
+    gtl::InlinedVector<int, 4> dim_size(num_dims);
+    // threshold[i] is the index that the roll starts to wrap back to the front
+    gtl::InlinedVector<int, 4> threshold(num_dims);
+    // dim_range is the number of indices over in the flattened tensor
+    // you need to skip in order to make it over from one side of a dimension
+    // to the other. Used to make the shifts wrap around after a threshold.
+    gtl::InlinedVector<int64, 4> dim_range(num_dims);
+    int64 dim_size_prod = 1;  // dimension size product
+    // inner shift dimension (inner most shifted dimension)
+    int64 isd = 0;
+    for (int i = num_dims - 1; i >= 0; i--) {
+      if (isd == 0 && shift_mod_sum[i] != 0) isd = i;
+      const int ds = std::max<int>(static_cast<int>(input.dim_size(i)), 1);
+      dim_size[i] = ds;
+      threshold[i] = (ds - shift_mod_sum[i]) % ds;
+      dim_size_prod *= static_cast<int64>(input.dim_size(i));
+      dim_range[i] = dim_size_prod;
+    }
+
+    Tensor* output = nullptr;
+    OP_REQUIRES_OK(context,
+                   context->allocate_output(0, input.shape(), &output));
+    auto input_flat = input.flat<T>().data();
+    auto output_flat = output->flat<T>().data();
+
+    if (std::is_same<Device, CPUDevice>::value) {
+      if (DataTypeCanUseMemcpy(DataTypeToEnum<T>::v())) {
+        // V2 copies memory in groups instead of element by element
+        DoRollWithMemcpy<T>(context, num_elements, num_dims, dim_size,
+                            input_flat, output_flat, threshold, dim_range, isd);
+      } else {
+        // in case memcpy does not work for current data type
+        DoRoll<T>(context, num_elements, num_dims, dim_size, input_flat,
+                  output_flat, threshold, dim_range);
+      }
+    }
+  }
+};
+
+// Register CPU kernels for `type` with all four int32/int64 Tshift/Taxis pairs.
+#define REGISTER_CPU(type) \
+ REGISTER_KERNEL_BUILDER(Name("Roll") \
+ .Device(DEVICE_CPU) \
+ .TypeConstraint<type>("T") \
+ .TypeConstraint<int32>("Tshift") \
+ .TypeConstraint<int32>("Taxis"), \
+ RollOp<CPUDevice, type, int32, int32>) \
+ REGISTER_KERNEL_BUILDER(Name("Roll") \
+ .Device(DEVICE_CPU) \
+ .TypeConstraint<type>("T") \
+ .TypeConstraint<int64>("Tshift") \
+ .TypeConstraint<int32>("Taxis"), \
+ RollOp<CPUDevice, type, int64, int32>) \
+ REGISTER_KERNEL_BUILDER(Name("Roll") \
+ .Device(DEVICE_CPU) \
+ .TypeConstraint<type>("T") \
+ .TypeConstraint<int32>("Tshift") \
+ .TypeConstraint<int64>("Taxis"), \
+ RollOp<CPUDevice, type, int32, int64>) \
+ REGISTER_KERNEL_BUILDER(Name("Roll") \
+ .Device(DEVICE_CPU) \
+ .TypeConstraint<type>("T") \
+ .TypeConstraint<int64>("Tshift") \
+ .TypeConstraint<int64>("Taxis"), \
+ RollOp<CPUDevice, type, int64, int64>)
+
+TF_CALL_ALL_TYPES(REGISTER_CPU);  // expand the macro once per type TF_CALL_ALL_TYPES lists
+#undef REGISTER_CPU
+} // namespace tensorflow
diff --git a/tensorflow/core/kernels/roll_op_test.cc b/tensorflow/core/kernels/roll_op_test.cc
new file mode 100644
index 0000000000..90b6f8d0f3
--- /dev/null
+++ b/tensorflow/core/kernels/roll_op_test.cc
@@ -0,0 +1,484 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include <functional>
+#include <memory>
+
+#include "tensorflow/core/common_runtime/device.h"
+#include "tensorflow/core/common_runtime/device_factory.h"
+#include "tensorflow/core/common_runtime/kernel_benchmark_testlib.h"
+#include "tensorflow/core/framework/allocator.h"
+#include "tensorflow/core/framework/fake_input.h"
+#include "tensorflow/core/framework/node_def_builder.h"
+#include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/framework/tensor.h"
+#include "tensorflow/core/framework/types.h"
+#include "tensorflow/core/framework/types.pb.h"
+#include "tensorflow/core/kernels/ops_testutil.h"
+#include "tensorflow/core/kernels/ops_util.h"
+#include "tensorflow/core/lib/io/path.h"
+#include "tensorflow/core/lib/strings/strcat.h"
+#include "tensorflow/core/platform/test.h"
+#include "tensorflow/core/platform/test_benchmark.h"
+
+namespace tensorflow {
+namespace {
+
+class RollOpTest : public OpsTestBase {  // shared fixture for the Roll kernel tests
+ protected:
+ void MakeOp(DataType data_type, DataType index_type) {  // builds and initializes a "Roll" node
+ TF_ASSERT_OK(NodeDefBuilder("myop", "Roll")
+ .Input(FakeInput(data_type))  // tensor to roll
+ .Input(FakeInput(index_type))  // shift
+ .Input(FakeInput(index_type))  // axis (same dtype as shift in these tests)
+ .Finalize(node_def()));
+ TF_ASSERT_OK(InitOp());
+ }
+};
+
+TEST_F(RollOpTest, ScalarIndices) {
+ MakeOp(DT_FLOAT, DT_INT32);
+
+ // Feed and run: 1-D float input with scalar shift = 3 and scalar axis = 0.
+ AddInputFromArray<float>(TensorShape({5}), {0, 1, 2, 3, 4});
+ AddInputFromArray<int32>(TensorShape({}), {3});
+ AddInputFromArray<int32>(TensorShape({}), {0});
+ TF_ASSERT_OK(RunOpKernel());
+
+ // Check the output: every element moved 3 places up, wrapping past the end.
+ Tensor expected(allocator(), DT_FLOAT, TensorShape({5}));
+ test::FillValues<float>(&expected, {2, 3, 4, 0, 1});
+ test::ExpectTensorEqual<float>(expected, *GetOutput(0));
+}
+
+TEST_F(RollOpTest, ScalarIndices_NoMemcpy) {
+ MakeOp(DT_STRING, DT_INT32);
+
+ // Feed and run: string input (per the test name, meant for the non-memcpy path), shift 3 on axis 0.
+ AddInputFromArray<string>(TensorShape({5}), {"a", "b", "c", "d", "e"});
+ AddInputFromArray<int32>(TensorShape({}), {3});
+ AddInputFromArray<int32>(TensorShape({}), {0});
+ TF_ASSERT_OK(RunOpKernel());
+
+ // Check the output: every element moved 3 places up, wrapping past the end.
+ Tensor expected(allocator(), DT_STRING, TensorShape({5}));
+ test::FillValues<string>(&expected, {"c", "d", "e", "a", "b"});
+ test::ExpectTensorEqual<string>(expected, *GetOutput(0));
+}
+
+TEST_F(RollOpTest, ScalarIndices_Complex) {
+ MakeOp(DT_COMPLEX64, DT_INT32);
+
+ // Feed and run: 1-D complex64 input with scalar shift = 3 and scalar axis = 0.
+ AddInputFromArray<std::complex<float>>(
+ TensorShape({5}), {std::complex<float>(0, 10), std::complex<float>(1, 11),
+ std::complex<float>(2, 12), std::complex<float>(3, 13),
+ std::complex<float>(4, 14)});
+ AddInputFromArray<int32>(TensorShape({}), {3});
+ AddInputFromArray<int32>(TensorShape({}), {0});
+ TF_ASSERT_OK(RunOpKernel());
+
+ // Check the output: real and imaginary parts travel together, 3 places up with wrap-around.
+ Tensor expected(allocator(), DT_COMPLEX64, TensorShape({5}));
+ test::FillValues<std::complex<float>>(
+ &expected, {std::complex<float>(2, 12), std::complex<float>(3, 13),
+ std::complex<float>(4, 14), std::complex<float>(0, 10),
+ std::complex<float>(1, 11)});
+ test::ExpectTensorEqual<std::complex<float>>(expected, *GetOutput(0));
+}
+
+TEST_F(RollOpTest, Simple_TwoD32) {
+ MakeOp(DT_FLOAT, DT_INT32);
+
+ // Feed and run: 3x5 input, shift +2 along axis 0 and -1 along axis 1.
+ AddInputFromArray<float>(TensorShape({3, 5}),
+ {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14});
+ AddInputFromArray<int32>(TensorShape({2}), {2, -1});
+ AddInputFromArray<int32>(TensorShape({2}), {0, 1});
+ TF_ASSERT_OK(RunOpKernel());
+
+ // Check the output: rows moved down by 2 and columns moved left by 1, both wrapping.
+ Tensor expected(allocator(), DT_FLOAT, TensorShape({3, 5}));
+ test::FillValues<float>(&expected,
+ {6, 7, 8, 9, 5, 11, 12, 13, 14, 10, 1, 2, 3, 4, 0});
+ test::ExpectTensorEqual<float>(expected, *GetOutput(0));
+}
+
+TEST_F(RollOpTest, Simple_TwoD32_NoMemcpy) {
+ MakeOp(DT_STRING, DT_INT32);
+
+ // Feed and run: 3x5 string input, shift +2 along axis 0 and -1 along axis 1.
+ AddInputFromArray<string>(TensorShape({3, 5}),
+ {"a", "b", "c", "d", "e", "f", "g", "h", "i", "j",
+ "k", "l", "m", "n", "o"});
+ AddInputFromArray<int32>(TensorShape({2}), {2, -1});
+ AddInputFromArray<int32>(TensorShape({2}), {0, 1});
+ TF_ASSERT_OK(RunOpKernel());
+
+ // Check the output: rows moved down by 2 and columns moved left by 1, both wrapping.
+ Tensor expected(allocator(), DT_STRING, TensorShape({3, 5}));
+ test::FillValues<string>(&expected, {"g", "h", "i", "j", "f", "l", "m", "n",
+ "o", "k", "b", "c", "d", "e", "a"});
+ test::ExpectTensorEqual<string>(expected, *GetOutput(0));
+}
+
+TEST_F(RollOpTest, Simple_ThreeD32) {
+ MakeOp(DT_FLOAT, DT_INT32);
+
+ // Feed and run: 2x2x3 input, shifts of +1, -1, -1 along axes 0, 1, 2.
+ AddInputFromArray<float>(TensorShape({2, 2, 3}),
+ {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11});
+ AddInputFromArray<int32>(TensorShape({3}), {1, -1, -1});
+ AddInputFromArray<int32>(TensorShape({3}), {0, 1, 2});
+ TF_ASSERT_OK(RunOpKernel());
+
+ // Check the output: all three dimensions are rolled at once.
+ Tensor expected(allocator(), DT_FLOAT, TensorShape({2, 2, 3}));
+ test::FillValues<float>(&expected, {10, 11, 9, 7, 8, 6, 4, 5, 3, 1, 2, 0});
+ test::ExpectTensorEqual<float>(expected, *GetOutput(0));
+}
+
+TEST_F(RollOpTest, Simple_ThreeD32_NoMemcpy) {
+ MakeOp(DT_STRING, DT_INT32);
+
+ // Feed and run: 2x2x3 string input, shifts of +1, -1, -1 along axes 0, 1, 2.
+ AddInputFromArray<string>(
+ TensorShape({2, 2, 3}),
+ {"a", "b", "c", "d", "e", "f", "g", "h", "i", "j", "k", "l"});
+ AddInputFromArray<int32>(TensorShape({3}), {1, -1, -1});
+ AddInputFromArray<int32>(TensorShape({3}), {0, 1, 2});
+ TF_ASSERT_OK(RunOpKernel());
+
+ // Check the output: all three dimensions are rolled at once.
+ Tensor expected(allocator(), DT_STRING, TensorShape({2, 2, 3}));
+ test::FillValues<string>(
+ &expected, {"k", "l", "j", "h", "i", "g", "e", "f", "d", "b", "c", "a"});
+ test::ExpectTensorEqual<string>(expected, *GetOutput(0));
+}
+
+TEST_F(RollOpTest, Simple_TwoD64) {
+ MakeOp(DT_FLOAT, DT_INT64);
+
+ // Feed and run: int64 shift/axis; 5x3 input, shift -1 on axis 0 and 4 on axis 1 (size 3, so 4 acts as 1).
+ AddInputFromArray<float>(TensorShape({5, 3}),
+ {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14});
+ AddInputFromArray<int64>(TensorShape({2}), {-1, 4});
+ AddInputFromArray<int64>(TensorShape({2}), {0, 1});
+ TF_ASSERT_OK(RunOpKernel());
+
+ // Check the output: rows moved up by 1 and columns moved right by 1, both wrapping.
+ Tensor expected(allocator(), DT_FLOAT, TensorShape({5, 3}));
+ test::FillValues<float>(&expected,
+ {5, 3, 4, 8, 6, 7, 11, 9, 10, 14, 12, 13, 2, 0, 1});
+ test::ExpectTensorEqual<float>(expected, *GetOutput(0));
+}
+
+TEST_F(RollOpTest, Simple_TwoD64_NoMemcpy) {
+ MakeOp(DT_STRING, DT_INT64);
+
+ // Feed and run: int64 shift/axis; 5x3 string input, shift -1 on axis 0 and 4 on axis 1 (size 3, so 4 acts as 1).
+ AddInputFromArray<string>(TensorShape({5, 3}),
+ {"a", "b", "c", "d", "e", "f", "g", "h", "i", "j",
+ "k", "l", "m", "n", "o"});
+ AddInputFromArray<int64>(TensorShape({2}), {-1, 4});
+ AddInputFromArray<int64>(TensorShape({2}), {0, 1});
+ TF_ASSERT_OK(RunOpKernel());
+
+ // Check the output: rows moved up by 1 and columns moved right by 1, both wrapping.
+ Tensor expected(allocator(), DT_STRING, TensorShape({5, 3}));
+ test::FillValues<string>(&expected, {"f", "d", "e", "i", "g", "h", "l", "j",
+ "k", "o", "m", "n", "c", "a", "b"});
+ test::ExpectTensorEqual<string>(expected, *GetOutput(0));
+}
+
+TEST_F(RollOpTest, Simple_ThreeD64) {
+ MakeOp(DT_FLOAT, DT_INT64);
+
+ // Feed and run: 4x1x3 input; shifts 4 and 3 are multiples of the sizes of axes 0 and 1, so only axis 2 (shift 2) moves.
+ AddInputFromArray<float>(TensorShape({4, 1, 3}),
+ {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11});
+ AddInputFromArray<int64>(TensorShape({3}), {4, 3, 2});
+ AddInputFromArray<int64>(TensorShape({3}), {0, 1, 2});
+ TF_ASSERT_OK(RunOpKernel());
+
+ // Check the output: each innermost triple is rolled by 2; the outer order is unchanged.
+ Tensor expected(allocator(), DT_FLOAT, TensorShape({4, 1, 3}));
+ test::FillValues<float>(&expected, {1, 2, 0, 4, 5, 3, 7, 8, 6, 10, 11, 9});
+ test::ExpectTensorEqual<float>(expected, *GetOutput(0));
+}
+
+TEST_F(RollOpTest, Simple_ThreeD64_NoMemcpy) {
+ MakeOp(DT_STRING, DT_INT64);
+
+ // Feed and run: 4x1x3 string input; shifts 4 and 3 are multiples of the sizes of axes 0 and 1, so only axis 2 (shift 2) moves.
+ AddInputFromArray<string>(
+ TensorShape({4, 1, 3}),
+ {"a", "b", "c", "d", "e", "f", "g", "h", "i", "j", "k", "l"});
+ AddInputFromArray<int64>(TensorShape({3}), {4, 3, 2});
+ AddInputFromArray<int64>(TensorShape({3}), {0, 1, 2});
+ TF_ASSERT_OK(RunOpKernel());
+
+ // Check the output: each innermost triple is rolled by 2; the outer order is unchanged.
+ Tensor expected(allocator(), DT_STRING, TensorShape({4, 1, 3}));
+ test::FillValues<string>(
+ &expected, {"b", "c", "a", "e", "f", "d", "h", "i", "g", "k", "l", "j"});
+ test::ExpectTensorEqual<string>(expected, *GetOutput(0));
+}
+
+TEST_F(RollOpTest, ZeroShift_ThreeD32) {
+ MakeOp(DT_FLOAT, DT_INT32);
+
+ // Feed and run: 2x2x3 input with a shift of 0 along every axis.
+ AddInputFromArray<float>(TensorShape({2, 2, 3}),
+ {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11});
+ AddInputFromArray<int32>(TensorShape({3}), {0, 0, 0});
+ AddInputFromArray<int32>(TensorShape({3}), {0, 1, 2});
+ TF_ASSERT_OK(RunOpKernel());
+
+ // Check the output: identical to the input.
+ Tensor expected(allocator(), DT_FLOAT, TensorShape({2, 2, 3}));
+ test::FillValues<float>(&expected, {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11});
+ test::ExpectTensorEqual<float>(expected, *GetOutput(0));
+}
+
+TEST_F(RollOpTest, ZeroShift_ThreeD32_NoMemcpy) {
+ MakeOp(DT_STRING, DT_INT32);
+
+ // Feed and run: 2x2x3 string input with a shift of 0 along every axis.
+ AddInputFromArray<string>(
+ TensorShape({2, 2, 3}),
+ {"a", "b", "c", "d", "e", "f", "g", "h", "i", "j", "k", "l"});
+ AddInputFromArray<int32>(TensorShape({3}), {0, 0, 0});
+ AddInputFromArray<int32>(TensorShape({3}), {0, 1, 2});
+ TF_ASSERT_OK(RunOpKernel());
+
+ // Check the output: identical to the input.
+ Tensor expected(allocator(), DT_STRING, TensorShape({2, 2, 3}));
+ test::FillValues<string>(
+ &expected, {"a", "b", "c", "d", "e", "f", "g", "h", "i", "j", "k", "l"});
+ test::ExpectTensorEqual<string>(expected, *GetOutput(0));
+}
+
+TEST_F(RollOpTest, ZeroSize_ThreeD32) {
+ MakeOp(DT_FLOAT, DT_INT32);
+
+ // Roll an empty 5x0x0 float tensor by a scalar shift of 1 along axis 0.
+ AddInputFromArray<float>(TensorShape({5, 0, 0}), {});
+ AddInputFromArray<int32>(TensorShape({}), {1});  // shift (scalar)
+ AddInputFromArray<int32>(TensorShape({}), {0});  // axis (scalar)
+ TF_ASSERT_OK(RunOpKernel());
+
+ // The op must succeed and produce an empty tensor of the same shape.
+ Tensor expected(allocator(), DT_FLOAT, TensorShape({5, 0, 0}));
+ test::ExpectTensorEqual<float>(expected, *GetOutput(0));
+}
+
+TEST_F(RollOpTest, ZeroSize_ThreeD32_NoMemcpy) {
+ MakeOp(DT_STRING, DT_INT32);
+
+ // Roll an empty 5x0x0 string tensor by a scalar shift of 1 along axis 0.
+ AddInputFromArray<string>(TensorShape({5, 0, 0}), {});
+ AddInputFromArray<int32>(TensorShape({}), {1});  // shift (scalar)
+ AddInputFromArray<int32>(TensorShape({}), {0});  // axis (scalar)
+ TF_ASSERT_OK(RunOpKernel());
+
+ // The op must succeed and produce an empty tensor of the same shape.
+ Tensor expected(allocator(), DT_STRING, TensorShape({5, 0, 0}));
+ test::ExpectTensorEqual<string>(expected, *GetOutput(0));
+}
+
+TEST_F(RollOpTest, OneSize_ThreeD32) {
+ MakeOp(DT_FLOAT, DT_INT32);
+
+ // Roll a single-element 1x1x1 float tensor by 1 along axis 0.
+ AddInputFromArray<float>(TensorShape({1, 1, 1}), {5});
+ AddInputFromArray<int32>(TensorShape({}), {1});  // shift (scalar)
+ AddInputFromArray<int32>(TensorShape({}), {0});  // axis (scalar)
+ TF_ASSERT_OK(RunOpKernel());
+
+ // With one element there is nothing to move; the value stays in place.
+ Tensor expected(allocator(), DT_FLOAT, TensorShape({1, 1, 1}));
+ test::FillValues<float>(&expected, {5});
+ test::ExpectTensorEqual<float>(expected, *GetOutput(0));
+}
+
+TEST_F(RollOpTest, OneSize_ThreeD32_NoMemcpy) {
+ MakeOp(DT_STRING, DT_INT32);
+
+ // Roll a single-element 1x1x1 string tensor by 1 along axis 0.
+ AddInputFromArray<string>(TensorShape({1, 1, 1}), {"a"});
+ AddInputFromArray<int32>(TensorShape({}), {1});  // shift (scalar)
+ AddInputFromArray<int32>(TensorShape({}), {0});  // axis (scalar)
+ TF_ASSERT_OK(RunOpKernel());
+
+ // With one element there is nothing to move; the value stays in place.
+ Tensor expected(allocator(), DT_STRING, TensorShape({1, 1, 1}));
+ test::FillValues<string>(&expected, {"a"});
+ test::ExpectTensorEqual<string>(expected, *GetOutput(0));
+}
+
+TEST_F(RollOpTest, MultiShifts_TwoD32) {
+ MakeOp(DT_FLOAT, DT_INT32);
+
+ // Roll a 3x5 float tensor with four shifts in which each axis is listed twice.
+ AddInputFromArray<float>(TensorShape({3, 5}),
+ {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14});
+ AddInputFromArray<int32>(TensorShape({4}), {-2, 2, -1, 1});  // shift
+ AddInputFromArray<int32>(TensorShape({4}), {1, 0, 0, 1});  // axis
+ TF_ASSERT_OK(RunOpKernel());
+
+ // Shifts on a repeated axis add up: net +1 on axis 0 (2 - 1) and -1 on axis 1 (-2 + 1).
+ Tensor expected(allocator(), DT_FLOAT, TensorShape({3, 5}));
+ test::FillValues<float>(&expected,
+ {11, 12, 13, 14, 10, 1, 2, 3, 4, 0, 6, 7, 8, 9, 5});
+ test::ExpectTensorEqual<float>(expected, *GetOutput(0));
+}
+
+TEST_F(RollOpTest, MultiShifts_TwoD32_NoMemcpy) {
+ MakeOp(DT_STRING, DT_INT32);
+
+ // Roll a 3x5 string tensor with four shifts in which each axis is listed twice.
+ AddInputFromArray<string>(TensorShape({3, 5}),
+ {"a", "b", "c", "d", "e", "f", "g", "h", "i", "j",
+ "k", "l", "m", "n", "o"});
+ AddInputFromArray<int32>(TensorShape({4}), {-2, 2, -1, 1});  // shift
+ AddInputFromArray<int32>(TensorShape({4}), {1, 0, 0, 1});  // axis
+ TF_ASSERT_OK(RunOpKernel());
+
+ // Shifts on a repeated axis add up: net +1 on axis 0 (2 - 1) and -1 on axis 1 (-2 + 1).
+ Tensor expected(allocator(), DT_STRING, TensorShape({3, 5}));
+ test::FillValues<string>(&expected, {"l", "m", "n", "o", "k", "b", "c", "d",
+ "e", "a", "g", "h", "i", "j", "f"});
+ test::ExpectTensorEqual<string>(expected, *GetOutput(0));
+}
+
+TEST_F(RollOpTest, Error_InputMustBeVectorOrHigher) {
+ MakeOp(DT_FLOAT, DT_INT32);
+
+ // A scalar (0-D) input must be rejected by the kernel.
+ AddInputFromArray<float>(TensorShape({}), {7});
+ AddInputFromArray<int32>(TensorShape({}), {1});  // shift
+ AddInputFromArray<int32>(TensorShape({}), {0});  // axis
+ Status s = RunOpKernel();
+ EXPECT_TRUE(StringPiece(s.ToString()).contains("input must be 1-D or higher"))
+ << s;
+}
+
+TEST_F(RollOpTest, Error_AxisMustBeScalarOrVector) {
+ MakeOp(DT_FLOAT, DT_INT32);
+
+ // A 2-D axis tensor (shape {1, 2}) must be rejected by the kernel.
+ AddInputFromArray<float>(TensorShape({2, 2}), {1, 2, 3, 4});
+ AddInputFromArray<int32>(TensorShape({}), {1});  // shift
+ AddInputFromArray<int32>(TensorShape({1, 2}), {0, 1});  // axis (invalid rank)
+ Status s = RunOpKernel();
+ EXPECT_TRUE(StringPiece(s.ToString())
+ .contains("axis must be a scalar or a 1-D vector"))
+ << s;
+}
+
+TEST_F(RollOpTest, Error_ShiftMustBeScalarOrVector) {
+ MakeOp(DT_FLOAT, DT_INT32);
+
+ // A 2-D shift tensor (shape {1, 2}) must be rejected by the kernel.
+ AddInputFromArray<float>(TensorShape({2, 2}), {1, 2, 3, 4});
+ AddInputFromArray<int32>(TensorShape({1, 2}), {0, 1});  // shift (invalid rank)
+ AddInputFromArray<int32>(TensorShape({}), {1});  // axis
+ Status s = RunOpKernel();
+ EXPECT_TRUE(StringPiece(s.ToString())
+ .contains("shift must be a scalar or a 1-D vector"))
+ << s;
+}
+
+TEST_F(RollOpTest, Error_ShiftAndAxisMustBeSameSize) {
+ MakeOp(DT_FLOAT, DT_INT32);
+
+ // One shift paired with two axes must be rejected by the kernel.
+ AddInputFromArray<float>(TensorShape({2, 2}), {1, 2, 3, 4});
+ AddInputFromArray<int32>(TensorShape({1}), {1});  // shift (1 element)
+ AddInputFromArray<int32>(TensorShape({2}), {0, 1});  // axis (2 elements)
+ Status s = RunOpKernel();
+ EXPECT_TRUE(StringPiece(s.ToString())
+ .contains("shift and axis must have the same size"))
+ << s;
+}
+
+TEST_F(RollOpTest, Error_AxisOutOfRange) {
+ MakeOp(DT_FLOAT, DT_INT32);
+
+ // Axis 1 does not exist on a 1-D input, so the kernel must report it as out of range.
+ AddInputFromArray<float>(TensorShape({4}), {1, 2, 3, 4});
+ AddInputFromArray<int32>(TensorShape({}), {1});  // shift
+ AddInputFromArray<int32>(TensorShape({}), {1});  // axis (invalid for rank 1)
+ Status s = RunOpKernel();
+ EXPECT_TRUE(StringPiece(s.ToString()).contains("is out of range")) << s;
+}
+
+// Builds a benchmark graph with one Roll node over a random float tensor of `shape`.
+// isd (inner shift dimension): dimensions 0..isd are rolled by 2, the rest by 0.
+static Graph* RollGraph(const TensorShape& shape, int isd) {
+ Graph* g = new Graph(OpRegistry::Global());
+ Tensor input(DT_FLOAT, shape);
+ input.flat<float>().setRandom();
+ const int dims = static_cast<int>(input.dims());
+ Tensor shift(DT_INT32, TensorShape({dims}));  // one shift entry per dimension
+ for (int i = 0; i < dims; i++) {
+ // shift the inner shift dimension and all outer dimensions by 2; leave the others at 0
+ shift.flat<int32>()(i) = (i <= isd) ? 2 : 0;
+ }
+ Tensor axis(DT_INT32, TensorShape({dims}));  // axis = [0, 1, ..., dims - 1]
+ for (int i = 0; i < dims; i++) {
+ axis.flat<int32>()(i) = i;
+ }
+ test::graph::Roll(g, test::graph::Constant(g, input),
+ test::graph::Constant(g, shift),
+ test::graph::Constant(g, axis));
+ return g;
+}
+
+#define BM_ROLL_OUTER(DEVICE) /* benchmark: only axis 0 gets a non-zero shift */ \
+ static void BM_##DEVICE##_roll_outer(int iters, int rows, int columns) { \
+ TensorShape shape{rows, columns}; \
+ const int64 num_items = static_cast<int64>(iters) * shape.num_elements(); \
+ testing::ItemsProcessed(num_items); \
+ testing::BytesProcessed(num_items * sizeof(float)); \
+ testing::UseRealTime(); \
+ test::Benchmark(#DEVICE, RollGraph(shape, 0)).Run(iters); /* isd = 0 */ \
+ } \
+ BENCHMARK(BM_##DEVICE##_roll_outer) /* square inputs from 256 to 2048 */ \
+ ->ArgPair(256, 256) \
+ ->ArgPair(512, 512) \
+ ->ArgPair(1024, 1024) \
+ ->ArgPair(2048, 2048)
+
+#define BM_ROLL_ALL(DEVICE) /* benchmark: both axes of the 2-D input are shifted */ \
+ static void BM_##DEVICE##_roll_all(int iters, int rows, int columns) { \
+ TensorShape shape{rows, columns}; \
+ const int64 num_items = static_cast<int64>(iters) * shape.num_elements(); \
+ testing::ItemsProcessed(num_items); \
+ testing::BytesProcessed(num_items * sizeof(float)); \
+ testing::UseRealTime(); \
+ test::Benchmark(#DEVICE, RollGraph(shape, 1)).Run(iters); /* isd = 1 */ \
+ } \
+ BENCHMARK(BM_##DEVICE##_roll_all) /* square inputs from 256 to 2048 */ \
+ ->ArgPair(256, 256) \
+ ->ArgPair(512, 512) \
+ ->ArgPair(1024, 1024) \
+ ->ArgPair(2048, 2048)
+
+BM_ROLL_OUTER(cpu);  // defines and registers BM_cpu_roll_outer
+BM_ROLL_ALL(cpu);  // defines and registers BM_cpu_roll_all
+} // namespace
+} // namespace tensorflow
diff --git a/tensorflow/core/ops/manip_ops.cc b/tensorflow/core/ops/manip_ops.cc
new file mode 100644
index 0000000000..95b4774fe6
--- /dev/null
+++ b/tensorflow/core/ops/manip_ops.cc
@@ -0,0 +1,33 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/core/framework/common_shape_fns.h"
+#include "tensorflow/core/framework/op.h"
+#include "tensorflow/core/framework/shape_inference.h"
+
+namespace tensorflow {
+
+// --------------------------------------------------------------------------
+REGISTER_OP("Roll")  // rolls `input` by `shift` along `axis`
+ .Input("input: T")
+ .Input("shift: Tshift")
+ .Input("axis: Taxis")
+ .Output("output: T")  // same element type as input
+ .Attr("T: type")
+ .Attr("Tshift: {int32,int64}")
+ .Attr("Taxis: {int32,int64}")
+ .SetShapeFn(shape_inference::UnchangedShape);  // NOTE(review): no rank checks on input/shift/axis here; they appear to be left to the kernel - confirm
+
+} // namespace tensorflow
diff --git a/tensorflow/python/BUILD b/tensorflow/python/BUILD
index a323d5bc39..c73d6c37ee 100644
--- a/tensorflow/python/BUILD
+++ b/tensorflow/python/BUILD
@@ -76,6 +76,7 @@ py_library(
":layers",
":lib",
":list_ops",
+ ":manip_ops",
":math_ops",
":metrics",
":nn",
@@ -1395,6 +1396,14 @@ tf_gen_op_wrapper_private_py(
)
tf_gen_op_wrapper_private_py(
+ name = "manip_ops_gen",  # presumably produces the gen_manip_ops module imported by ops/manip_ops.py
+ visibility = [
+ "//learning/brain/python/ops:__pkg__",  # NOTE(review): sibling math_ops_gen uses //learning/brain/google/python/ops - confirm path
+ "//tensorflow/python/kernel_tests:__pkg__",
+ ],
+)
+
+tf_gen_op_wrapper_private_py(
name = "math_ops_gen",
visibility = [
"//learning/brain/google/python/ops:__pkg__",
@@ -1726,6 +1735,8 @@ py_library(
":linalg_grad",
":linalg_ops",
":logging_ops",
+ ":manip_grad",
+ ":manip_ops",
":math_grad",
":math_ops",
":platform",
@@ -1849,6 +1860,29 @@ py_library(
)
py_library(
+ name = "manip_grad",  # gradient registration for the Roll op
+ srcs = ["ops/manip_grad.py"],
+ srcs_version = "PY2AND3",
+ deps = [
+ ":control_flow_ops",  # NOTE(review): manip_grad.py imports only framework.ops and manip_ops - confirm this dep is needed
+ ":framework_for_generated_wrappers",
+ ":manip_ops",
+ ],
+)
+
+py_library(
+ name = "manip_ops",  # public Python wrapper module for the manip ops
+ srcs = ["ops/manip_ops.py"],
+ srcs_version = "PY2AND3",
+ deps = [  # NOTE(review): manip_ops.py imports only gen_manip_ops and util.all_util - confirm deps match
+ ":dtypes",
+ ":framework_ops",
+ ":manip_ops_gen",
+ "//third_party/py/numpy",
+ ],
+)
+
+py_library(
name = "logging_ops",
srcs = ["ops/logging_ops.py"],
srcs_version = "PY2AND3",
@@ -2310,6 +2344,8 @@ py_library(
":linalg_ops",
":logging_ops",
":lookup_ops",
+ ":manip_grad",
+ ":manip_ops",
":math_grad",
":math_ops",
":numerics",
diff --git a/tensorflow/python/__init__.py b/tensorflow/python/__init__.py
index bc9ddec2a5..ea7604d30f 100644
--- a/tensorflow/python/__init__.py
+++ b/tensorflow/python/__init__.py
@@ -84,6 +84,7 @@ from tensorflow.python.feature_column import feature_column_lib as feature_colum
from tensorflow.python.layers import layers
from tensorflow.python.ops import bitwise_ops as bitwise
from tensorflow.python.ops import image_ops as image
+from tensorflow.python.ops import manip_ops as manip
from tensorflow.python.ops import metrics
from tensorflow.python.ops import nn
from tensorflow.python.ops import sets
@@ -241,6 +242,7 @@ _allowed_symbols.extend([
'linalg',
'logging',
'losses',
+ 'manip',
'metrics',
'newaxis',
'nn',
diff --git a/tensorflow/python/kernel_tests/BUILD b/tensorflow/python/kernel_tests/BUILD
index c87b7652ad..3a6058054b 100644
--- a/tensorflow/python/kernel_tests/BUILD
+++ b/tensorflow/python/kernel_tests/BUILD
@@ -1602,6 +1602,19 @@ cuda_py_test(
)
cuda_py_test(
+ name = "manip_ops_test",  # Python kernel tests for manip_ops.roll
+ size = "small",
+ srcs = ["manip_ops_test.py"],
+ additional_deps = [  # NOTE(review): the test also imports gradient_checker - confirm a listed target provides it
+ "//third_party/py/numpy",
+ "//tensorflow/python:manip_ops",
+ "//tensorflow/python:client_testlib",
+ "//tensorflow/python:framework_for_generated_wrappers",
+ ],
+ tags = ["no_windows_gpu"],
+)
+
+cuda_py_test(
name = "matmul_op_test",
size = "small",
srcs = ["matmul_op_test.py"],
diff --git a/tensorflow/python/kernel_tests/manip_ops_test.py b/tensorflow/python/kernel_tests/manip_ops_test.py
new file mode 100644
index 0000000000..3044b21aa4
--- /dev/null
+++ b/tensorflow/python/kernel_tests/manip_ops_test.py
@@ -0,0 +1,137 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for manip_ops."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from tensorflow.python.framework import constant_op
+from tensorflow.python.framework import errors_impl
+from tensorflow.python.framework import test_util
+from tensorflow.python.ops import manip_ops
+from tensorflow.python.ops import gradient_checker
+from tensorflow.python.platform import test as test_lib
+
+import numpy as np
+
+# pylint: disable=g-import-not-at-top
+try:
+ from distutils.version import StrictVersion as Version
+ # np.roll accepts sequence-valued shift/axis only from numpy 1.12.0 on; the flag gates the multi-shift test cases
+ NP_ROLL_CAN_MULTISHIFT = Version(np.version.version) >= Version('1.12.0')
+except ImportError:  # distutils unavailable: conservatively skip the multi-shift cases
+ NP_ROLL_CAN_MULTISHIFT = False
+# pylint: enable=g-import-not-at-top
+
+class RollTest(test_util.TensorFlowTestCase):  # compares manip_ops.roll with np.roll and checks its gradient
+ def _testRoll(self, np_input, shift, axis):  # forward result must equal np.roll
+ expected_roll = np.roll(np_input, shift, axis)
+ with self.test_session():
+ roll = manip_ops.roll(np_input, shift, axis)
+ self.assertAllEqual(roll.eval(), expected_roll)
+
+ def _testGradient(self, np_input, shift, axis):  # compares the two Jacobians returned by compute_gradient
+ with self.test_session():
+ inx = constant_op.constant(np_input.tolist())
+ xs = list(np_input.shape)
+ y = manip_ops.roll(inx, shift, axis)
+ # roll keeps the input shape, so y is checked with the same shape as x
+ ys = xs
+ jacob_t, jacob_n = gradient_checker.compute_gradient(
+ inx, xs, y, ys, x_init_value=np_input)
+ self.assertAllClose(jacob_t, jacob_n, rtol=1e-5, atol=1e-5)
+
+ def _testAll(self, np_input, shift, axis):
+ self._testRoll(np_input, shift, axis)
+ if np_input.dtype == np.float32:  # the gradient check is run for float32 inputs only
+ self._testGradient(np_input, shift, axis)
+
+ def testIntTypes(self):
+ for t in [np.int32, np.int64]:
+ self._testAll(np.random.randint(-100, 100, (5)).astype(t), 3, 0)
+ if NP_ROLL_CAN_MULTISHIFT:  # list-valued shift/axis need a numpy whose np.roll supports them
+ self._testAll(np.random.randint(-100, 100, (4, 4, 3)).astype(t),
+ [1, -2, 3], [0, 1, 2])
+ self._testAll(np.random.randint(-100, 100, (4, 2, 1, 3)).astype(t),
+ [0, 1, -2], [1, 2, 3])
+
+ def testFloatTypes(self):
+ for t in [np.float32, np.float64]:
+ self._testAll(np.random.rand(5).astype(t), 2, 0)
+ if NP_ROLL_CAN_MULTISHIFT:  # list-valued shift/axis need a numpy whose np.roll supports them
+ self._testAll(np.random.rand(3, 4).astype(t), [1, 2], [1, 0])
+ self._testAll(np.random.rand(1, 3, 4).astype(t), [1, 0, -3], [0, 1, 2])
+
+ def testComplexTypes(self):
+ for t in [np.complex64, np.complex128]:
+ x = np.random.rand(4, 4).astype(t)
+ self._testAll(x + 1j * x, 2, 0)  # real and imaginary parts are both populated
+ if NP_ROLL_CAN_MULTISHIFT:  # list-valued shift/axis need a numpy whose np.roll supports them
+ x = np.random.rand(2, 5).astype(t)
+ self._testAll(x + 1j * x, [1, 2], [1, 0])
+ x = np.random.rand(3, 2, 1, 1).astype(t)
+ self._testAll(x + 1j * x, [2, 1, 1, 0], [0, 3, 1, 2])
+
+
+ def testRollInputMustVectorHigherRaises(self):
+ tensor = 7  # scalar (0-D) input is invalid
+ shift = 1
+ axis = 0
+ with self.test_session():
+ with self.assertRaisesRegexp(errors_impl.InvalidArgumentError,
+ "input must be 1-D or higher"):
+ manip_ops.roll(tensor, shift, axis).eval()  # the error is expected when the op is evaluated
+
+ def testRollAxisMustBeScalarOrVectorRaises(self):
+ tensor = [[1, 2],
+ [3, 4]]
+ shift = 1
+ axis = [[0, 1]]  # 2-D axis is invalid
+ with self.test_session():
+ with self.assertRaisesRegexp(errors_impl.InvalidArgumentError,
+ "axis must be a scalar or a 1-D vector"):
+ manip_ops.roll(tensor, shift, axis).eval()
+
+ def testRollShiftMustBeScalarOrVectorRaises(self):
+ tensor = [[1, 2],
+ [3, 4]]
+ shift = [[0, 1]]  # 2-D shift is invalid
+ axis = 1
+ with self.test_session():
+ with self.assertRaisesRegexp(errors_impl.InvalidArgumentError,
+ "shift must be a scalar or a 1-D vector"):
+ manip_ops.roll(tensor, shift, axis).eval()
+
+ def testRollShiftAndAxisMustBeSameSizeRaises(self):
+ tensor = [[1, 2],
+ [3, 4]]
+ shift = [1]  # one shift ...
+ axis = [0, 1]  # ... but two axes
+ with self.test_session():
+ with self.assertRaisesRegexp(errors_impl.InvalidArgumentError,
+ "shift and axis must have the same size"):
+ manip_ops.roll(tensor, shift, axis).eval()
+
+ def testRollAxisOutOfRangeRaises(self):
+ tensor = [1, 2]
+ shift = 1
+ axis = 1  # a 1-D tensor only has axis 0
+ with self.test_session():
+ with self.assertRaisesRegexp(errors_impl.InvalidArgumentError,
+ "is out of range"):
+ manip_ops.roll(tensor, shift, axis).eval()
+
+if __name__ == "__main__":
+ test_lib.main()
diff --git a/tensorflow/python/ops/gradients_impl.py b/tensorflow/python/ops/gradients_impl.py
index 314726ede6..230b6c5946 100644
--- a/tensorflow/python/ops/gradients_impl.py
+++ b/tensorflow/python/ops/gradients_impl.py
@@ -44,6 +44,7 @@ from tensorflow.python.ops import image_grad # pylint: disable=unused-import
from tensorflow.python.ops import linalg_grad # pylint: disable=unused-import
from tensorflow.python.ops import linalg_ops # pylint: disable=unused-import
from tensorflow.python.ops import logging_ops # pylint: disable=unused-import
+from tensorflow.python.ops import manip_grad # pylint: disable=unused-import
from tensorflow.python.ops import math_grad # pylint: disable=unused-import
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import resource_variable_ops
diff --git a/tensorflow/python/ops/manip_grad.py b/tensorflow/python/ops/manip_grad.py
new file mode 100644
index 0000000000..573e8c0a0d
--- /dev/null
+++ b/tensorflow/python/ops/manip_grad.py
@@ -0,0 +1,32 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+"""Gradients for operators defined in manip_ops.py."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from tensorflow.python.framework import ops
+from tensorflow.python.ops import manip_ops
+
+
+@ops.RegisterGradient("Roll")
+def _RollGrad(op, grad):
+ # Roll only moves elements, so the input gradient is grad rolled back by -shift along the same axis
+ shift = op.inputs[1]
+ axis = op.inputs[2]
+ roll_grad = manip_ops.roll(grad, -shift, axis)
+ return roll_grad, None, None  # no gradient for the shift and axis inputs
diff --git a/tensorflow/python/ops/manip_ops.py b/tensorflow/python/ops/manip_ops.py
new file mode 100644
index 0000000000..c5f39784f4
--- /dev/null
+++ b/tensorflow/python/ops/manip_ops.py
@@ -0,0 +1,36 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Operators for manipulating tensors.
+
+@@roll
+"""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from tensorflow.python.ops import gen_manip_ops as _gen_manip_ops
+from tensorflow.python.util.all_util import remove_undocumented
+
+# pylint: disable=protected-access
+def roll(input, shift, axis):  # public wrapper that forwards to the generated Roll op
+ return _gen_manip_ops.roll(input, shift, axis)  # NOTE(review): `input` shadows the builtin; the pragma names protected-access, not redefined-builtin - confirm
+
+roll.__doc__ = _gen_manip_ops.roll.__doc__  # reuse the generated op's docstring
+# pylint: enable=protected-access
+
+_allowed_symbols = ['roll']  # names passed to remove_undocumented as the allowed exception list
+
+remove_undocumented(__name__, allowed_exception_list=_allowed_symbols)  # presumably hides module symbols that are neither documented nor listed above
diff --git a/tensorflow/python/ops/standard_ops.py b/tensorflow/python/ops/standard_ops.py
index 30bf4e4ef1..737b923415 100644
--- a/tensorflow/python/ops/standard_ops.py
+++ b/tensorflow/python/ops/standard_ops.py
@@ -26,6 +26,7 @@ import sys as _sys
from tensorflow.python.ops import array_grad
from tensorflow.python.ops import data_flow_grad
from tensorflow.python.ops import math_grad
+from tensorflow.python.ops import manip_grad
from tensorflow.python.ops import sparse_grad
from tensorflow.python.ops import spectral_grad
from tensorflow.python.ops import state_grad
@@ -59,6 +60,7 @@ from tensorflow.python.ops.logging_ops import Print
from tensorflow.python.ops.logging_ops import get_summary_op
from tensorflow.python.ops.lookup_ops import initialize_all_tables
from tensorflow.python.ops.lookup_ops import tables_initializer
+from tensorflow.python.ops.manip_ops import *
from tensorflow.python.ops.math_ops import *
from tensorflow.python.ops.numerics import *
from tensorflow.python.ops.parsing_ops import *
@@ -105,6 +107,7 @@ from tensorflow.python.ops import init_ops as _init_ops
from tensorflow.python.ops import io_ops as _io_ops
from tensorflow.python.ops import linalg_ops as _linalg_ops
from tensorflow.python.ops import logging_ops as _logging_ops
+from tensorflow.python.ops import manip_ops as _manip_ops
from tensorflow.python.ops import math_ops as _math_ops
from tensorflow.python.ops import numerics as _numerics
from tensorflow.python.ops import parsing_ops as _parsing_ops
@@ -280,6 +283,7 @@ remove_undocumented(__name__, _allowed_symbols,
_io_ops,
_linalg_ops,
_logging_ops,
+ _manip_ops,
_math_ops,
_numerics,
_parsing_ops,
diff --git a/tensorflow/tools/api/golden/tensorflow.manip.pbtxt b/tensorflow/tools/api/golden/tensorflow.manip.pbtxt
new file mode 100644
index 0000000000..0b84165285
--- /dev/null
+++ b/tensorflow/tools/api/golden/tensorflow.manip.pbtxt
@@ -0,0 +1,7 @@
+path: "tensorflow.manip"
+tf_module {
+ member_method {
+ name: "roll"
+ argspec: "args=[\'input\', \'shift\', \'axis\'], varargs=None, keywords=None, defaults=None"
+ }
+}
diff --git a/tensorflow/tools/api/golden/tensorflow.pbtxt b/tensorflow/tools/api/golden/tensorflow.pbtxt
index dc7c3a2f45..e8890e9cc0 100644
--- a/tensorflow/tools/api/golden/tensorflow.pbtxt
+++ b/tensorflow/tools/api/golden/tensorflow.pbtxt
@@ -397,6 +397,10 @@ tf_module {
mtype: "<type \'module\'>"
}
member {
+ name: "manip"
+ mtype: "<type \'module\'>"
+ }
+ member {
name: "metrics"
mtype: "<type \'module\'>"
}