diff options
author | 2016-08-04 15:15:57 -0800 | |
---|---|---|
committer | 2016-08-04 16:31:16 -0700 | |
commit | ee9241825d80bf295963ac2fad4dfa0fc9a7b998 (patch) | |
tree | 81260bea9c5328bd7c12fc4729c646e332115fb6 /tensorflow/core/framework/shape_inference.cc | |
parent | 21038467d71be31193715f7b023e252c0c5e2b05 (diff) |
Add C++ shape inference for SVD.
This also adds Min(), Max(), and Subtract() operators and a few convenience methods to the InferenceContext. Change test utils to emit a human readable error message in case the user forgot to set the inference function.
Refactored shape_inference* a bit to enforce the invariant that a Dimension or DimensionOrConstant is always non-negative or equal to InferenceContext::kUnknownDim.
This made it possible to tighten & simplify the arithmetic operations a bit.
Change: 129385995
Diffstat (limited to 'tensorflow/core/framework/shape_inference.cc')
-rw-r--r-- | tensorflow/core/framework/shape_inference.cc | 163 |
1 files changed, 88 insertions, 75 deletions
diff --git a/tensorflow/core/framework/shape_inference.cc b/tensorflow/core/framework/shape_inference.cc index e44d921d5d..9c90bfe0f5 100644 --- a/tensorflow/core/framework/shape_inference.cc +++ b/tensorflow/core/framework/shape_inference.cc @@ -387,12 +387,6 @@ Status InferenceContext::ReplaceDim(const Shape* s, int dim_index_in, return ReturnCreatedShape(dims, out); } -const Dimension* InferenceContext::GetDimension(const DimensionOrConstant& d) { - if (d.dim != nullptr) return d.dim; - DCHECK(d.val >= 0 || d.val == kUnknownDim); - return MakeDim(d.val); -} - const Shape* InferenceContext::MakeShape( const std::vector<const Dimension*>& dims) { all_shapes_.push_back(new Shape(dims)); @@ -404,7 +398,7 @@ const Shape* InferenceContext::MakeShape( std::vector<const Dimension*> dims_actual; dims_actual.reserve(dims.size()); for (const DimensionOrConstant& d : dims) { - dims_actual.push_back(GetDimension(d)); + dims_actual.push_back(MakeDim(d)); } return MakeShape(dims_actual); } @@ -488,11 +482,6 @@ Status InferenceContext::MakeShapeFromShapeProto(const TensorShapeProto& proto, return ReturnCreatedShape(dims, out); } -const Dimension* InferenceContext::MakeDim(int64 value) { - all_dims_.push_back(new Dimension(value)); - return all_dims_.back(); -} - // Returns a new dimension whose value is given by a scalar input tensor. Status InferenceContext::MakeDimForScalarInput(int idx, const Dimension** out) { const Tensor* t = input_tensor(idx); @@ -522,11 +511,6 @@ Status InferenceContext::MakeDimForScalarInput(int idx, const Dimension** out) { return Status::OK(); } -const Dimension* InferenceContext::UnknownDim() { - all_dims_.push_back(new Dimension()); - return all_dims_.back(); -} - Status InferenceContext::Divide(const Dimension* dividend, int64 divisor, const Dimension** out) { if (divisor == 1) { @@ -535,6 +519,10 @@ Status InferenceContext::Divide(const Dimension* dividend, int64 divisor, *out = UnknownDim(); } else { const int64 v = Value(dividend); + if (divisor <= 0) { + return errors::InvalidArgument("Divisor must be positive but is ", + divisor); + } if ((v % divisor) != 0) { return errors::InvalidArgument("Dimension size must be divisible by ", divisor, " but is ", v); @@ -546,87 +534,112 @@ Status InferenceContext::Divide(const Dimension* dividend, int64 divisor, Status InferenceContext::Add(const Dimension* first, DimensionOrConstant second, const Dimension** out) { - const int64 second_value = - second.dim == nullptr ? second.val : Value(second.dim); - if (second.dim != nullptr && !ValueKnown(second.dim)) { - *out = UnknownDim(); + const int64 first_value = Value(first); + const int64 second_value = Value(second); + // Special cases. + if (first_value == 0) { + *out = MakeDim(second); } else if (second_value == 0) { - *out = first; - } else if (!ValueKnown(first)) { + *out = MakeDim(first); + } else if (first_value == kUnknownDim || second_value == kUnknownDim) { *out = UnknownDim(); } else { - const int64 v = Value(first); - const int64 sum = v + second_value; - if (second_value > 0 && sum < 0) { - return errors::InvalidArgument("Dimension size overflow from adding ", v, - " and ", second_value); - } else if (second_value < 0 && sum < 0) { - return errors::InvalidArgument("Negative dimension size from adding ", v, - " and ", second_value); + // Invariant: Both values are known and positive. + const int64 sum = first_value + second_value; + if (sum < 0) { + return errors::InvalidArgument("Dimension size overflow from adding ", + first_value, " and ", second_value); } *out = MakeDim(sum); } return Status::OK(); } -Status InferenceContext::Multiply(const Dimension* first, +Status InferenceContext::Subtract(const Dimension* first, DimensionOrConstant second, const Dimension** out) { - int64 first_value = -1; - // Special cases for multiply are when the values are 0 or 1. - if (ValueKnown(first)) { - first_value = Value(first); - if (first_value == 0) { - *out = MakeDim(0); - return Status::OK(); - } - - // Output is whatever the second value is. - if (first_value == 1) { - *out = GetDimension(second); - return Status::OK(); + const int64 first_value = Value(first); + const int64 second_value = Value(second); + // Special cases. + if (second_value == 0) { + *out = MakeDim(first); + } else if (first_value == kUnknownDim || second_value == kUnknownDim) { + *out = UnknownDim(); + } else { + // Invariant: Both values are known, first_value is non-negative, and + // second_value is positive. + if (first_value < second_value) { + return errors::InvalidArgument( + "Negative dimension size caused by subtracting ", second_value, + " from ", first_value); } + *out = MakeDim(first_value - second_value); } + return Status::OK(); +} - // Same check for when the second argument is a known value. - // First find out if the value is known from DimOrConstant. - int64 second_value; - if (second.dim == nullptr) { - second_value = second.val; +Status InferenceContext::Multiply(const Dimension* first, + DimensionOrConstant second, + const Dimension** out) { + const int64 first_value = Value(first); + const int64 second_value = Value(second); + // Special cases. + if (first_value == 0) { + *out = first; + } else if (second_value == 0) { + *out = MakeDim(second); + } else if (first_value == 1) { + *out = MakeDim(second); + } else if (second_value == 1) { + *out = first; + } else if (first_value == kUnknownDim || second_value == kUnknownDim) { + *out = UnknownDim(); } else { - if (!ValueKnown(second.dim)) { - // Second value is not known and first is not a special caase - *out = UnknownDim(); - return Status::OK(); + // Invariant: Both values are known and and greater than 1. + const int64 product = first_value * second_value; + if (product < 0) { + return errors::InvalidArgument( + "Negative dimension size caused by overflow when multiplying ", + first_value, " and ", second_value); } - second_value = Value(second.dim); - } - - // Now that we know whether the value is known, apply the special - // casing. - if (second_value == 0) { - *out = MakeDim(0); - return Status::OK(); + *out = MakeDim(product); } + return Status::OK(); +} - // Output is whatever the first value is. - if (second_value == 1) { +Status InferenceContext::Min(const Dimension* first, DimensionOrConstant second, + const Dimension** out) { + const int64 first_value = Value(first); + const int64 second_value = Value(second); + if (first_value == 0) { *out = first; - return Status::OK(); - } - - if (!ValueKnown(first)) { - // First value is not known and second is not a special caase + } else if (second_value == 0) { + *out = MakeDim(second); + } else if (first_value == kUnknownDim || second_value == kUnknownDim) { *out = UnknownDim(); - return Status::OK(); + } else { + if (first_value <= second_value) { + *out = first; + } else { + *out = MakeDim(second); + } } + return Status::OK(); +} - const int64 product = first_value * second_value; - if (product < 0) { - return errors::InvalidArgument("Negative dimension size from multiplying ", - first_value, " and ", second_value); +Status InferenceContext::Max(const Dimension* first, DimensionOrConstant second, + const Dimension** out) { + const int64 first_value = Value(first); + const int64 second_value = Value(second); + if (first_value == kUnknownDim || second_value == kUnknownDim) { + *out = UnknownDim(); + } else { + if (first_value >= second_value) { + *out = first; + } else { + *out = MakeDim(second); + } } - *out = MakeDim(product); return Status::OK(); } |