aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorGravatar Jean Flaherty <kobejean@me.com>2018-01-29 17:03:31 -0700
committerGravatar Rasmus Munk Larsen <rmlarsen@google.com>2018-01-29 16:03:31 -0800
commit4ab2d8531c461169cd6a33bc0fef1129b419e9df (patch)
treef32970777a805b9ce6e57e933be813dd6d5e1247
parent8c1f3b3fcdc5f897dbce3b7315c1b7000069f9c4 (diff)
Tensor roll op implementation (#14953)
* Half migrated to manip Half migrated to tf.manip * Roll op: polymorphism & GPU attempt * Roll op: Added support for gradients Added compile script * Rebase for roll op code * Roll op: Migrated to manip namespace * Roll op: Supports CPU thread pooling fix namespace error * Remove roll from user_ops * Roll op: Optimization * Roll op: Pylint fix * Roll op: Updated documentation * Roll op: Two versions for CPU thread pooling * Roll op: Huge CPU speed up Fixed thread pooling issue that was due to a bad cost_per_unit parameter Also improved readability * Roll op: Rough draft of DoRollV2 DoRollV2 copies memory in groups instead of element by element. Not thoroughly tested yet. Polished DoRollV2 algorithm * Roll op: Restrict tensor size for GPU implementation * Roll op: Fixed clang-format and missing include * Roll op: Minor change * Roll op GPU bug fix Roll op GPU bug fix GPU bug fix Roll op GPU bug fix Roll op GPU fix Roll GPU test BUILD update * Roll op: Remove GPU code Fully remove roll op GPU code Remove compile_cpu.sh * Roll op: Fixes problems with array_ops_test.py and a size 1 dimension bug * Roll op: Migrated to manip Migrated to tf.manip Roll op registered Roll op uses InlinedVector Small improvements * Roll op: Revert array op changes * Roll op: Api def fix * Roll op: review changes * Roll op: API review changes Roll op: Docstring fix * Roll op: Review changes round 1 * Roll op: resolve conflicts * Roll op: Resolve conflicts * Roll op: clang-tidy * Roll op: Review round 2 changes Roll op: fixed BUILD file Roll op: api docs update * Roll op: failure fixes 1 - updates goldens and fixes api compatibility issue - fixes python op test issue for windows - fixes makefile issues * Roll op: Windows CMake failure fix Windows CMake checks were failing because numpy was on an older version that did not support np.roll on multiple shifts that was use to check the correctness of tf.manip.roll. manip_ops_test.py now checks for numpy version 1.12.0 before testing multiple shifts, otherwise it'll just test single shift roll. * Roll op: pylint changes
-rw-r--r--tensorflow/cc/BUILD1
-rw-r--r--tensorflow/contrib/cmake/tf_core_ops.cmake1
-rwxr-xr-xtensorflow/contrib/cmake/tf_python.cmake1
-rw-r--r--tensorflow/contrib/makefile/tf_op_files.txt2
-rw-r--r--tensorflow/core/BUILD3
-rw-r--r--tensorflow/core/api_def/base_api/api_def_Roll.pbtxt52
-rw-r--r--tensorflow/core/graph/testlib.cc12
-rw-r--r--tensorflow/core/graph/testlib.h4
-rw-r--r--tensorflow/core/kernels/BUILD39
-rw-r--r--tensorflow/core/kernels/roll_op.cc334
-rw-r--r--tensorflow/core/kernels/roll_op_test.cc484
-rw-r--r--tensorflow/core/ops/manip_ops.cc33
-rw-r--r--tensorflow/python/BUILD36
-rw-r--r--tensorflow/python/__init__.py2
-rw-r--r--tensorflow/python/kernel_tests/BUILD13
-rw-r--r--tensorflow/python/kernel_tests/manip_ops_test.py137
-rw-r--r--tensorflow/python/ops/gradients_impl.py1
-rw-r--r--tensorflow/python/ops/manip_grad.py32
-rw-r--r--tensorflow/python/ops/manip_ops.py36
-rw-r--r--tensorflow/python/ops/standard_ops.py4
-rw-r--r--tensorflow/tools/api/golden/tensorflow.manip.pbtxt7
-rw-r--r--tensorflow/tools/api/golden/tensorflow.pbtxt4
22 files changed, 1237 insertions, 1 deletions
diff --git a/tensorflow/cc/BUILD b/tensorflow/cc/BUILD
index c9ade5fb83..9060c19e9d 100644
--- a/tensorflow/cc/BUILD
+++ b/tensorflow/cc/BUILD
@@ -433,6 +433,7 @@ tf_gen_op_wrappers_cc(
"linalg_ops",
"logging_ops",
"lookup_ops",
+ "manip_ops",
"math_ops",
"nn_ops",
"no_op",
diff --git a/tensorflow/contrib/cmake/tf_core_ops.cmake b/tensorflow/contrib/cmake/tf_core_ops.cmake
index 138993db35..c42bc35ce7 100644
--- a/tensorflow/contrib/cmake/tf_core_ops.cmake
+++ b/tensorflow/contrib/cmake/tf_core_ops.cmake
@@ -30,6 +30,7 @@ set(tf_op_lib_names
"list_ops"
"lookup_ops"
"logging_ops"
+ "manip_ops"
"math_ops"
"nn_ops"
"no_op"
diff --git a/tensorflow/contrib/cmake/tf_python.cmake b/tensorflow/contrib/cmake/tf_python.cmake
index 8862390d2b..b7c816c24f 100755
--- a/tensorflow/contrib/cmake/tf_python.cmake
+++ b/tensorflow/contrib/cmake/tf_python.cmake
@@ -335,6 +335,7 @@ GENERATE_PYTHON_OP_LIB("list_ops")
GENERATE_PYTHON_OP_LIB("logging_ops")
GENERATE_PYTHON_OP_LIB("lookup_ops")
GENERATE_PYTHON_OP_LIB("nn_ops")
+GENERATE_PYTHON_OP_LIB("manip_ops")
GENERATE_PYTHON_OP_LIB("parsing_ops")
GENERATE_PYTHON_OP_LIB("random_ops")
GENERATE_PYTHON_OP_LIB("remote_fused_graph_ops"
diff --git a/tensorflow/contrib/makefile/tf_op_files.txt b/tensorflow/contrib/makefile/tf_op_files.txt
index 5f27566398..9a1ab50317 100644
--- a/tensorflow/contrib/makefile/tf_op_files.txt
+++ b/tensorflow/contrib/makefile/tf_op_files.txt
@@ -91,6 +91,7 @@ tensorflow/core/kernels/reduction_ops_max.cc
tensorflow/core/kernels/reduction_ops_common.cc
tensorflow/core/kernels/reduction_ops_any.cc
tensorflow/core/kernels/reduction_ops_all.cc
+tensorflow/core/kernels/roll_op.cc
tensorflow/core/kernels/queue_ops.cc
tensorflow/core/kernels/queue_base.cc
tensorflow/core/kernels/pooling_ops_common.cc
@@ -270,6 +271,7 @@ tensorflow/core/ops/parsing_ops.cc
tensorflow/core/ops/no_op.cc
tensorflow/core/ops/nn_ops.cc
tensorflow/core/ops/nn_grad.cc
+tensorflow/core/ops/manip_ops.cc
tensorflow/core/ops/math_ops.cc
tensorflow/core/ops/math_grad.cc
tensorflow/core/ops/logging_ops.cc
diff --git a/tensorflow/core/BUILD b/tensorflow/core/BUILD
index 3b4a10eedb..90c2823ea4 100644
--- a/tensorflow/core/BUILD
+++ b/tensorflow/core/BUILD
@@ -610,6 +610,7 @@ tf_gen_op_libs(
"list_ops",
"lookup_ops",
"logging_ops",
+ "manip_ops",
"math_ops",
"nn_ops",
"no_op",
@@ -692,6 +693,7 @@ cc_library(
":list_ops_op_lib",
":logging_ops_op_lib",
":lookup_ops_op_lib",
+ ":manip_ops_op_lib",
":math_ops_op_lib",
":nn_ops_op_lib",
":no_op_op_lib",
@@ -829,6 +831,7 @@ cc_library(
"//tensorflow/core/kernels:list_kernels",
"//tensorflow/core/kernels:lookup",
"//tensorflow/core/kernels:logging",
+ "//tensorflow/core/kernels:manip",
"//tensorflow/core/kernels:math",
"//tensorflow/core/kernels:multinomial_op",
"//tensorflow/core/kernels:nn",
diff --git a/tensorflow/core/api_def/base_api/api_def_Roll.pbtxt b/tensorflow/core/api_def/base_api/api_def_Roll.pbtxt
new file mode 100644
index 0000000000..b308ad1f9d
--- /dev/null
+++ b/tensorflow/core/api_def/base_api/api_def_Roll.pbtxt
@@ -0,0 +1,52 @@
+op {
+ graph_op_name: "Roll"
+ in_arg {
+ name: "shift"
+ description: <<END
+Dimension must be 0-D or 1-D. `shift[i]` specifies the number of places by which
+elements are shifted positively (towards larger indices) along the dimension
+specified by `axis[i]`. Negative shifts will roll the elements in the opposite
+direction.
+END
+ }
+ in_arg {
+ name: "axis"
+ description: <<END
+Dimension must be 0-D or 1-D. `axis[i]` specifies the dimension that the shift
+`shift[i]` should occur. If the same axis is referenced more than once, the
+total shift for that axis will be the sum of all the shifts that belong to that
+axis.
+END
+ }
+ out_arg {
+ name: "output"
+ description: <<END
+Has the same shape and size as the input. The elements are shifted
+positively (towards larger indices) by the offsets of `shift` along the
+dimensions of `axis`.
+END
+ }
+ summary: "Rolls the elements of a tensor along an axis."
+ description: <<END
+The elements are shifted positively (towards larger indices) by the offset of
+`shift` along the dimension of `axis`. Negative `shift` values will shift
+elements in the opposite direction. Elements that roll passed the last position
+will wrap around to the first and vice versa. Multiple shifts along multiple
+axes may be specified.
+
+For example:
+
+```
+# 't' is [0, 1, 2, 3, 4]
+roll(t, shift=2, axis=0) ==> [3, 4, 0, 1, 2]
+
+# shifting along multiple dimensions
+# 't' is [[0, 1, 2, 3, 4], [5, 6, 7, 8, 9]]
+roll(t, shift=[1, -2], axis=[0, 1]) ==> [[7, 8, 9, 5, 6], [2, 3, 4, 0, 1]]
+
+# shifting along the same axis multiple times
+# 't' is [[0, 1, 2, 3, 4], [5, 6, 7, 8, 9]]
+roll(t, shift=[2, -3], axis=[1, 1]) ==> [[1, 2, 3, 4, 0], [6, 7, 8, 9, 5]]
+```
+END
+}
diff --git a/tensorflow/core/graph/testlib.cc b/tensorflow/core/graph/testlib.cc
index 172471e34b..0d88d1ff72 100644
--- a/tensorflow/core/graph/testlib.cc
+++ b/tensorflow/core/graph/testlib.cc
@@ -40,7 +40,7 @@ REGISTER_KERNEL_BUILDER(
#ifdef TENSORFLOW_USE_SYCL
REGISTER_KERNEL_BUILDER(
Name("HostConst").Device(DEVICE_SYCL).HostMemory("output"), HostConstantOp);
-#endif // TENSORFLOW_USE_SYCL
+#endif // TENSORFLOW_USE_SYCL
// Register the HostConst Op
// Returns a constant tensor on the host. Useful for writing C++ tests
@@ -273,6 +273,16 @@ Node* Reverse(Graph* g, Node* tensor, Node* axis) {
return Binary(g, "ReverseV2", tensor, axis);
}
+Node* Roll(Graph* g, Node* input, Node* shift, Node* axis) {
+ Node* ret;
+ TF_CHECK_OK(NodeBuilder(g->NewName("n"), "Roll", g->op_registry())
+ .Input(input)
+ .Input(shift)
+ .Input(axis)
+ .Finalize(g, &ret));
+ return ret;
+}
+
Node* Error(Graph* g, Node* input, const string& errmsg) {
Node* ret;
TF_CHECK_OK(NodeBuilder(g->NewName("n"), "Error")
diff --git a/tensorflow/core/graph/testlib.h b/tensorflow/core/graph/testlib.h
index 06597778bb..eb9038d619 100644
--- a/tensorflow/core/graph/testlib.h
+++ b/tensorflow/core/graph/testlib.h
@@ -117,6 +117,10 @@ Node* RandomGamma(Graph* g, Node* shape, Node* alpha);
// Output dtype determined by lam.
Node* RandomPoisson(Graph* g, Node* shape, Node* lam);
+// Rolls tensor by an offset of <shift> along the corresponding
+// <axis> dimensions.
+Node* Roll(Graph* g, Node* input, Node* shift, Node* axis);
+
// Generates random parameters from the truncated standard normal distribution
// of the nput shape
Node* TruncatedNormal(Graph* g, Node* input, DataType dtype);
diff --git a/tensorflow/core/kernels/BUILD b/tensorflow/core/kernels/BUILD
index db309fc9da..e7192ec42f 100644
--- a/tensorflow/core/kernels/BUILD
+++ b/tensorflow/core/kernels/BUILD
@@ -2589,6 +2589,45 @@ tf_cc_tests(
],
)
+cc_library(
+ name = "manip",
+ deps = [
+ ":roll_op",
+ ],
+)
+
+MANIP_DEPS = [
+ "//tensorflow/core:framework",
+ "//tensorflow/core:lib",
+ "//tensorflow/core:manip_ops_op_lib",
+ "//third_party/eigen3",
+]
+
+tf_kernel_library(
+ name = "roll_op",
+ prefix = "roll_op",
+ deps = MANIP_DEPS,
+)
+
+tf_cc_test(
+ name = "roll_op_test",
+ size = "small",
+ srcs = ["roll_op_test.cc"],
+ deps = [
+ ":ops_testutil",
+ ":ops_util",
+ ":roll_op",
+ "//tensorflow/core:core_cpu",
+ "//tensorflow/core:core_cpu_internal",
+ "//tensorflow/core:framework",
+ "//tensorflow/core:lib",
+ "//tensorflow/core:protos_all_cc",
+ "//tensorflow/core:test",
+ "//tensorflow/core:test_main",
+ "//tensorflow/core:testlib",
+ ],
+)
+
MATH_DEPS = [
":bounds_check",
":fill_functor",
diff --git a/tensorflow/core/kernels/roll_op.cc b/tensorflow/core/kernels/roll_op.cc
new file mode 100644
index 0000000000..bcbdbee058
--- /dev/null
+++ b/tensorflow/core/kernels/roll_op.cc
@@ -0,0 +1,334 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/core/framework/common_shape_fns.h"
+#include "tensorflow/core/framework/op.h"
+#include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/framework/register_types.h"
+#include "tensorflow/core/framework/register_types_traits.h"
+#include "tensorflow/core/framework/shape_inference.h"
+#include "tensorflow/core/lib/gtl/array_slice.h"
+#include "tensorflow/core/platform/types.h"
+#include "tensorflow/core/util/work_sharder.h"
+
+namespace tensorflow {
+
+#define EIGEN_USE_THREADS
+using CPUDevice = Eigen::ThreadPoolDevice;
+
+// dim_size - the size of each dimension
+// dim_range - the number of indices over in the flattened tensor
+// you need to skip in order to make it over from one side of a dimension
+// to the other. Used to make the shifts wrap around after a threshold.
+// threshold - the index for each dimension that the roll starts to wrap
+// back to the front
+template <typename T>
+void DoRoll(OpKernelContext* context, const int64 num_elements,
+ const int num_dims, const gtl::ArraySlice<int>& dim_size,
+ const T* input, T* output, const gtl::ArraySlice<int>& threshold,
+ const gtl::ArraySlice<int64>& dim_range) {
+ auto work = [input, output, num_dims, &dim_size, &threshold, &dim_range](
+ int64 start, int64 end) {
+ // array of indices for each dimension
+ gtl::InlinedVector<int, 4> indices(num_dims);
+ int offset = 0; // the shift along the flattened tensor for current element
+ // initialize indices and offset
+ for (int i = 0; i < num_dims; i++) {
+ // stride is the number of indices over in the flattened tensor
+ // you need to skip in order to make it over to an adjacent element
+ // along a dimension. dim_size[i] != 0 because we set it to max(dim, 1)
+ const int64 stride = dim_range[i] / dim_size[i];
+ const int shift = dim_size[i] - threshold[i];
+ const int indx = (start / stride) % dim_size[i];
+ indices[i] = indx;
+ // calculate dimension index after the shift
+ const int shifted_indx = (indx + shift) % dim_size[i];
+ offset += (shifted_indx - indx) * stride;
+ }
+
+ for (int64 i = start; i < end; i++) {
+ output[i + offset] = input[i];
+ // create next combination of indices
+ // while at it adjust offset if needed
+ for (int j = num_dims - 1; j >= 0; j--) {
+ const int indx = (indices[j] + 1) % dim_size[j];
+ indices[j] = indx;
+ if (indx != 0) {
+ if (indx == threshold[j]) { // we've reached the threshold
+ // dim_range[j] = threshold[j] + shift[j]
+ // offset = shift[j] + ... other offsets
+ // offset - dim_range[j] = -threshold[j] + ... other offsets
+ // thus we undo our previous offset as well as add a new offset of
+ // -threshold[j] in one operation
+ offset -= dim_range[j]; // now wraps around
+ }
+ break; // indx != 0 don't need to carry
+ } else if (threshold[j] != 0) { // if threshold is 0 shift is 0
+ offset += dim_range[j]; // indx became 0 so reverse wrap around
+ }
+ }
+ }
+ };
+ // Shard
+ auto worker_threads = context->device()->tensorflow_cpu_worker_threads();
+ // 15 - expiramentally determined with float and bool types
+ const int cost_per_element = 15 * sizeof(T); // rough esitmate
+ Shard(worker_threads->num_threads, worker_threads->workers, num_elements,
+ cost_per_element, std::move(work));
+}
+
+// dim_size - the size of each dimension
+// dim_range - the number of indices over in the flattened tensor
+// you need to skip in order to make it over from one side of a dimension
+// to the other. Used to make the shifts wrap around after a threshold.
+// threshold - the index for each dimension that the roll starts to wrap
+// back to the front
+// isd - inner shift dimension
+template <typename T>
+// Use memcpy to copy memory in groups when the data type supports memcpy
+void DoRollWithMemcpy(OpKernelContext* context, const int64 num_elements,
+ const int num_dims, const gtl::ArraySlice<int>& dim_size,
+ const T* input, T* output,
+ const gtl::ArraySlice<int>& threshold,
+ const gtl::ArraySlice<int64>& dim_range,
+ const int64 isd) {
+ auto work = [input, output, num_dims, &dim_size, &threshold, &dim_range, isd](
+ int64 start, int64 end) {
+ // the number of indices over in the flattened tensor you need to skip in
+ // order to make it over from one side of the isd to the other
+ const int64 isd_range = std::max<int>(dim_range[isd], 1);
+ // the distance along the flattend tensor to the next element in the isd
+ const int64 isd_stride = isd_range / std::max<int>(dim_size[isd], 1);
+
+ // start and end represent the i-th group currently so we will convert
+ // them into numbers representing the i-th elements.
+ // there are 2 groups per isd one for all elements before threshold[isd]
+ // and another for all elements after threshold[isd].
+ const int64 start_remainder = (start % 2) * threshold[isd] * isd_stride;
+ const int64 end_remainder = (end % 2) * threshold[isd] * isd_stride;
+ start = (start / 2) * isd_range + start_remainder;
+ end = (end / 2) * isd_range + end_remainder;
+
+ const T* in_ptr = &input[0];
+ T* out_ptr = &output[0];
+ in_ptr += start;
+ out_ptr += start;
+
+ // array of indices for each dimension
+ // indicies = [i, j, k, l, m, n]
+ gtl::InlinedVector<int, 4> indicies(num_dims);
+ // the offset needed to make all inner non-shifting dimensions become 0
+ int64 remainder_offset = 0;
+ // initialize indicies
+ for (int i = 0; i < num_dims; i++) {
+ // stride is the number of indices over in the flattened tensor
+ // you need to skip in order to make it over to an adjacent element
+ // along a dimension. dim_size[i] != 0 because we set it to max(dim, 1)
+ const int64 stride = dim_range[i] / dim_size[i];
+ const int shift = dim_size[i] - threshold[i];
+ const int indx = (start / stride) % dim_size[i];
+ indicies[i] = indx;
+ // calculate dimension index after the shift
+ int out_indx = (indx + shift) % dim_size[i];
+ if (i > isd) {
+ // trailing zeroes for indices after the inner shifted dimension
+ out_indx = 0;
+ remainder_offset += (out_indx - indx) * stride;
+ }
+ out_ptr += (out_indx - indx) * stride;
+ }
+ // set trailing zeroes for indices after the inner shifted dimension
+ for (int i = num_dims - 1; i > isd; i--) indicies[i] = 0;
+
+ // the number of indices in the isd dimension the next group will skip
+ // to make it to the next threshold or end point
+ int isd_indx_skip = 0;
+ // the size of the next group
+ int64 group_size = 0;
+ // initialize isd_indx_skip and group_size
+ if (indicies[isd] < threshold[isd]) {
+ isd_indx_skip = threshold[isd] - indicies[isd];
+ group_size = isd_indx_skip * isd_stride + remainder_offset;
+ } else {
+ isd_indx_skip = dim_size[isd] - indicies[isd];
+ group_size = isd_indx_skip * isd_stride + remainder_offset;
+ }
+
+ int64 i = start;
+ while (i < end) {
+ // copy group of elements
+ memcpy(out_ptr, in_ptr, group_size * sizeof(T));
+
+ // shift i and the pointers over to the next group position
+ i += group_size;
+ out_ptr += group_size;
+ in_ptr += group_size;
+
+ // produce next combination of indices and adjust the out_ptr position
+ // to fix the offset if necessary
+ // the isd (inner shift dim) should skip to next threshold or endpoint
+ // all dimensions to the left increment by 1 when a digit is carried
+ // all dimensions to the right remain set to 0
+ // +1 +1 +1 +isd_indx_skip
+ // indicies = [i, j, k, l, 0, 0]
+ // ^isd
+ for (int j = isd; j >= 0; j--) {
+ int inc = 1;
+ if (j == isd) inc = isd_indx_skip;
+ const int indx = (indicies[j] + inc) % dim_size[j];
+ indicies[j] = indx;
+ if (indx != 0) {
+ if (indx == threshold[j]) {
+ out_ptr -= dim_range[j]; // now wraps around
+ }
+ break; // indx != 0 don't need to carry
+ } else if (threshold[j] != 0) { // if threshold is 0 shift is 0
+ out_ptr += dim_range[j]; // indx became 0 so reverse wrap around
+ }
+ }
+
+ // set isd_indx_skip and group_size for next iteration
+ if (indicies[isd] < threshold[isd]) {
+ isd_indx_skip = threshold[isd] - indicies[isd];
+ group_size = isd_indx_skip * isd_stride;
+ } else {
+ isd_indx_skip = dim_size[isd] - indicies[isd];
+ group_size = isd_indx_skip * isd_stride;
+ }
+ }
+ };
+ // Shard
+ auto worker_threads = context->device()->tensorflow_cpu_worker_threads();
+ const int64 ave_group_size = dim_range[isd] / 2;
+ const int total_work = 2 * num_elements / std::max<int>(dim_range[isd], 1);
+ // 25000 - expiramentally determined with float and bool types
+ const int cost_per_group = 25000 * sizeof(T) * ave_group_size;
+ Shard(worker_threads->num_threads, worker_threads->workers, total_work,
+ cost_per_group, std::move(work));
+}
+
+template <typename Device, typename T, typename Tshift, typename Taxis>
+class RollOp : public OpKernel {
+ public:
+ explicit RollOp(OpKernelConstruction* context) : OpKernel(context) {}
+
+ void Compute(OpKernelContext* context) override {
+ // Grab the input tensor
+ const Tensor& input = context->input(0);
+ const Tensor& shift = context->input(1);
+ const Tensor& axis = context->input(2);
+
+ auto shift_flat = shift.flat<Tshift>();
+ auto axis_flat = axis.flat<Taxis>();
+
+ OP_REQUIRES(context, TensorShapeUtils::IsVectorOrHigher(input.shape()),
+ errors::InvalidArgument("input must be 1-D or higher"));
+ OP_REQUIRES(context, shift.shape().dims() <= 1,
+ errors::InvalidArgument(
+ "shift must be a scalar or a 1-D vector. Found: ",
+ shift.shape().DebugString()));
+ OP_REQUIRES(context, axis.shape().dims() <= 1,
+ errors::InvalidArgument(
+ "axis must be a scalar or a 1-D vector. Found: ",
+ axis.shape().DebugString()));
+ OP_REQUIRES(
+ context, shift.shape() == axis.shape(),
+ errors::InvalidArgument("shift and axis must have the same size"));
+ const int64 num_elements = input.NumElements();
+ const int num_shifts = static_cast<int>(shift_flat.size());
+ const int num_dims = input.dims();
+
+ // if there are any duplicate axes, shift_mod_sum will have the
+ // total modulo sum of shifts for each dimension
+ gtl::InlinedVector<int, 4> shift_mod_sum(num_dims, 0);
+ for (int i = 0; i < num_shifts; i++) {
+ const int axis = axis_flat(i);
+ OP_REQUIRES(context, axis < num_dims,
+ errors::InvalidArgument("axis ", axis, " is out of range"));
+ const int ds = std::max<int>(static_cast<int>(input.dim_size(axis)), 1);
+ const int sum = shift_mod_sum[axis] + static_cast<int>(shift_flat(i));
+ // modulo that works with negatives: ((x % y) + y) % y
+ shift_mod_sum[axis] = (sum % ds + ds) % ds;
+ }
+ // the size of each dimension
+ gtl::InlinedVector<int, 4> dim_size(num_dims);
+ // threshold[i] is the index that the roll starts to wrap back to the front
+ gtl::InlinedVector<int, 4> threshold(num_dims);
+ // dim_range is the number of indices over in the flattened tensor
+ // you need to skip in order to make it over from one side of a dimension
+ // to the other. Used to make the shifts wrap around after a threshold.
+ gtl::InlinedVector<int64, 4> dim_range(num_dims);
+ int64 dim_size_prod = 1; // dimension size product
+ // inner shift dimension (inner most shifted dimension)
+ int64 isd = 0;
+ for (int i = num_dims - 1; i >= 0; i--) {
+ if (isd == 0 && shift_mod_sum[i] != 0) isd = i;
+ const int ds = std::max<int>(static_cast<int>(input.dim_size(i)), 1);
+ dim_size[i] = ds;
+ threshold[i] = (ds - shift_mod_sum[i]) % ds;
+ dim_size_prod *= static_cast<int64>(input.dim_size(i));
+ dim_range[i] = dim_size_prod;
+ }
+
+ Tensor* output = NULL;
+ OP_REQUIRES_OK(context,
+ context->allocate_output(0, input.shape(), &output));
+ auto input_flat = input.flat<T>().data();
+ auto output_flat = output->flat<T>().data();
+
+ if (std::is_same<Device, CPUDevice>::value) {
+ if (DataTypeCanUseMemcpy(DataTypeToEnum<T>::v())) {
+ // V2 copies memory in groups instead of element by element
+ DoRollWithMemcpy<T>(context, num_elements, num_dims, dim_size,
+ input_flat, output_flat, threshold, dim_range, isd);
+ } else {
+ // incase memcpy does not work for current data type
+ DoRoll<T>(context, num_elements, num_dims, dim_size, input_flat,
+ output_flat, threshold, dim_range);
+ }
+ }
+ }
+};
+
+// Register the CPU kernels.
+#define REGISTER_CPU(type) \
+ REGISTER_KERNEL_BUILDER(Name("Roll") \
+ .Device(DEVICE_CPU) \
+ .TypeConstraint<type>("T") \
+ .TypeConstraint<int32>("Tshift") \
+ .TypeConstraint<int32>("Taxis"), \
+ RollOp<CPUDevice, type, int32, int32>) \
+ REGISTER_KERNEL_BUILDER(Name("Roll") \
+ .Device(DEVICE_CPU) \
+ .TypeConstraint<type>("T") \
+ .TypeConstraint<int64>("Tshift") \
+ .TypeConstraint<int32>("Taxis"), \
+ RollOp<CPUDevice, type, int64, int32>) \
+ REGISTER_KERNEL_BUILDER(Name("Roll") \
+ .Device(DEVICE_CPU) \
+ .TypeConstraint<type>("T") \
+ .TypeConstraint<int32>("Tshift") \
+ .TypeConstraint<int64>("Taxis"), \
+ RollOp<CPUDevice, type, int32, int64>) \
+ REGISTER_KERNEL_BUILDER(Name("Roll") \
+ .Device(DEVICE_CPU) \
+ .TypeConstraint<type>("T") \
+ .TypeConstraint<int64>("Tshift") \
+ .TypeConstraint<int64>("Taxis"), \
+ RollOp<CPUDevice, type, int64, int64>)
+
+TF_CALL_ALL_TYPES(REGISTER_CPU);
+#undef REGISTER_CPU
+} // namespace tensorflow
diff --git a/tensorflow/core/kernels/roll_op_test.cc b/tensorflow/core/kernels/roll_op_test.cc
new file mode 100644
index 0000000000..90b6f8d0f3
--- /dev/null
+++ b/tensorflow/core/kernels/roll_op_test.cc
@@ -0,0 +1,484 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include <functional>
+#include <memory>
+
+#include "tensorflow/core/common_runtime/device.h"
+#include "tensorflow/core/common_runtime/device_factory.h"
+#include "tensorflow/core/common_runtime/kernel_benchmark_testlib.h"
+#include "tensorflow/core/framework/allocator.h"
+#include "tensorflow/core/framework/fake_input.h"
+#include "tensorflow/core/framework/node_def_builder.h"
+#include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/framework/tensor.h"
+#include "tensorflow/core/framework/types.h"
+#include "tensorflow/core/framework/types.pb.h"
+#include "tensorflow/core/kernels/ops_testutil.h"
+#include "tensorflow/core/kernels/ops_util.h"
+#include "tensorflow/core/lib/io/path.h"
+#include "tensorflow/core/lib/strings/strcat.h"
+#include "tensorflow/core/platform/test.h"
+#include "tensorflow/core/platform/test_benchmark.h"
+
+namespace tensorflow {
+namespace {
+
+class RollOpTest : public OpsTestBase {
+ protected:
+ void MakeOp(DataType data_type, DataType index_type) {
+ TF_ASSERT_OK(NodeDefBuilder("myop", "Roll")
+ .Input(FakeInput(data_type))
+ .Input(FakeInput(index_type))
+ .Input(FakeInput(index_type))
+ .Finalize(node_def()));
+ TF_ASSERT_OK(InitOp());
+ }
+};
+
+TEST_F(RollOpTest, ScalarIndices) {
+ MakeOp(DT_FLOAT, DT_INT32);
+
+ // Feed and run
+ AddInputFromArray<float>(TensorShape({5}), {0, 1, 2, 3, 4});
+ AddInputFromArray<int32>(TensorShape({}), {3});
+ AddInputFromArray<int32>(TensorShape({}), {0});
+ TF_ASSERT_OK(RunOpKernel());
+
+ // Check the output.
+ Tensor expected(allocator(), DT_FLOAT, TensorShape({5}));
+ test::FillValues<float>(&expected, {2, 3, 4, 0, 1});
+ test::ExpectTensorEqual<float>(expected, *GetOutput(0));
+}
+
+TEST_F(RollOpTest, ScalarIndices_NoMemcpy) {
+ MakeOp(DT_STRING, DT_INT32);
+
+ // Feed and run
+ AddInputFromArray<string>(TensorShape({5}), {"a", "b", "c", "d", "e"});
+ AddInputFromArray<int32>(TensorShape({}), {3});
+ AddInputFromArray<int32>(TensorShape({}), {0});
+ TF_ASSERT_OK(RunOpKernel());
+
+ // Check the output.
+ Tensor expected(allocator(), DT_STRING, TensorShape({5}));
+ test::FillValues<string>(&expected, {"c", "d", "e", "a", "b"});
+ test::ExpectTensorEqual<string>(expected, *GetOutput(0));
+}
+
+TEST_F(RollOpTest, ScalarIndices_Complex) {
+ MakeOp(DT_COMPLEX64, DT_INT32);
+
+ // Feed and run
+ AddInputFromArray<std::complex<float>>(
+ TensorShape({5}), {std::complex<float>(0, 10), std::complex<float>(1, 11),
+ std::complex<float>(2, 12), std::complex<float>(3, 13),
+ std::complex<float>(4, 14)});
+ AddInputFromArray<int32>(TensorShape({}), {3});
+ AddInputFromArray<int32>(TensorShape({}), {0});
+ TF_ASSERT_OK(RunOpKernel());
+
+ // Check the output.
+ Tensor expected(allocator(), DT_COMPLEX64, TensorShape({5}));
+ test::FillValues<std::complex<float>>(
+ &expected, {std::complex<float>(2, 12), std::complex<float>(3, 13),
+ std::complex<float>(4, 14), std::complex<float>(0, 10),
+ std::complex<float>(1, 11)});
+ test::ExpectTensorEqual<std::complex<float>>(expected, *GetOutput(0));
+}
+
+TEST_F(RollOpTest, Simple_TwoD32) {
+ MakeOp(DT_FLOAT, DT_INT32);
+
+ // Feed and run
+ AddInputFromArray<float>(TensorShape({3, 5}),
+ {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14});
+ AddInputFromArray<int32>(TensorShape({2}), {2, -1});
+ AddInputFromArray<int32>(TensorShape({2}), {0, 1});
+ TF_ASSERT_OK(RunOpKernel());
+
+ // Check the output.
+ Tensor expected(allocator(), DT_FLOAT, TensorShape({3, 5}));
+ test::FillValues<float>(&expected,
+ {6, 7, 8, 9, 5, 11, 12, 13, 14, 10, 1, 2, 3, 4, 0});
+ test::ExpectTensorEqual<float>(expected, *GetOutput(0));
+}
+
+TEST_F(RollOpTest, Simple_TwoD32_NoMemcpy) {
+ MakeOp(DT_STRING, DT_INT32);
+
+ // Feed and run
+ AddInputFromArray<string>(TensorShape({3, 5}),
+ {"a", "b", "c", "d", "e", "f", "g", "h", "i", "j",
+ "k", "l", "m", "n", "o"});
+ AddInputFromArray<int32>(TensorShape({2}), {2, -1});
+ AddInputFromArray<int32>(TensorShape({2}), {0, 1});
+ TF_ASSERT_OK(RunOpKernel());
+
+ // Check the output.
+ Tensor expected(allocator(), DT_STRING, TensorShape({3, 5}));
+ test::FillValues<string>(&expected, {"g", "h", "i", "j", "f", "l", "m", "n",
+ "o", "k", "b", "c", "d", "e", "a"});
+ test::ExpectTensorEqual<string>(expected, *GetOutput(0));
+}
+
+TEST_F(RollOpTest, Simple_ThreeD32) {
+ MakeOp(DT_FLOAT, DT_INT32);
+
+ // Feed and run
+ AddInputFromArray<float>(TensorShape({2, 2, 3}),
+ {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11});
+ AddInputFromArray<int32>(TensorShape({3}), {1, -1, -1});
+ AddInputFromArray<int32>(TensorShape({3}), {0, 1, 2});
+ TF_ASSERT_OK(RunOpKernel());
+
+ // Check the output.
+ Tensor expected(allocator(), DT_FLOAT, TensorShape({2, 2, 3}));
+ test::FillValues<float>(&expected, {10, 11, 9, 7, 8, 6, 4, 5, 3, 1, 2, 0});
+ test::ExpectTensorEqual<float>(expected, *GetOutput(0));
+}
+
+TEST_F(RollOpTest, Simple_ThreeD32_NoMemcpy) {
+ MakeOp(DT_STRING, DT_INT32);
+
+ // Feed and run
+ AddInputFromArray<string>(
+ TensorShape({2, 2, 3}),
+ {"a", "b", "c", "d", "e", "f", "g", "h", "i", "j", "k", "l"});
+ AddInputFromArray<int32>(TensorShape({3}), {1, -1, -1});
+ AddInputFromArray<int32>(TensorShape({3}), {0, 1, 2});
+ TF_ASSERT_OK(RunOpKernel());
+
+ // Check the output.
+ Tensor expected(allocator(), DT_STRING, TensorShape({2, 2, 3}));
+ test::FillValues<string>(
+ &expected, {"k", "l", "j", "h", "i", "g", "e", "f", "d", "b", "c", "a"});
+ test::ExpectTensorEqual<string>(expected, *GetOutput(0));
+}
+
+TEST_F(RollOpTest, Simple_TwoD64) {
+ MakeOp(DT_FLOAT, DT_INT64);
+
+ // Feed and run
+ AddInputFromArray<float>(TensorShape({5, 3}),
+ {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14});
+ AddInputFromArray<int64>(TensorShape({2}), {-1, 4});
+ AddInputFromArray<int64>(TensorShape({2}), {0, 1});
+ TF_ASSERT_OK(RunOpKernel());
+
+ // Check the output.
+ Tensor expected(allocator(), DT_FLOAT, TensorShape({5, 3}));
+ test::FillValues<float>(&expected,
+ {5, 3, 4, 8, 6, 7, 11, 9, 10, 14, 12, 13, 2, 0, 1});
+ test::ExpectTensorEqual<float>(expected, *GetOutput(0));
+}
+
+TEST_F(RollOpTest, Simple_TwoD64_NoMemcpy) {
+ MakeOp(DT_STRING, DT_INT64);
+
+ // Feed and run
+ AddInputFromArray<string>(TensorShape({5, 3}),
+ {"a", "b", "c", "d", "e", "f", "g", "h", "i", "j",
+ "k", "l", "m", "n", "o"});
+ AddInputFromArray<int64>(TensorShape({2}), {-1, 4});
+ AddInputFromArray<int64>(TensorShape({2}), {0, 1});
+ TF_ASSERT_OK(RunOpKernel());
+
+ // Check the output.
+ Tensor expected(allocator(), DT_STRING, TensorShape({5, 3}));
+ test::FillValues<string>(&expected, {"f", "d", "e", "i", "g", "h", "l", "j",
+ "k", "o", "m", "n", "c", "a", "b"});
+ test::ExpectTensorEqual<string>(expected, *GetOutput(0));
+}
+
+TEST_F(RollOpTest, Simple_ThreeD64) {
+ MakeOp(DT_FLOAT, DT_INT64);
+
+ // Feed and run
+ AddInputFromArray<float>(TensorShape({4, 1, 3}),
+ {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11});
+ AddInputFromArray<int64>(TensorShape({3}), {4, 3, 2});
+ AddInputFromArray<int64>(TensorShape({3}), {0, 1, 2});
+ TF_ASSERT_OK(RunOpKernel());
+
+ // Check the output.
+ Tensor expected(allocator(), DT_FLOAT, TensorShape({4, 1, 3}));
+ test::FillValues<float>(&expected, {1, 2, 0, 4, 5, 3, 7, 8, 6, 10, 11, 9});
+ test::ExpectTensorEqual<float>(expected, *GetOutput(0));
+}
+
+TEST_F(RollOpTest, Simple_ThreeD64_NoMemcpy) {
+ MakeOp(DT_STRING, DT_INT64);
+
+ // Feed and run
+ AddInputFromArray<string>(
+ TensorShape({4, 1, 3}),
+ {"a", "b", "c", "d", "e", "f", "g", "h", "i", "j", "k", "l"});
+ AddInputFromArray<int64>(TensorShape({3}), {4, 3, 2});
+ AddInputFromArray<int64>(TensorShape({3}), {0, 1, 2});
+ TF_ASSERT_OK(RunOpKernel());
+
+ // Check the output.
+ Tensor expected(allocator(), DT_STRING, TensorShape({4, 1, 3}));
+ test::FillValues<string>(
+ &expected, {"b", "c", "a", "e", "f", "d", "h", "i", "g", "k", "l", "j"});
+ test::ExpectTensorEqual<string>(expected, *GetOutput(0));
+}
+
+TEST_F(RollOpTest, ZeroShift_ThreeD32) {
+ MakeOp(DT_FLOAT, DT_INT32);
+
+ // Feed and run
+ AddInputFromArray<float>(TensorShape({2, 2, 3}),
+ {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11});
+ AddInputFromArray<int32>(TensorShape({3}), {0, 0, 0});
+ AddInputFromArray<int32>(TensorShape({3}), {0, 1, 2});
+ TF_ASSERT_OK(RunOpKernel());
+
+ // Check the output.
+ Tensor expected(allocator(), DT_FLOAT, TensorShape({2, 2, 3}));
+ test::FillValues<float>(&expected, {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11});
+ test::ExpectTensorEqual<float>(expected, *GetOutput(0));
+}
+
+TEST_F(RollOpTest, ZeroShift_ThreeD32_NoMemcpy) {
+ MakeOp(DT_STRING, DT_INT32);
+
+ // Feed and run
+ AddInputFromArray<string>(
+ TensorShape({2, 2, 3}),
+ {"a", "b", "c", "d", "e", "f", "g", "h", "i", "j", "k", "l"});
+ AddInputFromArray<int32>(TensorShape({3}), {0, 0, 0});
+ AddInputFromArray<int32>(TensorShape({3}), {0, 1, 2});
+ TF_ASSERT_OK(RunOpKernel());
+
+ // Check the output.
+ Tensor expected(allocator(), DT_STRING, TensorShape({2, 2, 3}));
+ test::FillValues<string>(
+ &expected, {"a", "b", "c", "d", "e", "f", "g", "h", "i", "j", "k", "l"});
+ test::ExpectTensorEqual<string>(expected, *GetOutput(0));
+}
+
+TEST_F(RollOpTest, ZeroSize_ThreeD32) {
+ MakeOp(DT_FLOAT, DT_INT32);
+
+ // Feed and run
+ AddInputFromArray<float>(TensorShape({5, 0, 0}), {});
+ AddInputFromArray<int32>(TensorShape({}), {1});
+ AddInputFromArray<int32>(TensorShape({}), {0});
+ TF_ASSERT_OK(RunOpKernel());
+
+ // Check the output.
+ Tensor expected(allocator(), DT_FLOAT, TensorShape({5, 0, 0}));
+ test::ExpectTensorEqual<float>(expected, *GetOutput(0));
+}
+
+TEST_F(RollOpTest, ZeroSize_ThreeD32_NoMemcpy) {
+ MakeOp(DT_STRING, DT_INT32);
+
+ // Feed and run
+ AddInputFromArray<string>(TensorShape({5, 0, 0}), {});
+ AddInputFromArray<int32>(TensorShape({}), {1});
+ AddInputFromArray<int32>(TensorShape({}), {0});
+ TF_ASSERT_OK(RunOpKernel());
+
+ // Check the output.
+ Tensor expected(allocator(), DT_STRING, TensorShape({5, 0, 0}));
+ test::ExpectTensorEqual<string>(expected, *GetOutput(0));
+}
+
+TEST_F(RollOpTest, OneSize_ThreeD32) {
+ MakeOp(DT_FLOAT, DT_INT32);
+
+ // Feed and run
+ AddInputFromArray<float>(TensorShape({1, 1, 1}), {5});
+ AddInputFromArray<int32>(TensorShape({}), {1});
+ AddInputFromArray<int32>(TensorShape({}), {0});
+ TF_ASSERT_OK(RunOpKernel());
+
+ // Check the output.
+ Tensor expected(allocator(), DT_FLOAT, TensorShape({1, 1, 1}));
+ test::FillValues<float>(&expected, {5});
+ test::ExpectTensorEqual<float>(expected, *GetOutput(0));
+}
+
+TEST_F(RollOpTest, OneSize_ThreeD32_NoMemcpy) {
+ MakeOp(DT_STRING, DT_INT32);
+
+ // Feed and run
+ AddInputFromArray<string>(TensorShape({1, 1, 1}), {"a"});
+ AddInputFromArray<int32>(TensorShape({}), {1});
+ AddInputFromArray<int32>(TensorShape({}), {0});
+ TF_ASSERT_OK(RunOpKernel());
+
+ // Check the output.
+ Tensor expected(allocator(), DT_STRING, TensorShape({1, 1, 1}));
+ test::FillValues<string>(&expected, {"a"});
+ test::ExpectTensorEqual<string>(expected, *GetOutput(0));
+}
+
+TEST_F(RollOpTest, MultiShifts_TwoD32) {
+ MakeOp(DT_FLOAT, DT_INT32);
+
+ // Feed and run
+ AddInputFromArray<float>(TensorShape({3, 5}),
+ {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14});
+ AddInputFromArray<int32>(TensorShape({4}), {-2, 2, -1, 1});
+ AddInputFromArray<int32>(TensorShape({4}), {1, 0, 0, 1});
+ TF_ASSERT_OK(RunOpKernel());
+
+ // Check the output.
+ Tensor expected(allocator(), DT_FLOAT, TensorShape({3, 5}));
+ test::FillValues<float>(&expected,
+ {11, 12, 13, 14, 10, 1, 2, 3, 4, 0, 6, 7, 8, 9, 5});
+ test::ExpectTensorEqual<float>(expected, *GetOutput(0));
+}
+
+TEST_F(RollOpTest, MultiShifts_TwoD32_NoMemcpy) {
+ MakeOp(DT_STRING, DT_INT32);
+
+ // Feed and run
+ AddInputFromArray<string>(TensorShape({3, 5}),
+ {"a", "b", "c", "d", "e", "f", "g", "h", "i", "j",
+ "k", "l", "m", "n", "o"});
+ AddInputFromArray<int32>(TensorShape({4}), {-2, 2, -1, 1});
+ AddInputFromArray<int32>(TensorShape({4}), {1, 0, 0, 1});
+ TF_ASSERT_OK(RunOpKernel());
+
+ // Check the output.
+ Tensor expected(allocator(), DT_STRING, TensorShape({3, 5}));
+ test::FillValues<string>(&expected, {"l", "m", "n", "o", "k", "b", "c", "d",
+ "e", "a", "g", "h", "i", "j", "f"});
+ test::ExpectTensorEqual<string>(expected, *GetOutput(0));
+}
+
+TEST_F(RollOpTest, Error_InputMustBeVectorOrHigher) {
+ MakeOp(DT_FLOAT, DT_INT32);
+
+ // Feed and run
+ AddInputFromArray<float>(TensorShape({}), {7});
+ AddInputFromArray<int32>(TensorShape({}), {1});
+ AddInputFromArray<int32>(TensorShape({}), {0});
+ Status s = RunOpKernel();
+ EXPECT_TRUE(StringPiece(s.ToString()).contains("input must be 1-D or higher"))
+ << s;
+}
+
+TEST_F(RollOpTest, Error_AxisMustBeScalarOrVector) {
+ MakeOp(DT_FLOAT, DT_INT32);
+
+ // Feed and run
+ AddInputFromArray<float>(TensorShape({2, 2}), {1, 2, 3, 4});
+ AddInputFromArray<int32>(TensorShape({}), {1});
+ AddInputFromArray<int32>(TensorShape({1, 2}), {0, 1});
+ Status s = RunOpKernel();
+ EXPECT_TRUE(StringPiece(s.ToString())
+ .contains("axis must be a scalar or a 1-D vector"))
+ << s;
+}
+
+TEST_F(RollOpTest, Error_ShiftMustBeScalarOrVector) {
+ MakeOp(DT_FLOAT, DT_INT32);
+
+ // Feed and run
+ AddInputFromArray<float>(TensorShape({2, 2}), {1, 2, 3, 4});
+ AddInputFromArray<int32>(TensorShape({1, 2}), {0, 1});
+ AddInputFromArray<int32>(TensorShape({}), {1});
+ Status s = RunOpKernel();
+ EXPECT_TRUE(StringPiece(s.ToString())
+ .contains("shift must be a scalar or a 1-D vector"))
+ << s;
+}
+
+TEST_F(RollOpTest, Error_ShiftAndAxisMustBeSameSize) {
+ MakeOp(DT_FLOAT, DT_INT32);
+
+ // Feed and run
+ AddInputFromArray<float>(TensorShape({2, 2}), {1, 2, 3, 4});
+ AddInputFromArray<int32>(TensorShape({1}), {1});
+ AddInputFromArray<int32>(TensorShape({2}), {0, 1});
+ Status s = RunOpKernel();
+ EXPECT_TRUE(StringPiece(s.ToString())
+ .contains("shift and axis must have the same size"))
+ << s;
+}
+
+TEST_F(RollOpTest, Error_AxisOutOfRange) {
+ MakeOp(DT_FLOAT, DT_INT32);
+
+ // Feed and run
+ AddInputFromArray<float>(TensorShape({4}), {1, 2, 3, 4});
+ AddInputFromArray<int32>(TensorShape({}), {1});
+ AddInputFromArray<int32>(TensorShape({}), {1});
+ Status s = RunOpKernel();
+ EXPECT_TRUE(StringPiece(s.ToString()).contains("is out of range")) << s;
+}
+
+// isd - (inner shift dimension) The inner most dimension to be shifted.
+// All outer dimensions will also be shifted for testing.
+static Graph* RollGraph(const TensorShape& shape, int isd) {
+ Graph* g = new Graph(OpRegistry::Global());
+ Tensor input(DT_FLOAT, shape);
+ input.flat<float>().setRandom();
+ const int dims = static_cast<int>(input.dims());
+ Tensor shift(DT_INT32, TensorShape({dims}));
+ for (int i = 0; i < dims; i++) {
+ // shift the inner shift dimension and all outer dimensions
+ shift.flat<int32>()(i) = (i <= isd) ? 2 : 0;
+ }
+ Tensor axis(DT_INT32, TensorShape({dims}));
+ for (int i = 0; i < dims; i++) {
+ axis.flat<int32>()(i) = i;
+ }
+ test::graph::Roll(g, test::graph::Constant(g, input),
+ test::graph::Constant(g, shift),
+ test::graph::Constant(g, axis));
+ return g;
+}
+
+#define BM_ROLL_OUTER(DEVICE) \
+ static void BM_##DEVICE##_roll_outer(int iters, int rows, int columns) { \
+ TensorShape shape{rows, columns}; \
+ const int64 num_items = static_cast<int64>(iters) * shape.num_elements(); \
+ testing::ItemsProcessed(num_items); \
+ testing::BytesProcessed(num_items * sizeof(float)); \
+ testing::UseRealTime(); \
+ test::Benchmark(#DEVICE, RollGraph(shape, 0)).Run(iters); \
+ } \
+ BENCHMARK(BM_##DEVICE##_roll_outer) \
+ ->ArgPair(256, 256) \
+ ->ArgPair(512, 512) \
+ ->ArgPair(1024, 1024) \
+ ->ArgPair(2048, 2048)
+
+#define BM_ROLL_ALL(DEVICE) \
+ static void BM_##DEVICE##_roll_all(int iters, int rows, int columns) { \
+ TensorShape shape{rows, columns}; \
+ const int64 num_items = static_cast<int64>(iters) * shape.num_elements(); \
+ testing::ItemsProcessed(num_items); \
+ testing::BytesProcessed(num_items * sizeof(float)); \
+ testing::UseRealTime(); \
+ test::Benchmark(#DEVICE, RollGraph(shape, 1)).Run(iters); \
+ } \
+ BENCHMARK(BM_##DEVICE##_roll_all) \
+ ->ArgPair(256, 256) \
+ ->ArgPair(512, 512) \
+ ->ArgPair(1024, 1024) \
+ ->ArgPair(2048, 2048)
+
+BM_ROLL_OUTER(cpu);
+BM_ROLL_ALL(cpu);
+} // namespace
+} // namespace tensorflow
diff --git a/tensorflow/core/ops/manip_ops.cc b/tensorflow/core/ops/manip_ops.cc
new file mode 100644
index 0000000000..95b4774fe6
--- /dev/null
+++ b/tensorflow/core/ops/manip_ops.cc
@@ -0,0 +1,33 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/core/framework/common_shape_fns.h"
+#include "tensorflow/core/framework/op.h"
+#include "tensorflow/core/framework/shape_inference.h"
+
+namespace tensorflow {
+
+// --------------------------------------------------------------------------
+REGISTER_OP("Roll")
+ .Input("input: T")
+ .Input("shift: Tshift")
+ .Input("axis: Taxis")
+ .Output("output: T")
+ .Attr("T: type")
+ .Attr("Tshift: {int32,int64}")
+ .Attr("Taxis: {int32,int64}")
+ .SetShapeFn(shape_inference::UnchangedShape);
+
+} // namespace tensorflow
diff --git a/tensorflow/python/BUILD b/tensorflow/python/BUILD
index a323d5bc39..c73d6c37ee 100644
--- a/tensorflow/python/BUILD
+++ b/tensorflow/python/BUILD
@@ -76,6 +76,7 @@ py_library(
":layers",
":lib",
":list_ops",
+ ":manip_ops",
":math_ops",
":metrics",
":nn",
@@ -1395,6 +1396,14 @@ tf_gen_op_wrapper_private_py(
)
tf_gen_op_wrapper_private_py(
+ name = "manip_ops_gen",
+ visibility = [
+ "//learning/brain/python/ops:__pkg__",
+ "//tensorflow/python/kernel_tests:__pkg__",
+ ],
+)
+
+tf_gen_op_wrapper_private_py(
name = "math_ops_gen",
visibility = [
"//learning/brain/google/python/ops:__pkg__",
@@ -1726,6 +1735,8 @@ py_library(
":linalg_grad",
":linalg_ops",
":logging_ops",
+ ":manip_grad",
+ ":manip_ops",
":math_grad",
":math_ops",
":platform",
@@ -1849,6 +1860,29 @@ py_library(
)
py_library(
+ name = "manip_grad",
+ srcs = ["ops/manip_grad.py"],
+ srcs_version = "PY2AND3",
+ deps = [
+ ":control_flow_ops",
+ ":framework_for_generated_wrappers",
+ ":manip_ops",
+ ],
+)
+
+py_library(
+ name = "manip_ops",
+ srcs = ["ops/manip_ops.py"],
+ srcs_version = "PY2AND3",
+ deps = [
+ ":dtypes",
+ ":framework_ops",
+ ":manip_ops_gen",
+ "//third_party/py/numpy",
+ ],
+)
+
+py_library(
name = "logging_ops",
srcs = ["ops/logging_ops.py"],
srcs_version = "PY2AND3",
@@ -2310,6 +2344,8 @@ py_library(
":linalg_ops",
":logging_ops",
":lookup_ops",
+ ":manip_grad",
+ ":manip_ops",
":math_grad",
":math_ops",
":numerics",
diff --git a/tensorflow/python/__init__.py b/tensorflow/python/__init__.py
index bc9ddec2a5..ea7604d30f 100644
--- a/tensorflow/python/__init__.py
+++ b/tensorflow/python/__init__.py
@@ -84,6 +84,7 @@ from tensorflow.python.feature_column import feature_column_lib as feature_colum
from tensorflow.python.layers import layers
from tensorflow.python.ops import bitwise_ops as bitwise
from tensorflow.python.ops import image_ops as image
+from tensorflow.python.ops import manip_ops as manip
from tensorflow.python.ops import metrics
from tensorflow.python.ops import nn
from tensorflow.python.ops import sets
@@ -241,6 +242,7 @@ _allowed_symbols.extend([
'linalg',
'logging',
'losses',
+ 'manip',
'metrics',
'newaxis',
'nn',
diff --git a/tensorflow/python/kernel_tests/BUILD b/tensorflow/python/kernel_tests/BUILD
index c87b7652ad..3a6058054b 100644
--- a/tensorflow/python/kernel_tests/BUILD
+++ b/tensorflow/python/kernel_tests/BUILD
@@ -1602,6 +1602,19 @@ cuda_py_test(
)
cuda_py_test(
+ name = "manip_ops_test",
+ size = "small",
+ srcs = ["manip_ops_test.py"],
+ additional_deps = [
+ "//third_party/py/numpy",
+ "//tensorflow/python:manip_ops",
+ "//tensorflow/python:client_testlib",
+ "//tensorflow/python:framework_for_generated_wrappers",
+ ],
+ tags = ["no_windows_gpu"],
+)
+
+cuda_py_test(
name = "matmul_op_test",
size = "small",
srcs = ["matmul_op_test.py"],
diff --git a/tensorflow/python/kernel_tests/manip_ops_test.py b/tensorflow/python/kernel_tests/manip_ops_test.py
new file mode 100644
index 0000000000..3044b21aa4
--- /dev/null
+++ b/tensorflow/python/kernel_tests/manip_ops_test.py
@@ -0,0 +1,137 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for manip_ops."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from tensorflow.python.framework import constant_op
+from tensorflow.python.framework import errors_impl
+from tensorflow.python.framework import test_util
+from tensorflow.python.ops import manip_ops
+from tensorflow.python.ops import gradient_checker
+from tensorflow.python.platform import test as test_lib
+
+import numpy as np
+
+# pylint: disable=g-import-not-at-top
+try:
+ from distutils.version import StrictVersion as Version
+ # numpy.roll for multiple shifts was introduced in numpy version 1.12.0
+ NP_ROLL_CAN_MULTISHIFT = Version(np.version.version) >= Version('1.12.0')
+except ImportError:
+ NP_ROLL_CAN_MULTISHIFT = False
+# pylint: enable=g-import-not-at-top
+
+class RollTest(test_util.TensorFlowTestCase):
+ def _testRoll(self, np_input, shift, axis):
+ expected_roll = np.roll(np_input, shift, axis)
+ with self.test_session():
+ roll = manip_ops.roll(np_input, shift, axis)
+ self.assertAllEqual(roll.eval(), expected_roll)
+
+ def _testGradient(self, np_input, shift, axis):
+ with self.test_session():
+ inx = constant_op.constant(np_input.tolist())
+ xs = list(np_input.shape)
+ y = manip_ops.roll(inx, shift, axis)
+ # Expected y's shape to be the same
+ ys = xs
+ jacob_t, jacob_n = gradient_checker.compute_gradient(
+ inx, xs, y, ys, x_init_value=np_input)
+ self.assertAllClose(jacob_t, jacob_n, rtol=1e-5, atol=1e-5)
+
+ def _testAll(self, np_input, shift, axis):
+ self._testRoll(np_input, shift, axis)
+ if np_input.dtype == np.float32:
+ self._testGradient(np_input, shift, axis)
+
+ def testIntTypes(self):
+ for t in [np.int32, np.int64]:
+ self._testAll(np.random.randint(-100, 100, (5)).astype(t), 3, 0)
+ if NP_ROLL_CAN_MULTISHIFT:
+ self._testAll(np.random.randint(-100, 100, (4, 4, 3)).astype(t),
+ [1, -2, 3], [0, 1, 2])
+ self._testAll(np.random.randint(-100, 100, (4, 2, 1, 3)).astype(t),
+ [0, 1, -2], [1, 2, 3])
+
+ def testFloatTypes(self):
+ for t in [np.float32, np.float64]:
+ self._testAll(np.random.rand(5).astype(t), 2, 0)
+ if NP_ROLL_CAN_MULTISHIFT:
+ self._testAll(np.random.rand(3, 4).astype(t), [1, 2], [1, 0])
+ self._testAll(np.random.rand(1, 3, 4).astype(t), [1, 0, -3], [0, 1, 2])
+
+ def testComplexTypes(self):
+ for t in [np.complex64, np.complex128]:
+ x = np.random.rand(4, 4).astype(t)
+ self._testAll(x + 1j * x, 2, 0)
+ if NP_ROLL_CAN_MULTISHIFT:
+ x = np.random.rand(2, 5).astype(t)
+ self._testAll(x + 1j * x, [1, 2], [1, 0])
+ x = np.random.rand(3, 2, 1, 1).astype(t)
+ self._testAll(x + 1j * x, [2, 1, 1, 0], [0, 3, 1, 2])
+
+
+ def testRollInputMustVectorHigherRaises(self):
+ tensor = 7
+ shift = 1
+ axis = 0
+ with self.test_session():
+ with self.assertRaisesRegexp(errors_impl.InvalidArgumentError,
+ "input must be 1-D or higher"):
+ manip_ops.roll(tensor, shift, axis).eval()
+
+ def testRollAxisMustBeScalarOrVectorRaises(self):
+ tensor = [[1, 2],
+ [3, 4]]
+ shift = 1
+ axis = [[0, 1]]
+ with self.test_session():
+ with self.assertRaisesRegexp(errors_impl.InvalidArgumentError,
+ "axis must be a scalar or a 1-D vector"):
+ manip_ops.roll(tensor, shift, axis).eval()
+
+ def testRollShiftMustBeScalarOrVectorRaises(self):
+ tensor = [[1, 2],
+ [3, 4]]
+ shift = [[0, 1]]
+ axis = 1
+ with self.test_session():
+ with self.assertRaisesRegexp(errors_impl.InvalidArgumentError,
+ "shift must be a scalar or a 1-D vector"):
+ manip_ops.roll(tensor, shift, axis).eval()
+
+ def testRollShiftAndAxisMustBeSameSizeRaises(self):
+ tensor = [[1, 2],
+ [3, 4]]
+ shift = [1]
+ axis = [0, 1]
+ with self.test_session():
+ with self.assertRaisesRegexp(errors_impl.InvalidArgumentError,
+ "shift and axis must have the same size"):
+ manip_ops.roll(tensor, shift, axis).eval()
+
+ def testRollAxisOutOfRangeRaises(self):
+ tensor = [1, 2]
+ shift = 1
+ axis = 1
+ with self.test_session():
+ with self.assertRaisesRegexp(errors_impl.InvalidArgumentError,
+ "is out of range"):
+ manip_ops.roll(tensor, shift, axis).eval()
+
+if __name__ == "__main__":
+ test_lib.main()
diff --git a/tensorflow/python/ops/gradients_impl.py b/tensorflow/python/ops/gradients_impl.py
index 314726ede6..230b6c5946 100644
--- a/tensorflow/python/ops/gradients_impl.py
+++ b/tensorflow/python/ops/gradients_impl.py
@@ -44,6 +44,7 @@ from tensorflow.python.ops import image_grad # pylint: disable=unused-import
from tensorflow.python.ops import linalg_grad # pylint: disable=unused-import
from tensorflow.python.ops import linalg_ops # pylint: disable=unused-import
from tensorflow.python.ops import logging_ops # pylint: disable=unused-import
+from tensorflow.python.ops import manip_grad # pylint: disable=unused-import
from tensorflow.python.ops import math_grad # pylint: disable=unused-import
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import resource_variable_ops
diff --git a/tensorflow/python/ops/manip_grad.py b/tensorflow/python/ops/manip_grad.py
new file mode 100644
index 0000000000..573e8c0a0d
--- /dev/null
+++ b/tensorflow/python/ops/manip_grad.py
@@ -0,0 +1,32 @@
+# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+"""Gradients for operators defined in manip_ops.py."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from tensorflow.python.framework import ops
+from tensorflow.python.ops import manip_ops
+
+
+@ops.RegisterGradient("Roll")
+def _RollGrad(op, grad):
+ # The gradient is just the roll reversed
+ shift = op.inputs[1]
+ axis = op.inputs[2]
+ roll_grad = manip_ops.roll(grad, -shift, axis)
+ return roll_grad, None, None
diff --git a/tensorflow/python/ops/manip_ops.py b/tensorflow/python/ops/manip_ops.py
new file mode 100644
index 0000000000..c5f39784f4
--- /dev/null
+++ b/tensorflow/python/ops/manip_ops.py
@@ -0,0 +1,36 @@
+# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Operators for manipulating tensors.
+
+@@roll
+"""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from tensorflow.python.ops import gen_manip_ops as _gen_manip_ops
+from tensorflow.python.util.all_util import remove_undocumented
+
+# pylint: disable=protected-access
+def roll(input, shift, axis):
+ return _gen_manip_ops.roll(input, shift, axis)
+
+roll.__doc__ = _gen_manip_ops.roll.__doc__
+# pylint: enable=protected-access
+
+_allowed_symbols = ['roll']
+
+remove_undocumented(__name__, allowed_exception_list=_allowed_symbols)
diff --git a/tensorflow/python/ops/standard_ops.py b/tensorflow/python/ops/standard_ops.py
index 30bf4e4ef1..737b923415 100644
--- a/tensorflow/python/ops/standard_ops.py
+++ b/tensorflow/python/ops/standard_ops.py
@@ -26,6 +26,7 @@ import sys as _sys
from tensorflow.python.ops import array_grad
from tensorflow.python.ops import data_flow_grad
from tensorflow.python.ops import math_grad
+from tensorflow.python.ops import manip_grad
from tensorflow.python.ops import sparse_grad
from tensorflow.python.ops import spectral_grad
from tensorflow.python.ops import state_grad
@@ -59,6 +60,7 @@ from tensorflow.python.ops.logging_ops import Print
from tensorflow.python.ops.logging_ops import get_summary_op
from tensorflow.python.ops.lookup_ops import initialize_all_tables
from tensorflow.python.ops.lookup_ops import tables_initializer
+from tensorflow.python.ops.manip_ops import *
from tensorflow.python.ops.math_ops import *
from tensorflow.python.ops.numerics import *
from tensorflow.python.ops.parsing_ops import *
@@ -105,6 +107,7 @@ from tensorflow.python.ops import init_ops as _init_ops
from tensorflow.python.ops import io_ops as _io_ops
from tensorflow.python.ops import linalg_ops as _linalg_ops
from tensorflow.python.ops import logging_ops as _logging_ops
+from tensorflow.python.ops import manip_ops as _manip_ops
from tensorflow.python.ops import math_ops as _math_ops
from tensorflow.python.ops import numerics as _numerics
from tensorflow.python.ops import parsing_ops as _parsing_ops
@@ -280,6 +283,7 @@ remove_undocumented(__name__, _allowed_symbols,
_io_ops,
_linalg_ops,
_logging_ops,
+ _manip_ops,
_math_ops,
_numerics,
_parsing_ops,
diff --git a/tensorflow/tools/api/golden/tensorflow.manip.pbtxt b/tensorflow/tools/api/golden/tensorflow.manip.pbtxt
new file mode 100644
index 0000000000..0b84165285
--- /dev/null
+++ b/tensorflow/tools/api/golden/tensorflow.manip.pbtxt
@@ -0,0 +1,7 @@
+path: "tensorflow.manip"
+tf_module {
+ member_method {
+ name: "roll"
+ argspec: "args=[\'input\', \'shift\', \'axis\'], varargs=None, keywords=None, defaults=None"
+ }
+}
diff --git a/tensorflow/tools/api/golden/tensorflow.pbtxt b/tensorflow/tools/api/golden/tensorflow.pbtxt
index dc7c3a2f45..e8890e9cc0 100644
--- a/tensorflow/tools/api/golden/tensorflow.pbtxt
+++ b/tensorflow/tools/api/golden/tensorflow.pbtxt
@@ -397,6 +397,10 @@ tf_module {
mtype: "<type \'module\'>"
}
member {
+ name: "manip"
+ mtype: "<type \'module\'>"
+ }
+ member {
name: "metrics"
mtype: "<type \'module\'>"
}