diff options
Diffstat (limited to 'tensorflow/core/kernels/histogram_op.cc')
-rw-r--r-- | tensorflow/core/kernels/histogram_op.cc | 147 |
1 files changed, 147 insertions, 0 deletions
diff --git a/tensorflow/core/kernels/histogram_op.cc b/tensorflow/core/kernels/histogram_op.cc new file mode 100644 index 0000000000..4e035286f6 --- /dev/null +++ b/tensorflow/core/kernels/histogram_op.cc @@ -0,0 +1,147 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// See docs in ../ops/math_ops.cc. + +#define EIGEN_USE_THREADS + +#include "tensorflow/core/kernels/histogram_op.h" +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/framework/register_types.h" +#include "tensorflow/core/framework/types.h" +#include "tensorflow/core/lib/core/threadpool.h" +#include "tensorflow/core/platform/types.h" + +namespace tensorflow { + +typedef Eigen::ThreadPoolDevice CPUDevice; +typedef Eigen::GpuDevice GPUDevice; + +namespace functor { + +template <typename T, typename Tout> +struct HistogramFixedWidthFunctor<CPUDevice, T, Tout> { + static Status Compute(OpKernelContext* context, + const typename TTypes<T, 1>::ConstTensor& values, + const typename TTypes<T, 1>::ConstTensor& value_range, + int32 nbins, typename TTypes<Tout, 1>::Tensor& out) { + const CPUDevice& d = context->eigen_device<CPUDevice>(); + + Tensor index_to_bin_tensor; + + TF_RETURN_IF_ERROR(context->forward_input_or_allocate_temp( + {0}, DataTypeToEnum<int32>::value, TensorShape({values.size()}), + &index_to_bin_tensor)); + auto index_to_bin = index_to_bin_tensor.flat<int32>(); + + const double step = static_cast<double>(value_range(1) - value_range(0)) / + static_cast<double>(nbins); + + // The calculation is done by finding the slot of each value in `values`. + // With [a, b]: + // step = (b - a) / nbins + // (x - a) / step + // , then the entries are mapped to output. + index_to_bin.device(d) = + ((values.cwiseMax(value_range(0)) - values.constant(value_range(0))) + .template cast<double>() / + step) + .template cast<int32>() + .cwiseMin(nbins - 1); + + out.setZero(); + for (int32 i = 0; i < index_to_bin.size(); i++) { + out(index_to_bin(i)) += Tout(1); + } + return Status::OK(); + } +}; + +} // namespace functor + +template <typename Device, typename T, typename Tout> +class HistogramFixedWidthOp : public OpKernel { + public: + explicit HistogramFixedWidthOp(OpKernelConstruction* ctx) : OpKernel(ctx) {} + + void Compute(OpKernelContext* ctx) override { + const Tensor& values_tensor = ctx->input(0); + const Tensor& value_range_tensor = ctx->input(1); + const Tensor& nbins_tensor = ctx->input(2); + + OP_REQUIRES(ctx, TensorShapeUtils::IsVector(value_range_tensor.shape()), + errors::InvalidArgument("value_range should be a vector.")); + OP_REQUIRES(ctx, (value_range_tensor.shape().num_elements() == 2), + errors::InvalidArgument( + "value_range should be a vector of 2 elements.")); + OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(nbins_tensor.shape()), + errors::InvalidArgument("nbins should be a scalar.")); + + const auto values = values_tensor.flat<T>(); + const auto value_range = value_range_tensor.flat<T>(); + const auto nbins = nbins_tensor.scalar<int32>()(); + + OP_REQUIRES( + ctx, (value_range(0) < value_range(1)), + errors::InvalidArgument("value_range should satisfy value_range[0] < " + "value_range[1], but got '[", + value_range(0), ", ", value_range(1), "]'")); + OP_REQUIRES( + ctx, (nbins > 0), + errors::InvalidArgument("nbins should be a positive number, but got '", + nbins, "'")); + + Tensor* out_tensor; + OP_REQUIRES_OK(ctx, + ctx->allocate_output(0, TensorShape({nbins}), &out_tensor)); + auto out = out_tensor->flat<Tout>(); + + OP_REQUIRES_OK( + ctx, functor::HistogramFixedWidthFunctor<Device, T, Tout>::Compute( + ctx, values, value_range, nbins, out)); + } +}; + +#define REGISTER_KERNELS(type) \ + REGISTER_KERNEL_BUILDER(Name("HistogramFixedWidth") \ + .Device(DEVICE_CPU) \ + .TypeConstraint<type>("T") \ + .TypeConstraint<int32>("dtype"), \ + HistogramFixedWidthOp<CPUDevice, type, int32>) \ + REGISTER_KERNEL_BUILDER(Name("HistogramFixedWidth") \ + .Device(DEVICE_CPU) \ + .TypeConstraint<type>("T") \ + .TypeConstraint<int64>("dtype"), \ + HistogramFixedWidthOp<CPUDevice, type, int64>) + +TF_CALL_REAL_NUMBER_TYPES(REGISTER_KERNELS); +#undef REGISTER_KERNELS + +#if GOOGLE_CUDA +#define REGISTER_KERNELS(type) \ + REGISTER_KERNEL_BUILDER(Name("HistogramFixedWidth") \ + .Device(DEVICE_GPU) \ + .HostMemory("value_range") \ + .HostMemory("nbins") \ + .TypeConstraint<type>("T") \ + .TypeConstraint<int32>("dtype"), \ + HistogramFixedWidthOp<GPUDevice, type, int32>) + +TF_CALL_GPU_NUMBER_TYPES(REGISTER_KERNELS); +#undef REGISTER_KERNELS + +#endif // GOOGLE_CUDA + +} // end namespace tensorflow |