aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core/framework/shape_inference.cc
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2016-08-04 15:15:57 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2016-08-04 16:31:16 -0700
commitee9241825d80bf295963ac2fad4dfa0fc9a7b998 (patch)
tree81260bea9c5328bd7c12fc4729c646e332115fb6 /tensorflow/core/framework/shape_inference.cc
parent21038467d71be31193715f7b023e252c0c5e2b05 (diff)
Add C++ shape inference for SVD.
This also adds Min(), Max(), and Subtract() operators and a few convenience methods to the InferenceContext. Change test utils to emit a human readable error message in case the user forgot to set the inference function. Refactored shape_inference* a bit to enforce the invariant that a Dimension or DimensionOrConstant is always non-negative or equal to InferenceContext::kUnknownDim. This made it possible to tighten & simplify the arithmetic operations a bit. Change: 129385995
Diffstat (limited to 'tensorflow/core/framework/shape_inference.cc')
-rw-r--r--tensorflow/core/framework/shape_inference.cc163
1 files changed, 88 insertions, 75 deletions
diff --git a/tensorflow/core/framework/shape_inference.cc b/tensorflow/core/framework/shape_inference.cc
index e44d921d5d..9c90bfe0f5 100644
--- a/tensorflow/core/framework/shape_inference.cc
+++ b/tensorflow/core/framework/shape_inference.cc
@@ -387,12 +387,6 @@ Status InferenceContext::ReplaceDim(const Shape* s, int dim_index_in,
return ReturnCreatedShape(dims, out);
}
-const Dimension* InferenceContext::GetDimension(const DimensionOrConstant& d) {
- if (d.dim != nullptr) return d.dim;
- DCHECK(d.val >= 0 || d.val == kUnknownDim);
- return MakeDim(d.val);
-}
-
const Shape* InferenceContext::MakeShape(
const std::vector<const Dimension*>& dims) {
all_shapes_.push_back(new Shape(dims));
@@ -404,7 +398,7 @@ const Shape* InferenceContext::MakeShape(
std::vector<const Dimension*> dims_actual;
dims_actual.reserve(dims.size());
for (const DimensionOrConstant& d : dims) {
- dims_actual.push_back(GetDimension(d));
+ dims_actual.push_back(MakeDim(d));
}
return MakeShape(dims_actual);
}
@@ -488,11 +482,6 @@ Status InferenceContext::MakeShapeFromShapeProto(const TensorShapeProto& proto,
return ReturnCreatedShape(dims, out);
}
-const Dimension* InferenceContext::MakeDim(int64 value) {
- all_dims_.push_back(new Dimension(value));
- return all_dims_.back();
-}
-
// Returns a new dimension whose value is given by a scalar input tensor.
Status InferenceContext::MakeDimForScalarInput(int idx, const Dimension** out) {
const Tensor* t = input_tensor(idx);
@@ -522,11 +511,6 @@ Status InferenceContext::MakeDimForScalarInput(int idx, const Dimension** out) {
return Status::OK();
}
-const Dimension* InferenceContext::UnknownDim() {
- all_dims_.push_back(new Dimension());
- return all_dims_.back();
-}
-
Status InferenceContext::Divide(const Dimension* dividend, int64 divisor,
const Dimension** out) {
if (divisor == 1) {
@@ -535,6 +519,10 @@ Status InferenceContext::Divide(const Dimension* dividend, int64 divisor,
*out = UnknownDim();
} else {
const int64 v = Value(dividend);
+ if (divisor <= 0) {
+ return errors::InvalidArgument("Divisor must be positive but is ",
+ divisor);
+ }
if ((v % divisor) != 0) {
return errors::InvalidArgument("Dimension size must be divisible by ",
divisor, " but is ", v);
@@ -546,87 +534,112 @@ Status InferenceContext::Divide(const Dimension* dividend, int64 divisor,
Status InferenceContext::Add(const Dimension* first, DimensionOrConstant second,
const Dimension** out) {
- const int64 second_value =
- second.dim == nullptr ? second.val : Value(second.dim);
- if (second.dim != nullptr && !ValueKnown(second.dim)) {
- *out = UnknownDim();
+ const int64 first_value = Value(first);
+ const int64 second_value = Value(second);
+ // Special cases.
+ if (first_value == 0) {
+ *out = MakeDim(second);
} else if (second_value == 0) {
- *out = first;
- } else if (!ValueKnown(first)) {
+ *out = MakeDim(first);
+ } else if (first_value == kUnknownDim || second_value == kUnknownDim) {
*out = UnknownDim();
} else {
- const int64 v = Value(first);
- const int64 sum = v + second_value;
- if (second_value > 0 && sum < 0) {
- return errors::InvalidArgument("Dimension size overflow from adding ", v,
- " and ", second_value);
- } else if (second_value < 0 && sum < 0) {
- return errors::InvalidArgument("Negative dimension size from adding ", v,
- " and ", second_value);
+ // Invariant: Both values are known and positive.
+ const int64 sum = first_value + second_value;
+ if (sum < 0) {
+ return errors::InvalidArgument("Dimension size overflow from adding ",
+ first_value, " and ", second_value);
}
*out = MakeDim(sum);
}
return Status::OK();
}
-Status InferenceContext::Multiply(const Dimension* first,
+Status InferenceContext::Subtract(const Dimension* first,
DimensionOrConstant second,
const Dimension** out) {
- int64 first_value = -1;
- // Special cases for multiply are when the values are 0 or 1.
- if (ValueKnown(first)) {
- first_value = Value(first);
- if (first_value == 0) {
- *out = MakeDim(0);
- return Status::OK();
- }
-
- // Output is whatever the second value is.
- if (first_value == 1) {
- *out = GetDimension(second);
- return Status::OK();
+ const int64 first_value = Value(first);
+ const int64 second_value = Value(second);
+ // Special cases.
+ if (second_value == 0) {
+ *out = MakeDim(first);
+ } else if (first_value == kUnknownDim || second_value == kUnknownDim) {
+ *out = UnknownDim();
+ } else {
+ // Invariant: Both values are known, first_value is non-negative, and
+ // second_value is positive.
+ if (first_value < second_value) {
+ return errors::InvalidArgument(
+ "Negative dimension size caused by subtracting ", second_value,
+ " from ", first_value);
}
+ *out = MakeDim(first_value - second_value);
}
+ return Status::OK();
+}
- // Same check for when the second argument is a known value.
- // First find out if the value is known from DimOrConstant.
- int64 second_value;
- if (second.dim == nullptr) {
- second_value = second.val;
+Status InferenceContext::Multiply(const Dimension* first,
+ DimensionOrConstant second,
+ const Dimension** out) {
+ const int64 first_value = Value(first);
+ const int64 second_value = Value(second);
+ // Special cases.
+ if (first_value == 0) {
+ *out = first;
+ } else if (second_value == 0) {
+ *out = MakeDim(second);
+ } else if (first_value == 1) {
+ *out = MakeDim(second);
+ } else if (second_value == 1) {
+ *out = first;
+ } else if (first_value == kUnknownDim || second_value == kUnknownDim) {
+ *out = UnknownDim();
} else {
- if (!ValueKnown(second.dim)) {
- // Second value is not known and first is not a special caase
- *out = UnknownDim();
- return Status::OK();
+ // Invariant: Both values are known and and greater than 1.
+ const int64 product = first_value * second_value;
+ if (product < 0) {
+ return errors::InvalidArgument(
+ "Negative dimension size caused by overflow when multiplying ",
+ first_value, " and ", second_value);
}
- second_value = Value(second.dim);
- }
-
- // Now that we know whether the value is known, apply the special
- // casing.
- if (second_value == 0) {
- *out = MakeDim(0);
- return Status::OK();
+ *out = MakeDim(product);
}
+ return Status::OK();
+}
- // Output is whatever the first value is.
- if (second_value == 1) {
+Status InferenceContext::Min(const Dimension* first, DimensionOrConstant second,
+ const Dimension** out) {
+ const int64 first_value = Value(first);
+ const int64 second_value = Value(second);
+ if (first_value == 0) {
*out = first;
- return Status::OK();
- }
-
- if (!ValueKnown(first)) {
- // First value is not known and second is not a special caase
+ } else if (second_value == 0) {
+ *out = MakeDim(second);
+ } else if (first_value == kUnknownDim || second_value == kUnknownDim) {
*out = UnknownDim();
- return Status::OK();
+ } else {
+ if (first_value <= second_value) {
+ *out = first;
+ } else {
+ *out = MakeDim(second);
+ }
}
+ return Status::OK();
+}
- const int64 product = first_value * second_value;
- if (product < 0) {
- return errors::InvalidArgument("Negative dimension size from multiplying ",
- first_value, " and ", second_value);
+Status InferenceContext::Max(const Dimension* first, DimensionOrConstant second,
+ const Dimension** out) {
+ const int64 first_value = Value(first);
+ const int64 second_value = Value(second);
+ if (first_value == kUnknownDim || second_value == kUnknownDim) {
+ *out = UnknownDim();
+ } else {
+ if (first_value >= second_value) {
+ *out = first;
+ } else {
+ *out = MakeDim(second);
+ }
}
- *out = MakeDim(product);
return Status::OK();
}