diff options
author | 2016-08-04 15:15:57 -0800 | |
---|---|---|
committer | 2016-08-04 16:31:16 -0700 | |
commit | ee9241825d80bf295963ac2fad4dfa0fc9a7b998 (patch) | |
tree | 81260bea9c5328bd7c12fc4729c646e332115fb6 /tensorflow/core/framework | |
parent | 21038467d71be31193715f7b023e252c0c5e2b05 (diff) |
Add C++ shape inference for SVD.
This also adds Min(), Max(), and Subtract() operators and a few convenience methods to the InferenceContext. Change test utils to emit a human readable error message in case the user forgot to set the inference function.
Refactored shape_inference* a bit to enforce the invariant that a Dimension or DimensionOrConstant is always non-negative or equal to InferenceContext::kUnknownDim.
This made it possible to tighten & simplify the arithmetic operations a bit.
Change: 129385995
Diffstat (limited to 'tensorflow/core/framework')
-rw-r--r-- | tensorflow/core/framework/shape_inference.cc | 163 | ||||
-rw-r--r-- | tensorflow/core/framework/shape_inference.h | 75 | ||||
-rw-r--r-- | tensorflow/core/framework/shape_inference_test.cc | 184 | ||||
-rw-r--r-- | tensorflow/core/framework/shape_inference_testutil.cc | 5 |
4 files changed, 320 insertions, 107 deletions
diff --git a/tensorflow/core/framework/shape_inference.cc b/tensorflow/core/framework/shape_inference.cc index e44d921d5d..9c90bfe0f5 100644 --- a/tensorflow/core/framework/shape_inference.cc +++ b/tensorflow/core/framework/shape_inference.cc @@ -387,12 +387,6 @@ Status InferenceContext::ReplaceDim(const Shape* s, int dim_index_in, return ReturnCreatedShape(dims, out); } -const Dimension* InferenceContext::GetDimension(const DimensionOrConstant& d) { - if (d.dim != nullptr) return d.dim; - DCHECK(d.val >= 0 || d.val == kUnknownDim); - return MakeDim(d.val); -} - const Shape* InferenceContext::MakeShape( const std::vector<const Dimension*>& dims) { all_shapes_.push_back(new Shape(dims)); @@ -404,7 +398,7 @@ const Shape* InferenceContext::MakeShape( std::vector<const Dimension*> dims_actual; dims_actual.reserve(dims.size()); for (const DimensionOrConstant& d : dims) { - dims_actual.push_back(GetDimension(d)); + dims_actual.push_back(MakeDim(d)); } return MakeShape(dims_actual); } @@ -488,11 +482,6 @@ Status InferenceContext::MakeShapeFromShapeProto(const TensorShapeProto& proto, return ReturnCreatedShape(dims, out); } -const Dimension* InferenceContext::MakeDim(int64 value) { - all_dims_.push_back(new Dimension(value)); - return all_dims_.back(); -} - // Returns a new dimension whose value is given by a scalar input tensor. Status InferenceContext::MakeDimForScalarInput(int idx, const Dimension** out) { const Tensor* t = input_tensor(idx); @@ -522,11 +511,6 @@ Status InferenceContext::MakeDimForScalarInput(int idx, const Dimension** out) { return Status::OK(); } -const Dimension* InferenceContext::UnknownDim() { - all_dims_.push_back(new Dimension()); - return all_dims_.back(); -} - Status InferenceContext::Divide(const Dimension* dividend, int64 divisor, const Dimension** out) { if (divisor == 1) { @@ -535,6 +519,10 @@ Status InferenceContext::Divide(const Dimension* dividend, int64 divisor, *out = UnknownDim(); } else { const int64 v = Value(dividend); + if (divisor <= 0) { + return errors::InvalidArgument("Divisor must be positive but is ", + divisor); + } if ((v % divisor) != 0) { return errors::InvalidArgument("Dimension size must be divisible by ", divisor, " but is ", v); @@ -546,87 +534,112 @@ Status InferenceContext::Divide(const Dimension* dividend, int64 divisor, Status InferenceContext::Add(const Dimension* first, DimensionOrConstant second, const Dimension** out) { - const int64 second_value = - second.dim == nullptr ? second.val : Value(second.dim); - if (second.dim != nullptr && !ValueKnown(second.dim)) { - *out = UnknownDim(); + const int64 first_value = Value(first); + const int64 second_value = Value(second); + // Special cases. + if (first_value == 0) { + *out = MakeDim(second); } else if (second_value == 0) { - *out = first; - } else if (!ValueKnown(first)) { + *out = MakeDim(first); + } else if (first_value == kUnknownDim || second_value == kUnknownDim) { *out = UnknownDim(); } else { - const int64 v = Value(first); - const int64 sum = v + second_value; - if (second_value > 0 && sum < 0) { - return errors::InvalidArgument("Dimension size overflow from adding ", v, - " and ", second_value); - } else if (second_value < 0 && sum < 0) { - return errors::InvalidArgument("Negative dimension size from adding ", v, - " and ", second_value); + // Invariant: Both values are known and positive. + const int64 sum = first_value + second_value; + if (sum < 0) { + return errors::InvalidArgument("Dimension size overflow from adding ", + first_value, " and ", second_value); } *out = MakeDim(sum); } return Status::OK(); } -Status InferenceContext::Multiply(const Dimension* first, +Status InferenceContext::Subtract(const Dimension* first, DimensionOrConstant second, const Dimension** out) { - int64 first_value = -1; - // Special cases for multiply are when the values are 0 or 1. - if (ValueKnown(first)) { - first_value = Value(first); - if (first_value == 0) { - *out = MakeDim(0); - return Status::OK(); - } - - // Output is whatever the second value is. - if (first_value == 1) { - *out = GetDimension(second); - return Status::OK(); + const int64 first_value = Value(first); + const int64 second_value = Value(second); + // Special cases. + if (second_value == 0) { + *out = MakeDim(first); + } else if (first_value == kUnknownDim || second_value == kUnknownDim) { + *out = UnknownDim(); + } else { + // Invariant: Both values are known, first_value is non-negative, and + // second_value is positive. + if (first_value < second_value) { + return errors::InvalidArgument( + "Negative dimension size caused by subtracting ", second_value, + " from ", first_value); } + *out = MakeDim(first_value - second_value); } + return Status::OK(); +} - // Same check for when the second argument is a known value. - // First find out if the value is known from DimOrConstant. - int64 second_value; - if (second.dim == nullptr) { - second_value = second.val; +Status InferenceContext::Multiply(const Dimension* first, + DimensionOrConstant second, + const Dimension** out) { + const int64 first_value = Value(first); + const int64 second_value = Value(second); + // Special cases. + if (first_value == 0) { + *out = first; + } else if (second_value == 0) { + *out = MakeDim(second); + } else if (first_value == 1) { + *out = MakeDim(second); + } else if (second_value == 1) { + *out = first; + } else if (first_value == kUnknownDim || second_value == kUnknownDim) { + *out = UnknownDim(); } else { - if (!ValueKnown(second.dim)) { - // Second value is not known and first is not a special caase - *out = UnknownDim(); - return Status::OK(); + // Invariant: Both values are known and and greater than 1. + const int64 product = first_value * second_value; + if (product < 0) { + return errors::InvalidArgument( + "Negative dimension size caused by overflow when multiplying ", + first_value, " and ", second_value); } - second_value = Value(second.dim); - } - - // Now that we know whether the value is known, apply the special - // casing. - if (second_value == 0) { - *out = MakeDim(0); - return Status::OK(); + *out = MakeDim(product); } + return Status::OK(); +} - // Output is whatever the first value is. - if (second_value == 1) { +Status InferenceContext::Min(const Dimension* first, DimensionOrConstant second, + const Dimension** out) { + const int64 first_value = Value(first); + const int64 second_value = Value(second); + if (first_value == 0) { *out = first; - return Status::OK(); - } - - if (!ValueKnown(first)) { - // First value is not known and second is not a special caase + } else if (second_value == 0) { + *out = MakeDim(second); + } else if (first_value == kUnknownDim || second_value == kUnknownDim) { *out = UnknownDim(); - return Status::OK(); + } else { + if (first_value <= second_value) { + *out = first; + } else { + *out = MakeDim(second); + } } + return Status::OK(); +} - const int64 product = first_value * second_value; - if (product < 0) { - return errors::InvalidArgument("Negative dimension size from multiplying ", - first_value, " and ", second_value); +Status InferenceContext::Max(const Dimension* first, DimensionOrConstant second, + const Dimension** out) { + const int64 first_value = Value(first); + const int64 second_value = Value(second); + if (first_value == kUnknownDim || second_value == kUnknownDim) { + *out = UnknownDim(); + } else { + if (first_value >= second_value) { + *out = first; + } else { + *out = MakeDim(second); + } } - *out = MakeDim(product); return Status::OK(); } diff --git a/tensorflow/core/framework/shape_inference.h b/tensorflow/core/framework/shape_inference.h index 1aa51f5017..f35c8a4c81 100644 --- a/tensorflow/core/framework/shape_inference.h +++ b/tensorflow/core/framework/shape_inference.h @@ -46,7 +46,7 @@ class Dimension { class Shape { private: Shape(); - Shape(std::vector<const Dimension*> dims); + Shape(const std::vector<const Dimension*>& dims); ~Shape() {} const int32 rank_; @@ -61,13 +61,17 @@ class Shape { struct DimensionOrConstant { public: // Intentionally not explicit. - DimensionOrConstant(const Dimension* dim) : dim(dim) {} + DimensionOrConstant(const Dimension* dim); // val must be non-negative or InferenceContext::kUnknownDim. - DimensionOrConstant(int64 val) : val(val) {} + DimensionOrConstant(int64 val); - const Dimension* dim = nullptr; - int64 val = 0; + // dim takes precedence. If dim != nullptr, val is ignored. + const Dimension* dim; + int64 val; + + private: + DimensionOrConstant(); }; // Note: This is experimental support for op shape inference in C++. Shape @@ -81,8 +85,8 @@ struct DimensionOrConstant { // by the InferenceContext. class InferenceContext { public: - static constexpr int32 kUnknownRank = -1; static constexpr int64 kUnknownDim = -1; + static constexpr int32 kUnknownRank = -1; // This is a temporary constructor used for initial testing. // @@ -127,8 +131,12 @@ class InferenceContext { } int32 Rank(const Shape* s) { return s->rank_; } bool RankKnown(const Shape* s) { return Rank(s) != kUnknownRank; } - int64 Value(const Dimension* d) { return d->value_; } - bool ValueKnown(const Dimension* d) { return Value(d) != kUnknownDim; } + inline int64 Value(DimensionOrConstant d) { + return d.dim ? d.dim->value_ : d.val; + } + inline bool ValueKnown(DimensionOrConstant d) { + return Value(d) != kUnknownDim; + } // Returns true if the rank and all dimensions of the Shape are known. bool FullyDefined(const Shape* s); @@ -232,8 +240,15 @@ class InferenceContext { // Returns a new dimension of the given size. The returned value is owned by // this context. - const Dimension* MakeDim(int64 value); - const Dimension* UnknownDim(); + inline const Dimension* MakeDim(DimensionOrConstant d) { + if (d.dim) { + return d.dim; + } else { + all_dims_.push_back(new Dimension(d.val)); + return all_dims_.back(); + } + } + inline const Dimension* UnknownDim() { return MakeDim(kUnknownDim); } // Returns a new dimension whose value is given by a scalar input tensor. // The input tensor must be in host memory, since it is dereferenced to get @@ -247,7 +262,8 @@ class InferenceContext { Status GetAttr(StringPiece attr_name, T* value) const; // Returns in <out> the result of dividing <dividend> by <divisor>. - // Returns an error if <divisor> does not evenly divide <dividend>. + // Returns an error if <divisor> is not positive or does not evenly + // divide <dividend>. Status Divide(const Dimension* dividend, int64 divisor, const Dimension** out); @@ -255,10 +271,25 @@ class InferenceContext { Status Add(const Dimension* first, DimensionOrConstant second, const Dimension** out); + // Returns in <out> the dimension that is <first> minus <second>. + Status Subtract(const Dimension* first, DimensionOrConstant second, + const Dimension** out); + // Returns in <out> the product of <first> and <second>. Status Multiply(const Dimension* first, DimensionOrConstant second, const Dimension** out); + // Returns in <out> the minimum of <first> and <second>. If either <first> or + // <second> is zero the results is zero. Otherwise, if either <first> or + // <second> is unknown the results is unknown. + Status Min(const Dimension* first, DimensionOrConstant second, + const Dimension** out); + + // Returns in <out> the maximum of <first> and <second>. If either <first> or + // <second> is unknown the results is unknown. + Status Max(const Dimension* first, DimensionOrConstant second, + const Dimension** out); + Status construction_status() const { return construction_status_; } // Validates that 'dim' has a known value, and prints an error @@ -307,12 +338,30 @@ class InferenceContext { // Template and inline method implementations, please ignore inline Dimension::Dimension() : value_(InferenceContext::kUnknownDim) {} -inline Dimension::Dimension(int64 value) : value_(value) {} +inline Dimension::Dimension(int64 value) : value_(value) { + DCHECK(value >= 0 || value == InferenceContext::kUnknownDim) + << "Dimension must be non-negative or equal to " + "InferenceContext::kUnknownDim but got" + << value; +} inline Shape::Shape() : rank_(InferenceContext::kUnknownRank) {} -inline Shape::Shape(const std::vector<const Dimension*> dims) +inline Shape::Shape(const std::vector<const Dimension*>& dims) : rank_(dims.size()), dims_(dims) {} +inline DimensionOrConstant::DimensionOrConstant(const Dimension* dim) + : dim(dim) { + DCHECK(dim != nullptr) << "Internal error: Got nullptr for Dimension."; +} + +inline DimensionOrConstant::DimensionOrConstant(int64 val) + : dim(nullptr), val(val) { + DCHECK(val >= 0 || val == InferenceContext::kUnknownDim) + << "Dimension must be non-negative or equal to " + "InferenceContext::kUnknownDim but got" + << val; +} + template <class T> Status InferenceContext::GetAttr(StringPiece attr_name, T* value) const { return GetNodeAttr(node_def_, attr_name, value); diff --git a/tensorflow/core/framework/shape_inference_test.cc b/tensorflow/core/framework/shape_inference_test.cc index fffb25da6d..1ecba2839a 100644 --- a/tensorflow/core/framework/shape_inference_test.cc +++ b/tensorflow/core/framework/shape_inference_test.cc @@ -36,6 +36,19 @@ static OpDef MakeOpDef(int num_inputs, int num_outputs) { return op_reg_data.op_def; } +TEST(ShapeInferenceTest, DimensionOrConstant) { + NodeDef def; + InferenceContext c(&def, MakeOpDef(1, 1), {"?"}, {}); + EXPECT_EQ(InferenceContext::kUnknownDim, + c.Value(InferenceContext::kUnknownDim)); + EXPECT_EQ(1, c.Value(1)); + +#ifndef NDEBUG + // Only run death test if DCHECKS are enabled. + EXPECT_DEATH(c.Value(-7), "Dimension must be non\\-negative or equal to"); +#endif +} + TEST(ShapeInferenceTest, RankAndDimInspection) { NodeDef def; InferenceContext c(&def, MakeOpDef(3, 2), {"?", "[1,?,3]", "[]"}, {}); @@ -767,15 +780,20 @@ TEST(ShapeInferenceTest, Divide) { EXPECT_EQ("Dimension size must be divisible by 5 but is 6", c.Divide(d_6, 5, &out).error_message()); + EXPECT_EQ("Divisor must be positive but is 0", + c.Divide(d_6, 0, &out).error_message()); + EXPECT_EQ("Divisor must be positive but is -1", + c.Divide(d_6, -1, &out).error_message()); } TEST(ShapeInferenceTest, Add) { NodeDef def; - InferenceContext c(&def, MakeOpDef(1, 2), {"[6,?]"}, {}); + InferenceContext c(&def, MakeOpDef(1, 2), {"[6,?,0]"}, {}); auto s = c.input(0); auto d_6 = c.Dim(s, 0); auto d_unknown = c.Dim(s, 1); + auto d_0 = c.Dim(s, 2); // Adding non-zero to unknown gives new unknown. const Dimension* out; @@ -790,16 +808,14 @@ TEST(ShapeInferenceTest, Add) { EXPECT_TRUE(out == d_6); // Adding dimension with value 0 to anything gives input. - EXPECT_TRUE(c.Add(d_unknown, c.MakeDim(0), &out).ok()); + EXPECT_TRUE(c.Add(d_unknown, c.MakeDim(0ll), &out).ok()); EXPECT_TRUE(out == d_unknown); - EXPECT_TRUE(c.Add(d_6, c.MakeDim(0), &out).ok()); + EXPECT_TRUE(c.Add(d_6, c.MakeDim(0ll), &out).ok()); EXPECT_TRUE(out == d_6); // Test addition. EXPECT_TRUE(c.Add(d_6, 2, &out).ok()); EXPECT_EQ("8", c.DebugString(out)); - EXPECT_TRUE(c.Add(d_6, -6, &out).ok()); - EXPECT_EQ("0", c.DebugString(out)); EXPECT_TRUE(c.Add(d_6, std::numeric_limits<int64>::max() - 6, &out).ok()); EXPECT_EQ(std::numeric_limits<int64>::max(), c.Value(out)); @@ -811,14 +827,62 @@ TEST(ShapeInferenceTest, Add) { EXPECT_EQ(std::numeric_limits<int64>::max(), c.Value(out)); EXPECT_TRUE(c.Add(d_6, c.UnknownDim(), &out).ok()); EXPECT_EQ("?", c.DebugString(out)); + EXPECT_TRUE(c.Add(d_0, d_6, &out).ok()); + EXPECT_TRUE(out == d_6); - EXPECT_EQ("Negative dimension size from adding 6 and -7", - c.Add(d_6, -7, &out).error_message()); EXPECT_EQ( "Dimension size overflow from adding 6 and 9223372036854775802", c.Add(d_6, std::numeric_limits<int64>::max() - 5, &out).error_message()); } +TEST(ShapeInferenceTest, Subtract) { + NodeDef def; + InferenceContext c(&def, MakeOpDef(1, 2), {"[6,?,0,5]"}, {}); + + auto s = c.input(0); + auto d_6 = c.Dim(s, 0); + auto d_unknown = c.Dim(s, 1); + auto d_0 = c.Dim(s, 2); + auto d_5 = c.Dim(s, 3); + + // Subtracting non-zero from unknown gives new unknown. + const Dimension* out; + EXPECT_TRUE(c.Subtract(d_unknown, 1, &out).ok()); + EXPECT_EQ("?", c.DebugString(out)); + EXPECT_TRUE(out != d_unknown); + + // Subtracting 0 from anything gives input. + EXPECT_TRUE(c.Subtract(d_unknown, 0ll, &out).ok()); + EXPECT_TRUE(out == d_unknown); + EXPECT_TRUE(c.Subtract(d_6, 0ll, &out).ok()); + EXPECT_TRUE(out == d_6); + + // Subtracting dimension with value 0 from anything gives input. + EXPECT_TRUE(c.Subtract(d_unknown, c.MakeDim(0ll), &out).ok()); + EXPECT_TRUE(out == d_unknown); + EXPECT_TRUE(c.Subtract(d_6, c.MakeDim(0ll), &out).ok()); + EXPECT_TRUE(out == d_6); + + // Test subtraction. + EXPECT_TRUE(c.Subtract(d_6, 2, &out).ok()); + EXPECT_EQ("4", c.DebugString(out)); + EXPECT_TRUE(c.Subtract(d_6, 6, &out).ok()); + EXPECT_EQ("0", c.DebugString(out)); + + // Test subtraction using dimension as second value. + EXPECT_TRUE(c.Subtract(d_6, c.MakeDim(2), &out).ok()); + EXPECT_EQ("4", c.DebugString(out)); + EXPECT_TRUE(c.Subtract(d_6, d_5, &out).ok()); + EXPECT_EQ("1", c.DebugString(out)); + EXPECT_TRUE(c.Subtract(d_6, c.UnknownDim(), &out).ok()); + EXPECT_EQ("?", c.DebugString(out)); + EXPECT_TRUE(c.Subtract(d_6, d_0, &out).ok()); + EXPECT_TRUE(out == d_6); + + EXPECT_EQ("Negative dimension size caused by subtracting 6 from 5", + c.Subtract(d_5, d_6, &out).error_message()); +} + TEST(ShapeInferenceTest, Multiply) { NodeDef def; InferenceContext c(&def, MakeOpDef(1, 2), {"[6,?,0,1]"}, {}); @@ -831,7 +895,7 @@ TEST(ShapeInferenceTest, Multiply) { // Multiplying non-zero to unknown gives new unknown. const Dimension* out; - EXPECT_TRUE(c.Multiply(d_unknown, 1, &out).ok()); + EXPECT_TRUE(c.Multiply(d_unknown, 2, &out).ok()); EXPECT_EQ("?", c.DebugString(out)); // Multiplying 0 to anything gives 0. @@ -844,19 +908,19 @@ TEST(ShapeInferenceTest, Multiply) { // Multiplying 1 to anything gives the original. // (unknown -> unknown) - EXPECT_TRUE(c.Multiply(d_unknown, static_cast<int64>(1), &out).ok()); - EXPECT_EQ("?", c.DebugString(out)); + EXPECT_TRUE(c.Multiply(d_unknown, 1, &out).ok()); + EXPECT_EQ(d_unknown, out); EXPECT_TRUE(c.Multiply(d_unknown, d_1, &out).ok()); - EXPECT_EQ("?", c.DebugString(out)); + EXPECT_EQ(d_unknown, out); EXPECT_TRUE(c.Multiply(d_1, d_unknown, &out).ok()); - EXPECT_EQ("?", c.DebugString(out)); + EXPECT_EQ(d_unknown, out); // (known -> known) - EXPECT_TRUE(c.Multiply(d_6, static_cast<int64>(1), &out).ok()); - EXPECT_EQ("6", c.DebugString(out)); + EXPECT_TRUE(c.Multiply(d_6, 1, &out).ok()); + EXPECT_EQ(d_6, out); EXPECT_TRUE(c.Multiply(d_6, d_1, &out).ok()); - EXPECT_EQ("6", c.DebugString(out)); + EXPECT_EQ(d_6, out); EXPECT_TRUE(c.Multiply(d_1, d_6, &out).ok()); - EXPECT_EQ("6", c.DebugString(out)); + EXPECT_EQ(d_6, out); // Test multiplication. EXPECT_TRUE(c.Multiply(d_6, 2, &out).ok()); @@ -869,9 +933,6 @@ TEST(ShapeInferenceTest, Multiply) { EXPECT_EQ("12", c.DebugString(out)); EXPECT_TRUE(c.Multiply(d_6, c.UnknownDim(), &out).ok()); EXPECT_EQ("?", c.DebugString(out)); - - EXPECT_EQ("Negative dimension size from multiplying 6 and -7", - c.Multiply(d_6, -7, &out).error_message()); } TEST(ShapeInferenceTest, FullyDefined) { @@ -895,5 +956,90 @@ TEST(ShapeInferenceTest, ValidateKnownDim) { EXPECT_TRUE(c.ValidateKnownDim(c.Dim(c.Matrix(1, 2), 0), "known").ok()); } +TEST(ShapeInferenceTest, Min) { + NodeDef def; + InferenceContext c(&def, MakeOpDef(1, 2), {"[1,2,?,0]"}, {}); + + auto s = c.input(0); + auto d_1 = c.Dim(s, 0); + auto d_2 = c.Dim(s, 1); + auto d_unknown = c.Dim(s, 2); + auto d_0 = c.Dim(s, 3); + + // Minimum involving zero and unknown returns zero. + const Dimension* out; + EXPECT_TRUE(c.Min(d_0, d_unknown, &out).ok()); + EXPECT_EQ(d_0, out); + EXPECT_TRUE(c.Min(d_unknown, d_0, &out).ok()); + EXPECT_EQ(d_0, out); + EXPECT_TRUE(c.Min(c.MakeDim(0ll), d_unknown, &out).ok()); + EXPECT_EQ("0", c.DebugString(out)); + EXPECT_TRUE(c.Min(d_unknown, 0ll, &out).ok()); + EXPECT_EQ("0", c.DebugString(out)); + + // Minimum involving unknowns and non-zeros gives new unknown. + EXPECT_TRUE(c.Min(d_unknown, d_unknown, &out).ok()); + EXPECT_EQ("?", c.DebugString(out)); + EXPECT_TRUE(c.Min(d_unknown, 1, &out).ok()); + EXPECT_EQ("?", c.DebugString(out)); + EXPECT_TRUE(c.Min(d_1, d_unknown, &out).ok()); + EXPECT_EQ("?", c.DebugString(out)); + + // Minimum with constant second arg. + EXPECT_TRUE(c.Min(d_1, 1, &out).ok()); + EXPECT_EQ(d_1, out); + EXPECT_TRUE(c.Min(d_1, 3, &out).ok()); + EXPECT_EQ(d_1, out); + EXPECT_TRUE(c.Min(d_2, 1, &out).ok()); + EXPECT_EQ("1", c.DebugString(out)); + + // Minimum with two dimensions. + EXPECT_TRUE(c.Min(d_1, d_1, &out).ok()); + EXPECT_EQ(d_1, out); + EXPECT_TRUE(c.Min(d_1, d_2, &out).ok()); + EXPECT_EQ(d_1, out); + EXPECT_TRUE(c.Min(d_2, d_1, &out).ok()); + EXPECT_EQ(d_1, out); + EXPECT_TRUE(c.Min(d_2, d_2, &out).ok()); + EXPECT_EQ(d_2, out); +} + +TEST(ShapeInferenceTest, Max) { + NodeDef def; + InferenceContext c(&def, MakeOpDef(1, 2), {"[1,2,?]"}, {}); + + auto s = c.input(0); + auto d_1 = c.Dim(s, 0); + auto d_2 = c.Dim(s, 1); + auto d_unknown = c.Dim(s, 2); + + // Maximum involving unknowns gives new unknown. + const Dimension* out; + EXPECT_TRUE(c.Max(d_unknown, d_unknown, &out).ok()); + EXPECT_EQ("?", c.DebugString(out)); + EXPECT_TRUE(c.Max(d_unknown, 1, &out).ok()); + EXPECT_EQ("?", c.DebugString(out)); + EXPECT_TRUE(c.Max(d_1, d_unknown, &out).ok()); + EXPECT_EQ("?", c.DebugString(out)); + + // Maximum with constant second arg. + EXPECT_TRUE(c.Max(d_1, 1, &out).ok()); + EXPECT_EQ(d_1, out); + EXPECT_TRUE(c.Max(d_2, 1, &out).ok()); + EXPECT_EQ(d_2, out); + EXPECT_TRUE(c.Max(d_2, 3, &out).ok()); + EXPECT_EQ("3", c.DebugString(out)); + + // Maximum with two dimensions. + EXPECT_TRUE(c.Max(d_1, d_1, &out).ok()); + EXPECT_EQ(d_1, out); + EXPECT_TRUE(c.Max(d_1, d_2, &out).ok()); + EXPECT_EQ(d_2, out); + EXPECT_TRUE(c.Max(d_2, d_1, &out).ok()); + EXPECT_EQ(d_2, out); + EXPECT_TRUE(c.Max(d_2, d_2, &out).ok()); + EXPECT_EQ(d_2, out); +} + } // namespace shape_inference } // namespace tensorflow diff --git a/tensorflow/core/framework/shape_inference_testutil.cc b/tensorflow/core/framework/shape_inference_testutil.cc index c1e55d032d..60a9cb101f 100644 --- a/tensorflow/core/framework/shape_inference_testutil.cc +++ b/tensorflow/core/framework/shape_inference_testutil.cc @@ -40,6 +40,11 @@ Status InferShapes(ShapeInferenceTestOp op, const string& ins, shape_inference::InferenceContext c(&op.node_def, op_reg_data->op_def, ins_v, op.input_tensors); TF_RETURN_IF_ERROR(c.construction_status()); + if (op_reg_data->shape_inference_fn == nullptr) { + return errors::InvalidArgument( + "No shape inference function exists for op '", op.name, + "', did you forget to define it?"); + } TF_RETURN_IF_ERROR(op_reg_data->shape_inference_fn(&c)); const int num_outputs = c.num_outputs(); |