diff options
author | 2017-10-17 12:22:20 -0700 | |
---|---|---|
committer | 2017-10-17 12:22:20 -0700 | |
commit | 962ed613cf1087637848d3e2b23f5b01d93c7eda (patch) | |
tree | 39264e15cc6f7593e3b7181cbe752c012fb7821c | |
parent | 27767d8e9c1325979cf32ff5b81c10df9006fd57 (diff) |
Fix #13731 by adding HistogramdFixedWidth in hidden_ops.txt and create the python wrapper (#13781)
* Fix 13731 by adding HistogramdFixedWidth in hidden_ops.txt and create the python wrapper
so that both api compatibility and test utility code in contrib could be
preserved. See https://github.com/tensorflow/tensorflow/pull/13731#issuecomment-337186002
for reference.
Signed-off-by: Yong Tang <yong.tang.github@outlook.com>
* Add underscore (`_histogram_fixed_width`) in calling gen_math_ops.py
Signed-off-by: Yong Tang <yong.tang.github@outlook.com>
* clang-format -i histogram_op.cc
Signed-off-by: Yong Tang <yong.tang.github@outlook.com>
-rw-r--r-- | tensorflow/core/kernels/histogram_op.cc | 24 | ||||
-rw-r--r-- | tensorflow/core/ops/math_ops.cc | 2 | ||||
-rw-r--r-- | tensorflow/python/ops/hidden_ops.txt | 1 | ||||
-rw-r--r-- | tensorflow/python/ops/histogram_ops.py | 4 |
4 files changed, 16 insertions, 15 deletions
diff --git a/tensorflow/core/kernels/histogram_op.cc b/tensorflow/core/kernels/histogram_op.cc index c170f172e4..4e035286f6 100644 --- a/tensorflow/core/kernels/histogram_op.cc +++ b/tensorflow/core/kernels/histogram_op.cc @@ -74,45 +74,44 @@ struct HistogramFixedWidthFunctor<CPUDevice, T, Tout> { template <typename Device, typename T, typename Tout> class HistogramFixedWidthOp : public OpKernel { public: - explicit HistogramFixedWidthOp(OpKernelConstruction* ctx) : OpKernel(ctx) { - OP_REQUIRES_OK(ctx, ctx->GetAttr("nbins", &nbins_)); - OP_REQUIRES( - ctx, (nbins_ > 0), - errors::InvalidArgument("nbins should be a positive number, but got '", - nbins_, "'")); - } + explicit HistogramFixedWidthOp(OpKernelConstruction* ctx) : OpKernel(ctx) {} void Compute(OpKernelContext* ctx) override { const Tensor& values_tensor = ctx->input(0); const Tensor& value_range_tensor = ctx->input(1); + const Tensor& nbins_tensor = ctx->input(2); OP_REQUIRES(ctx, TensorShapeUtils::IsVector(value_range_tensor.shape()), errors::InvalidArgument("value_range should be a vector.")); OP_REQUIRES(ctx, (value_range_tensor.shape().num_elements() == 2), errors::InvalidArgument( "value_range should be a vector of 2 elements.")); + OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(nbins_tensor.shape()), + errors::InvalidArgument("nbins should be a scalar.")); const auto values = values_tensor.flat<T>(); const auto value_range = value_range_tensor.flat<T>(); + const auto nbins = nbins_tensor.scalar<int32>()(); OP_REQUIRES( ctx, (value_range(0) < value_range(1)), errors::InvalidArgument("value_range should satisfy value_range[0] < " "value_range[1], but got '[", value_range(0), ", ", value_range(1), "]'")); + OP_REQUIRES( + ctx, (nbins > 0), + errors::InvalidArgument("nbins should be a positive number, but got '", + nbins, "'")); Tensor* out_tensor; OP_REQUIRES_OK(ctx, - ctx->allocate_output(0, TensorShape({nbins_}), &out_tensor)); + ctx->allocate_output(0, TensorShape({nbins}), &out_tensor)); auto out = out_tensor->flat<Tout>(); OP_REQUIRES_OK( ctx, functor::HistogramFixedWidthFunctor<Device, T, Tout>::Compute( - ctx, values, value_range, nbins_, out)); + ctx, values, value_range, nbins, out)); } - - private: - int nbins_; }; #define REGISTER_KERNELS(type) \ @@ -135,6 +134,7 @@ TF_CALL_REAL_NUMBER_TYPES(REGISTER_KERNELS); REGISTER_KERNEL_BUILDER(Name("HistogramFixedWidth") \ .Device(DEVICE_GPU) \ .HostMemory("value_range") \ + .HostMemory("nbins") \ .TypeConstraint<type>("T") \ .TypeConstraint<int32>("dtype"), \ HistogramFixedWidthOp<GPUDevice, type, int32>) diff --git a/tensorflow/core/ops/math_ops.cc b/tensorflow/core/ops/math_ops.cc index a1c608ee54..61db896c51 100644 --- a/tensorflow/core/ops/math_ops.cc +++ b/tensorflow/core/ops/math_ops.cc @@ -2253,8 +2253,8 @@ product: Pairwise cross product of the vectors in `a` and `b`. REGISTER_OP("HistogramFixedWidth") .Input("values: T") .Input("value_range: T") + .Input("nbins: int32") .Output("out: dtype") - .Attr("nbins: int = 100") .Attr("T: {int32, int64, float32, float64}") .Attr("dtype: {int32, int64} = DT_INT32") .SetShapeFn([](InferenceContext* c) { diff --git a/tensorflow/python/ops/hidden_ops.txt b/tensorflow/python/ops/hidden_ops.txt index 04dfb5b65d..86bc038e86 100644 --- a/tensorflow/python/ops/hidden_ops.txt +++ b/tensorflow/python/ops/hidden_ops.txt @@ -259,6 +259,7 @@ ComplexAbs Conj FloorDiv FloorMod +HistogramFixedWidth Max Mean Min diff --git a/tensorflow/python/ops/histogram_ops.py b/tensorflow/python/ops/histogram_ops.py index 040c3a5ae8..51e4be9343 100644 --- a/tensorflow/python/ops/histogram_ops.py +++ b/tensorflow/python/ops/histogram_ops.py @@ -71,5 +71,5 @@ def histogram_fixed_width(values, """ with ops.name_scope(name, 'histogram_fixed_width', [values, value_range, nbins]) as name: - return gen_math_ops.histogram_fixed_width(values, value_range, nbins, - dtype=dtype, name=name) + return gen_math_ops._histogram_fixed_width(values, value_range, nbins, + dtype=dtype, name=name) |