aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorGravatar Yong Tang <yong.tang.github@outlook.com>2017-10-17 12:22:20 -0700
committerGravatar Vijay Vasudevan <vrv@google.com>2017-10-17 12:22:20 -0700
commit962ed613cf1087637848d3e2b23f5b01d93c7eda (patch)
tree39264e15cc6f7593e3b7181cbe752c012fb7821c
parent27767d8e9c1325979cf32ff5b81c10df9006fd57 (diff)
Fix #13731 by adding HistogramdFixedWidth in hidden_ops.txt and create the python wrapper (#13781)
* Fix 13731 by adding HistogramdFixedWidth in hidden_ops.txt and create the python wrapper so that both api compatibility and test utility code in contrib could be preserved. See https://github.com/tensorflow/tensorflow/pull/13731#issuecomment-337186002 for reference. Signed-off-by: Yong Tang <yong.tang.github@outlook.com> * Add underscore (`_histogram_fixed_width`) in calling gen_math_ops.py Signed-off-by: Yong Tang <yong.tang.github@outlook.com> * clang-format -i histogram_op.cc Signed-off-by: Yong Tang <yong.tang.github@outlook.com>
-rw-r--r--tensorflow/core/kernels/histogram_op.cc24
-rw-r--r--tensorflow/core/ops/math_ops.cc2
-rw-r--r--tensorflow/python/ops/hidden_ops.txt1
-rw-r--r--tensorflow/python/ops/histogram_ops.py4
4 files changed, 16 insertions, 15 deletions
diff --git a/tensorflow/core/kernels/histogram_op.cc b/tensorflow/core/kernels/histogram_op.cc
index c170f172e4..4e035286f6 100644
--- a/tensorflow/core/kernels/histogram_op.cc
+++ b/tensorflow/core/kernels/histogram_op.cc
@@ -74,45 +74,44 @@ struct HistogramFixedWidthFunctor<CPUDevice, T, Tout> {
template <typename Device, typename T, typename Tout>
class HistogramFixedWidthOp : public OpKernel {
public:
- explicit HistogramFixedWidthOp(OpKernelConstruction* ctx) : OpKernel(ctx) {
- OP_REQUIRES_OK(ctx, ctx->GetAttr("nbins", &nbins_));
- OP_REQUIRES(
- ctx, (nbins_ > 0),
- errors::InvalidArgument("nbins should be a positive number, but got '",
- nbins_, "'"));
- }
+ explicit HistogramFixedWidthOp(OpKernelConstruction* ctx) : OpKernel(ctx) {}
void Compute(OpKernelContext* ctx) override {
const Tensor& values_tensor = ctx->input(0);
const Tensor& value_range_tensor = ctx->input(1);
+ const Tensor& nbins_tensor = ctx->input(2);
OP_REQUIRES(ctx, TensorShapeUtils::IsVector(value_range_tensor.shape()),
errors::InvalidArgument("value_range should be a vector."));
OP_REQUIRES(ctx, (value_range_tensor.shape().num_elements() == 2),
errors::InvalidArgument(
"value_range should be a vector of 2 elements."));
+ OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(nbins_tensor.shape()),
+ errors::InvalidArgument("nbins should be a scalar."));
const auto values = values_tensor.flat<T>();
const auto value_range = value_range_tensor.flat<T>();
+ const auto nbins = nbins_tensor.scalar<int32>()();
OP_REQUIRES(
ctx, (value_range(0) < value_range(1)),
errors::InvalidArgument("value_range should satisfy value_range[0] < "
"value_range[1], but got '[",
value_range(0), ", ", value_range(1), "]'"));
+ OP_REQUIRES(
+ ctx, (nbins > 0),
+ errors::InvalidArgument("nbins should be a positive number, but got '",
+ nbins, "'"));
Tensor* out_tensor;
OP_REQUIRES_OK(ctx,
- ctx->allocate_output(0, TensorShape({nbins_}), &out_tensor));
+ ctx->allocate_output(0, TensorShape({nbins}), &out_tensor));
auto out = out_tensor->flat<Tout>();
OP_REQUIRES_OK(
ctx, functor::HistogramFixedWidthFunctor<Device, T, Tout>::Compute(
- ctx, values, value_range, nbins_, out));
+ ctx, values, value_range, nbins, out));
}
-
- private:
- int nbins_;
};
#define REGISTER_KERNELS(type) \
@@ -135,6 +134,7 @@ TF_CALL_REAL_NUMBER_TYPES(REGISTER_KERNELS);
REGISTER_KERNEL_BUILDER(Name("HistogramFixedWidth") \
.Device(DEVICE_GPU) \
.HostMemory("value_range") \
+ .HostMemory("nbins") \
.TypeConstraint<type>("T") \
.TypeConstraint<int32>("dtype"), \
HistogramFixedWidthOp<GPUDevice, type, int32>)
diff --git a/tensorflow/core/ops/math_ops.cc b/tensorflow/core/ops/math_ops.cc
index a1c608ee54..61db896c51 100644
--- a/tensorflow/core/ops/math_ops.cc
+++ b/tensorflow/core/ops/math_ops.cc
@@ -2253,8 +2253,8 @@ product: Pairwise cross product of the vectors in `a` and `b`.
REGISTER_OP("HistogramFixedWidth")
.Input("values: T")
.Input("value_range: T")
+ .Input("nbins: int32")
.Output("out: dtype")
- .Attr("nbins: int = 100")
.Attr("T: {int32, int64, float32, float64}")
.Attr("dtype: {int32, int64} = DT_INT32")
.SetShapeFn([](InferenceContext* c) {
diff --git a/tensorflow/python/ops/hidden_ops.txt b/tensorflow/python/ops/hidden_ops.txt
index 04dfb5b65d..86bc038e86 100644
--- a/tensorflow/python/ops/hidden_ops.txt
+++ b/tensorflow/python/ops/hidden_ops.txt
@@ -259,6 +259,7 @@ ComplexAbs
Conj
FloorDiv
FloorMod
+HistogramFixedWidth
Max
Mean
Min
diff --git a/tensorflow/python/ops/histogram_ops.py b/tensorflow/python/ops/histogram_ops.py
index 040c3a5ae8..51e4be9343 100644
--- a/tensorflow/python/ops/histogram_ops.py
+++ b/tensorflow/python/ops/histogram_ops.py
@@ -71,5 +71,5 @@ def histogram_fixed_width(values,
"""
with ops.name_scope(name, 'histogram_fixed_width',
[values, value_range, nbins]) as name:
- return gen_math_ops.histogram_fixed_width(values, value_range, nbins,
- dtype=dtype, name=name)
+ return gen_math_ops._histogram_fixed_width(values, value_range, nbins,
+ dtype=dtype, name=name)