diff options
Diffstat (limited to 'tensorflow/core/ops/math_ops.cc')
-rw-r--r-- | tensorflow/core/ops/math_ops.cc | 23 |
1 files changed, 23 insertions, 0 deletions
diff --git a/tensorflow/core/ops/math_ops.cc b/tensorflow/core/ops/math_ops.cc index c229bd5a41..386ae9635a 100644 --- a/tensorflow/core/ops/math_ops.cc +++ b/tensorflow/core/ops/math_ops.cc @@ -1380,10 +1380,26 @@ REGISTER_OP("HistogramFixedWidth") .Attr("T: {int32, int64, float32, float64}") .Attr("dtype: {int32, int64} = DT_INT32") .SetShapeFn([](InferenceContext* c) { + // value_range should be a vector. + ShapeHandle value_range_shape; + TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 1, &value_range_shape)); + // value_range should have two elements. + DimensionHandle unused; + TF_RETURN_IF_ERROR( + c->WithValue(c->Dim(value_range_shape, 0), 2, &unused)); + // nbins should be a scalar. + ShapeHandle nbins_shape; + TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &nbins_shape)); + + // If nbins is available, set the shape from nbins. const Tensor* nbins_input = c->input_tensor(2); if (nbins_input != nullptr) { int64 nbins; TF_RETURN_IF_ERROR(c->GetScalarFromTensor(nbins_input, &nbins)); + // nbins has to be positive. + if (nbins <= 0) { + return errors::InvalidArgument("Requires nbins > 0: ", nbins); + } c->set_output(0, c->Vector(nbins)); } else { c->set_output(0, c->UnknownShapeOfRank(1)); @@ -1488,6 +1504,13 @@ REGISTER_OP("QuantizedAdd") .SetIsCommutative() .SetShapeFn([](InferenceContext* c) { TF_RETURN_IF_ERROR(shape_inference::BroadcastBinaryOpShapeFn(c)); + // min_x, max_x, min_y, max_y should be scalar. + ShapeHandle unused; + TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &unused)); + TF_RETURN_IF_ERROR(c->WithRank(c->input(3), 0, &unused)); + TF_RETURN_IF_ERROR(c->WithRank(c->input(4), 0, &unused)); + TF_RETURN_IF_ERROR(c->WithRank(c->input(5), 0, &unused)); + c->set_output(1, c->Scalar()); c->set_output(2, c->Scalar()); return Status::OK(); |