author    A. Unique TensorFlower <gardener@tensorflow.org>  2016-08-04 15:15:57 -0800
committer TensorFlower Gardener <gardener@tensorflow.org>   2016-08-04 16:31:16 -0700
commit    ee9241825d80bf295963ac2fad4dfa0fc9a7b998 (patch)
tree      81260bea9c5328bd7c12fc4729c646e332115fb6 /tensorflow/core/framework
parent    21038467d71be31193715f7b023e252c0c5e2b05 (diff)
Add C++ shape inference for SVD.

This also adds Min(), Max(), and Subtract() operators and a few convenience
methods to the InferenceContext. Changed test utils to emit a human-readable
error message in case the user forgot to set the inference function.
Refactored shape_inference* a bit to enforce the invariant that a Dimension or
DimensionOrConstant is always non-negative or equal to
InferenceContext::kUnknownDim. This made it possible to tighten and simplify
the arithmetic operations a bit.

Change: 129385995
Diffstat (limited to 'tensorflow/core/framework')
-rw-r--r--  tensorflow/core/framework/shape_inference.cc           163
-rw-r--r--  tensorflow/core/framework/shape_inference.h             75
-rw-r--r--  tensorflow/core/framework/shape_inference_test.cc      184
-rw-r--r--  tensorflow/core/framework/shape_inference_testutil.cc    5
4 files changed, 320 insertions, 107 deletions
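
For orientation, a shape function built on these new helpers might look like the sketch below. This is a hypothetical example, not the actual SVD shape function registered elsewhere in this change; the name SvdLikeShapeFn is made up, and the sketch assumes a rank-2 input plus the usual InferenceContext accessors (input, Dim, Vector, Matrix, set_output) declared in shape_inference.h.

Status SvdLikeShapeFn(shape_inference::InferenceContext* c) {
  // Assumes input 0 is a rank-2 matrix of shape [m, n].
  const shape_inference::Shape* a = c->input(0);
  const shape_inference::Dimension* m = c->Dim(a, 0);
  const shape_inference::Dimension* n = c->Dim(a, 1);
  // p = min(m, n); propagates unknown if either m or n is unknown.
  const shape_inference::Dimension* p;
  TF_RETURN_IF_ERROR(c->Min(m, n, &p));
  c->set_output(0, c->Vector(p));     // singular values: [p]
  c->set_output(1, c->Matrix(m, p));  // left singular vectors: [m, p]
  c->set_output(2, c->Matrix(n, p));  // right singular vectors: [n, p]
  return Status::OK();
}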
diff --git a/tensorflow/core/framework/shape_inference.cc b/tensorflow/core/framework/shape_inference.cc
index e44d921d5d..9c90bfe0f5 100644
--- a/tensorflow/core/framework/shape_inference.cc
+++ b/tensorflow/core/framework/shape_inference.cc
@@ -387,12 +387,6 @@ Status InferenceContext::ReplaceDim(const Shape* s, int dim_index_in,
return ReturnCreatedShape(dims, out);
}
-const Dimension* InferenceContext::GetDimension(const DimensionOrConstant& d) {
- if (d.dim != nullptr) return d.dim;
- DCHECK(d.val >= 0 || d.val == kUnknownDim);
- return MakeDim(d.val);
-}
-
const Shape* InferenceContext::MakeShape(
const std::vector<const Dimension*>& dims) {
all_shapes_.push_back(new Shape(dims));
@@ -404,7 +398,7 @@ const Shape* InferenceContext::MakeShape(
std::vector<const Dimension*> dims_actual;
dims_actual.reserve(dims.size());
for (const DimensionOrConstant& d : dims) {
- dims_actual.push_back(GetDimension(d));
+ dims_actual.push_back(MakeDim(d));
}
return MakeShape(dims_actual);
}
@@ -488,11 +482,6 @@ Status InferenceContext::MakeShapeFromShapeProto(const TensorShapeProto& proto,
return ReturnCreatedShape(dims, out);
}
-const Dimension* InferenceContext::MakeDim(int64 value) {
- all_dims_.push_back(new Dimension(value));
- return all_dims_.back();
-}
-
// Returns a new dimension whose value is given by a scalar input tensor.
Status InferenceContext::MakeDimForScalarInput(int idx, const Dimension** out) {
const Tensor* t = input_tensor(idx);
@@ -522,11 +511,6 @@ Status InferenceContext::MakeDimForScalarInput(int idx, const Dimension** out) {
return Status::OK();
}
-const Dimension* InferenceContext::UnknownDim() {
- all_dims_.push_back(new Dimension());
- return all_dims_.back();
-}
-
Status InferenceContext::Divide(const Dimension* dividend, int64 divisor,
const Dimension** out) {
if (divisor == 1) {
@@ -535,6 +519,10 @@ Status InferenceContext::Divide(const Dimension* dividend, int64 divisor,
*out = UnknownDim();
} else {
const int64 v = Value(dividend);
+ if (divisor <= 0) {
+ return errors::InvalidArgument("Divisor must be positive but is ",
+ divisor);
+ }
if ((v % divisor) != 0) {
return errors::InvalidArgument("Dimension size must be divisible by ",
divisor, " but is ", v);
@@ -546,87 +534,112 @@ Status InferenceContext::Divide(const Dimension* dividend, int64 divisor,
Status InferenceContext::Add(const Dimension* first, DimensionOrConstant second,
const Dimension** out) {
- const int64 second_value =
- second.dim == nullptr ? second.val : Value(second.dim);
- if (second.dim != nullptr && !ValueKnown(second.dim)) {
- *out = UnknownDim();
+ const int64 first_value = Value(first);
+ const int64 second_value = Value(second);
+ // Special cases.
+ if (first_value == 0) {
+ *out = MakeDim(second);
} else if (second_value == 0) {
- *out = first;
- } else if (!ValueKnown(first)) {
+ *out = MakeDim(first);
+ } else if (first_value == kUnknownDim || second_value == kUnknownDim) {
*out = UnknownDim();
} else {
- const int64 v = Value(first);
- const int64 sum = v + second_value;
- if (second_value > 0 && sum < 0) {
- return errors::InvalidArgument("Dimension size overflow from adding ", v,
- " and ", second_value);
- } else if (second_value < 0 && sum < 0) {
- return errors::InvalidArgument("Negative dimension size from adding ", v,
- " and ", second_value);
+ // Invariant: Both values are known and positive.
+ const int64 sum = first_value + second_value;
+ if (sum < 0) {
+ return errors::InvalidArgument("Dimension size overflow from adding ",
+ first_value, " and ", second_value);
}
*out = MakeDim(sum);
}
return Status::OK();
}
-Status InferenceContext::Multiply(const Dimension* first,
+Status InferenceContext::Subtract(const Dimension* first,
DimensionOrConstant second,
const Dimension** out) {
- int64 first_value = -1;
- // Special cases for multiply are when the values are 0 or 1.
- if (ValueKnown(first)) {
- first_value = Value(first);
- if (first_value == 0) {
- *out = MakeDim(0);
- return Status::OK();
- }
-
- // Output is whatever the second value is.
- if (first_value == 1) {
- *out = GetDimension(second);
- return Status::OK();
+ const int64 first_value = Value(first);
+ const int64 second_value = Value(second);
+ // Special cases.
+ if (second_value == 0) {
+ *out = MakeDim(first);
+ } else if (first_value == kUnknownDim || second_value == kUnknownDim) {
+ *out = UnknownDim();
+ } else {
+ // Invariant: Both values are known, first_value is non-negative, and
+ // second_value is positive.
+ if (first_value < second_value) {
+ return errors::InvalidArgument(
+ "Negative dimension size caused by subtracting ", second_value,
+ " from ", first_value);
}
+ *out = MakeDim(first_value - second_value);
}
+ return Status::OK();
+}
- // Same check for when the second argument is a known value.
- // First find out if the value is known from DimOrConstant.
- int64 second_value;
- if (second.dim == nullptr) {
- second_value = second.val;
+Status InferenceContext::Multiply(const Dimension* first,
+ DimensionOrConstant second,
+ const Dimension** out) {
+ const int64 first_value = Value(first);
+ const int64 second_value = Value(second);
+ // Special cases.
+ if (first_value == 0) {
+ *out = first;
+ } else if (second_value == 0) {
+ *out = MakeDim(second);
+ } else if (first_value == 1) {
+ *out = MakeDim(second);
+ } else if (second_value == 1) {
+ *out = first;
+ } else if (first_value == kUnknownDim || second_value == kUnknownDim) {
+ *out = UnknownDim();
} else {
- if (!ValueKnown(second.dim)) {
- // Second value is not known and first is not a special caase
- *out = UnknownDim();
- return Status::OK();
+ // Invariant: Both values are known and greater than 1.
+ const int64 product = first_value * second_value;
+ if (product < 0) {
+ return errors::InvalidArgument(
+ "Negative dimension size caused by overflow when multiplying ",
+ first_value, " and ", second_value);
}
- second_value = Value(second.dim);
- }
-
- // Now that we know whether the value is known, apply the special
- // casing.
- if (second_value == 0) {
- *out = MakeDim(0);
- return Status::OK();
+ *out = MakeDim(product);
}
+ return Status::OK();
+}
- // Output is whatever the first value is.
- if (second_value == 1) {
+Status InferenceContext::Min(const Dimension* first, DimensionOrConstant second,
+ const Dimension** out) {
+ const int64 first_value = Value(first);
+ const int64 second_value = Value(second);
+ if (first_value == 0) {
*out = first;
- return Status::OK();
- }
-
- if (!ValueKnown(first)) {
- // First value is not known and second is not a special caase
+ } else if (second_value == 0) {
+ *out = MakeDim(second);
+ } else if (first_value == kUnknownDim || second_value == kUnknownDim) {
*out = UnknownDim();
- return Status::OK();
+ } else {
+ if (first_value <= second_value) {
+ *out = first;
+ } else {
+ *out = MakeDim(second);
+ }
}
+ return Status::OK();
+}
- const int64 product = first_value * second_value;
- if (product < 0) {
- return errors::InvalidArgument("Negative dimension size from multiplying ",
- first_value, " and ", second_value);
+Status InferenceContext::Max(const Dimension* first, DimensionOrConstant second,
+ const Dimension** out) {
+ const int64 first_value = Value(first);
+ const int64 second_value = Value(second);
+ if (first_value == kUnknownDim || second_value == kUnknownDim) {
+ *out = UnknownDim();
+ } else {
+ if (first_value >= second_value) {
+ *out = first;
+ } else {
+ *out = MakeDim(second);
+ }
}
- *out = MakeDim(product);
return Status::OK();
}
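
With the non-negative invariant enforced above, the arithmetic helpers compose without per-call sign checks. A minimal usage sketch (hypothetical caller code, assuming an InferenceContext* c and an input shape s whose first dimension is being transformed):

const Dimension* d = c->Dim(s, 0);
const Dimension* out;
// Unknown dimensions propagate: if d is "?", each result below is unknown.
TF_RETURN_IF_ERROR(c->Subtract(d, 1, &out));  // error if d is known and < 1
TF_RETURN_IF_ERROR(c->Divide(out, 2, &out));  // error if known and not divisible by 2
TF_RETURN_IF_ERROR(c->Add(out, 1, &out));     // error only on int64 overflow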
diff --git a/tensorflow/core/framework/shape_inference.h b/tensorflow/core/framework/shape_inference.h
index 1aa51f5017..f35c8a4c81 100644
--- a/tensorflow/core/framework/shape_inference.h
+++ b/tensorflow/core/framework/shape_inference.h
@@ -46,7 +46,7 @@ class Dimension {
class Shape {
private:
Shape();
- Shape(std::vector<const Dimension*> dims);
+ Shape(const std::vector<const Dimension*>& dims);
~Shape() {}
const int32 rank_;
@@ -61,13 +61,17 @@ class Shape {
struct DimensionOrConstant {
public:
// Intentionally not explicit.
- DimensionOrConstant(const Dimension* dim) : dim(dim) {}
+ DimensionOrConstant(const Dimension* dim);
// val must be non-negative or InferenceContext::kUnknownDim.
- DimensionOrConstant(int64 val) : val(val) {}
+ DimensionOrConstant(int64 val);
- const Dimension* dim = nullptr;
- int64 val = 0;
+ // dim takes precedence. If dim != nullptr, val is ignored.
+ const Dimension* dim;
+ int64 val;
+
+ private:
+ DimensionOrConstant();
};
// Note: This is experimental support for op shape inference in C++. Shape
@@ -81,8 +85,8 @@ struct DimensionOrConstant {
// by the InferenceContext.
class InferenceContext {
public:
- static constexpr int32 kUnknownRank = -1;
static constexpr int64 kUnknownDim = -1;
+ static constexpr int32 kUnknownRank = -1;
// This is a temporary constructor used for initial testing.
//
@@ -127,8 +131,12 @@ class InferenceContext {
}
int32 Rank(const Shape* s) { return s->rank_; }
bool RankKnown(const Shape* s) { return Rank(s) != kUnknownRank; }
- int64 Value(const Dimension* d) { return d->value_; }
- bool ValueKnown(const Dimension* d) { return Value(d) != kUnknownDim; }
+ inline int64 Value(DimensionOrConstant d) {
+ return d.dim ? d.dim->value_ : d.val;
+ }
+ inline bool ValueKnown(DimensionOrConstant d) {
+ return Value(d) != kUnknownDim;
+ }
// Returns true if the rank and all dimensions of the Shape are known.
bool FullyDefined(const Shape* s);
@@ -232,8 +240,15 @@ class InferenceContext {
// Returns a new dimension of the given size. The returned value is owned by
// this context.
- const Dimension* MakeDim(int64 value);
- const Dimension* UnknownDim();
+ inline const Dimension* MakeDim(DimensionOrConstant d) {
+ if (d.dim) {
+ return d.dim;
+ } else {
+ all_dims_.push_back(new Dimension(d.val));
+ return all_dims_.back();
+ }
+ }
+ inline const Dimension* UnknownDim() { return MakeDim(kUnknownDim); }
// Returns a new dimension whose value is given by a scalar input tensor.
// The input tensor must be in host memory, since it is dereferenced to get
@@ -247,7 +262,8 @@ class InferenceContext {
Status GetAttr(StringPiece attr_name, T* value) const;
// Returns in <out> the result of dividing <dividend> by <divisor>.
- // Returns an error if <divisor> does not evenly divide <dividend>.
+ // Returns an error if <divisor> is not positive or does not evenly
+ // divide <dividend>.
Status Divide(const Dimension* dividend, int64 divisor,
const Dimension** out);
@@ -255,10 +271,25 @@ class InferenceContext {
Status Add(const Dimension* first, DimensionOrConstant second,
const Dimension** out);
+ // Returns in <out> the dimension that is <first> minus <second>.
+ Status Subtract(const Dimension* first, DimensionOrConstant second,
+ const Dimension** out);
+
// Returns in <out> the product of <first> and <second>.
Status Multiply(const Dimension* first, DimensionOrConstant second,
const Dimension** out);
+ // Returns in <out> the minimum of <first> and <second>. If either <first> or
+ // <second> is zero, the result is zero. Otherwise, if either <first> or
+ // <second> is unknown, the result is unknown.
+ Status Min(const Dimension* first, DimensionOrConstant second,
+ const Dimension** out);
+
+ // Returns in <out> the maximum of <first> and <second>. If either <first> or
+ // <second> is unknown, the result is unknown.
+ Status Max(const Dimension* first, DimensionOrConstant second,
+ const Dimension** out);
+
Status construction_status() const { return construction_status_; }
// Validates that 'dim' has a known value, and prints an error
@@ -307,12 +338,30 @@ class InferenceContext {
// Template and inline method implementations, please ignore
inline Dimension::Dimension() : value_(InferenceContext::kUnknownDim) {}
-inline Dimension::Dimension(int64 value) : value_(value) {}
+inline Dimension::Dimension(int64 value) : value_(value) {
+ DCHECK(value >= 0 || value == InferenceContext::kUnknownDim)
+ << "Dimension must be non-negative or equal to "
+ "InferenceContext::kUnknownDim but got"
+ << value;
+}
inline Shape::Shape() : rank_(InferenceContext::kUnknownRank) {}
-inline Shape::Shape(const std::vector<const Dimension*> dims)
+inline Shape::Shape(const std::vector<const Dimension*>& dims)
: rank_(dims.size()), dims_(dims) {}
+inline DimensionOrConstant::DimensionOrConstant(const Dimension* dim)
+ : dim(dim) {
+ DCHECK(dim != nullptr) << "Internal error: Got nullptr for Dimension.";
+}
+
+inline DimensionOrConstant::DimensionOrConstant(int64 val)
+ : dim(nullptr), val(val) {
+ DCHECK(val >= 0 || val == InferenceContext::kUnknownDim)
+ << "Dimension must be non-negative or equal to "
+ "InferenceContext::kUnknownDim but got"
+ << val;
+}
+
template <class T>
Status InferenceContext::GetAttr(StringPiece attr_name, T* value) const {
return GetNodeAttr(node_def_, attr_name, value);
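
Because DimensionOrConstant converts implicitly from both a const Dimension* and an int64, the helpers above accept either form. A small hypothetical illustration (note the 0ll spelling, also used in the updated tests below, which keeps the literal from being ambiguous between a null Dimension* and an int64):

const Dimension* five = c->MakeDim(5);      // built from an int64 constant
const Dimension* same = c->MakeDim(five);   // an existing Dimension* is returned as-is
int64 v = c->Value(5);                      // 5; no Dimension object is created
bool known = c->ValueKnown(InferenceContext::kUnknownDim);  // false
const Dimension* zero = c->MakeDim(0ll);    // plain 0 would be ambiguous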
diff --git a/tensorflow/core/framework/shape_inference_test.cc b/tensorflow/core/framework/shape_inference_test.cc
index fffb25da6d..1ecba2839a 100644
--- a/tensorflow/core/framework/shape_inference_test.cc
+++ b/tensorflow/core/framework/shape_inference_test.cc
@@ -36,6 +36,19 @@ static OpDef MakeOpDef(int num_inputs, int num_outputs) {
return op_reg_data.op_def;
}
+TEST(ShapeInferenceTest, DimensionOrConstant) {
+ NodeDef def;
+ InferenceContext c(&def, MakeOpDef(1, 1), {"?"}, {});
+ EXPECT_EQ(InferenceContext::kUnknownDim,
+ c.Value(InferenceContext::kUnknownDim));
+ EXPECT_EQ(1, c.Value(1));
+
+#ifndef NDEBUG
+ // Only run death test if DCHECKS are enabled.
+ EXPECT_DEATH(c.Value(-7), "Dimension must be non\\-negative or equal to");
+#endif
+}
+
TEST(ShapeInferenceTest, RankAndDimInspection) {
NodeDef def;
InferenceContext c(&def, MakeOpDef(3, 2), {"?", "[1,?,3]", "[]"}, {});
@@ -767,15 +780,20 @@ TEST(ShapeInferenceTest, Divide) {
EXPECT_EQ("Dimension size must be divisible by 5 but is 6",
c.Divide(d_6, 5, &out).error_message());
+ EXPECT_EQ("Divisor must be positive but is 0",
+ c.Divide(d_6, 0, &out).error_message());
+ EXPECT_EQ("Divisor must be positive but is -1",
+ c.Divide(d_6, -1, &out).error_message());
}
TEST(ShapeInferenceTest, Add) {
NodeDef def;
- InferenceContext c(&def, MakeOpDef(1, 2), {"[6,?]"}, {});
+ InferenceContext c(&def, MakeOpDef(1, 2), {"[6,?,0]"}, {});
auto s = c.input(0);
auto d_6 = c.Dim(s, 0);
auto d_unknown = c.Dim(s, 1);
+ auto d_0 = c.Dim(s, 2);
// Adding non-zero to unknown gives new unknown.
const Dimension* out;
@@ -790,16 +808,14 @@ TEST(ShapeInferenceTest, Add) {
EXPECT_TRUE(out == d_6);
// Adding dimension with value 0 to anything gives input.
- EXPECT_TRUE(c.Add(d_unknown, c.MakeDim(0), &out).ok());
+ EXPECT_TRUE(c.Add(d_unknown, c.MakeDim(0ll), &out).ok());
EXPECT_TRUE(out == d_unknown);
- EXPECT_TRUE(c.Add(d_6, c.MakeDim(0), &out).ok());
+ EXPECT_TRUE(c.Add(d_6, c.MakeDim(0ll), &out).ok());
EXPECT_TRUE(out == d_6);
// Test addition.
EXPECT_TRUE(c.Add(d_6, 2, &out).ok());
EXPECT_EQ("8", c.DebugString(out));
- EXPECT_TRUE(c.Add(d_6, -6, &out).ok());
- EXPECT_EQ("0", c.DebugString(out));
EXPECT_TRUE(c.Add(d_6, std::numeric_limits<int64>::max() - 6, &out).ok());
EXPECT_EQ(std::numeric_limits<int64>::max(), c.Value(out));
@@ -811,14 +827,62 @@ TEST(ShapeInferenceTest, Add) {
EXPECT_EQ(std::numeric_limits<int64>::max(), c.Value(out));
EXPECT_TRUE(c.Add(d_6, c.UnknownDim(), &out).ok());
EXPECT_EQ("?", c.DebugString(out));
+ EXPECT_TRUE(c.Add(d_0, d_6, &out).ok());
+ EXPECT_TRUE(out == d_6);
- EXPECT_EQ("Negative dimension size from adding 6 and -7",
- c.Add(d_6, -7, &out).error_message());
EXPECT_EQ(
"Dimension size overflow from adding 6 and 9223372036854775802",
c.Add(d_6, std::numeric_limits<int64>::max() - 5, &out).error_message());
}
+TEST(ShapeInferenceTest, Subtract) {
+ NodeDef def;
+ InferenceContext c(&def, MakeOpDef(1, 2), {"[6,?,0,5]"}, {});
+
+ auto s = c.input(0);
+ auto d_6 = c.Dim(s, 0);
+ auto d_unknown = c.Dim(s, 1);
+ auto d_0 = c.Dim(s, 2);
+ auto d_5 = c.Dim(s, 3);
+
+ // Subtracting non-zero from unknown gives new unknown.
+ const Dimension* out;
+ EXPECT_TRUE(c.Subtract(d_unknown, 1, &out).ok());
+ EXPECT_EQ("?", c.DebugString(out));
+ EXPECT_TRUE(out != d_unknown);
+
+ // Subtracting 0 from anything gives input.
+ EXPECT_TRUE(c.Subtract(d_unknown, 0ll, &out).ok());
+ EXPECT_TRUE(out == d_unknown);
+ EXPECT_TRUE(c.Subtract(d_6, 0ll, &out).ok());
+ EXPECT_TRUE(out == d_6);
+
+ // Subtracting dimension with value 0 from anything gives input.
+ EXPECT_TRUE(c.Subtract(d_unknown, c.MakeDim(0ll), &out).ok());
+ EXPECT_TRUE(out == d_unknown);
+ EXPECT_TRUE(c.Subtract(d_6, c.MakeDim(0ll), &out).ok());
+ EXPECT_TRUE(out == d_6);
+
+ // Test subtraction.
+ EXPECT_TRUE(c.Subtract(d_6, 2, &out).ok());
+ EXPECT_EQ("4", c.DebugString(out));
+ EXPECT_TRUE(c.Subtract(d_6, 6, &out).ok());
+ EXPECT_EQ("0", c.DebugString(out));
+
+ // Test subtraction using dimension as second value.
+ EXPECT_TRUE(c.Subtract(d_6, c.MakeDim(2), &out).ok());
+ EXPECT_EQ("4", c.DebugString(out));
+ EXPECT_TRUE(c.Subtract(d_6, d_5, &out).ok());
+ EXPECT_EQ("1", c.DebugString(out));
+ EXPECT_TRUE(c.Subtract(d_6, c.UnknownDim(), &out).ok());
+ EXPECT_EQ("?", c.DebugString(out));
+ EXPECT_TRUE(c.Subtract(d_6, d_0, &out).ok());
+ EXPECT_TRUE(out == d_6);
+
+ EXPECT_EQ("Negative dimension size caused by subtracting 6 from 5",
+ c.Subtract(d_5, d_6, &out).error_message());
+}
+
TEST(ShapeInferenceTest, Multiply) {
NodeDef def;
InferenceContext c(&def, MakeOpDef(1, 2), {"[6,?,0,1]"}, {});
@@ -831,7 +895,7 @@ TEST(ShapeInferenceTest, Multiply) {
// Multiplying non-zero to unknown gives new unknown.
const Dimension* out;
- EXPECT_TRUE(c.Multiply(d_unknown, 1, &out).ok());
+ EXPECT_TRUE(c.Multiply(d_unknown, 2, &out).ok());
EXPECT_EQ("?", c.DebugString(out));
// Multiplying 0 to anything gives 0.
@@ -844,19 +908,19 @@ TEST(ShapeInferenceTest, Multiply) {
// Multiplying 1 to anything gives the original.
// (unknown -> unknown)
- EXPECT_TRUE(c.Multiply(d_unknown, static_cast<int64>(1), &out).ok());
- EXPECT_EQ("?", c.DebugString(out));
+ EXPECT_TRUE(c.Multiply(d_unknown, 1, &out).ok());
+ EXPECT_EQ(d_unknown, out);
EXPECT_TRUE(c.Multiply(d_unknown, d_1, &out).ok());
- EXPECT_EQ("?", c.DebugString(out));
+ EXPECT_EQ(d_unknown, out);
EXPECT_TRUE(c.Multiply(d_1, d_unknown, &out).ok());
- EXPECT_EQ("?", c.DebugString(out));
+ EXPECT_EQ(d_unknown, out);
// (known -> known)
- EXPECT_TRUE(c.Multiply(d_6, static_cast<int64>(1), &out).ok());
- EXPECT_EQ("6", c.DebugString(out));
+ EXPECT_TRUE(c.Multiply(d_6, 1, &out).ok());
+ EXPECT_EQ(d_6, out);
EXPECT_TRUE(c.Multiply(d_6, d_1, &out).ok());
- EXPECT_EQ("6", c.DebugString(out));
+ EXPECT_EQ(d_6, out);
EXPECT_TRUE(c.Multiply(d_1, d_6, &out).ok());
- EXPECT_EQ("6", c.DebugString(out));
+ EXPECT_EQ(d_6, out);
// Test multiplication.
EXPECT_TRUE(c.Multiply(d_6, 2, &out).ok());
@@ -869,9 +933,6 @@ TEST(ShapeInferenceTest, Multiply) {
EXPECT_EQ("12", c.DebugString(out));
EXPECT_TRUE(c.Multiply(d_6, c.UnknownDim(), &out).ok());
EXPECT_EQ("?", c.DebugString(out));
-
- EXPECT_EQ("Negative dimension size from multiplying 6 and -7",
- c.Multiply(d_6, -7, &out).error_message());
}
TEST(ShapeInferenceTest, FullyDefined) {
@@ -895,5 +956,90 @@ TEST(ShapeInferenceTest, ValidateKnownDim) {
EXPECT_TRUE(c.ValidateKnownDim(c.Dim(c.Matrix(1, 2), 0), "known").ok());
}
+TEST(ShapeInferenceTest, Min) {
+ NodeDef def;
+ InferenceContext c(&def, MakeOpDef(1, 2), {"[1,2,?,0]"}, {});
+
+ auto s = c.input(0);
+ auto d_1 = c.Dim(s, 0);
+ auto d_2 = c.Dim(s, 1);
+ auto d_unknown = c.Dim(s, 2);
+ auto d_0 = c.Dim(s, 3);
+
+ // Minimum involving zero and unknown returns zero.
+ const Dimension* out;
+ EXPECT_TRUE(c.Min(d_0, d_unknown, &out).ok());
+ EXPECT_EQ(d_0, out);
+ EXPECT_TRUE(c.Min(d_unknown, d_0, &out).ok());
+ EXPECT_EQ(d_0, out);
+ EXPECT_TRUE(c.Min(c.MakeDim(0ll), d_unknown, &out).ok());
+ EXPECT_EQ("0", c.DebugString(out));
+ EXPECT_TRUE(c.Min(d_unknown, 0ll, &out).ok());
+ EXPECT_EQ("0", c.DebugString(out));
+
+ // Minimum involving unknowns and non-zeros gives new unknown.
+ EXPECT_TRUE(c.Min(d_unknown, d_unknown, &out).ok());
+ EXPECT_EQ("?", c.DebugString(out));
+ EXPECT_TRUE(c.Min(d_unknown, 1, &out).ok());
+ EXPECT_EQ("?", c.DebugString(out));
+ EXPECT_TRUE(c.Min(d_1, d_unknown, &out).ok());
+ EXPECT_EQ("?", c.DebugString(out));
+
+ // Minimum with constant second arg.
+ EXPECT_TRUE(c.Min(d_1, 1, &out).ok());
+ EXPECT_EQ(d_1, out);
+ EXPECT_TRUE(c.Min(d_1, 3, &out).ok());
+ EXPECT_EQ(d_1, out);
+ EXPECT_TRUE(c.Min(d_2, 1, &out).ok());
+ EXPECT_EQ("1", c.DebugString(out));
+
+ // Minimum with two dimensions.
+ EXPECT_TRUE(c.Min(d_1, d_1, &out).ok());
+ EXPECT_EQ(d_1, out);
+ EXPECT_TRUE(c.Min(d_1, d_2, &out).ok());
+ EXPECT_EQ(d_1, out);
+ EXPECT_TRUE(c.Min(d_2, d_1, &out).ok());
+ EXPECT_EQ(d_1, out);
+ EXPECT_TRUE(c.Min(d_2, d_2, &out).ok());
+ EXPECT_EQ(d_2, out);
+}
+
+TEST(ShapeInferenceTest, Max) {
+ NodeDef def;
+ InferenceContext c(&def, MakeOpDef(1, 2), {"[1,2,?]"}, {});
+
+ auto s = c.input(0);
+ auto d_1 = c.Dim(s, 0);
+ auto d_2 = c.Dim(s, 1);
+ auto d_unknown = c.Dim(s, 2);
+
+ // Maximum involving unknowns gives new unknown.
+ const Dimension* out;
+ EXPECT_TRUE(c.Max(d_unknown, d_unknown, &out).ok());
+ EXPECT_EQ("?", c.DebugString(out));
+ EXPECT_TRUE(c.Max(d_unknown, 1, &out).ok());
+ EXPECT_EQ("?", c.DebugString(out));
+ EXPECT_TRUE(c.Max(d_1, d_unknown, &out).ok());
+ EXPECT_EQ("?", c.DebugString(out));
+
+ // Maximum with constant second arg.
+ EXPECT_TRUE(c.Max(d_1, 1, &out).ok());
+ EXPECT_EQ(d_1, out);
+ EXPECT_TRUE(c.Max(d_2, 1, &out).ok());
+ EXPECT_EQ(d_2, out);
+ EXPECT_TRUE(c.Max(d_2, 3, &out).ok());
+ EXPECT_EQ("3", c.DebugString(out));
+
+ // Maximum with two dimensions.
+ EXPECT_TRUE(c.Max(d_1, d_1, &out).ok());
+ EXPECT_EQ(d_1, out);
+ EXPECT_TRUE(c.Max(d_1, d_2, &out).ok());
+ EXPECT_EQ(d_2, out);
+ EXPECT_TRUE(c.Max(d_2, d_1, &out).ok());
+ EXPECT_EQ(d_2, out);
+ EXPECT_TRUE(c.Max(d_2, d_2, &out).ok());
+ EXPECT_EQ(d_2, out);
+}
+
} // namespace shape_inference
} // namespace tensorflow
diff --git a/tensorflow/core/framework/shape_inference_testutil.cc b/tensorflow/core/framework/shape_inference_testutil.cc
index c1e55d032d..60a9cb101f 100644
--- a/tensorflow/core/framework/shape_inference_testutil.cc
+++ b/tensorflow/core/framework/shape_inference_testutil.cc
@@ -40,6 +40,11 @@ Status InferShapes(ShapeInferenceTestOp op, const string& ins,
shape_inference::InferenceContext c(&op.node_def, op_reg_data->op_def, ins_v,
op.input_tensors);
TF_RETURN_IF_ERROR(c.construction_status());
+ if (op_reg_data->shape_inference_fn == nullptr) {
+ return errors::InvalidArgument(
+ "No shape inference function exists for op '", op.name,
+ "', did you forget to define it?");
+ }
TF_RETURN_IF_ERROR(op_reg_data->shape_inference_fn(&c));
const int num_outputs = c.num_outputs();
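
The testutil change above reports a clear error when the op under test has no shape inference function registered. A hypothetical op registration that satisfies the check might look like the following, assuming OpDefBuilder's SetShapeFn hook is how the shape_inference_fn field gets populated:

Status UnchangedShapeFn(shape_inference::InferenceContext* c) {
  c->set_output(0, c->input(0));
  return Status::OK();
}

REGISTER_OP("MyTestOp")
    .Input("a: float")
    .Output("out: float")
    .SetShapeFn(UnchangedShapeFn);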