Diffstat (limited to 'tensorflow/core/kernels/argmax_op.cc')
 tensorflow/core/kernels/argmax_op.cc | 163 +++++++++++++++++++++++++++++++++++
 1 file changed, 163 insertions(+), 0 deletions(-)
diff --git a/tensorflow/core/kernels/argmax_op.cc b/tensorflow/core/kernels/argmax_op.cc
new file mode 100644
index 0000000000..0845eebf09
--- /dev/null
+++ b/tensorflow/core/kernels/argmax_op.cc
@@ -0,0 +1,163 @@
+// See docs in ../ops/math_ops.cc.
+
+#define EIGEN_USE_THREADS
+
+#if GOOGLE_CUDA
+#define EIGEN_USE_GPU
+#endif // GOOGLE_CUDA
+
+#include "tensorflow/core/kernels/argmax_op.h"
+
+#include <memory>
+#include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/framework/register_types.h"
+#include "tensorflow/core/framework/tensor_types.h"
+#include "tensorflow/core/framework/types.h"
+#include "tensorflow/core/platform/logging.h"
+#include "tensorflow/core/public/tensor_shape.h"
+#include "tensorflow/core/public/tensor.h"
+#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
+
+namespace tensorflow {
+
+typedef Eigen::ThreadPoolDevice CPUDevice;
+typedef Eigen::GpuDevice GPUDevice;
+
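+// ArgOp implements the logic shared by ArgMax and ArgMin; the ArgFunctor
+// template parameter (functor::ArgMax or functor::ArgMin) supplies the
+// rank-specialized reduction.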
+template <typename Device, typename T, typename ArgFunctor>
+class ArgOp : public OpKernel {
+ public:
+ explicit ArgOp(OpKernelConstruction* context) : OpKernel(context) {}
+
+ void Compute(OpKernelContext* context) override {
+ const Tensor& input = context->input(0);
+ const Tensor& dimension = context->input(1);
+
+ OP_REQUIRES(context, TensorShapeUtils::IsScalar(dimension.shape()),
+ errors::InvalidArgument(
+ "dim must be a scalar, but received tensor of shape: ",
+ dimension.shape().DebugString()));
+
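+    // The reduction dimension is read on the host; this is why the kernel
+    // registrations below keep "dimension" in HostMemory.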
+ const int32 dim = dimension.scalar<int32>()();
+ const int input_dims = input.dims();
+
+    OP_REQUIRES(context, dim >= 0,
+                errors::InvalidArgument("dim must be >= 0, but got: ", dim));
+    OP_REQUIRES(context, dim < input_dims,
+                errors::InvalidArgument("dim must be < ", input_dims,
+                                        " (the rank of the input), but got: ",
+                                        dim));
+
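+    // The output shape is the input shape with the reduced dimension removed.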
+ TensorShape output_shape;
+ TensorShape input_shape = input.shape();
+ for (int d = 0; d < input_dims - 1; ++d) {
+ output_shape.AddDim(input_shape.dim_size((d < dim) ? d : d + 1));
+ }
+ Tensor* output = nullptr;
+ OP_REQUIRES_OK(context, context->allocate_output(0, output_shape, &output));
+
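+// Eigen needs the tensor rank at compile time, so dispatch to a
+// rank-specialized reduction for each supported rank.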
+#define HANDLE_DIM(NDIM) \
+ case NDIM: \
+ ArgFunctor::Reduce##NDIM(context->eigen_device<Device>(), \
+ input.tensor<T, NDIM>(), dim, \
+ output->tensor<int64, NDIM - 1>()); \
+ break;
+
+ switch (input_dims) {
+ HANDLE_DIM(1);
+ HANDLE_DIM(2);
+ HANDLE_DIM(3);
+ HANDLE_DIM(4);
+ HANDLE_DIM(5);
+
+      default:
+        OP_REQUIRES(context, false,
+                    errors::InvalidArgument(
+                        "ArgOp: unsupported input rank: ", input_dims));
+ }
+ }
+#undef HANDLE_DIM
+
+ private:
+ TF_DISALLOW_COPY_AND_ASSIGN(ArgOp);
+};
+
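+// ArgMaxOp and ArgMinOp are thin wrappers that bind ArgOp to the matching
+// functor.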
+template <typename Device, typename T>
+class ArgMaxOp : public ArgOp<Device, T, functor::ArgMax<Device, T> > {
+ public:
+ explicit ArgMaxOp(OpKernelConstruction* context)
+ : ArgOp<Device, T, functor::ArgMax<Device, T> >(context) {}
+};
+
+template <typename Device, typename T>
+class ArgMinOp : public ArgOp<Device, T, functor::ArgMin<Device, T> > {
+ public:
+ explicit ArgMinOp(OpKernelConstruction* context)
+ : ArgOp<Device, T, functor::ArgMin<Device, T> >(context) {}
+};
+
+#define REGISTER_ARGMAX(type) \
+ REGISTER_KERNEL_BUILDER(Name("ArgMax") \
+ .Device(DEVICE_CPU) \
+ .TypeConstraint<type>("T") \
+ .HostMemory("dimension"), \
+ ArgMaxOp<CPUDevice, type>); \
+ REGISTER_KERNEL_BUILDER(Name("ArgMin") \
+ .Device(DEVICE_CPU) \
+ .TypeConstraint<type>("T") \
+ .HostMemory("dimension"), \
+ ArgMinOp<CPUDevice, type>);
+
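+// TF_CALL_REAL_NUMBER_TYPES expands REGISTER_ARGMAX once per real numeric
+// type (float, double, int32, ...), registering both CPU kernels for each.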
+TF_CALL_REAL_NUMBER_TYPES(REGISTER_ARGMAX);
+
+#undef REGISTER_ARGMAX
+
+#if GOOGLE_CUDA
+
+// Forward declarations of the functor specializations for GPU.
+namespace functor {
+
+#define DECLARE_GPU_SPEC(T, Dims) \
+ template <> \
+ void ArgMax<GPUDevice, T>::Reduce##Dims( \
+ const GPUDevice& d, typename TTypes<T, Dims>::ConstTensor input, \
+ const int32 dimension, typename TTypes<int64, Dims - 1>::Tensor output); \
+ template <> \
+ void ArgMin<GPUDevice, T>::Reduce##Dims( \
+ const GPUDevice& d, typename TTypes<T, Dims>::ConstTensor input, \
+ const int32 dimension, typename TTypes<int64, Dims - 1>::Tensor output);
+
+#define DECLARE_GPU_SPECS(T) \
+ DECLARE_GPU_SPEC(T, 1); \
+ DECLARE_GPU_SPEC(T, 2); \
+ DECLARE_GPU_SPEC(T, 3); \
+ DECLARE_GPU_SPEC(T, 4); \
+ DECLARE_GPU_SPEC(T, 5);
+
+#define DECLARE_GPU_CLASS(T) \
+ extern template struct ArgMax<GPUDevice, T>; \
+ extern template struct ArgMin<GPUDevice, T>;
+
+TF_CALL_GPU_NUMBER_TYPES(DECLARE_GPU_SPECS);
+TF_CALL_GPU_NUMBER_TYPES(DECLARE_GPU_CLASS);
+
+#undef DECLARE_GPU_SPEC
+#undef DECLARE_GPU_SPECS
+#undef DECLARE_GPU_CLASS
+
+} // namespace functor
+
+// Registration of the GPU implementations.
+#define REGISTER_ARGMAX_GPU(type) \
+ REGISTER_KERNEL_BUILDER(Name("ArgMax") \
+ .Device(DEVICE_GPU) \
+ .TypeConstraint<type>("T") \
+ .HostMemory("dimension"), \
+ ArgMaxOp<GPUDevice, type>); \
+ REGISTER_KERNEL_BUILDER(Name("ArgMin") \
+ .Device(DEVICE_GPU) \
+ .TypeConstraint<type>("T") \
+ .HostMemory("dimension"), \
+ ArgMinOp<GPUDevice, type>);
+
+TF_CALL_GPU_NUMBER_TYPES(REGISTER_ARGMAX_GPU);
+
+#undef REGISTER_ARGMAX_GPU
+
+#endif // GOOGLE_CUDA
+
+} // namespace tensorflow