diff options
author | Manjunath Kudlur <keveman@gmail.com> | 2015-11-06 16:27:58 -0800 |
---|---|---|
committer | Manjunath Kudlur <keveman@gmail.com> | 2015-11-06 16:27:58 -0800 |
commit | f41959ccb2d9d4c722fe8fc3351401d53bcf4900 (patch) | |
tree | ef0ca22cb2a5ac4bdec9d080d8e0788a53ed496d /tensorflow/core/kernels/argmax_op.cc |
TensorFlow: Initial commit of TensorFlow library.
TensorFlow is an open source software library for numerical computation
using data flow graphs.
Base CL: 107276108
Diffstat (limited to 'tensorflow/core/kernels/argmax_op.cc')
-rw-r--r-- | tensorflow/core/kernels/argmax_op.cc | 163 |
1 files changed, 163 insertions, 0 deletions
diff --git a/tensorflow/core/kernels/argmax_op.cc b/tensorflow/core/kernels/argmax_op.cc new file mode 100644 index 0000000000..0845eebf09 --- /dev/null +++ b/tensorflow/core/kernels/argmax_op.cc @@ -0,0 +1,163 @@ +// See docs in ../ops/math_ops.cc. + +#define EIGEN_USE_THREADS + +#if GOOGLE_CUDA +#define EIGEN_USE_GPU +#endif // GOOGLE_CUDA + +#include "tensorflow/core/kernels/argmax_op.h" + +#include <memory> +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/framework/register_types.h" +#include "tensorflow/core/framework/tensor_types.h" +#include "tensorflow/core/framework/types.h" +#include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/public/tensor_shape.h" +#include "tensorflow/core/public/tensor.h" +#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" + +namespace tensorflow { + +typedef Eigen::ThreadPoolDevice CPUDevice; +typedef Eigen::GpuDevice GPUDevice; + +template <typename Device, typename T, typename ArgFunctor> +class ArgOp : public OpKernel { + public: + explicit ArgOp(OpKernelConstruction* context) : OpKernel(context) {} + + void Compute(OpKernelContext* context) override { + const Tensor& input = context->input(0); + const Tensor& dimension = context->input(1); + + OP_REQUIRES(context, TensorShapeUtils::IsScalar(dimension.shape()), + errors::InvalidArgument( + "dim must be a scalar, but received tensor of shape: ", + dimension.shape().DebugString())); + + const int32 dim = dimension.scalar<int32>()(); + const int input_dims = input.dims(); + + OP_REQUIRES(context, dim >= 0, errors::InvalidArgument("dim must be >= 0")); + OP_REQUIRES(context, dim < input_dims, + errors::InvalidArgument("Minimum tensor rank: ", dim, + " but got: ", input_dims)); + + TensorShape output_shape; + TensorShape input_shape = input.shape(); + for (int d = 0; d < input_dims - 1; ++d) { + output_shape.AddDim(input_shape.dim_size((d < dim) ? d : d + 1)); + } + Tensor* output = nullptr; + OP_REQUIRES_OK(context, context->allocate_output(0, output_shape, &output)); + +#define HANDLE_DIM(NDIM) \ + case NDIM: \ + ArgFunctor::Reduce##NDIM(context->eigen_device<Device>(), \ + input.tensor<T, NDIM>(), dim, \ + output->tensor<int64, NDIM - 1>()); \ + break; + + switch (input_dims) { + HANDLE_DIM(1); + HANDLE_DIM(2); + HANDLE_DIM(3); + HANDLE_DIM(4); + HANDLE_DIM(5); + + default: + OP_REQUIRES(context, false, + errors::InvalidArgument( + "ArgOp : Unhandled input dimensions: ", input_dims)); + } + } +#undef HANDLE_DIM + + private: + TF_DISALLOW_COPY_AND_ASSIGN(ArgOp); +}; + +template <typename Device, typename T> +class ArgMaxOp : public ArgOp<Device, T, functor::ArgMax<Device, T> > { + public: + explicit ArgMaxOp(OpKernelConstruction* context) + : ArgOp<Device, T, functor::ArgMax<Device, T> >(context) {} +}; + +template <typename Device, typename T> +class ArgMinOp : public ArgOp<Device, T, functor::ArgMin<Device, T> > { + public: + explicit ArgMinOp(OpKernelConstruction* context) + : ArgOp<Device, T, functor::ArgMin<Device, T> >(context) {} +}; + +#define REGISTER_ARGMAX(type) \ + REGISTER_KERNEL_BUILDER(Name("ArgMax") \ + .Device(DEVICE_CPU) \ + .TypeConstraint<type>("T") \ + .HostMemory("dimension"), \ + ArgMaxOp<CPUDevice, type>); \ + REGISTER_KERNEL_BUILDER(Name("ArgMin") \ + .Device(DEVICE_CPU) \ + .TypeConstraint<type>("T") \ + .HostMemory("dimension"), \ + ArgMinOp<CPUDevice, type>); + +TF_CALL_REAL_NUMBER_TYPES(REGISTER_ARGMAX); + +#if GOOGLE_CUDA + +// Forward declarations of the functor specializations for GPU. +namespace functor { + +#define DECLARE_GPU_SPEC(T, Dims) \ + template <> \ + void ArgMax<GPUDevice, T>::Reduce##Dims( \ + const GPUDevice& d, typename TTypes<T, Dims>::ConstTensor input, \ + const int32 dimension, typename TTypes<int64, Dims - 1>::Tensor output); \ + template <> \ + void ArgMin<GPUDevice, T>::Reduce##Dims( \ + const GPUDevice& d, typename TTypes<T, Dims>::ConstTensor input, \ + const int32 dimension, typename TTypes<int64, Dims - 1>::Tensor output); + +#define DECLARE_GPU_SPECS(T) \ + DECLARE_GPU_SPEC(T, 1); \ + DECLARE_GPU_SPEC(T, 2); \ + DECLARE_GPU_SPEC(T, 3); \ + DECLARE_GPU_SPEC(T, 4); \ + DECLARE_GPU_SPEC(T, 5); + +#define DECLARE_GPU_CLASS(T) \ + extern template struct ArgMax<GPUDevice, T>; \ + extern template struct ArgMin<GPUDevice, T>; + +TF_CALL_GPU_NUMBER_TYPES(DECLARE_GPU_SPECS); +TF_CALL_GPU_NUMBER_TYPES(DECLARE_GPU_CLASS); + +#undef DECLARE_GPU_SPECS +#undef DECLARE_GPU_CLASS + +} // namespace functor + +// Registration of the GPU implementations. +#define REGISTER_ARGMAX_GPU(type) \ + REGISTER_KERNEL_BUILDER(Name("ArgMax") \ + .Device(DEVICE_GPU) \ + .TypeConstraint<type>("T") \ + .HostMemory("dimension"), \ + ArgMaxOp<GPUDevice, type>); \ + REGISTER_KERNEL_BUILDER(Name("ArgMin") \ + .Device(DEVICE_GPU) \ + .TypeConstraint<type>("T") \ + .HostMemory("dimension"), \ + ArgMinOp<GPUDevice, type>); + +TF_CALL_GPU_NUMBER_TYPES(REGISTER_ARGMAX_GPU); + +#undef REGISTER_ARGMAX_GPU + +#endif // GOOGLE_CUDA + +} // namespace tensorflow |