aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core/kernels/unravel_index_op.cc
diff options
context:
space:
mode:
authorGravatar Michael Case <mikecase@google.com>2018-02-07 14:36:00 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-02-07 14:39:49 -0800
commitd90054e7c0f41f4bab81df0548577a73b939a87a (patch)
treea15aea686a9d3f305e316d2a6ada0859ad8170d1 /tensorflow/core/kernels/unravel_index_op.cc
parent8461760f9f6cde8ed97507484d2a879140141032 (diff)
Merge changes from github.
PiperOrigin-RevId: 184897758
Diffstat (limited to 'tensorflow/core/kernels/unravel_index_op.cc')
-rw-r--r--tensorflow/core/kernels/unravel_index_op.cc122
1 files changed, 122 insertions, 0 deletions
diff --git a/tensorflow/core/kernels/unravel_index_op.cc b/tensorflow/core/kernels/unravel_index_op.cc
new file mode 100644
index 0000000000..a61272675b
--- /dev/null
+++ b/tensorflow/core/kernels/unravel_index_op.cc
@@ -0,0 +1,122 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#define EIGEN_USE_THREADS
+
+#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
+#include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/framework/register_types.h"
+#include "tensorflow/core/framework/tensor.h"
+#include "tensorflow/core/framework/types.h"
+
+namespace tensorflow {
+
+namespace {
+template <typename T>
+struct mod_op {
+ const T operator()(const T& a, const T& b) const { return a % b; }
+};
+} // namespace
+
+typedef Eigen::ThreadPoolDevice CPUDevice;
+
+template <typename Tidx>
+class UnravelIndexOp : public OpKernel {
+ public:
+ explicit UnravelIndexOp(OpKernelConstruction* ctx) : OpKernel(ctx) {}
+
+ void Compute(OpKernelContext* ctx) override {
+ const Tensor& indices_tensor = ctx->input(0);
+ OP_REQUIRES(ctx,
+ TensorShapeUtils::IsVector(indices_tensor.shape()) ||
+ TensorShapeUtils::IsScalar(indices_tensor.shape()),
+ errors::InvalidArgument(
+ "The indices can only be scalar or vector, got \"",
+ indices_tensor.shape().DebugString(), "\""));
+
+ const Tensor& dims_tensor = ctx->input(1);
+ OP_REQUIRES(
+ ctx, TensorShapeUtils::IsVector(dims_tensor.shape()),
+ errors::InvalidArgument("The indices can only be 1-D, got \"",
+ dims_tensor.shape().DebugString(), "\""));
+
+ auto dims = dims_tensor.vec<Tidx>();
+
+ Eigen::array<bool, 1> reverse({true});
+
+ Tensor strides_tensor;
+ OP_REQUIRES_OK(ctx,
+ ctx->allocate_temp(DataTypeToEnum<Tidx>::value,
+ TensorShape({dims_tensor.NumElements()}),
+ &strides_tensor));
+
+ auto strides = strides_tensor.vec<Tidx>();
+ strides = dims.reverse(reverse)
+ .scan(0, Eigen::internal::ProdReducer<Tidx>(), false)
+ .reverse(reverse);
+
+ Tensor strides_shifted_tensor;
+ OP_REQUIRES_OK(ctx,
+ ctx->allocate_temp(DataTypeToEnum<Tidx>::value,
+ TensorShape({dims_tensor.NumElements()}),
+ &strides_shifted_tensor));
+
+ auto strides_shifted = strides_shifted_tensor.vec<Tidx>();
+ strides_shifted = dims.reverse(reverse)
+ .scan(0, Eigen::internal::ProdReducer<Tidx>(), true)
+ .reverse(reverse);
+
+ Tensor* output_tensor = nullptr;
+ if (TensorShapeUtils::IsScalar(indices_tensor.shape())) {
+ OP_REQUIRES_OK(
+ ctx, ctx->allocate_output(0, TensorShape({dims_tensor.NumElements()}),
+ &output_tensor));
+
+ auto output = output_tensor->vec<Tidx>();
+
+ output = output.constant(indices_tensor.scalar<Tidx>()());
+ output = output.binaryExpr(strides, mod_op<Tidx>()) / strides_shifted;
+ } else {
+ OP_REQUIRES_OK(
+ ctx, ctx->allocate_output(0,
+ TensorShape({dims_tensor.NumElements(),
+ indices_tensor.NumElements()}),
+ &output_tensor));
+
+ auto output = output_tensor->matrix<Tidx>();
+
+ Eigen::array<int64, 2> reshape{{dims_tensor.NumElements(), 1}};
+ Eigen::array<int64, 2> bcast({1, indices_tensor.NumElements()});
+ Eigen::array<int64, 2> indices_reshape{{1, indices_tensor.NumElements()}};
+ Eigen::array<int64, 2> indices_bcast({dims_tensor.NumElements(), 1});
+
+ output = indices_tensor.vec<Tidx>()
+ .reshape(indices_reshape)
+ .broadcast(indices_bcast);
+ output = output.binaryExpr(strides.reshape(reshape).broadcast(bcast),
+ mod_op<Tidx>()) /
+ strides_shifted.reshape(reshape).broadcast(bcast);
+ }
+ }
+};
+
+#define REGISTER_KERNEL(type) \
+ REGISTER_KERNEL_BUILDER( \
+ Name("UnravelIndex").Device(DEVICE_CPU).TypeConstraint<type>("Tidx"), \
+ UnravelIndexOp<type>);
+TF_CALL_int32(REGISTER_KERNEL) TF_CALL_int64(REGISTER_KERNEL)
+#undef REGISTER_KERNEL
+
+} // namespace tensorflow