Diffstat (limited to 'tensorflow/core/kernels')
-rw-r--r--  tensorflow/core/kernels/BUILD                          46
-rw-r--r--  tensorflow/core/kernels/compare_and_bitpack_op.cc      15
-rw-r--r--  tensorflow/core/kernels/decode_bmp_op.cc               19
-rw-r--r--  tensorflow/core/kernels/fractional_pool_common.h        2
-rw-r--r--  tensorflow/core/kernels/mkl_aggregate_ops.cc           13
-rw-r--r--  tensorflow/core/kernels/mkl_avgpooling_op.cc           31
-rw-r--r--  tensorflow/core/kernels/mkl_concat_op.cc                6
-rw-r--r--  tensorflow/core/kernels/mkl_conv_grad_filter_ops.cc     6
-rw-r--r--  tensorflow/core/kernels/mkl_conv_grad_input_ops.cc      6
-rw-r--r--  tensorflow/core/kernels/mkl_conv_ops.cc                 7
-rw-r--r--  tensorflow/core/kernels/mkl_conv_ops.h                  6
-rw-r--r--  tensorflow/core/kernels/mkl_cwise_ops_common.cc         2
-rw-r--r--  tensorflow/core/kernels/mkl_fused_batch_norm_op.cc      6
-rw-r--r--  tensorflow/core/kernels/mkl_identity_op.cc              4
-rw-r--r--  tensorflow/core/kernels/mkl_input_conversion_op.cc     62
-rw-r--r--  tensorflow/core/kernels/mkl_lrn_op.cc                   6
-rw-r--r--  tensorflow/core/kernels/mkl_maxpooling_op.cc           10
-rw-r--r--  tensorflow/core/kernels/mkl_pooling_ops_common.cc       6
-rw-r--r--  tensorflow/core/kernels/mkl_pooling_ops_common.h        8
-rw-r--r--  tensorflow/core/kernels/mkl_relu_op.cc                  8
-rw-r--r--  tensorflow/core/kernels/mkl_reshape_op.cc               6
-rw-r--r--  tensorflow/core/kernels/mkl_softmax_op.cc               4
-rw-r--r--  tensorflow/core/kernels/mkl_tfconv_op.h                 4
-rw-r--r--  tensorflow/core/kernels/roll_op.cc                    334
-rw-r--r--  tensorflow/core/kernels/roll_op_test.cc               484
-rw-r--r--  tensorflow/core/kernels/unravel_index_op.cc           122
26 files changed, 1151 insertions(+), 72 deletions(-)
diff --git a/tensorflow/core/kernels/BUILD b/tensorflow/core/kernels/BUILD
index fd99409c9b..e7192ec42f 100644
--- a/tensorflow/core/kernels/BUILD
+++ b/tensorflow/core/kernels/BUILD
@@ -629,6 +629,7 @@ cc_library(
":transpose_op",
":unique_op",
":unpack_op",
+ ":unravel_index_op",
":where_op",
],
)
@@ -884,6 +885,12 @@ tf_kernel_library(
)
tf_kernel_library(
+ name = "unravel_index_op",
+ prefix = "unravel_index_op",
+ deps = ARRAY_DEPS,
+)
+
+tf_kernel_library(
name = "where_op",
srcs = ["where_op.cc"],
hdrs = ["where_op.h"],
@@ -2582,6 +2589,45 @@ tf_cc_tests(
],
)
+cc_library(
+ name = "manip",
+ deps = [
+ ":roll_op",
+ ],
+)
+
+MANIP_DEPS = [
+ "//tensorflow/core:framework",
+ "//tensorflow/core:lib",
+ "//tensorflow/core:manip_ops_op_lib",
+ "//third_party/eigen3",
+]
+
+tf_kernel_library(
+ name = "roll_op",
+ prefix = "roll_op",
+ deps = MANIP_DEPS,
+)
+
+tf_cc_test(
+ name = "roll_op_test",
+ size = "small",
+ srcs = ["roll_op_test.cc"],
+ deps = [
+ ":ops_testutil",
+ ":ops_util",
+ ":roll_op",
+ "//tensorflow/core:core_cpu",
+ "//tensorflow/core:core_cpu_internal",
+ "//tensorflow/core:framework",
+ "//tensorflow/core:lib",
+ "//tensorflow/core:protos_all_cc",
+ "//tensorflow/core:test",
+ "//tensorflow/core:test_main",
+ "//tensorflow/core:testlib",
+ ],
+)
+
MATH_DEPS = [
":bounds_check",
":fill_functor",
diff --git a/tensorflow/core/kernels/compare_and_bitpack_op.cc b/tensorflow/core/kernels/compare_and_bitpack_op.cc
index 9f626a274a..224fe534e3 100644
--- a/tensorflow/core/kernels/compare_and_bitpack_op.cc
+++ b/tensorflow/core/kernels/compare_and_bitpack_op.cc
@@ -110,7 +110,19 @@ struct ComputeShard<T,
typename TTypes<bool>::ConstMatrix input,
typename TTypes<uint8>::Matrix output, bool /*thresh*/, int64 start,
int64 limit) {
- // NOTE(ebrevdo): This assumes memory is little-endian.
+#if __BYTE_ORDER__ == __ORDER_BIG_ENDIAN__
+ for (int64 i = start; i < limit; ++i) {
+ uint8* out = output.data() + i;
+ const int64 block = *reinterpret_cast<const int64*>(input.data() + 8 * i);
+ *out = ((((block & (1LL << (7 * 8))) >> (7 * 8 - 7))) |
+ (((block & (1LL << (6 * 8))) >> (6 * 8 - 6))) |
+ (((block & (1LL << (5 * 8))) >> (5 * 8 - 5))) |
+ (((block & (1LL << (4 * 8))) >> (4 * 8 - 4))) |
+ (((block & (1LL << (3 * 8))) >> (3 * 8 - 3))) |
+ (((block & (1LL << (2 * 8))) >> (2 * 8 - 2))) |
+ (((block & (1LL << 8)) >> (1 * 8 - 1))) | (((block & (1LL)))));
+ }
+#else
for (int64 i = start; i < limit; ++i) {
uint8* out = output.data() + i;
const int64 block = *reinterpret_cast<const int64*>(input.data() + 8 * i);
@@ -123,6 +135,7 @@ struct ComputeShard<T,
(((block & (1LL << (2 * 8))) >> (2 * 8 - 5))) |
(((block & (1LL << 8)) >> (1 * 8 - 6))) | (((block & (1LL)) << 7)));
}
+#endif
}
};
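Both branches above produce the same packing: the first of the eight bools in
memory lands in the most significant bit of the output byte; only which byte of
the int64 block holds that bool depends on host byte order. For reference, an
endian-agnostic formulation of the same packing (a minimal standalone sketch,
not the kernel code) indexes the bytes directly:

    #include <cstdint>
    #include <cstdio>

    // Pack 8 bools (one per byte in memory) into one byte, MSB-first,
    // independent of host endianness.
    uint8_t PackBools(const bool* b) {
      uint8_t out = 0;
      for (int bit = 0; bit < 8; ++bit) {
        out |= static_cast<uint8_t>(b[bit]) << (7 - bit);
      }
      return out;
    }

    int main() {
      const bool bits[8] = {true, false, false, false, false, false, false, true};
      std::printf("0x%02x\n", PackBools(bits));  // prints 0x81
    }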
diff --git a/tensorflow/core/kernels/decode_bmp_op.cc b/tensorflow/core/kernels/decode_bmp_op.cc
index c778278e8f..b7d120a617 100644
--- a/tensorflow/core/kernels/decode_bmp_op.cc
+++ b/tensorflow/core/kernels/decode_bmp_op.cc
@@ -39,6 +39,13 @@ class DecodeBmpOp : public OpKernel {
errors::InvalidArgument("channels must be 0, 1, 3 or 4, got ",
channels_));
}
+ inline int32 ByteSwapInt32ForBigEndian(int32 x) {
+#if (__BYTE_ORDER__ == __ORDER_BIG_ENDIAN__)
+ return le32toh(x);
+#else
+ return x;
+#endif
+ }
void Compute(OpKernelContext* context) override {
const Tensor& contents = context->input(0);
@@ -56,14 +63,18 @@ class DecodeBmpOp : public OpKernel {
input.size(), " bytes"));
const uint8* img_bytes = reinterpret_cast<const uint8*>(input.data());
- const int32 header_size = internal::SubtleMustCopy(
+ int32 header_size_ = internal::SubtleMustCopy(
*(reinterpret_cast<const int32*>(img_bytes + 10)));
- const int32 width = internal::SubtleMustCopy(
+ const int32 header_size = ByteSwapInt32ForBigEndian(header_size_);
+ int32 width_ = internal::SubtleMustCopy(
*(reinterpret_cast<const int32*>(img_bytes + 18)));
- const int32 height = internal::SubtleMustCopy(
+ const int32 width = ByteSwapInt32ForBigEndian(width_);
+ int32 height_ = internal::SubtleMustCopy(
*(reinterpret_cast<const int32*>(img_bytes + 22)));
- const int32 bpp = internal::SubtleMustCopy(
+ const int32 height = ByteSwapInt32ForBigEndian(height_);
+ int32 bpp_ = internal::SubtleMustCopy(
*(reinterpret_cast<const int32*>(img_bytes + 28)));
+ const int32 bpp = ByteSwapInt32ForBigEndian(bpp_);
if (channels_) {
OP_REQUIRES(context, (channels_ == bpp / 8),
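BMP header fields are stored little-endian, so on big-endian hosts the raw
loads above must be byte-swapped, which le32toh performs. A portable
alternative (a sketch under the same assumption about field offsets, not the
kernel's helper) assembles each field byte by byte and is correct on any host:

    #include <cstdint>

    // Read a little-endian int32 field from a byte buffer, independent of
    // host byte order (no <endian.h> required).
    int32_t ReadLE32(const uint8_t* p) {
      return static_cast<int32_t>(static_cast<uint32_t>(p[0]) |
                                  (static_cast<uint32_t>(p[1]) << 8) |
                                  (static_cast<uint32_t>(p[2]) << 16) |
                                  (static_cast<uint32_t>(p[3]) << 24));
    }
    // e.g. const int32_t width = ReadLE32(img_bytes + 18);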
diff --git a/tensorflow/core/kernels/fractional_pool_common.h b/tensorflow/core/kernels/fractional_pool_common.h
index df0bbbfa06..2d7a230fc0 100644
--- a/tensorflow/core/kernels/fractional_pool_common.h
+++ b/tensorflow/core/kernels/fractional_pool_common.h
@@ -57,7 +57,7 @@ static inline void RandomShuffle(Iter first, Iter last, const Random& uniform) {
// * sum(generated_diff_pooling_sequence) = input_length
// * Let's define floor(input_length / output_length) = K, then
// K <= generated_diff_pooling_sequence[i] <= K+1
-// For example, when input_length = 10, output_length = 6, the followings are
+// For example, when input_length = 10, output_length = 6, the following are
// valid pooling sequence:
// * [1, 2, 2, 1, 2, 2]
// * [1, 1, 2, 2, 2, 2]
diff --git a/tensorflow/core/kernels/mkl_aggregate_ops.cc b/tensorflow/core/kernels/mkl_aggregate_ops.cc
index 89d37d2f87..b539b00009 100644
--- a/tensorflow/core/kernels/mkl_aggregate_ops.cc
+++ b/tensorflow/core/kernels/mkl_aggregate_ops.cc
@@ -28,7 +28,7 @@ limitations under the License.
#include "mkl_dnn_types.h"
#include "tensorflow/core/util/mkl_util.h"
-#ifdef INTEL_MKL_DNN
+#ifndef INTEL_MKL_ML
#include "mkldnn.hpp"
using mkldnn::stream;
using mkldnn::sum;
@@ -37,7 +37,7 @@ using mkldnn::sum;
namespace tensorflow {
typedef Eigen::ThreadPoolDevice CPUDevice;
-#ifndef INTEL_MKL_DNN
+#ifdef INTEL_MKL_ML
template <typename Device, typename T>
class MklAddNOp : public OpKernel {
@@ -285,7 +285,7 @@ class MklAddNOp : public OpKernel {
} MklAddNOpContext;
};
-#else // INTEL_MKL_DNN
+#else // INTEL_MKL_ML
template <typename Device, typename T>
class MklAddNOp : public OpKernel {
public:
@@ -317,8 +317,11 @@ class MklAddNOp : public OpKernel {
: src2_tensor.dims();
// if the shapes of two tensors are not same raise op error
TensorShape src1_shape, src2_shape;
- src1_shape = src1_tensor.shape();
- src2_shape = src2_tensor.shape();
+ src1_shape = input1_in_mkl_format ? src1_mkl_shape.GetTfShape()
+ : src1_tensor.shape();
+ src2_shape = input2_in_mkl_format ? src2_mkl_shape.GetTfShape()
+ : src2_tensor.shape();
+
if (!src1_shape.IsSameSize(src2_shape)) {
ctx->SetStatus(errors::InvalidArgument(
"Inputs to operation ", this->name(), " of type ",
diff --git a/tensorflow/core/kernels/mkl_avgpooling_op.cc b/tensorflow/core/kernels/mkl_avgpooling_op.cc
index a7c569ee05..d545d34fdf 100644
--- a/tensorflow/core/kernels/mkl_avgpooling_op.cc
+++ b/tensorflow/core/kernels/mkl_avgpooling_op.cc
@@ -24,7 +24,7 @@
#include "tensorflow/core/kernels/mkl_pooling_ops_common.h"
-#ifdef INTEL_MKL_DNN
+#ifndef INTEL_MKL_ML
#include "mkldnn.hpp"
using mkldnn::algorithm;
using mkldnn::engine;
@@ -40,8 +40,7 @@ namespace tensorflow {
typedef Eigen::ThreadPoolDevice CPUDevice;
-// For now, MKL-ML is default. So making MKL-DNN not a default choice.
-#ifndef INTEL_MKL_DNN
+#ifdef INTEL_MKL_ML
template <typename Device, typename T>
class MklAvgPoolingOp : public OpKernel {
@@ -429,7 +428,7 @@ class MklAvgPoolingGradOp : public OpKernel {
TensorFormat data_format_;
}; // MklAvgPoolingGradOp
-#else // INTEL_MKL_DNN is defined
+#else
template <typename Device, typename T>
class MklAvgPoolingOp : public MklPoolingForwardOpBase<T> {
@@ -466,6 +465,28 @@ class MklAvgPoolingOp : public MklPoolingForwardOpBase<T> {
memory::dims output_dims_mkl_order;
this->GetOutputDims(pool_params, &output_dims_mkl_order);
+ // If input is an empty tensor, allocate an empty output tensor and return
+ if (input_tensor.NumElements() == 0) {
+ MklDnnShape output_mkl_shape;
+ output_mkl_shape.SetMklTensor(false);
+ TensorShape output_tf_shape;
+ if (pool_params.data_format == TensorFormat::FORMAT_NCHW) {
+ output_tf_shape = MklDnnDimsToTFShape(output_dims_mkl_order);
+ } else {
+ memory::dims output_dims_NHWC_order;
+ output_dims_NHWC_order = {pool_params.tensor_in_batch,
+ static_cast<int>(pool_params.out_height),
+ static_cast<int>(pool_params.out_width),
+ pool_params.out_depth};
+ output_tf_shape = MklDnnDimsToTFShape(output_dims_NHWC_order);
+ }
+ const int kOutputIndex = 0;
+ AllocateOutputSetMklShape(context, kOutputIndex, &output_tensor,
+ output_tf_shape, output_mkl_shape);
+ CHECK_NOTNULL(output_tensor);
+ return;
+ }
+
// If input is in Mkl layout, then just get the memory format from it
// directly, instead of using input data_format to AvgPool.
if (dnn_shape_input.IsMklTensor()) {
@@ -678,7 +699,7 @@ class MklAvgPoolingGradOp : public MklPoolingBackwardOpBase<T> {
}
}; // MklAvgPoolingGradOp
-#endif // INTEL_MKL_DNN
+#endif // INTEL_MKL_ML
REGISTER_KERNEL_BUILDER(Name("_MklAvgPool")
.Device(DEVICE_CPU)
diff --git a/tensorflow/core/kernels/mkl_concat_op.cc b/tensorflow/core/kernels/mkl_concat_op.cc
index 7da63604d2..f1f267e849 100644
--- a/tensorflow/core/kernels/mkl_concat_op.cc
+++ b/tensorflow/core/kernels/mkl_concat_op.cc
@@ -30,7 +30,7 @@ limitations under the License.
#include "mkl_dnn_types.h"
#include "tensorflow/core/util/mkl_util.h"
-#ifdef INTEL_MKL_DNN
+#ifndef INTEL_MKL_ML
#include "mkldnn.hpp"
using mkldnn::concat;
@@ -62,7 +62,7 @@ class EigenConcatBaseOp : public OpKernel {
// we need to have empty Compute because Compute is pure virtual function.
void Compute(OpKernelContext* c) {}
-#ifndef INTEL_MKL_DNN
+#ifdef INTEL_MKL_ML
void Compute(OpKernelContext* c, const std::vector<Tensor>& values) {
const Tensor* concat_dim_tensor;
@@ -230,7 +230,7 @@ class EigenConcatBaseOp : public OpKernel {
#endif
};
-#ifndef INTEL_MKL_DNN
+#ifdef INTEL_MKL_ML
// --------------------------------------------------------------------------
// Mkl Concat Op
diff --git a/tensorflow/core/kernels/mkl_conv_grad_filter_ops.cc b/tensorflow/core/kernels/mkl_conv_grad_filter_ops.cc
index ef3f8cfec1..1401bc65a4 100644
--- a/tensorflow/core/kernels/mkl_conv_grad_filter_ops.cc
+++ b/tensorflow/core/kernels/mkl_conv_grad_filter_ops.cc
@@ -42,7 +42,7 @@ limitations under the License.
#include "mkl_dnn_types.h"
#include "tensorflow/core/util/mkl_util.h"
-#ifdef INTEL_MKL_DNN
+#ifndef INTEL_MKL_ML
#include "mkldnn.hpp"
using mkldnn::convolution_backward_weights;
@@ -55,7 +55,7 @@ namespace tensorflow {
typedef Eigen::ThreadPoolDevice CPUDevice;
-#ifndef INTEL_MKL_DNN
+#ifdef INTEL_MKL_ML
template <typename Device, class T>
class MklConv2DCustomBackpropFilterOp : public OpKernel {
@@ -655,7 +655,7 @@ class MklConv2DCustomBackpropFilterOp
TF_CALL_float(REGISTER_MKL_FILTER_KERNELS);
#undef REGISTER_MKL_FILTER_KERNELS
-#endif // INTEL_MKL_DNN
+#endif // INTEL_MKL_ML
} // namespace tensorflow
diff --git a/tensorflow/core/kernels/mkl_conv_grad_input_ops.cc b/tensorflow/core/kernels/mkl_conv_grad_input_ops.cc
index a6745489f4..eeed009531 100644
--- a/tensorflow/core/kernels/mkl_conv_grad_input_ops.cc
+++ b/tensorflow/core/kernels/mkl_conv_grad_input_ops.cc
@@ -44,7 +44,7 @@ limitations under the License.
#include "tensorflow/core/util/use_cudnn.h"
#include "tensorflow/core/util/work_sharder.h"
-#ifdef INTEL_MKL_DNN
+#ifndef INTEL_MKL_ML
#include "mkldnn.hpp"
using mkldnn::convolution_backward_data;
@@ -56,7 +56,7 @@ namespace tensorflow {
typedef Eigen::ThreadPoolDevice CPUDevice;
-#ifndef INTEL_MKL_DNN
+#ifdef INTEL_MKL_ML
template <typename Device, class T>
class MklConv2DCustomBackpropInputOp : public OpKernel {
@@ -493,7 +493,7 @@ class MklConv2DCustomBackpropInputOp
}
};
-#endif // INTEL_MKL_DNN
+#endif // INTEL_MKL_ML
#define REGISTER_MKL_CPU_KERNELS(T) \
REGISTER_KERNEL_BUILDER(Name("_MklConv2DBackpropInput") \
diff --git a/tensorflow/core/kernels/mkl_conv_ops.cc b/tensorflow/core/kernels/mkl_conv_ops.cc
index e44fba754b..2953426d58 100644
--- a/tensorflow/core/kernels/mkl_conv_ops.cc
+++ b/tensorflow/core/kernels/mkl_conv_ops.cc
@@ -41,7 +41,8 @@ limitations under the License.
#include "tensorflow/core/util/mkl_util.h"
-#ifdef INTEL_MKL_DNN
+#ifndef INTEL_MKL_ML
+
#include "mkldnn.hpp"
using mkldnn::prop_kind;
@@ -58,8 +59,8 @@ namespace tensorflow {
typedef Eigen::ThreadPoolDevice CPUDevice;
-// For now, MKL-ML is default. So making MKL-DNN not a default choice.
-#ifndef INTEL_MKL_DNN
+// MKL-DNN is now default. MKL-ML must be specified explicitly.
+#ifdef INTEL_MKL_ML
template <typename Device, typename T, bool biasEnabled>
class MklConv2DOp : public OpKernel {
diff --git a/tensorflow/core/kernels/mkl_conv_ops.h b/tensorflow/core/kernels/mkl_conv_ops.h
index 8b65eaea0d..9dd88221a8 100644
--- a/tensorflow/core/kernels/mkl_conv_ops.h
+++ b/tensorflow/core/kernels/mkl_conv_ops.h
@@ -40,7 +40,7 @@ limitations under the License.
#include "tensorflow/core/util/mkl_util.h"
-#ifdef INTEL_MKL_DNN
+#ifndef INTEL_MKL_ML
#include "mkldnn.hpp"
using mkldnn::prop_kind;
@@ -52,7 +52,7 @@ using mkldnn::convolution_forward;
namespace tensorflow {
-#ifdef INTEL_MKL_DNN
+#ifndef INTEL_MKL_ML
class MklDnnConvUtil {
protected:
@@ -553,7 +553,7 @@ class MklConv2DBackpropCommonOp : public OpKernel {
Padding padding_;
TensorFormat data_format_;
};
-#endif // INTEL_MKL_DNN
+#endif // INTEL_MKL_ML
/////////////////////////////////////////////////////////////////////
/// Dummy Mkl op that is just used for operators that are intermediate
diff --git a/tensorflow/core/kernels/mkl_cwise_ops_common.cc b/tensorflow/core/kernels/mkl_cwise_ops_common.cc
index c065724e0d..58f0c30f32 100644
--- a/tensorflow/core/kernels/mkl_cwise_ops_common.cc
+++ b/tensorflow/core/kernels/mkl_cwise_ops_common.cc
@@ -1,4 +1,4 @@
-/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
+/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0(the "License");
you may not use this file except in compliance with the License.
diff --git a/tensorflow/core/kernels/mkl_fused_batch_norm_op.cc b/tensorflow/core/kernels/mkl_fused_batch_norm_op.cc
index 0b6d838e09..8313224d7f 100644
--- a/tensorflow/core/kernels/mkl_fused_batch_norm_op.cc
+++ b/tensorflow/core/kernels/mkl_fused_batch_norm_op.cc
@@ -25,7 +25,7 @@ limitations under the License.
#include "mkl_dnn_types.h"
#include "tensorflow/core/util/mkl_util.h"
-#ifdef INTEL_MKL_DNN
+#ifndef INTEL_MKL_ML
#include "mkldnn.hpp"
using mkldnn::batch_normalization_backward;
@@ -41,7 +41,7 @@ using mkldnn::use_scale_shift;
namespace tensorflow {
using CPUDevice = Eigen::ThreadPoolDevice;
-#ifndef INTEL_MKL_DNN
+#ifdef INTEL_MKL_ML
template <typename Device, typename T>
class MklFusedBatchNormOp : public OpKernel {
@@ -683,7 +683,7 @@ class MklFusedBatchNormGradOp : public OpKernel {
};
#endif
-#ifdef INTEL_MKL_DNN
+#ifndef INTEL_MKL_ML
template <typename Device, typename T>
class MklFusedBatchNormOp : public OpKernel {
diff --git a/tensorflow/core/kernels/mkl_identity_op.cc b/tensorflow/core/kernels/mkl_identity_op.cc
index 9ee27ee21c..6c027f8e72 100644
--- a/tensorflow/core/kernels/mkl_identity_op.cc
+++ b/tensorflow/core/kernels/mkl_identity_op.cc
@@ -28,14 +28,14 @@ limitations under the License.
#include "mkl_dnn_types.h"
#include "tensorflow/core/util/mkl_util.h"
-#ifdef INTEL_MKL_DNN
+#ifndef INTEL_MKL_ML
#include "mkldnn.hpp"
#endif
namespace tensorflow {
typedef Eigen::ThreadPoolDevice CPUDevice;
-#ifndef INTEL_MKL_DNN
+#ifdef INTEL_MKL_ML
template <typename Device, typename T>
class MklIdentityOp : public OpKernel {
diff --git a/tensorflow/core/kernels/mkl_input_conversion_op.cc b/tensorflow/core/kernels/mkl_input_conversion_op.cc
index 73d41efce1..5a8799ae93 100644
--- a/tensorflow/core/kernels/mkl_input_conversion_op.cc
+++ b/tensorflow/core/kernels/mkl_input_conversion_op.cc
@@ -31,7 +31,7 @@ limitations under the License.
#include "tensorflow/core/kernels/mkl_tfconv_op.h"
#include "tensorflow/core/util/mkl_util.h"
-#ifdef INTEL_MKL_DNN
+#ifndef INTEL_MKL_ML
#include "mkldnn.hpp"
using mkldnn::stream;
@@ -59,7 +59,7 @@ typedef Eigen::ThreadPoolDevice CPUDevice;
// convert the TF format input to MKL format
///////////////////////////////////////////////////////////
-#ifndef INTEL_MKL_DNN
+#ifdef INTEL_MKL_ML
template <typename Device, typename T>
class MklInputConversionOp : public OpKernel {
public:
@@ -293,14 +293,58 @@ class MklInputConversionOp : public OpKernel {
// - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
// If both inputs are in MKL format
if (input_shape_0.IsMklTensor() && input_shape_1.IsMklTensor()) {
- // If both have the same shape, pass them through
if (tf_shapes_are_same) {
- VLOG(1) << "MklInputConversionOp: No conversion needed, "
- << "copying MKL inputs with identical shapes to output";
-
- ForwardMklTensorInToOut(context, 0, 0);
- ForwardMklTensorInToOut(context, 1, 1);
- return;
+ auto input0_md = input_shape_0.GetMklLayout();
+ auto input1_md = input_shape_1.GetMklLayout();
+
+ // If both have the same shape and same format, pass them through
+ if (input0_md.data.format == input1_md.data.format) {
+ VLOG(1) << "MklInputConversionOp: No conversion needed, "
+ << "copying MKL inputs with identical shapes to output";
+
+ ForwardMklTensorInToOut(context, 0, 0);
+ ForwardMklTensorInToOut(context, 1, 1);
+ return;
+ } else {
+ VLOG(1) << "MklInputConversionOp: Shapes are the same, but formats "
+ "differ, "
+ << "need to convert to the same format";
+
+ // Convert input0, and keep input1 unchanged
+ // Create MklDnnShape for output mkl tensor based on input0
+ Tensor* tensor_out;
+ MklDnnShape mkl_output_mkl_shape;
+ mkl_output_mkl_shape.SetMklTensor(true);
+ mkl_output_mkl_shape.SetElemType(MklDnnType<T>());
+ mkl_output_mkl_shape.SetTfLayout(input_shape_0.GetDimension(),
+ input_shape_0.GetSizesAsMklDnnDims(),
+ input_shape_0.GetTfDataFormat());
+
+ // Get MKL layout from input1 as destination layout
+ mkl_output_mkl_shape.SetMklLayout(&input1_md);
+
+ // Create output Mkl tensor for index 0
+ AllocateOutputSetMklShape(context, 0, &tensor_out,
+ input_tensor_0.shape(),
+ mkl_output_mkl_shape);
+
+ // Create MklDnnData object for input0 tensor
+ auto cpu_engine = engine(engine::cpu, 0);
+ MklDnnData<T> input(&cpu_engine);
+ input.SetUsrMem(input0_md, &input_tensor_0);
+
+ // Create reorder from input0's layout to input1's layout
+ std::vector<primitive> net;
+ CHECK_EQ(input.CheckReorderToOpMem(
+ memory::primitive_desc(input1_md, cpu_engine),
+ tensor_out, &net),
+ true);
+ stream(stream::kind::eager).submit(net).wait();
+
+ // Input1 will be passed through
+ ForwardMklTensorInToOut(context, 1, 1);
+ return;
+ }
}
// Sanity check
diff --git a/tensorflow/core/kernels/mkl_lrn_op.cc b/tensorflow/core/kernels/mkl_lrn_op.cc
index a8b45004b7..5f0a12a1fb 100644
--- a/tensorflow/core/kernels/mkl_lrn_op.cc
+++ b/tensorflow/core/kernels/mkl_lrn_op.cc
@@ -38,7 +38,7 @@ limitations under the License.
#include "tensorflow/core/util/work_sharder.h"
#endif
-#ifdef INTEL_MKL_DNN
+#ifndef INTEL_MKL_ML
#include "mkldnn.hpp"
using mkldnn::lrn_across_channels;
using mkldnn::lrn_backward;
@@ -67,7 +67,7 @@ void GetBandMatrix(int depth, int depth_radius,
} // namespace
-#ifndef INTEL_MKL_DNN
+#ifdef INTEL_MKL_ML
template <typename T>
class MklLRNOp : public OpKernel {
@@ -1343,7 +1343,7 @@ class MklLRNGradOp : public OpKernel {
float beta_;
};
-#endif // INTEL_MKL_DNN
+#endif // INTEL_MKL_ML
#define REGISTER_MKL_LRN_CPU(T) \
REGISTER_KERNEL_BUILDER(Name("_MklLRN") \
diff --git a/tensorflow/core/kernels/mkl_maxpooling_op.cc b/tensorflow/core/kernels/mkl_maxpooling_op.cc
index 0de27ccd60..14607f26e0 100644
--- a/tensorflow/core/kernels/mkl_maxpooling_op.cc
+++ b/tensorflow/core/kernels/mkl_maxpooling_op.cc
@@ -22,7 +22,7 @@ limitations under the License.
#include "tensorflow/core/util/mkl_util.h"
#include "tensorflow/core/util/padding.h"
-#ifdef INTEL_MKL_DNN
+#ifndef INTEL_MKL_ML
#include <algorithm>
#include "mkldnn.hpp"
using mkldnn::algorithm;
@@ -39,8 +39,8 @@ namespace tensorflow {
typedef Eigen::ThreadPoolDevice CPUDevice;
-// For now, MKL-ML is default. So making MKL-DNN not a default choice.
-#ifndef INTEL_MKL_DNN
+// MKL-DNN is now default. MKL-ML must be specified explicitly.
+#ifdef INTEL_MKL_ML
// An implementation of MaxPooling (forward).
template <typename Device, typename T>
@@ -494,7 +494,7 @@ class MklMaxPoolingGradOp : public OpKernel {
bool workspace_enabled_;
}; // MklMaxPoolingGradOp
-#else // INTEL_MKL_DNN is defined
+#else
// An implementation of MaxPooling (forward).
template <typename Device, typename T>
@@ -793,7 +793,7 @@ class MklMaxPoolingGradOp : public MklPoolingBackwardOpBase<T> {
}
}; // MklMaxPoolingGradOp
-#endif // INTEL_MKL_DNN
+#endif // INTEL_MKL_ML
REGISTER_KERNEL_BUILDER(Name("_MklMaxPool")
.Device(DEVICE_CPU)
diff --git a/tensorflow/core/kernels/mkl_pooling_ops_common.cc b/tensorflow/core/kernels/mkl_pooling_ops_common.cc
index ef8597b057..5ef6ce2a57 100644
--- a/tensorflow/core/kernels/mkl_pooling_ops_common.cc
+++ b/tensorflow/core/kernels/mkl_pooling_ops_common.cc
@@ -42,7 +42,7 @@ void MklPoolParameters::Init(OpKernelContext* context,
Init(context, ksize, stride, padding, data_format);
}
-#ifndef INTEL_MKL_DNN
+#ifdef INTEL_MKL_ML
// Initialization for MKL format
void MklPoolParameters::Init(OpKernelContext* context,
const std::vector<int32>& ksize,
@@ -72,7 +72,7 @@ void MklPoolParameters::Init(OpKernelContext* context,
Init(context, ksize, stride, padding, data_format);
}
-#endif // INTEL_MKL_DNN
+#endif // INTEL_MKL_ML
// Common Initialization for TensorFlow and MKL formats
void MklPoolParameters::Init(OpKernelContext* context,
const std::vector<int32>& ksize,
@@ -107,7 +107,7 @@ void MklPoolParameters::Init(OpKernelContext* context,
OP_REQUIRES_OK(context, GetWindowedOutputSizeVerbose(
tensor_in_cols, window_cols, col_stride,
padding, &out_width, &pad_left, &pad_right));
-#ifdef INTEL_MKL_DNN
+#ifndef INTEL_MKL_ML
// TF can work with int64, but mkldnn only supports int32
// Fail if the height or width are greater than MAX_INT
diff --git a/tensorflow/core/kernels/mkl_pooling_ops_common.h b/tensorflow/core/kernels/mkl_pooling_ops_common.h
index 880e45ab1e..279167aba2 100644
--- a/tensorflow/core/kernels/mkl_pooling_ops_common.h
+++ b/tensorflow/core/kernels/mkl_pooling_ops_common.h
@@ -22,7 +22,7 @@ limitations under the License.
#include "tensorflow/core/util/mkl_util.h"
#include "tensorflow/core/util/padding.h"
-#ifdef INTEL_MKL_DNN
+#ifndef INTEL_MKL_ML
#include "mkldnn.hpp"
using mkldnn::memory;
using mkldnn::pooling_backward;
@@ -85,7 +85,7 @@ struct MklPoolParameters {
void Init(OpKernelContext* context, const std::vector<int32>& ksize,
const std::vector<int32>& stride, Padding padding,
TensorFormat data_format, const TensorShape& tensor_in_shape);
-#ifndef INTEL_MKL_DNN
+#ifdef INTEL_MKL_ML
void Init(OpKernelContext* context, const std::vector<int32>& ksize,
const std::vector<int32>& stride, Padding padding,
TensorFormat data_format, const MklShape* mkl_in_shape);
@@ -102,7 +102,7 @@ struct MklPoolParameters {
TensorFormat data_format);
};
-#ifdef INTEL_MKL_DNN
+#ifndef INTEL_MKL_ML
template <class T>
class MklPoolingOpBase : public OpKernel {
@@ -395,7 +395,7 @@ class MklPoolingBackwardOpBase : public MklPoolingOpBase<T> {
return grad_reorder_needed ? target_diff_dst_md : original_input_grad_md;
}
};
-#endif // INTEL_MKL_DNN
+#endif // INTEL_MKL_ML
//-------------------------------------------------------------------
// Utility functions
diff --git a/tensorflow/core/kernels/mkl_relu_op.cc b/tensorflow/core/kernels/mkl_relu_op.cc
index 873aca30ca..51db3991e2 100644
--- a/tensorflow/core/kernels/mkl_relu_op.cc
+++ b/tensorflow/core/kernels/mkl_relu_op.cc
@@ -28,7 +28,7 @@ limitations under the License.
#include "tensorflow/core/platform/default/logging.h"
#include "tensorflow/core/util/mkl_util.h"
-#ifdef INTEL_MKL_DNN
+#ifndef INTEL_MKL_ML
#include "mkldnn.hpp"
using mkldnn::algorithm;
@@ -58,7 +58,7 @@ struct MklReluHelpers {
}
};
-#ifndef INTEL_MKL_DNN
+#ifdef INTEL_MKL_ML
template <typename Device, typename T>
class MklReluOp : public OpKernel {
@@ -368,7 +368,7 @@ void MklReluGradOp<Device, T>::Compute(OpKernelContext* context) {
mkl_context.MklCleanup();
}
-#else // INTEL_MKL_DNN
+#else // INTEL_MKL_ML
template <typename Device, typename T, algorithm alg_kind>
class MklReluOpBase : public OpKernel {
@@ -849,7 +849,7 @@ class MklTanhGradOp : public MklReluGradOpBase<Device, T, eltwise_tanh> {
MklReluGradOp<CPUDevice, type>);
TF_CALL_float(REGISTER_RELU_MKL_SUPPORTED_KERNELS_TYPES);
-#ifdef INTEL_MKL_DNN
+#ifndef INTEL_MKL_ML
// register dnn kernels for supported operations and supported types
#define REGISTER_ELU_MKL_SUPPORTED_KERNELS_TYPES(type) \
diff --git a/tensorflow/core/kernels/mkl_reshape_op.cc b/tensorflow/core/kernels/mkl_reshape_op.cc
index 7d471e1e4c..5dbc4a2709 100644
--- a/tensorflow/core/kernels/mkl_reshape_op.cc
+++ b/tensorflow/core/kernels/mkl_reshape_op.cc
@@ -28,7 +28,7 @@ limitations under the License.
#include "mkl_dnn_types.h"
#include "tensorflow/core/util/mkl_util.h"
-#ifdef INTEL_MKL_DNN
+#ifndef INTEL_MKL_ML
#include "mkldnn.hpp"
using mkldnn::stream;
#endif
@@ -40,7 +40,7 @@ class MklReshapeOp : public OpKernel {
public:
explicit MklReshapeOp(OpKernelConstruction* context) : OpKernel(context) {}
-#ifndef INTEL_MKL_DNN
+#ifdef INTEL_MKL_ML
void Compute(OpKernelContext* context) override {
const Tensor& input = MklGetInput(context, 0);
const Tensor& sizes = MklGetInput(context, 1);
@@ -312,7 +312,7 @@ class MklReshapeOp : public OpKernel {
}
}
-#endif // INTEL_MKL_DNN
+#endif // INTEL_MKL_ML
private:
const int kInputSlotIdx = 0;
diff --git a/tensorflow/core/kernels/mkl_softmax_op.cc b/tensorflow/core/kernels/mkl_softmax_op.cc
index c46eabdde1..aceef1e234 100644
--- a/tensorflow/core/kernels/mkl_softmax_op.cc
+++ b/tensorflow/core/kernels/mkl_softmax_op.cc
@@ -15,7 +15,7 @@ limitations under the License.
// See docs in ../ops/nn_ops.cc.
#ifdef INTEL_MKL
-#ifdef INTEL_MKL_DNN
+#ifndef INTEL_MKL_ML
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/framework/numeric_op.h"
@@ -156,5 +156,5 @@ TF_CALL_float(REGISTER_SOFTMAX_MKL_SUPPORTED_KERNELS_TYPES);
} // namespace tensorflow
-#endif // INTEL_MKL_DNN
+#endif // INTEL_MKL_ML
#endif // INTEL_MKL
diff --git a/tensorflow/core/kernels/mkl_tfconv_op.h b/tensorflow/core/kernels/mkl_tfconv_op.h
index c4d5a45d3c..5fafa14b5d 100644
--- a/tensorflow/core/kernels/mkl_tfconv_op.h
+++ b/tensorflow/core/kernels/mkl_tfconv_op.h
@@ -35,7 +35,7 @@ limitations under the License.
#include "mkl_dnn_types.h"
#include "tensorflow/core/util/mkl_util.h"
-#ifdef INTEL_MKL_DNN
+#ifndef INTEL_MKL_ML
using mkldnn::stream;
#endif
@@ -61,7 +61,7 @@ class MklToTfOp : public OpKernel {
VLOG(1) << "MKLToTFConversion complete successfully.";
}
-#ifdef INTEL_MKL_DNN
+#ifndef INTEL_MKL_ML
static void ConvertMklToTf(OpKernel* op_kernel, OpKernelContext* context,
string data_format_str, DataType op_data_type,
bool has_avx512f, uint input_number) {
diff --git a/tensorflow/core/kernels/roll_op.cc b/tensorflow/core/kernels/roll_op.cc
new file mode 100644
index 0000000000..bcbdbee058
--- /dev/null
+++ b/tensorflow/core/kernels/roll_op.cc
@@ -0,0 +1,334 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/core/framework/common_shape_fns.h"
+#include "tensorflow/core/framework/op.h"
+#include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/framework/register_types.h"
+#include "tensorflow/core/framework/register_types_traits.h"
+#include "tensorflow/core/framework/shape_inference.h"
+#include "tensorflow/core/lib/gtl/array_slice.h"
+#include "tensorflow/core/platform/types.h"
+#include "tensorflow/core/util/work_sharder.h"
+
+namespace tensorflow {
+
+#define EIGEN_USE_THREADS
+using CPUDevice = Eigen::ThreadPoolDevice;
+
+// dim_size - the size of each dimension
+// dim_range - the number of indices over in the flattened tensor
+// you need to skip in order to make it over from one side of a dimension
+// to the other. Used to make the shifts wrap around after a threshold.
+// threshold - the index for each dimension that the roll starts to wrap
+// back to the front
+template <typename T>
+void DoRoll(OpKernelContext* context, const int64 num_elements,
+ const int num_dims, const gtl::ArraySlice<int>& dim_size,
+ const T* input, T* output, const gtl::ArraySlice<int>& threshold,
+ const gtl::ArraySlice<int64>& dim_range) {
+ auto work = [input, output, num_dims, &dim_size, &threshold, &dim_range](
+ int64 start, int64 end) {
+ // array of indices for each dimension
+ gtl::InlinedVector<int, 4> indices(num_dims);
+ int offset = 0; // the shift along the flattened tensor for current element
+ // initialize indices and offset
+ for (int i = 0; i < num_dims; i++) {
+ // stride is the number of indices over in the flattened tensor
+ // you need to skip in order to make it over to an adjacent element
+ // along a dimension. dim_size[i] != 0 because we set it to max(dim, 1)
+ const int64 stride = dim_range[i] / dim_size[i];
+ const int shift = dim_size[i] - threshold[i];
+ const int indx = (start / stride) % dim_size[i];
+ indices[i] = indx;
+ // calculate dimension index after the shift
+ const int shifted_indx = (indx + shift) % dim_size[i];
+ offset += (shifted_indx - indx) * stride;
+ }
+
+ for (int64 i = start; i < end; i++) {
+ output[i + offset] = input[i];
+ // create next combination of indices
+ // while at it adjust offset if needed
+ for (int j = num_dims - 1; j >= 0; j--) {
+ const int indx = (indices[j] + 1) % dim_size[j];
+ indices[j] = indx;
+ if (indx != 0) {
+ if (indx == threshold[j]) { // we've reached the threshold
+ // dim_range[j] = threshold[j] + shift[j]
+ // offset = shift[j] + ... other offsets
+ // offset - dim_range[j] = -threshold[j] + ... other offsets
+ // thus we undo our previous offset as well as add a new offset of
+ // -threshold[j] in one operation
+ offset -= dim_range[j]; // now wraps around
+ }
+ break; // indx != 0 don't need to carry
+ } else if (threshold[j] != 0) { // if threshold is 0 shift is 0
+ offset += dim_range[j]; // indx became 0 so reverse wrap around
+ }
+ }
+ }
+ };
+ // Shard
+ auto worker_threads = context->device()->tensorflow_cpu_worker_threads();
+ // 15 - experimentally determined with float and bool types
+ const int cost_per_element = 15 * sizeof(T); // rough estimate
+ Shard(worker_threads->num_threads, worker_threads->workers, num_elements,
+ cost_per_element, std::move(work));
+}
+
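DoRoll's inner loop is driven by per-dimension thresholds: an element's offset
jumps back by dim_range[j] as soon as its index reaches threshold[j]. Reduced
to one dimension, the same arithmetic (a minimal standalone sketch, not the
kernel itself) looks like this:

    #include <vector>

    // Roll a length-n vector right by `shift` (negative shifts allowed).
    std::vector<int> Roll1D(const std::vector<int>& in, int shift) {
      const int n = static_cast<int>(in.size());
      // Normalize the shift the way RollOp does: ((x % y) + y) % y
      // handles negative shifts correctly.
      const int s = ((shift % n) + n) % n;
      const int threshold = n - s;  // index where the output wraps around
      std::vector<int> out(n);
      for (int i = 0; i < n; ++i) {
        // Elements before the threshold move forward by s; from the
        // threshold on they wrap to the front, mirroring offset -= dim_range.
        out[i < threshold ? i + s : i + s - n] = in[i];
      }
      return out;
    }
    // Roll1D({0, 1, 2, 3, 4}, 3) == {2, 3, 4, 0, 1}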
+// dim_size - the size of each dimension
+// dim_range - the number of indices over in the flattened tensor
+// you need to skip in order to make it over from one side of a dimension
+// to the other. Used to make the shifts wrap around after a threshold.
+// threshold - the index for each dimension that the roll starts to wrap
+// back to the front
+// isd - inner shift dimension
+template <typename T>
+// Use memcpy to copy memory in groups when the data type supports memcpy
+void DoRollWithMemcpy(OpKernelContext* context, const int64 num_elements,
+ const int num_dims, const gtl::ArraySlice<int>& dim_size,
+ const T* input, T* output,
+ const gtl::ArraySlice<int>& threshold,
+ const gtl::ArraySlice<int64>& dim_range,
+ const int64 isd) {
+ auto work = [input, output, num_dims, &dim_size, &threshold, &dim_range, isd](
+ int64 start, int64 end) {
+ // the number of indices over in the flattened tensor you need to skip in
+ // order to make it over from one side of the isd to the other
+ const int64 isd_range = std::max<int>(dim_range[isd], 1);
+ // the distance along the flattened tensor to the next element in the isd
+ const int64 isd_stride = isd_range / std::max<int>(dim_size[isd], 1);
+
+ // start and end initially index groups, so convert them into element
+ // indices within the flattened tensor.
+ // there are 2 groups per isd one for all elements before threshold[isd]
+ // and another for all elements after threshold[isd].
+ const int64 start_remainder = (start % 2) * threshold[isd] * isd_stride;
+ const int64 end_remainder = (end % 2) * threshold[isd] * isd_stride;
+ start = (start / 2) * isd_range + start_remainder;
+ end = (end / 2) * isd_range + end_remainder;
+
+ const T* in_ptr = &input[0];
+ T* out_ptr = &output[0];
+ in_ptr += start;
+ out_ptr += start;
+
+ // array of indices for each dimension
+ // indices = [i, j, k, l, m, n]
+ gtl::InlinedVector<int, 4> indices(num_dims);
+ // the offset needed to make all inner non-shifting dimensions become 0
+ int64 remainder_offset = 0;
+ // initialize indices
+ for (int i = 0; i < num_dims; i++) {
+ // stride is the number of indices over in the flattened tensor
+ // you need to skip in order to make it over to an adjacent element
+ // along a dimension. dim_size[i] != 0 because we set it to max(dim, 1)
+ const int64 stride = dim_range[i] / dim_size[i];
+ const int shift = dim_size[i] - threshold[i];
+ const int indx = (start / stride) % dim_size[i];
+ indices[i] = indx;
+ // calculate dimension index after the shift
+ int out_indx = (indx + shift) % dim_size[i];
+ if (i > isd) {
+ // trailing zeroes for indices after the inner shifted dimension
+ out_indx = 0;
+ remainder_offset += (out_indx - indx) * stride;
+ }
+ out_ptr += (out_indx - indx) * stride;
+ }
+ // set trailing zeroes for indices after the inner shifted dimension
+ for (int i = num_dims - 1; i > isd; i--) indices[i] = 0;
+
+ // the number of indices in the isd dimension the next group will skip
+ // to make it to the next threshold or end point
+ int isd_indx_skip = 0;
+ // the size of the next group
+ int64 group_size = 0;
+ // initialize isd_indx_skip and group_size
+ if (indices[isd] < threshold[isd]) {
+ isd_indx_skip = threshold[isd] - indices[isd];
+ group_size = isd_indx_skip * isd_stride + remainder_offset;
+ } else {
+ isd_indx_skip = dim_size[isd] - indices[isd];
+ group_size = isd_indx_skip * isd_stride + remainder_offset;
+ }
+
+ int64 i = start;
+ while (i < end) {
+ // copy group of elements
+ memcpy(out_ptr, in_ptr, group_size * sizeof(T));
+
+ // shift i and the pointers over to the next group position
+ i += group_size;
+ out_ptr += group_size;
+ in_ptr += group_size;
+
+ // produce next combination of indices and adjust the out_ptr position
+ // to fix the offset if necessary
+ // the isd (inner shift dim) should skip to next threshold or endpoint
+ // all dimensions to the left increment by 1 when a digit is carried
+ // all dimensions to the right remain set to 0
+ // +1 +1 +1 +isd_indx_skip
+ // indices = [i, j, k, l, 0, 0]
+ // ^isd
+ for (int j = isd; j >= 0; j--) {
+ int inc = 1;
+ if (j == isd) inc = isd_indx_skip;
+ const int indx = (indices[j] + inc) % dim_size[j];
+ indices[j] = indx;
+ if (indx != 0) {
+ if (indx == threshold[j]) {
+ out_ptr -= dim_range[j]; // now wraps around
+ }
+ break; // indx != 0 don't need to carry
+ } else if (threshold[j] != 0) { // if threshold is 0 shift is 0
+ out_ptr += dim_range[j]; // indx became 0 so reverse wrap around
+ }
+ }
+
+ // set isd_indx_skip and group_size for next iteration
+ if (indices[isd] < threshold[isd]) {
+ isd_indx_skip = threshold[isd] - indices[isd];
+ group_size = isd_indx_skip * isd_stride;
+ } else {
+ isd_indx_skip = dim_size[isd] - indices[isd];
+ group_size = isd_indx_skip * isd_stride;
+ }
+ }
+ };
+ // Shard
+ auto worker_threads = context->device()->tensorflow_cpu_worker_threads();
+ const int64 ave_group_size = dim_range[isd] / 2;
+ const int total_work = 2 * num_elements / std::max<int>(dim_range[isd], 1);
+ // 25000 - experimentally determined with float and bool types
+ const int cost_per_group = 25000 * sizeof(T) * ave_group_size;
+ Shard(worker_threads->num_threads, worker_threads->workers, total_work,
+ cost_per_group, std::move(work));
+}
+
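In one dimension the memcpy strategy is easy to see: a roll splits the array
into two contiguous groups, one on each side of the threshold, and each group
moves as a single block. A minimal sketch (hypothetical helper, not the kernel
code):

    #include <cstring>

    // Roll n floats right by `shift` using two block copies.
    void Roll1DMemcpy(const float* in, float* out, int n, int shift) {
      const int s = ((shift % n) + n) % n;
      const int threshold = n - s;  // elements [0, threshold) shift right by s
      std::memcpy(out + s, in, threshold * sizeof(float));  // leading group
      std::memcpy(out, in + threshold, s * sizeof(float));  // wrapped group
    }

DoRollWithMemcpy generalizes this to N dimensions by walking the groups along
the innermost shifted dimension and adjusting the output pointer with the same
threshold/dim_range bookkeeping as DoRoll.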
+template <typename Device, typename T, typename Tshift, typename Taxis>
+class RollOp : public OpKernel {
+ public:
+ explicit RollOp(OpKernelConstruction* context) : OpKernel(context) {}
+
+ void Compute(OpKernelContext* context) override {
+ // Grab the input tensor
+ const Tensor& input = context->input(0);
+ const Tensor& shift = context->input(1);
+ const Tensor& axis = context->input(2);
+
+ auto shift_flat = shift.flat<Tshift>();
+ auto axis_flat = axis.flat<Taxis>();
+
+ OP_REQUIRES(context, TensorShapeUtils::IsVectorOrHigher(input.shape()),
+ errors::InvalidArgument("input must be 1-D or higher"));
+ OP_REQUIRES(context, shift.shape().dims() <= 1,
+ errors::InvalidArgument(
+ "shift must be a scalar or a 1-D vector. Found: ",
+ shift.shape().DebugString()));
+ OP_REQUIRES(context, axis.shape().dims() <= 1,
+ errors::InvalidArgument(
+ "axis must be a scalar or a 1-D vector. Found: ",
+ axis.shape().DebugString()));
+ OP_REQUIRES(
+ context, shift.shape() == axis.shape(),
+ errors::InvalidArgument("shift and axis must have the same size"));
+ const int64 num_elements = input.NumElements();
+ const int num_shifts = static_cast<int>(shift_flat.size());
+ const int num_dims = input.dims();
+
+ // if there are any duplicate axes, shift_mod_sum will have the
+ // total modulo sum of shifts for each dimension
+ gtl::InlinedVector<int, 4> shift_mod_sum(num_dims, 0);
+ for (int i = 0; i < num_shifts; i++) {
+ const int axis = axis_flat(i);
+ OP_REQUIRES(context, axis < num_dims,
+ errors::InvalidArgument("axis ", axis, " is out of range"));
+ const int ds = std::max<int>(static_cast<int>(input.dim_size(axis)), 1);
+ const int sum = shift_mod_sum[axis] + static_cast<int>(shift_flat(i));
+ // modulo that works with negatives: ((x % y) + y) % y
+ shift_mod_sum[axis] = (sum % ds + ds) % ds;
+ }
+ // the size of each dimension
+ gtl::InlinedVector<int, 4> dim_size(num_dims);
+ // threshold[i] is the index that the roll starts to wrap back to the front
+ gtl::InlinedVector<int, 4> threshold(num_dims);
+ // dim_range is the number of indices over in the flattened tensor
+ // you need to skip in order to make it over from one side of a dimension
+ // to the other. Used to make the shifts wrap around after a threshold.
+ gtl::InlinedVector<int64, 4> dim_range(num_dims);
+ int64 dim_size_prod = 1; // dimension size product
+ // inner shift dimension (innermost shifted dimension)
+ int64 isd = 0;
+ for (int i = num_dims - 1; i >= 0; i--) {
+ if (isd == 0 && shift_mod_sum[i] != 0) isd = i;
+ const int ds = std::max<int>(static_cast<int>(input.dim_size(i)), 1);
+ dim_size[i] = ds;
+ threshold[i] = (ds - shift_mod_sum[i]) % ds;
+ dim_size_prod *= static_cast<int64>(input.dim_size(i));
+ dim_range[i] = dim_size_prod;
+ }
+
+ Tensor* output = NULL;
+ OP_REQUIRES_OK(context,
+ context->allocate_output(0, input.shape(), &output));
+ auto input_flat = input.flat<T>().data();
+ auto output_flat = output->flat<T>().data();
+
+ if (std::is_same<Device, CPUDevice>::value) {
+ if (DataTypeCanUseMemcpy(DataTypeToEnum<T>::v())) {
+ // V2 copies memory in groups instead of element by element
+ DoRollWithMemcpy<T>(context, num_elements, num_dims, dim_size,
+ input_flat, output_flat, threshold, dim_range, isd);
+ } else {
+ // in case memcpy does not work for the current data type
+ DoRoll<T>(context, num_elements, num_dims, dim_size, input_flat,
+ output_flat, threshold, dim_range);
+ }
+ }
+ }
+};
+
+// Register the CPU kernels.
+#define REGISTER_CPU(type) \
+ REGISTER_KERNEL_BUILDER(Name("Roll") \
+ .Device(DEVICE_CPU) \
+ .TypeConstraint<type>("T") \
+ .TypeConstraint<int32>("Tshift") \
+ .TypeConstraint<int32>("Taxis"), \
+ RollOp<CPUDevice, type, int32, int32>) \
+ REGISTER_KERNEL_BUILDER(Name("Roll") \
+ .Device(DEVICE_CPU) \
+ .TypeConstraint<type>("T") \
+ .TypeConstraint<int64>("Tshift") \
+ .TypeConstraint<int32>("Taxis"), \
+ RollOp<CPUDevice, type, int64, int32>) \
+ REGISTER_KERNEL_BUILDER(Name("Roll") \
+ .Device(DEVICE_CPU) \
+ .TypeConstraint<type>("T") \
+ .TypeConstraint<int32>("Tshift") \
+ .TypeConstraint<int64>("Taxis"), \
+ RollOp<CPUDevice, type, int32, int64>) \
+ REGISTER_KERNEL_BUILDER(Name("Roll") \
+ .Device(DEVICE_CPU) \
+ .TypeConstraint<type>("T") \
+ .TypeConstraint<int64>("Tshift") \
+ .TypeConstraint<int64>("Taxis"), \
+ RollOp<CPUDevice, type, int64, int64>)
+
+TF_CALL_ALL_TYPES(REGISTER_CPU);
+#undef REGISTER_CPU
+} // namespace tensorflow
diff --git a/tensorflow/core/kernels/roll_op_test.cc b/tensorflow/core/kernels/roll_op_test.cc
new file mode 100644
index 0000000000..90b6f8d0f3
--- /dev/null
+++ b/tensorflow/core/kernels/roll_op_test.cc
@@ -0,0 +1,484 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include <functional>
+#include <memory>
+
+#include "tensorflow/core/common_runtime/device.h"
+#include "tensorflow/core/common_runtime/device_factory.h"
+#include "tensorflow/core/common_runtime/kernel_benchmark_testlib.h"
+#include "tensorflow/core/framework/allocator.h"
+#include "tensorflow/core/framework/fake_input.h"
+#include "tensorflow/core/framework/node_def_builder.h"
+#include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/framework/tensor.h"
+#include "tensorflow/core/framework/types.h"
+#include "tensorflow/core/framework/types.pb.h"
+#include "tensorflow/core/kernels/ops_testutil.h"
+#include "tensorflow/core/kernels/ops_util.h"
+#include "tensorflow/core/lib/io/path.h"
+#include "tensorflow/core/lib/strings/strcat.h"
+#include "tensorflow/core/platform/test.h"
+#include "tensorflow/core/platform/test_benchmark.h"
+
+namespace tensorflow {
+namespace {
+
+class RollOpTest : public OpsTestBase {
+ protected:
+ void MakeOp(DataType data_type, DataType index_type) {
+ TF_ASSERT_OK(NodeDefBuilder("myop", "Roll")
+ .Input(FakeInput(data_type))
+ .Input(FakeInput(index_type))
+ .Input(FakeInput(index_type))
+ .Finalize(node_def()));
+ TF_ASSERT_OK(InitOp());
+ }
+};
+
+TEST_F(RollOpTest, ScalarIndices) {
+ MakeOp(DT_FLOAT, DT_INT32);
+
+ // Feed and run
+ AddInputFromArray<float>(TensorShape({5}), {0, 1, 2, 3, 4});
+ AddInputFromArray<int32>(TensorShape({}), {3});
+ AddInputFromArray<int32>(TensorShape({}), {0});
+ TF_ASSERT_OK(RunOpKernel());
+
+ // Check the output.
+ Tensor expected(allocator(), DT_FLOAT, TensorShape({5}));
+ test::FillValues<float>(&expected, {2, 3, 4, 0, 1});
+ test::ExpectTensorEqual<float>(expected, *GetOutput(0));
+}
+
+TEST_F(RollOpTest, ScalarIndices_NoMemcpy) {
+ MakeOp(DT_STRING, DT_INT32);
+
+ // Feed and run
+ AddInputFromArray<string>(TensorShape({5}), {"a", "b", "c", "d", "e"});
+ AddInputFromArray<int32>(TensorShape({}), {3});
+ AddInputFromArray<int32>(TensorShape({}), {0});
+ TF_ASSERT_OK(RunOpKernel());
+
+ // Check the output.
+ Tensor expected(allocator(), DT_STRING, TensorShape({5}));
+ test::FillValues<string>(&expected, {"c", "d", "e", "a", "b"});
+ test::ExpectTensorEqual<string>(expected, *GetOutput(0));
+}
+
+TEST_F(RollOpTest, ScalarIndices_Complex) {
+ MakeOp(DT_COMPLEX64, DT_INT32);
+
+ // Feed and run
+ AddInputFromArray<std::complex<float>>(
+ TensorShape({5}), {std::complex<float>(0, 10), std::complex<float>(1, 11),
+ std::complex<float>(2, 12), std::complex<float>(3, 13),
+ std::complex<float>(4, 14)});
+ AddInputFromArray<int32>(TensorShape({}), {3});
+ AddInputFromArray<int32>(TensorShape({}), {0});
+ TF_ASSERT_OK(RunOpKernel());
+
+ // Check the output.
+ Tensor expected(allocator(), DT_COMPLEX64, TensorShape({5}));
+ test::FillValues<std::complex<float>>(
+ &expected, {std::complex<float>(2, 12), std::complex<float>(3, 13),
+ std::complex<float>(4, 14), std::complex<float>(0, 10),
+ std::complex<float>(1, 11)});
+ test::ExpectTensorEqual<std::complex<float>>(expected, *GetOutput(0));
+}
+
+TEST_F(RollOpTest, Simple_TwoD32) {
+ MakeOp(DT_FLOAT, DT_INT32);
+
+ // Feed and run
+ AddInputFromArray<float>(TensorShape({3, 5}),
+ {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14});
+ AddInputFromArray<int32>(TensorShape({2}), {2, -1});
+ AddInputFromArray<int32>(TensorShape({2}), {0, 1});
+ TF_ASSERT_OK(RunOpKernel());
+
+ // Check the output.
+ Tensor expected(allocator(), DT_FLOAT, TensorShape({3, 5}));
+ test::FillValues<float>(&expected,
+ {6, 7, 8, 9, 5, 11, 12, 13, 14, 10, 1, 2, 3, 4, 0});
+ test::ExpectTensorEqual<float>(expected, *GetOutput(0));
+}
+
+TEST_F(RollOpTest, Simple_TwoD32_NoMemcpy) {
+ MakeOp(DT_STRING, DT_INT32);
+
+ // Feed and run
+ AddInputFromArray<string>(TensorShape({3, 5}),
+ {"a", "b", "c", "d", "e", "f", "g", "h", "i", "j",
+ "k", "l", "m", "n", "o"});
+ AddInputFromArray<int32>(TensorShape({2}), {2, -1});
+ AddInputFromArray<int32>(TensorShape({2}), {0, 1});
+ TF_ASSERT_OK(RunOpKernel());
+
+ // Check the output.
+ Tensor expected(allocator(), DT_STRING, TensorShape({3, 5}));
+ test::FillValues<string>(&expected, {"g", "h", "i", "j", "f", "l", "m", "n",
+ "o", "k", "b", "c", "d", "e", "a"});
+ test::ExpectTensorEqual<string>(expected, *GetOutput(0));
+}
+
+TEST_F(RollOpTest, Simple_ThreeD32) {
+ MakeOp(DT_FLOAT, DT_INT32);
+
+ // Feed and run
+ AddInputFromArray<float>(TensorShape({2, 2, 3}),
+ {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11});
+ AddInputFromArray<int32>(TensorShape({3}), {1, -1, -1});
+ AddInputFromArray<int32>(TensorShape({3}), {0, 1, 2});
+ TF_ASSERT_OK(RunOpKernel());
+
+ // Check the output.
+ Tensor expected(allocator(), DT_FLOAT, TensorShape({2, 2, 3}));
+ test::FillValues<float>(&expected, {10, 11, 9, 7, 8, 6, 4, 5, 3, 1, 2, 0});
+ test::ExpectTensorEqual<float>(expected, *GetOutput(0));
+}
+
+TEST_F(RollOpTest, Simple_ThreeD32_NoMemcpy) {
+ MakeOp(DT_STRING, DT_INT32);
+
+ // Feed and run
+ AddInputFromArray<string>(
+ TensorShape({2, 2, 3}),
+ {"a", "b", "c", "d", "e", "f", "g", "h", "i", "j", "k", "l"});
+ AddInputFromArray<int32>(TensorShape({3}), {1, -1, -1});
+ AddInputFromArray<int32>(TensorShape({3}), {0, 1, 2});
+ TF_ASSERT_OK(RunOpKernel());
+
+ // Check the output.
+ Tensor expected(allocator(), DT_STRING, TensorShape({2, 2, 3}));
+ test::FillValues<string>(
+ &expected, {"k", "l", "j", "h", "i", "g", "e", "f", "d", "b", "c", "a"});
+ test::ExpectTensorEqual<string>(expected, *GetOutput(0));
+}
+
+TEST_F(RollOpTest, Simple_TwoD64) {
+ MakeOp(DT_FLOAT, DT_INT64);
+
+ // Feed and run
+ AddInputFromArray<float>(TensorShape({5, 3}),
+ {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14});
+ AddInputFromArray<int64>(TensorShape({2}), {-1, 4});
+ AddInputFromArray<int64>(TensorShape({2}), {0, 1});
+ TF_ASSERT_OK(RunOpKernel());
+
+ // Check the output.
+ Tensor expected(allocator(), DT_FLOAT, TensorShape({5, 3}));
+ test::FillValues<float>(&expected,
+ {5, 3, 4, 8, 6, 7, 11, 9, 10, 14, 12, 13, 2, 0, 1});
+ test::ExpectTensorEqual<float>(expected, *GetOutput(0));
+}
+
+TEST_F(RollOpTest, Simple_TwoD64_NoMemcpy) {
+ MakeOp(DT_STRING, DT_INT64);
+
+ // Feed and run
+ AddInputFromArray<string>(TensorShape({5, 3}),
+ {"a", "b", "c", "d", "e", "f", "g", "h", "i", "j",
+ "k", "l", "m", "n", "o"});
+ AddInputFromArray<int64>(TensorShape({2}), {-1, 4});
+ AddInputFromArray<int64>(TensorShape({2}), {0, 1});
+ TF_ASSERT_OK(RunOpKernel());
+
+ // Check the output.
+ Tensor expected(allocator(), DT_STRING, TensorShape({5, 3}));
+ test::FillValues<string>(&expected, {"f", "d", "e", "i", "g", "h", "l", "j",
+ "k", "o", "m", "n", "c", "a", "b"});
+ test::ExpectTensorEqual<string>(expected, *GetOutput(0));
+}
+
+TEST_F(RollOpTest, Simple_ThreeD64) {
+ MakeOp(DT_FLOAT, DT_INT64);
+
+ // Feed and run
+ AddInputFromArray<float>(TensorShape({4, 1, 3}),
+ {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11});
+ AddInputFromArray<int64>(TensorShape({3}), {4, 3, 2});
+ AddInputFromArray<int64>(TensorShape({3}), {0, 1, 2});
+ TF_ASSERT_OK(RunOpKernel());
+
+ // Check the output.
+ Tensor expected(allocator(), DT_FLOAT, TensorShape({4, 1, 3}));
+ test::FillValues<float>(&expected, {1, 2, 0, 4, 5, 3, 7, 8, 6, 10, 11, 9});
+ test::ExpectTensorEqual<float>(expected, *GetOutput(0));
+}
+
+TEST_F(RollOpTest, Simple_ThreeD64_NoMemcpy) {
+ MakeOp(DT_STRING, DT_INT64);
+
+ // Feed and run
+ AddInputFromArray<string>(
+ TensorShape({4, 1, 3}),
+ {"a", "b", "c", "d", "e", "f", "g", "h", "i", "j", "k", "l"});
+ AddInputFromArray<int64>(TensorShape({3}), {4, 3, 2});
+ AddInputFromArray<int64>(TensorShape({3}), {0, 1, 2});
+ TF_ASSERT_OK(RunOpKernel());
+
+ // Check the output.
+ Tensor expected(allocator(), DT_STRING, TensorShape({4, 1, 3}));
+ test::FillValues<string>(
+ &expected, {"b", "c", "a", "e", "f", "d", "h", "i", "g", "k", "l", "j"});
+ test::ExpectTensorEqual<string>(expected, *GetOutput(0));
+}
+
+TEST_F(RollOpTest, ZeroShift_ThreeD32) {
+ MakeOp(DT_FLOAT, DT_INT32);
+
+ // Feed and run
+ AddInputFromArray<float>(TensorShape({2, 2, 3}),
+ {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11});
+ AddInputFromArray<int32>(TensorShape({3}), {0, 0, 0});
+ AddInputFromArray<int32>(TensorShape({3}), {0, 1, 2});
+ TF_ASSERT_OK(RunOpKernel());
+
+ // Check the output.
+ Tensor expected(allocator(), DT_FLOAT, TensorShape({2, 2, 3}));
+ test::FillValues<float>(&expected, {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11});
+ test::ExpectTensorEqual<float>(expected, *GetOutput(0));
+}
+
+TEST_F(RollOpTest, ZeroShift_ThreeD32_NoMemcpy) {
+ MakeOp(DT_STRING, DT_INT32);
+
+ // Feed and run
+ AddInputFromArray<string>(
+ TensorShape({2, 2, 3}),
+ {"a", "b", "c", "d", "e", "f", "g", "h", "i", "j", "k", "l"});
+ AddInputFromArray<int32>(TensorShape({3}), {0, 0, 0});
+ AddInputFromArray<int32>(TensorShape({3}), {0, 1, 2});
+ TF_ASSERT_OK(RunOpKernel());
+
+ // Check the output.
+ Tensor expected(allocator(), DT_STRING, TensorShape({2, 2, 3}));
+ test::FillValues<string>(
+ &expected, {"a", "b", "c", "d", "e", "f", "g", "h", "i", "j", "k", "l"});
+ test::ExpectTensorEqual<string>(expected, *GetOutput(0));
+}
+
+TEST_F(RollOpTest, ZeroSize_ThreeD32) {
+ MakeOp(DT_FLOAT, DT_INT32);
+
+ // Feed and run
+ AddInputFromArray<float>(TensorShape({5, 0, 0}), {});
+ AddInputFromArray<int32>(TensorShape({}), {1});
+ AddInputFromArray<int32>(TensorShape({}), {0});
+ TF_ASSERT_OK(RunOpKernel());
+
+ // Check the output.
+ Tensor expected(allocator(), DT_FLOAT, TensorShape({5, 0, 0}));
+ test::ExpectTensorEqual<float>(expected, *GetOutput(0));
+}
+
+TEST_F(RollOpTest, ZeroSize_ThreeD32_NoMemcpy) {
+ MakeOp(DT_STRING, DT_INT32);
+
+ // Feed and run
+ AddInputFromArray<string>(TensorShape({5, 0, 0}), {});
+ AddInputFromArray<int32>(TensorShape({}), {1});
+ AddInputFromArray<int32>(TensorShape({}), {0});
+ TF_ASSERT_OK(RunOpKernel());
+
+ // Check the output.
+ Tensor expected(allocator(), DT_STRING, TensorShape({5, 0, 0}));
+ test::ExpectTensorEqual<string>(expected, *GetOutput(0));
+}
+
+TEST_F(RollOpTest, OneSize_ThreeD32) {
+ MakeOp(DT_FLOAT, DT_INT32);
+
+ // Feed and run
+ AddInputFromArray<float>(TensorShape({1, 1, 1}), {5});
+ AddInputFromArray<int32>(TensorShape({}), {1});
+ AddInputFromArray<int32>(TensorShape({}), {0});
+ TF_ASSERT_OK(RunOpKernel());
+
+ // Check the output.
+ Tensor expected(allocator(), DT_FLOAT, TensorShape({1, 1, 1}));
+ test::FillValues<float>(&expected, {5});
+ test::ExpectTensorEqual<float>(expected, *GetOutput(0));
+}
+
+TEST_F(RollOpTest, OneSize_ThreeD32_NoMemcpy) {
+ MakeOp(DT_STRING, DT_INT32);
+
+ // Feed and run
+ AddInputFromArray<string>(TensorShape({1, 1, 1}), {"a"});
+ AddInputFromArray<int32>(TensorShape({}), {1});
+ AddInputFromArray<int32>(TensorShape({}), {0});
+ TF_ASSERT_OK(RunOpKernel());
+
+ // Check the output.
+ Tensor expected(allocator(), DT_STRING, TensorShape({1, 1, 1}));
+ test::FillValues<string>(&expected, {"a"});
+ test::ExpectTensorEqual<string>(expected, *GetOutput(0));
+}
+
+TEST_F(RollOpTest, MultiShifts_TwoD32) {
+ MakeOp(DT_FLOAT, DT_INT32);
+
+ // Feed and run
+ AddInputFromArray<float>(TensorShape({3, 5}),
+ {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14});
+ AddInputFromArray<int32>(TensorShape({4}), {-2, 2, -1, 1});
+ AddInputFromArray<int32>(TensorShape({4}), {1, 0, 0, 1});
+ TF_ASSERT_OK(RunOpKernel());
+
+ // Check the output.
+ Tensor expected(allocator(), DT_FLOAT, TensorShape({3, 5}));
+ test::FillValues<float>(&expected,
+ {11, 12, 13, 14, 10, 1, 2, 3, 4, 0, 6, 7, 8, 9, 5});
+ test::ExpectTensorEqual<float>(expected, *GetOutput(0));
+}
+
+TEST_F(RollOpTest, MultiShifts_TwoD32_NoMemcpy) {
+ MakeOp(DT_STRING, DT_INT32);
+
+ // Feed and run
+ AddInputFromArray<string>(TensorShape({3, 5}),
+ {"a", "b", "c", "d", "e", "f", "g", "h", "i", "j",
+ "k", "l", "m", "n", "o"});
+ AddInputFromArray<int32>(TensorShape({4}), {-2, 2, -1, 1});
+ AddInputFromArray<int32>(TensorShape({4}), {1, 0, 0, 1});
+ TF_ASSERT_OK(RunOpKernel());
+
+ // Check the output.
+ Tensor expected(allocator(), DT_STRING, TensorShape({3, 5}));
+ test::FillValues<string>(&expected, {"l", "m", "n", "o", "k", "b", "c", "d",
+ "e", "a", "g", "h", "i", "j", "f"});
+ test::ExpectTensorEqual<string>(expected, *GetOutput(0));
+}
+
+TEST_F(RollOpTest, Error_InputMustBeVectorOrHigher) {
+ MakeOp(DT_FLOAT, DT_INT32);
+
+ // Feed and run
+ AddInputFromArray<float>(TensorShape({}), {7});
+ AddInputFromArray<int32>(TensorShape({}), {1});
+ AddInputFromArray<int32>(TensorShape({}), {0});
+ Status s = RunOpKernel();
+ EXPECT_TRUE(StringPiece(s.ToString()).contains("input must be 1-D or higher"))
+ << s;
+}
+
+TEST_F(RollOpTest, Error_AxisMustBeScalarOrVector) {
+ MakeOp(DT_FLOAT, DT_INT32);
+
+ // Feed and run
+ AddInputFromArray<float>(TensorShape({2, 2}), {1, 2, 3, 4});
+ AddInputFromArray<int32>(TensorShape({}), {1});
+ AddInputFromArray<int32>(TensorShape({1, 2}), {0, 1});
+ Status s = RunOpKernel();
+ EXPECT_TRUE(StringPiece(s.ToString())
+ .contains("axis must be a scalar or a 1-D vector"))
+ << s;
+}
+
+TEST_F(RollOpTest, Error_ShiftMustBeScalarOrVector) {
+ MakeOp(DT_FLOAT, DT_INT32);
+
+ // Feed and run
+ AddInputFromArray<float>(TensorShape({2, 2}), {1, 2, 3, 4});
+ AddInputFromArray<int32>(TensorShape({1, 2}), {0, 1});
+ AddInputFromArray<int32>(TensorShape({}), {1});
+ Status s = RunOpKernel();
+ EXPECT_TRUE(StringPiece(s.ToString())
+ .contains("shift must be a scalar or a 1-D vector"))
+ << s;
+}
+
+TEST_F(RollOpTest, Error_ShiftAndAxisMustBeSameSize) {
+ MakeOp(DT_FLOAT, DT_INT32);
+
+ // Feed and run
+ AddInputFromArray<float>(TensorShape({2, 2}), {1, 2, 3, 4});
+ AddInputFromArray<int32>(TensorShape({1}), {1});
+ AddInputFromArray<int32>(TensorShape({2}), {0, 1});
+ Status s = RunOpKernel();
+ EXPECT_TRUE(StringPiece(s.ToString())
+ .contains("shift and axis must have the same size"))
+ << s;
+}
+
+TEST_F(RollOpTest, Error_AxisOutOfRange) {
+ MakeOp(DT_FLOAT, DT_INT32);
+
+ // Feed and run
+ AddInputFromArray<float>(TensorShape({4}), {1, 2, 3, 4});
+ AddInputFromArray<int32>(TensorShape({}), {1});
+ AddInputFromArray<int32>(TensorShape({}), {1});
+ Status s = RunOpKernel();
+ EXPECT_TRUE(StringPiece(s.ToString()).contains("is out of range")) << s;
+}
+
+// isd - (inner shift dimension) The innermost dimension to be shifted.
+// All outer dimensions will also be shifted for testing.
+static Graph* RollGraph(const TensorShape& shape, int isd) {
+ Graph* g = new Graph(OpRegistry::Global());
+ Tensor input(DT_FLOAT, shape);
+ input.flat<float>().setRandom();
+ const int dims = static_cast<int>(input.dims());
+ Tensor shift(DT_INT32, TensorShape({dims}));
+ for (int i = 0; i < dims; i++) {
+ // shift the inner shift dimension and all outer dimensions
+ shift.flat<int32>()(i) = (i <= isd) ? 2 : 0;
+ }
+ Tensor axis(DT_INT32, TensorShape({dims}));
+ for (int i = 0; i < dims; i++) {
+ axis.flat<int32>()(i) = i;
+ }
+ test::graph::Roll(g, test::graph::Constant(g, input),
+ test::graph::Constant(g, shift),
+ test::graph::Constant(g, axis));
+ return g;
+}
+
+#define BM_ROLL_OUTER(DEVICE) \
+ static void BM_##DEVICE##_roll_outer(int iters, int rows, int columns) { \
+ TensorShape shape{rows, columns}; \
+ const int64 num_items = static_cast<int64>(iters) * shape.num_elements(); \
+ testing::ItemsProcessed(num_items); \
+ testing::BytesProcessed(num_items * sizeof(float)); \
+ testing::UseRealTime(); \
+ test::Benchmark(#DEVICE, RollGraph(shape, 0)).Run(iters); \
+ } \
+ BENCHMARK(BM_##DEVICE##_roll_outer) \
+ ->ArgPair(256, 256) \
+ ->ArgPair(512, 512) \
+ ->ArgPair(1024, 1024) \
+ ->ArgPair(2048, 2048)
+
+#define BM_ROLL_ALL(DEVICE) \
+ static void BM_##DEVICE##_roll_all(int iters, int rows, int columns) { \
+ TensorShape shape{rows, columns}; \
+ const int64 num_items = static_cast<int64>(iters) * shape.num_elements(); \
+ testing::ItemsProcessed(num_items); \
+ testing::BytesProcessed(num_items * sizeof(float)); \
+ testing::UseRealTime(); \
+ test::Benchmark(#DEVICE, RollGraph(shape, 1)).Run(iters); \
+ } \
+ BENCHMARK(BM_##DEVICE##_roll_all) \
+ ->ArgPair(256, 256) \
+ ->ArgPair(512, 512) \
+ ->ArgPair(1024, 1024) \
+ ->ArgPair(2048, 2048)
+
+BM_ROLL_OUTER(cpu);
+BM_ROLL_ALL(cpu);
+} // namespace
+} // namespace tensorflow
diff --git a/tensorflow/core/kernels/unravel_index_op.cc b/tensorflow/core/kernels/unravel_index_op.cc
new file mode 100644
index 0000000000..a61272675b
--- /dev/null
+++ b/tensorflow/core/kernels/unravel_index_op.cc
@@ -0,0 +1,122 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#define EIGEN_USE_THREADS
+
+#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
+#include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/framework/register_types.h"
+#include "tensorflow/core/framework/tensor.h"
+#include "tensorflow/core/framework/types.h"
+
+namespace tensorflow {
+
+namespace {
+template <typename T>
+struct mod_op {
+ const T operator()(const T& a, const T& b) const { return a % b; }
+};
+} // namespace
+
+typedef Eigen::ThreadPoolDevice CPUDevice;
+
+template <typename Tidx>
+class UnravelIndexOp : public OpKernel {
+ public:
+ explicit UnravelIndexOp(OpKernelConstruction* ctx) : OpKernel(ctx) {}
+
+ void Compute(OpKernelContext* ctx) override {
+ const Tensor& indices_tensor = ctx->input(0);
+ OP_REQUIRES(ctx,
+ TensorShapeUtils::IsVector(indices_tensor.shape()) ||
+ TensorShapeUtils::IsScalar(indices_tensor.shape()),
+ errors::InvalidArgument(
+ "The indices can only be scalar or vector, got \"",
+ indices_tensor.shape().DebugString(), "\""));
+
+ const Tensor& dims_tensor = ctx->input(1);
+ OP_REQUIRES(
+ ctx, TensorShapeUtils::IsVector(dims_tensor.shape()),
+ errors::InvalidArgument("The dims can only be 1-D, got \"",
+ dims_tensor.shape().DebugString(), "\""));
+
+ auto dims = dims_tensor.vec<Tidx>();
+
+ Eigen::array<bool, 1> reverse({true});
+
+ Tensor strides_tensor;
+ OP_REQUIRES_OK(ctx,
+ ctx->allocate_temp(DataTypeToEnum<Tidx>::value,
+ TensorShape({dims_tensor.NumElements()}),
+ &strides_tensor));
+
+ auto strides = strides_tensor.vec<Tidx>();
+ strides = dims.reverse(reverse)
+ .scan(0, Eigen::internal::ProdReducer<Tidx>(), false)
+ .reverse(reverse);
+
+ Tensor strides_shifted_tensor;
+ OP_REQUIRES_OK(ctx,
+ ctx->allocate_temp(DataTypeToEnum<Tidx>::value,
+ TensorShape({dims_tensor.NumElements()}),
+ &strides_shifted_tensor));
+
+ auto strides_shifted = strides_shifted_tensor.vec<Tidx>();
+ strides_shifted = dims.reverse(reverse)
+ .scan(0, Eigen::internal::ProdReducer<Tidx>(), true)
+ .reverse(reverse);
+
+ Tensor* output_tensor = nullptr;
+ if (TensorShapeUtils::IsScalar(indices_tensor.shape())) {
+ OP_REQUIRES_OK(
+ ctx, ctx->allocate_output(0, TensorShape({dims_tensor.NumElements()}),
+ &output_tensor));
+
+ auto output = output_tensor->vec<Tidx>();
+
+ output = output.constant(indices_tensor.scalar<Tidx>()());
+ output = output.binaryExpr(strides, mod_op<Tidx>()) / strides_shifted;
+ } else {
+ OP_REQUIRES_OK(
+ ctx, ctx->allocate_output(0,
+ TensorShape({dims_tensor.NumElements(),
+ indices_tensor.NumElements()}),
+ &output_tensor));
+
+ auto output = output_tensor->matrix<Tidx>();
+
+ Eigen::array<int64, 2> reshape{{dims_tensor.NumElements(), 1}};
+ Eigen::array<int64, 2> bcast({1, indices_tensor.NumElements()});
+ Eigen::array<int64, 2> indices_reshape{{1, indices_tensor.NumElements()}};
+ Eigen::array<int64, 2> indices_bcast({dims_tensor.NumElements(), 1});
+
+ output = indices_tensor.vec<Tidx>()
+ .reshape(indices_reshape)
+ .broadcast(indices_bcast);
+ output = output.binaryExpr(strides.reshape(reshape).broadcast(bcast),
+ mod_op<Tidx>()) /
+ strides_shifted.reshape(reshape).broadcast(bcast);
+ }
+ }
+};
+
+#define REGISTER_KERNEL(type) \
+ REGISTER_KERNEL_BUILDER( \
+ Name("UnravelIndex").Device(DEVICE_CPU).TypeConstraint<type>("Tidx"), \
+ UnravelIndexOp<type>);
+TF_CALL_int32(REGISTER_KERNEL) TF_CALL_int64(REGISTER_KERNEL)
+#undef REGISTER_KERNEL
+
+} // namespace tensorflow
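Per flat index, UnravelIndexOp computes (index mod strides[i]) /
strides_shifted[i], where strides[i] is the product of dims[i..n-1] and
strides_shifted[i] is the same product excluding dims[i]. The equivalent
scalar loop (a minimal standalone sketch, not the kernel) is:

    #include <cstdio>
    #include <vector>

    // Convert a flat index into per-dimension indices for the given dims.
    std::vector<int> UnravelIndex(int flat, const std::vector<int>& dims) {
      std::vector<int> out(dims.size());
      int stride = 1;  // product of the dimensions to the right of i
      for (int i = static_cast<int>(dims.size()) - 1; i >= 0; --i) {
        out[i] = (flat / stride) % dims[i];
        stride *= dims[i];
      }
      return out;
    }

    int main() {
      // For dims {3, 4, 5}, flat index 37 -> {1, 3, 2}: 1*20 + 3*5 + 2 = 37.
      const auto idx = UnravelIndex(37, {3, 4, 5});
      std::printf("%d %d %d\n", idx[0], idx[1], idx[2]);
    }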