aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core/ops/math_ops.cc
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/core/ops/math_ops.cc')
-rw-r--r--tensorflow/core/ops/math_ops.cc23
1 files changed, 23 insertions, 0 deletions
diff --git a/tensorflow/core/ops/math_ops.cc b/tensorflow/core/ops/math_ops.cc
index c229bd5a41..386ae9635a 100644
--- a/tensorflow/core/ops/math_ops.cc
+++ b/tensorflow/core/ops/math_ops.cc
@@ -1380,10 +1380,26 @@ REGISTER_OP("HistogramFixedWidth")
.Attr("T: {int32, int64, float32, float64}")
.Attr("dtype: {int32, int64} = DT_INT32")
.SetShapeFn([](InferenceContext* c) {
+ // value_range should be a vector.
+ ShapeHandle value_range_shape;
+ TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 1, &value_range_shape));
+ // value_range should have two elements.
+ DimensionHandle unused;
+ TF_RETURN_IF_ERROR(
+ c->WithValue(c->Dim(value_range_shape, 0), 2, &unused));
+ // nbins should be a scalar.
+ ShapeHandle nbins_shape;
+ TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &nbins_shape));
+
+ // If nbins is available, set the shape from nbins.
const Tensor* nbins_input = c->input_tensor(2);
if (nbins_input != nullptr) {
int64 nbins;
TF_RETURN_IF_ERROR(c->GetScalarFromTensor(nbins_input, &nbins));
+ // nbins has to be positive.
+ if (nbins <= 0) {
+ return errors::InvalidArgument("Requires nbins > 0: ", nbins);
+ }
c->set_output(0, c->Vector(nbins));
} else {
c->set_output(0, c->UnknownShapeOfRank(1));
@@ -1488,6 +1504,13 @@ REGISTER_OP("QuantizedAdd")
.SetIsCommutative()
.SetShapeFn([](InferenceContext* c) {
TF_RETURN_IF_ERROR(shape_inference::BroadcastBinaryOpShapeFn(c));
+ // min_x, max_x, min_y, max_y should be scalar.
+ ShapeHandle unused;
+ TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &unused));
+ TF_RETURN_IF_ERROR(c->WithRank(c->input(3), 0, &unused));
+ TF_RETURN_IF_ERROR(c->WithRank(c->input(4), 0, &unused));
+ TF_RETURN_IF_ERROR(c->WithRank(c->input(5), 0, &unused));
+
c->set_output(1, c->Scalar());
c->set_output(2, c->Scalar());
return Status::OK();