author     A. Unique TensorFlower <gardener@tensorflow.org>  2016-08-04 15:15:57 -0800
committer  TensorFlower Gardener <gardener@tensorflow.org>   2016-08-04 16:31:16 -0700
commit     ee9241825d80bf295963ac2fad4dfa0fc9a7b998 (patch)
tree       81260bea9c5328bd7c12fc4729c646e332115fb6 /tensorflow
parent     21038467d71be31193715f7b023e252c0c5e2b05 (diff)
Add C++ shape inference for SVD.
This also adds Min(), Max(), and Subtract() operators and a few convenience methods to the InferenceContext.

Change test utils to emit a human readable error message in case the user forgot to set the inference function.

Refactored shape_inference* a bit to enforce the invariant that a Dimension or DimensionOrConstant is always non-negative or equal to InferenceContext::kUnknownDim. This made it possible to tighten & simplify the arithmetic operations a bit.

Change: 129385995
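To illustrate the new arithmetic helpers, here is a minimal sketch of a shape function that might use them. The op and its output shape are invented for this illustration; Min(), Subtract(), Vector(), WithRank(), Dim(), and set_output() are the InferenceContext methods declared or used in this change, and TF_RETURN_IF_ERROR is assumed to be in scope.

#include "tensorflow/core/framework/shape_inference.h"

using tensorflow::Status;
using tensorflow::shape_inference::Dimension;
using tensorflow::shape_inference::InferenceContext;
using tensorflow::shape_inference::Shape;

// Hypothetical shape function: for a [M, N] input, output 0 is a vector of
// length min(M, N) - 1.
Status ExampleShapeFn(InferenceContext* c) {
  const Shape* input;
  TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 2, &input));  // require a matrix
  const Dimension* m = c->Dim(input, 0);
  const Dimension* n = c->Dim(input, 1);
  const Dimension* p;
  TF_RETURN_IF_ERROR(c->Min(m, n, &p));  // p = min(M, N); unknown stays unknown
  const Dimension* out_dim;
  // Fails with InvalidArgument if the subtraction would go negative.
  TF_RETURN_IF_ERROR(c->Subtract(p, 1, &out_dim));
  c->set_output(0, c->Vector(out_dim));
  return Status::OK();
}

With an unknown M or N the output dimension stays unknown; Subtract() returns an error instead of producing a negative dimension, which is the invariant this change enforces.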
Diffstat (limited to 'tensorflow')
-rw-r--r--  tensorflow/core/framework/shape_inference.cc           | 163
-rw-r--r--  tensorflow/core/framework/shape_inference.h            |  75
-rw-r--r--  tensorflow/core/framework/shape_inference_test.cc      | 184
-rw-r--r--  tensorflow/core/framework/shape_inference_testutil.cc  |   5
-rw-r--r--  tensorflow/core/ops/array_ops.cc                       |   7
-rw-r--r--  tensorflow/core/ops/linalg_ops.cc                      |  71
-rw-r--r--  tensorflow/core/ops/linalg_ops_test.cc                 |  97
-rw-r--r--  tensorflow/core/ops/sparse_ops.cc                      |   2
8 files changed, 489 insertions, 115 deletions
diff --git a/tensorflow/core/framework/shape_inference.cc b/tensorflow/core/framework/shape_inference.cc
index e44d921d5d..9c90bfe0f5 100644
--- a/tensorflow/core/framework/shape_inference.cc
+++ b/tensorflow/core/framework/shape_inference.cc
@@ -387,12 +387,6 @@ Status InferenceContext::ReplaceDim(const Shape* s, int dim_index_in,
return ReturnCreatedShape(dims, out);
}
-const Dimension* InferenceContext::GetDimension(const DimensionOrConstant& d) {
- if (d.dim != nullptr) return d.dim;
- DCHECK(d.val >= 0 || d.val == kUnknownDim);
- return MakeDim(d.val);
-}
-
const Shape* InferenceContext::MakeShape(
const std::vector<const Dimension*>& dims) {
all_shapes_.push_back(new Shape(dims));
@@ -404,7 +398,7 @@ const Shape* InferenceContext::MakeShape(
std::vector<const Dimension*> dims_actual;
dims_actual.reserve(dims.size());
for (const DimensionOrConstant& d : dims) {
- dims_actual.push_back(GetDimension(d));
+ dims_actual.push_back(MakeDim(d));
}
return MakeShape(dims_actual);
}
@@ -488,11 +482,6 @@ Status InferenceContext::MakeShapeFromShapeProto(const TensorShapeProto& proto,
return ReturnCreatedShape(dims, out);
}
-const Dimension* InferenceContext::MakeDim(int64 value) {
- all_dims_.push_back(new Dimension(value));
- return all_dims_.back();
-}
-
// Returns a new dimension whose value is given by a scalar input tensor.
Status InferenceContext::MakeDimForScalarInput(int idx, const Dimension** out) {
const Tensor* t = input_tensor(idx);
@@ -522,11 +511,6 @@ Status InferenceContext::MakeDimForScalarInput(int idx, const Dimension** out) {
return Status::OK();
}
-const Dimension* InferenceContext::UnknownDim() {
- all_dims_.push_back(new Dimension());
- return all_dims_.back();
-}
-
Status InferenceContext::Divide(const Dimension* dividend, int64 divisor,
const Dimension** out) {
if (divisor == 1) {
@@ -535,6 +519,10 @@ Status InferenceContext::Divide(const Dimension* dividend, int64 divisor,
*out = UnknownDim();
} else {
const int64 v = Value(dividend);
+ if (divisor <= 0) {
+ return errors::InvalidArgument("Divisor must be positive but is ",
+ divisor);
+ }
if ((v % divisor) != 0) {
return errors::InvalidArgument("Dimension size must be divisible by ",
divisor, " but is ", v);
@@ -546,87 +534,112 @@ Status InferenceContext::Divide(const Dimension* dividend, int64 divisor,
Status InferenceContext::Add(const Dimension* first, DimensionOrConstant second,
const Dimension** out) {
- const int64 second_value =
- second.dim == nullptr ? second.val : Value(second.dim);
- if (second.dim != nullptr && !ValueKnown(second.dim)) {
- *out = UnknownDim();
+ const int64 first_value = Value(first);
+ const int64 second_value = Value(second);
+ // Special cases.
+ if (first_value == 0) {
+ *out = MakeDim(second);
} else if (second_value == 0) {
- *out = first;
- } else if (!ValueKnown(first)) {
+ *out = MakeDim(first);
+ } else if (first_value == kUnknownDim || second_value == kUnknownDim) {
*out = UnknownDim();
} else {
- const int64 v = Value(first);
- const int64 sum = v + second_value;
- if (second_value > 0 && sum < 0) {
- return errors::InvalidArgument("Dimension size overflow from adding ", v,
- " and ", second_value);
- } else if (second_value < 0 && sum < 0) {
- return errors::InvalidArgument("Negative dimension size from adding ", v,
- " and ", second_value);
+ // Invariant: Both values are known and positive.
+ const int64 sum = first_value + second_value;
+ if (sum < 0) {
+ return errors::InvalidArgument("Dimension size overflow from adding ",
+ first_value, " and ", second_value);
}
*out = MakeDim(sum);
}
return Status::OK();
}
-Status InferenceContext::Multiply(const Dimension* first,
+Status InferenceContext::Subtract(const Dimension* first,
DimensionOrConstant second,
const Dimension** out) {
- int64 first_value = -1;
- // Special cases for multiply are when the values are 0 or 1.
- if (ValueKnown(first)) {
- first_value = Value(first);
- if (first_value == 0) {
- *out = MakeDim(0);
- return Status::OK();
- }
-
- // Output is whatever the second value is.
- if (first_value == 1) {
- *out = GetDimension(second);
- return Status::OK();
+ const int64 first_value = Value(first);
+ const int64 second_value = Value(second);
+ // Special cases.
+ if (second_value == 0) {
+ *out = MakeDim(first);
+ } else if (first_value == kUnknownDim || second_value == kUnknownDim) {
+ *out = UnknownDim();
+ } else {
+ // Invariant: Both values are known, first_value is non-negative, and
+ // second_value is positive.
+ if (first_value < second_value) {
+ return errors::InvalidArgument(
+ "Negative dimension size caused by subtracting ", second_value,
+ " from ", first_value);
}
+ *out = MakeDim(first_value - second_value);
}
+ return Status::OK();
+}
- // Same check for when the second argument is a known value.
- // First find out if the value is known from DimOrConstant.
- int64 second_value;
- if (second.dim == nullptr) {
- second_value = second.val;
+Status InferenceContext::Multiply(const Dimension* first,
+ DimensionOrConstant second,
+ const Dimension** out) {
+ const int64 first_value = Value(first);
+ const int64 second_value = Value(second);
+ // Special cases.
+ if (first_value == 0) {
+ *out = first;
+ } else if (second_value == 0) {
+ *out = MakeDim(second);
+ } else if (first_value == 1) {
+ *out = MakeDim(second);
+ } else if (second_value == 1) {
+ *out = first;
+ } else if (first_value == kUnknownDim || second_value == kUnknownDim) {
+ *out = UnknownDim();
} else {
- if (!ValueKnown(second.dim)) {
- // Second value is not known and first is not a special caase
- *out = UnknownDim();
- return Status::OK();
+ // Invariant: Both values are known and greater than 1.
+ const int64 product = first_value * second_value;
+ if (product < 0) {
+ return errors::InvalidArgument(
+ "Negative dimension size caused by overflow when multiplying ",
+ first_value, " and ", second_value);
}
- second_value = Value(second.dim);
- }
-
- // Now that we know whether the value is known, apply the special
- // casing.
- if (second_value == 0) {
- *out = MakeDim(0);
- return Status::OK();
+ *out = MakeDim(product);
}
+ return Status::OK();
+}
- // Output is whatever the first value is.
- if (second_value == 1) {
+Status InferenceContext::Min(const Dimension* first, DimensionOrConstant second,
+ const Dimension** out) {
+ const int64 first_value = Value(first);
+ const int64 second_value = Value(second);
+ if (first_value == 0) {
*out = first;
- return Status::OK();
- }
-
- if (!ValueKnown(first)) {
- // First value is not known and second is not a special caase
+ } else if (second_value == 0) {
+ *out = MakeDim(second);
+ } else if (first_value == kUnknownDim || second_value == kUnknownDim) {
*out = UnknownDim();
- return Status::OK();
+ } else {
+ if (first_value <= second_value) {
+ *out = first;
+ } else {
+ *out = MakeDim(second);
+ }
}
+ return Status::OK();
+}
- const int64 product = first_value * second_value;
- if (product < 0) {
- return errors::InvalidArgument("Negative dimension size from multiplying ",
- first_value, " and ", second_value);
+Status InferenceContext::Max(const Dimension* first, DimensionOrConstant second,
+ const Dimension** out) {
+ const int64 first_value = Value(first);
+ const int64 second_value = Value(second);
+ if (first_value == kUnknownDim || second_value == kUnknownDim) {
+ *out = UnknownDim();
+ } else {
+ if (first_value >= second_value) {
+ *out = first;
+ } else {
+ *out = MakeDim(second);
+ }
}
- *out = MakeDim(product);
return Status::OK();
}
diff --git a/tensorflow/core/framework/shape_inference.h b/tensorflow/core/framework/shape_inference.h
index 1aa51f5017..f35c8a4c81 100644
--- a/tensorflow/core/framework/shape_inference.h
+++ b/tensorflow/core/framework/shape_inference.h
@@ -46,7 +46,7 @@ class Dimension {
class Shape {
private:
Shape();
- Shape(std::vector<const Dimension*> dims);
+ Shape(const std::vector<const Dimension*>& dims);
~Shape() {}
const int32 rank_;
@@ -61,13 +61,17 @@ class Shape {
struct DimensionOrConstant {
public:
// Intentionally not explicit.
- DimensionOrConstant(const Dimension* dim) : dim(dim) {}
+ DimensionOrConstant(const Dimension* dim);
// val must be non-negative or InferenceContext::kUnknownDim.
- DimensionOrConstant(int64 val) : val(val) {}
+ DimensionOrConstant(int64 val);
- const Dimension* dim = nullptr;
- int64 val = 0;
+ // dim takes precedence. If dim != nullptr, val is ignored.
+ const Dimension* dim;
+ int64 val;
+
+ private:
+ DimensionOrConstant();
};
// Note: This is experimental support for op shape inference in C++. Shape
@@ -81,8 +85,8 @@ struct DimensionOrConstant {
// by the InferenceContext.
class InferenceContext {
public:
- static constexpr int32 kUnknownRank = -1;
static constexpr int64 kUnknownDim = -1;
+ static constexpr int32 kUnknownRank = -1;
// This is a temporary constructor used for initial testing.
//
@@ -127,8 +131,12 @@ class InferenceContext {
}
int32 Rank(const Shape* s) { return s->rank_; }
bool RankKnown(const Shape* s) { return Rank(s) != kUnknownRank; }
- int64 Value(const Dimension* d) { return d->value_; }
- bool ValueKnown(const Dimension* d) { return Value(d) != kUnknownDim; }
+ inline int64 Value(DimensionOrConstant d) {
+ return d.dim ? d.dim->value_ : d.val;
+ }
+ inline bool ValueKnown(DimensionOrConstant d) {
+ return Value(d) != kUnknownDim;
+ }
// Returns true if the rank and all dimensions of the Shape are known.
bool FullyDefined(const Shape* s);
@@ -232,8 +240,15 @@ class InferenceContext {
// Returns a new dimension of the given size. The returned value is owned by
// this context.
- const Dimension* MakeDim(int64 value);
- const Dimension* UnknownDim();
+ inline const Dimension* MakeDim(DimensionOrConstant d) {
+ if (d.dim) {
+ return d.dim;
+ } else {
+ all_dims_.push_back(new Dimension(d.val));
+ return all_dims_.back();
+ }
+ }
+ inline const Dimension* UnknownDim() { return MakeDim(kUnknownDim); }
// Returns a new dimension whose value is given by a scalar input tensor.
// The input tensor must be in host memory, since it is dereferenced to get
@@ -247,7 +262,8 @@ class InferenceContext {
Status GetAttr(StringPiece attr_name, T* value) const;
// Returns in <out> the result of dividing <dividend> by <divisor>.
- // Returns an error if <divisor> does not evenly divide <dividend>.
+ // Returns an error if <divisor> is not positive or does not evenly
+ // divide <dividend>.
Status Divide(const Dimension* dividend, int64 divisor,
const Dimension** out);
@@ -255,10 +271,25 @@ class InferenceContext {
Status Add(const Dimension* first, DimensionOrConstant second,
const Dimension** out);
+ // Returns in <out> the dimension that is <first> minus <second>.
+ Status Subtract(const Dimension* first, DimensionOrConstant second,
+ const Dimension** out);
+
// Returns in <out> the product of <first> and <second>.
Status Multiply(const Dimension* first, DimensionOrConstant second,
const Dimension** out);
+ // Returns in <out> the minimum of <first> and <second>. If either <first> or
+ // <second> is zero the result is zero. Otherwise, if either <first> or
+ // <second> is unknown the result is unknown.
+ Status Min(const Dimension* first, DimensionOrConstant second,
+ const Dimension** out);
+
+ // Returns in <out> the maximum of <first> and <second>. If either <first> or
+ // <second> is unknown the result is unknown.
+ Status Max(const Dimension* first, DimensionOrConstant second,
+ const Dimension** out);
+
Status construction_status() const { return construction_status_; }
// Validates that 'dim' has a known value, and prints an error
@@ -307,12 +338,30 @@ class InferenceContext {
// Template and inline method implementations, please ignore
inline Dimension::Dimension() : value_(InferenceContext::kUnknownDim) {}
-inline Dimension::Dimension(int64 value) : value_(value) {}
+inline Dimension::Dimension(int64 value) : value_(value) {
+ DCHECK(value >= 0 || value == InferenceContext::kUnknownDim)
+ << "Dimension must be non-negative or equal to "
+ "InferenceContext::kUnknownDim but got"
+ << value;
+}
inline Shape::Shape() : rank_(InferenceContext::kUnknownRank) {}
-inline Shape::Shape(const std::vector<const Dimension*> dims)
+inline Shape::Shape(const std::vector<const Dimension*>& dims)
: rank_(dims.size()), dims_(dims) {}
+inline DimensionOrConstant::DimensionOrConstant(const Dimension* dim)
+ : dim(dim) {
+ DCHECK(dim != nullptr) << "Internal error: Got nullptr for Dimension.";
+}
+
+inline DimensionOrConstant::DimensionOrConstant(int64 val)
+ : dim(nullptr), val(val) {
+ DCHECK(val >= 0 || val == InferenceContext::kUnknownDim)
+ << "Dimension must be non-negative or equal to "
+ "InferenceContext::kUnknownDim but got"
+ << val;
+}
+
template <class T>
Status InferenceContext::GetAttr(StringPiece attr_name, T* value) const {
return GetNodeAttr(node_def_, attr_name, value);
diff --git a/tensorflow/core/framework/shape_inference_test.cc b/tensorflow/core/framework/shape_inference_test.cc
index fffb25da6d..1ecba2839a 100644
--- a/tensorflow/core/framework/shape_inference_test.cc
+++ b/tensorflow/core/framework/shape_inference_test.cc
@@ -36,6 +36,19 @@ static OpDef MakeOpDef(int num_inputs, int num_outputs) {
return op_reg_data.op_def;
}
+TEST(ShapeInferenceTest, DimensionOrConstant) {
+ NodeDef def;
+ InferenceContext c(&def, MakeOpDef(1, 1), {"?"}, {});
+ EXPECT_EQ(InferenceContext::kUnknownDim,
+ c.Value(InferenceContext::kUnknownDim));
+ EXPECT_EQ(1, c.Value(1));
+
+#ifndef NDEBUG
+ // Only run death test if DCHECKS are enabled.
+ EXPECT_DEATH(c.Value(-7), "Dimension must be non\\-negative or equal to");
+#endif
+}
+
TEST(ShapeInferenceTest, RankAndDimInspection) {
NodeDef def;
InferenceContext c(&def, MakeOpDef(3, 2), {"?", "[1,?,3]", "[]"}, {});
@@ -767,15 +780,20 @@ TEST(ShapeInferenceTest, Divide) {
EXPECT_EQ("Dimension size must be divisible by 5 but is 6",
c.Divide(d_6, 5, &out).error_message());
+ EXPECT_EQ("Divisor must be positive but is 0",
+ c.Divide(d_6, 0, &out).error_message());
+ EXPECT_EQ("Divisor must be positive but is -1",
+ c.Divide(d_6, -1, &out).error_message());
}
TEST(ShapeInferenceTest, Add) {
NodeDef def;
- InferenceContext c(&def, MakeOpDef(1, 2), {"[6,?]"}, {});
+ InferenceContext c(&def, MakeOpDef(1, 2), {"[6,?,0]"}, {});
auto s = c.input(0);
auto d_6 = c.Dim(s, 0);
auto d_unknown = c.Dim(s, 1);
+ auto d_0 = c.Dim(s, 2);
// Adding non-zero to unknown gives new unknown.
const Dimension* out;
@@ -790,16 +808,14 @@ TEST(ShapeInferenceTest, Add) {
EXPECT_TRUE(out == d_6);
// Adding dimension with value 0 to anything gives input.
- EXPECT_TRUE(c.Add(d_unknown, c.MakeDim(0), &out).ok());
+ EXPECT_TRUE(c.Add(d_unknown, c.MakeDim(0ll), &out).ok());
EXPECT_TRUE(out == d_unknown);
- EXPECT_TRUE(c.Add(d_6, c.MakeDim(0), &out).ok());
+ EXPECT_TRUE(c.Add(d_6, c.MakeDim(0ll), &out).ok());
EXPECT_TRUE(out == d_6);
// Test addition.
EXPECT_TRUE(c.Add(d_6, 2, &out).ok());
EXPECT_EQ("8", c.DebugString(out));
- EXPECT_TRUE(c.Add(d_6, -6, &out).ok());
- EXPECT_EQ("0", c.DebugString(out));
EXPECT_TRUE(c.Add(d_6, std::numeric_limits<int64>::max() - 6, &out).ok());
EXPECT_EQ(std::numeric_limits<int64>::max(), c.Value(out));
@@ -811,14 +827,62 @@ TEST(ShapeInferenceTest, Add) {
EXPECT_EQ(std::numeric_limits<int64>::max(), c.Value(out));
EXPECT_TRUE(c.Add(d_6, c.UnknownDim(), &out).ok());
EXPECT_EQ("?", c.DebugString(out));
+ EXPECT_TRUE(c.Add(d_0, d_6, &out).ok());
+ EXPECT_TRUE(out == d_6);
- EXPECT_EQ("Negative dimension size from adding 6 and -7",
- c.Add(d_6, -7, &out).error_message());
EXPECT_EQ(
"Dimension size overflow from adding 6 and 9223372036854775802",
c.Add(d_6, std::numeric_limits<int64>::max() - 5, &out).error_message());
}
+TEST(ShapeInferenceTest, Subtract) {
+ NodeDef def;
+ InferenceContext c(&def, MakeOpDef(1, 2), {"[6,?,0,5]"}, {});
+
+ auto s = c.input(0);
+ auto d_6 = c.Dim(s, 0);
+ auto d_unknown = c.Dim(s, 1);
+ auto d_0 = c.Dim(s, 2);
+ auto d_5 = c.Dim(s, 3);
+
+ // Subtracting non-zero from unknown gives new unknown.
+ const Dimension* out;
+ EXPECT_TRUE(c.Subtract(d_unknown, 1, &out).ok());
+ EXPECT_EQ("?", c.DebugString(out));
+ EXPECT_TRUE(out != d_unknown);
+
+ // Subtracting 0 from anything gives input.
+ EXPECT_TRUE(c.Subtract(d_unknown, 0ll, &out).ok());
+ EXPECT_TRUE(out == d_unknown);
+ EXPECT_TRUE(c.Subtract(d_6, 0ll, &out).ok());
+ EXPECT_TRUE(out == d_6);
+
+ // Subtracting dimension with value 0 from anything gives input.
+ EXPECT_TRUE(c.Subtract(d_unknown, c.MakeDim(0ll), &out).ok());
+ EXPECT_TRUE(out == d_unknown);
+ EXPECT_TRUE(c.Subtract(d_6, c.MakeDim(0ll), &out).ok());
+ EXPECT_TRUE(out == d_6);
+
+ // Test subtraction.
+ EXPECT_TRUE(c.Subtract(d_6, 2, &out).ok());
+ EXPECT_EQ("4", c.DebugString(out));
+ EXPECT_TRUE(c.Subtract(d_6, 6, &out).ok());
+ EXPECT_EQ("0", c.DebugString(out));
+
+ // Test subtraction using dimension as second value.
+ EXPECT_TRUE(c.Subtract(d_6, c.MakeDim(2), &out).ok());
+ EXPECT_EQ("4", c.DebugString(out));
+ EXPECT_TRUE(c.Subtract(d_6, d_5, &out).ok());
+ EXPECT_EQ("1", c.DebugString(out));
+ EXPECT_TRUE(c.Subtract(d_6, c.UnknownDim(), &out).ok());
+ EXPECT_EQ("?", c.DebugString(out));
+ EXPECT_TRUE(c.Subtract(d_6, d_0, &out).ok());
+ EXPECT_TRUE(out == d_6);
+
+ EXPECT_EQ("Negative dimension size caused by subtracting 6 from 5",
+ c.Subtract(d_5, d_6, &out).error_message());
+}
+
TEST(ShapeInferenceTest, Multiply) {
NodeDef def;
InferenceContext c(&def, MakeOpDef(1, 2), {"[6,?,0,1]"}, {});
@@ -831,7 +895,7 @@ TEST(ShapeInferenceTest, Multiply) {
// Multiplying non-zero to unknown gives new unknown.
const Dimension* out;
- EXPECT_TRUE(c.Multiply(d_unknown, 1, &out).ok());
+ EXPECT_TRUE(c.Multiply(d_unknown, 2, &out).ok());
EXPECT_EQ("?", c.DebugString(out));
// Multiplying 0 to anything gives 0.
@@ -844,19 +908,19 @@ TEST(ShapeInferenceTest, Multiply) {
// Multiplying 1 to anything gives the original.
// (unknown -> unknown)
- EXPECT_TRUE(c.Multiply(d_unknown, static_cast<int64>(1), &out).ok());
- EXPECT_EQ("?", c.DebugString(out));
+ EXPECT_TRUE(c.Multiply(d_unknown, 1, &out).ok());
+ EXPECT_EQ(d_unknown, out);
EXPECT_TRUE(c.Multiply(d_unknown, d_1, &out).ok());
- EXPECT_EQ("?", c.DebugString(out));
+ EXPECT_EQ(d_unknown, out);
EXPECT_TRUE(c.Multiply(d_1, d_unknown, &out).ok());
- EXPECT_EQ("?", c.DebugString(out));
+ EXPECT_EQ(d_unknown, out);
// (known -> known)
- EXPECT_TRUE(c.Multiply(d_6, static_cast<int64>(1), &out).ok());
- EXPECT_EQ("6", c.DebugString(out));
+ EXPECT_TRUE(c.Multiply(d_6, 1, &out).ok());
+ EXPECT_EQ(d_6, out);
EXPECT_TRUE(c.Multiply(d_6, d_1, &out).ok());
- EXPECT_EQ("6", c.DebugString(out));
+ EXPECT_EQ(d_6, out);
EXPECT_TRUE(c.Multiply(d_1, d_6, &out).ok());
- EXPECT_EQ("6", c.DebugString(out));
+ EXPECT_EQ(d_6, out);
// Test multiplication.
EXPECT_TRUE(c.Multiply(d_6, 2, &out).ok());
@@ -869,9 +933,6 @@ TEST(ShapeInferenceTest, Multiply) {
EXPECT_EQ("12", c.DebugString(out));
EXPECT_TRUE(c.Multiply(d_6, c.UnknownDim(), &out).ok());
EXPECT_EQ("?", c.DebugString(out));
-
- EXPECT_EQ("Negative dimension size from multiplying 6 and -7",
- c.Multiply(d_6, -7, &out).error_message());
}
TEST(ShapeInferenceTest, FullyDefined) {
@@ -895,5 +956,90 @@ TEST(ShapeInferenceTest, ValidateKnownDim) {
EXPECT_TRUE(c.ValidateKnownDim(c.Dim(c.Matrix(1, 2), 0), "known").ok());
}
+TEST(ShapeInferenceTest, Min) {
+ NodeDef def;
+ InferenceContext c(&def, MakeOpDef(1, 2), {"[1,2,?,0]"}, {});
+
+ auto s = c.input(0);
+ auto d_1 = c.Dim(s, 0);
+ auto d_2 = c.Dim(s, 1);
+ auto d_unknown = c.Dim(s, 2);
+ auto d_0 = c.Dim(s, 3);
+
+ // Minimum involving zero and unknown returns zero.
+ const Dimension* out;
+ EXPECT_TRUE(c.Min(d_0, d_unknown, &out).ok());
+ EXPECT_EQ(d_0, out);
+ EXPECT_TRUE(c.Min(d_unknown, d_0, &out).ok());
+ EXPECT_EQ(d_0, out);
+ EXPECT_TRUE(c.Min(c.MakeDim(0ll), d_unknown, &out).ok());
+ EXPECT_EQ("0", c.DebugString(out));
+ EXPECT_TRUE(c.Min(d_unknown, 0ll, &out).ok());
+ EXPECT_EQ("0", c.DebugString(out));
+
+ // Minimum involving unknowns and non-zeros gives new unknown.
+ EXPECT_TRUE(c.Min(d_unknown, d_unknown, &out).ok());
+ EXPECT_EQ("?", c.DebugString(out));
+ EXPECT_TRUE(c.Min(d_unknown, 1, &out).ok());
+ EXPECT_EQ("?", c.DebugString(out));
+ EXPECT_TRUE(c.Min(d_1, d_unknown, &out).ok());
+ EXPECT_EQ("?", c.DebugString(out));
+
+ // Minimum with constant second arg.
+ EXPECT_TRUE(c.Min(d_1, 1, &out).ok());
+ EXPECT_EQ(d_1, out);
+ EXPECT_TRUE(c.Min(d_1, 3, &out).ok());
+ EXPECT_EQ(d_1, out);
+ EXPECT_TRUE(c.Min(d_2, 1, &out).ok());
+ EXPECT_EQ("1", c.DebugString(out));
+
+ // Minimum with two dimensions.
+ EXPECT_TRUE(c.Min(d_1, d_1, &out).ok());
+ EXPECT_EQ(d_1, out);
+ EXPECT_TRUE(c.Min(d_1, d_2, &out).ok());
+ EXPECT_EQ(d_1, out);
+ EXPECT_TRUE(c.Min(d_2, d_1, &out).ok());
+ EXPECT_EQ(d_1, out);
+ EXPECT_TRUE(c.Min(d_2, d_2, &out).ok());
+ EXPECT_EQ(d_2, out);
+}
+
+TEST(ShapeInferenceTest, Max) {
+ NodeDef def;
+ InferenceContext c(&def, MakeOpDef(1, 2), {"[1,2,?]"}, {});
+
+ auto s = c.input(0);
+ auto d_1 = c.Dim(s, 0);
+ auto d_2 = c.Dim(s, 1);
+ auto d_unknown = c.Dim(s, 2);
+
+ // Maximum involving unknowns gives new unknown.
+ const Dimension* out;
+ EXPECT_TRUE(c.Max(d_unknown, d_unknown, &out).ok());
+ EXPECT_EQ("?", c.DebugString(out));
+ EXPECT_TRUE(c.Max(d_unknown, 1, &out).ok());
+ EXPECT_EQ("?", c.DebugString(out));
+ EXPECT_TRUE(c.Max(d_1, d_unknown, &out).ok());
+ EXPECT_EQ("?", c.DebugString(out));
+
+ // Maximum with constant second arg.
+ EXPECT_TRUE(c.Max(d_1, 1, &out).ok());
+ EXPECT_EQ(d_1, out);
+ EXPECT_TRUE(c.Max(d_2, 1, &out).ok());
+ EXPECT_EQ(d_2, out);
+ EXPECT_TRUE(c.Max(d_2, 3, &out).ok());
+ EXPECT_EQ("3", c.DebugString(out));
+
+ // Maximum with two dimensions.
+ EXPECT_TRUE(c.Max(d_1, d_1, &out).ok());
+ EXPECT_EQ(d_1, out);
+ EXPECT_TRUE(c.Max(d_1, d_2, &out).ok());
+ EXPECT_EQ(d_2, out);
+ EXPECT_TRUE(c.Max(d_2, d_1, &out).ok());
+ EXPECT_EQ(d_2, out);
+ EXPECT_TRUE(c.Max(d_2, d_2, &out).ok());
+ EXPECT_EQ(d_2, out);
+}
+
} // namespace shape_inference
} // namespace tensorflow
diff --git a/tensorflow/core/framework/shape_inference_testutil.cc b/tensorflow/core/framework/shape_inference_testutil.cc
index c1e55d032d..60a9cb101f 100644
--- a/tensorflow/core/framework/shape_inference_testutil.cc
+++ b/tensorflow/core/framework/shape_inference_testutil.cc
@@ -40,6 +40,11 @@ Status InferShapes(ShapeInferenceTestOp op, const string& ins,
shape_inference::InferenceContext c(&op.node_def, op_reg_data->op_def, ins_v,
op.input_tensors);
TF_RETURN_IF_ERROR(c.construction_status());
+ if (op_reg_data->shape_inference_fn == nullptr) {
+ return errors::InvalidArgument(
+ "No shape inference function exists for op '", op.name,
+ "', did you forget to define it?");
+ }
TF_RETURN_IF_ERROR(op_reg_data->shape_inference_fn(&c));
const int num_outputs = c.num_outputs();
diff --git a/tensorflow/core/ops/array_ops.cc b/tensorflow/core/ops/array_ops.cc
index f10bccd87c..a696888867 100644
--- a/tensorflow/core/ops/array_ops.cc
+++ b/tensorflow/core/ops/array_ops.cc
@@ -2061,13 +2061,14 @@ REGISTER_OP("MirrorPadGrad")
auto paddings_data = paddings_t->matrix<int32>();
std::vector<const Dimension*> dims(input_rank);
for (int i = 0; i < input_rank; ++i) {
- const int32 pad0 = paddings_data(i, 0);
- const int32 pad1 = paddings_data(i, 1);
+ const int64 pad0 = static_cast<int64>(paddings_data(i, 0));
+ const int64 pad1 = static_cast<int64>(paddings_data(i, 1));
if (pad0 < 0 || pad1 < 0) {
return errors::InvalidArgument("Paddings must be non-negative");
}
- TF_RETURN_IF_ERROR(c->Add(c->Dim(input, i), -(pad0 + pad1), &dims[i]));
+ TF_RETURN_IF_ERROR(
+ c->Subtract(c->Dim(input, i), pad0 + pad1, &dims[i]));
}
c->set_output(0, c->MakeShape(dims));
return Status::OK();
diff --git a/tensorflow/core/ops/linalg_ops.cc b/tensorflow/core/ops/linalg_ops.cc
index 0ea31ddca3..4686fa4b9f 100644
--- a/tensorflow/core/ops/linalg_ops.cc
+++ b/tensorflow/core/ops/linalg_ops.cc
@@ -115,6 +115,67 @@ Status BatchMatrixSolveShapeFn(InferenceContext* c, bool square) {
return Status::OK();
}
+Status BatchSvdShapeHelperFn(InferenceContext* c, const Shape* input) {
+ const Dimension* m = c->Dim(input, -2);
+ const Dimension* n = c->Dim(input, -1);
+ const Dimension* p;
+ TF_RETURN_IF_ERROR(c->Min(m, n, &p));
+ const Shape* batch_shape;
+ TF_RETURN_IF_ERROR(c->Subshape(input, 0, -2, &batch_shape));
+ const Shape* e_shape;
+ TF_RETURN_IF_ERROR(c->Concatenate(batch_shape, c->Vector(p), &e_shape));
+ c->set_output(0, e_shape);
+ bool compute_uv;
+ TF_RETURN_IF_ERROR(c->GetAttr("compute_uv", &compute_uv));
+ if (compute_uv) {
+ const Shape* u_shape;
+ const Shape* v_shape;
+ bool full_matrices;
+ TF_RETURN_IF_ERROR(c->GetAttr("full_matrices", &full_matrices));
+ if (full_matrices) {
+ TF_RETURN_IF_ERROR(
+ c->Concatenate(batch_shape, c->Matrix(m, m), &u_shape));
+ TF_RETURN_IF_ERROR(
+ c->Concatenate(batch_shape, c->Matrix(n, n), &v_shape));
+ } else {
+ TF_RETURN_IF_ERROR(
+ c->Concatenate(batch_shape, c->Matrix(m, p), &u_shape));
+ TF_RETURN_IF_ERROR(
+ c->Concatenate(batch_shape, c->Matrix(n, p), &v_shape));
+ }
+ c->set_output(1, u_shape);
+ c->set_output(2, v_shape);
+ } else {
+ c->set_output(1, c->Vector(0ll));
+ c->set_output(2, c->Vector(0ll));
+ }
+ return Status::OK();
+}
+
+// Input is [M,N]. First output is [min(M,N)].
+// Second and third outputs are:
+// [0]; [0], if compute_uv is false.
+// [M,M]; [N,N], if compute_uv is true and full_matrices is true,
+// [M,P]; [N,P], if compute_uv is true and full_matrices is false,
+// where P = min(M,N).
+Status SvdShapeFn(InferenceContext* c) {
+ const Shape* input;
+ TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 2, &input));
+ return BatchSvdShapeHelperFn(c, input);
+}
+
+// Input is [...,M,N]. First output is [...,min(M,N)].
+// Second and third outputs are:
+// [0]; [0], if compute_uv is false.
+// [...,M,M]; [...,N,N], if compute_uv is true and full_matrices is true,
+// [...,M,P]; [...,N,P], if compute_uv is true and full_matrices is false,
+// where P = min(M,N).
+Status BatchSvdShapeFn(InferenceContext* c) {
+ const Shape* input;
+ TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(0), 2, &input));
+ return BatchSvdShapeHelperFn(c, input);
+}
+
} // namespace
REGISTER_OP("MatrixDeterminant")
@@ -258,9 +319,9 @@ Iain Murray http://arxiv.org/abs/1602.07527.
l: Output of Cholesky algorithm l = chol(A). Shape is `[M, M]`.
Algorithm depends only on lower triangular part of this matrix.
-grad: df/dl where f is some scalar function. Shape is `[M, M]'.
+grad: df/dl where f is some scalar function. Shape is `[M, M]`.
Algorithm depends only on lower triangular part of this matrix.
-output: Symmetrized version of df/dA . Shape is `[M, M]'.
+output: Symmetrized version of df/dA . Shape is `[M, M]`.
)doc");
REGISTER_OP("BatchCholeskyGrad")
@@ -278,10 +339,10 @@ Iain Murray http://arxiv.org/abs/1602.07527.
l: Output of batch Cholesky algorithm l = batch_cholesky(A). Shape is `[..., M, M]`.
Algorithm depends only on lower triangular part of the innermost matrices of
this tensor.
-grad: df/dl where f is some scalar function. Shape is `[..., M, M]'.
+grad: df/dl where f is some scalar function. Shape is `[..., M, M]`.
Algorithm depends only on lower triangular part of the innermost matrices of
this tensor.
-output: Symmetrized version of df/dA . Shape is `[..., M, M]'
+output: Symmetrized version of df/dA . Shape is `[..., M, M]`
)doc");
REGISTER_OP("SelfAdjointEig")
@@ -571,6 +632,7 @@ REGISTER_OP("Svd")
.Attr("compute_uv: bool = False")
.Attr("full_matrices: bool = False")
.Attr("T: {double, float}")
+ .SetShapeFn(SvdShapeFn)
.Doc(R"doc(
Computes the singular value decomposition of a matrix.
@@ -609,6 +671,7 @@ REGISTER_OP("BatchSvd")
.Attr("compute_uv: bool = False")
.Attr("full_matrices: bool = False")
.Attr("T: {double, float}")
+ .SetShapeFn(BatchSvdShapeFn)
.Doc(R"doc(
Computes the singular value decompositions of a batch of matrices.
diff --git a/tensorflow/core/ops/linalg_ops_test.cc b/tensorflow/core/ops/linalg_ops_test.cc
index 84e888bb9c..bc95afaa37 100644
--- a/tensorflow/core/ops/linalg_ops_test.cc
+++ b/tensorflow/core/ops/linalg_ops_test.cc
@@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
+#include "tensorflow/core/framework/node_def_builder.h"
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/shape_inference_testutil.h"
#include "tensorflow/core/platform/test.h"
@@ -200,4 +201,100 @@ TEST(LinalgOpsTest, BatchMatrixSolveLs_ShapeFn) {
INFER_ERROR("Shape must be at least rank 2 but is rank 1", op, "?;[1];?");
}
+TEST(LinalgOpsTest, Svd_ShapeFn) {
+ ShapeInferenceTestOp op("Svd");
+ auto set_attrs = [&op](bool compute_uv, bool full_matrices) {
+ TF_CHECK_OK(NodeDefBuilder("test", "Svd")
+ .Input({"input", 0, DT_FLOAT})
+ .Attr("compute_uv", compute_uv)
+ .Attr("full_matrices", full_matrices)
+ .Finalize(&op.node_def));
+ };
+
+ set_attrs(false, false);
+ INFER_OK(op, "?", "[?];[0];[0]");
+ INFER_OK(op, "[?,?]", "[?];[0];[0]");
+ INFER_OK(op, "[2,?]", "[?];[0];[0]");
+ INFER_OK(op, "[?,2]", "[?];[0];[0]");
+ INFER_OK(op, "[2,2]", "[d0_0];[0];[0]");
+ INFER_OK(op, "[3,2]", "[d0_1];[0];[0]");
+ INFER_OK(op, "[2,3]", "[d0_0];[0];[0]");
+ INFER_ERROR("Shape must be rank 2 but is rank 1", op, "[1]");
+ INFER_ERROR("Shape must be rank 2 but is rank 3", op, "[1,2,3]");
+
+ set_attrs(true, false);
+ INFER_OK(op, "?", "[?];[?,?];[?,?]");
+ INFER_OK(op, "[?,?]", "[?];[d0_0,?];[d0_1,?]");
+ INFER_OK(op, "[2,?]", "[?];[d0_0,?];[d0_1,?]");
+ INFER_OK(op, "[?,2]", "[?];[d0_0,?];[d0_1,?]");
+ INFER_OK(op, "[2,2]", "[d0_0];[d0_0,d0_0];[d0_1,d0_0]");
+ INFER_OK(op, "[3,2]", "[d0_1];[d0_0,d0_1];[d0_1,d0_1]");
+ INFER_OK(op, "[2,3]", "[d0_0];[d0_0,d0_0];[d0_1,d0_0]");
+ INFER_ERROR("Shape must be rank 2 but is rank 1", op, "[1]");
+ INFER_ERROR("Shape must be rank 2 but is rank 3", op, "[1,2,3]");
+
+ set_attrs(true, true);
+ INFER_OK(op, "?", "[?];[?,?];[?,?]");
+ INFER_OK(op, "[?,?]", "[?];[d0_0,d0_0];[d0_1,d0_1]");
+ INFER_OK(op, "[2,?]", "[?];[d0_0,d0_0];[d0_1,d0_1]");
+ INFER_OK(op, "[?,2]", "[?];[d0_0,d0_0];[d0_1,d0_1]");
+ INFER_OK(op, "[2,2]", "[d0_0];[d0_0,d0_0];[d0_1,d0_1]");
+ INFER_OK(op, "[3,2]", "[d0_1];[d0_0,d0_0];[d0_1,d0_1]");
+ INFER_OK(op, "[2,3]", "[d0_0];[d0_0,d0_0];[d0_1,d0_1]");
+ INFER_ERROR("Shape must be rank 2 but is rank 1", op, "[1]");
+ INFER_ERROR("Shape must be rank 2 but is rank 3", op, "[1,2,3]");
+}
+
+TEST(LinalgOpsTest, BatchSvd_ShapeFn) {
+ ShapeInferenceTestOp op("BatchSvd");
+ auto set_attrs = [&op](bool compute_uv, bool full_matrices) {
+ TF_CHECK_OK(NodeDefBuilder("test", "BatchSvd")
+ .Input({"input", 0, DT_FLOAT})
+ .Attr("compute_uv", compute_uv)
+ .Attr("full_matrices", full_matrices)
+ .Finalize(&op.node_def));
+ };
+ set_attrs(false, false);
+ INFER_OK(op, "?", "?;[0];[0]");
+ INFER_OK(op, "[?,?,?]", "[d0_0,?];[0];[0]");
+ INFER_OK(op, "[4,?,?]", "[d0_0,?];[0];[0]");
+ INFER_OK(op, "[4,2,?]", "[d0_0,?];[0];[0]");
+ INFER_OK(op, "[4,?,2]", "[d0_0,?];[0];[0]");
+ INFER_OK(op, "[?,2,2]", "[d0_0,d0_1];[0];[0]");
+ INFER_OK(op, "[4,2,2]", "[d0_0,d0_1];[0];[0]");
+ INFER_OK(op, "[?,3,2]", "[d0_0,d0_2];[0];[0]");
+ INFER_OK(op, "[4,3,2]", "[d0_0,d0_2];[0];[0]");
+ INFER_OK(op, "[?,2,3]", "[d0_0,d0_1];[0];[0]");
+ INFER_OK(op, "[4,2,3]", "[d0_0,d0_1];[0];[0]");
+ INFER_ERROR("Shape must be at least rank 2 but is rank 1", op, "[1]");
+
+ set_attrs(true, false);
+ INFER_OK(op, "?", "?;?;?");
+ INFER_OK(op, "[?,?,?]", "[d0_0,?];[d0_0,d0_1,?];[d0_0,d0_2,?]");
+ INFER_OK(op, "[4,?,?]", "[d0_0,?];[d0_0,d0_1,?];[d0_0,d0_2,?]");
+ INFER_OK(op, "[4,2,?]", "[d0_0,?];[d0_0,d0_1,?];[d0_0,d0_2,?]");
+ INFER_OK(op, "[4,?,2]", "[d0_0,?];[d0_0,d0_1,?];[d0_0,d0_2,?]");
+ INFER_OK(op, "[?,2,2]", "[d0_0,d0_1];[d0_0,d0_1,d0_1];[d0_0,d0_2,d0_1]");
+ INFER_OK(op, "[4,2,2]", "[d0_0,d0_1];[d0_0,d0_1,d0_1];[d0_0,d0_2,d0_1]");
+ INFER_OK(op, "[?,3,2]", "[d0_0,d0_2];[d0_0,d0_1,d0_2];[d0_0,d0_2,d0_2]");
+ INFER_OK(op, "[4,3,2]", "[d0_0,d0_2];[d0_0,d0_1,d0_2];[d0_0,d0_2,d0_2]");
+ INFER_OK(op, "[?,2,3]", "[d0_0,d0_1];[d0_0,d0_1,d0_1];[d0_0,d0_2,d0_1]");
+ INFER_OK(op, "[4,2,3]", "[d0_0,d0_1];[d0_0,d0_1,d0_1];[d0_0,d0_2,d0_1]");
+ INFER_ERROR("Shape must be at least rank 2 but is rank 1", op, "[1]");
+
+ set_attrs(true, true);
+ INFER_OK(op, "?", "?;?;?");
+ INFER_OK(op, "[?,?,?]", "[d0_0,?];[d0_0,d0_1,d0_1];[d0_0,d0_2,d0_2]");
+ INFER_OK(op, "[4,?,?]", "[d0_0,?];[d0_0,d0_1,d0_1];[d0_0,d0_2,d0_2]");
+ INFER_OK(op, "[4,2,?]", "[d0_0,?];[d0_0,d0_1,d0_1];[d0_0,d0_2,d0_2]");
+ INFER_OK(op, "[4,?,2]", "[d0_0,?];[d0_0,d0_1,d0_1];[d0_0,d0_2,d0_2]");
+ INFER_OK(op, "[?,2,2]", "[d0_0,d0_1];[d0_0,d0_1,d0_1];[d0_0,d0_2,d0_2]");
+ INFER_OK(op, "[4,2,2]", "[d0_0,d0_1];[d0_0,d0_1,d0_1];[d0_0,d0_2,d0_2]");
+ INFER_OK(op, "[?,3,2]", "[d0_0,d0_2];[d0_0,d0_1,d0_1];[d0_0,d0_2,d0_2]");
+ INFER_OK(op, "[4,3,2]", "[d0_0,d0_2];[d0_0,d0_1,d0_1];[d0_0,d0_2,d0_2]");
+ INFER_OK(op, "[?,2,3]", "[d0_0,d0_1];[d0_0,d0_1,d0_1];[d0_0,d0_2,d0_2]");
+ INFER_OK(op, "[4,2,3]", "[d0_0,d0_1];[d0_0,d0_1,d0_1];[d0_0,d0_2,d0_2]");
+ INFER_ERROR("Shape must be at least rank 2 but is rank 1", op, "[1]");
+}
+
} // end namespace tensorflow
diff --git a/tensorflow/core/ops/sparse_ops.cc b/tensorflow/core/ops/sparse_ops.cc
index ac21338505..17d5983d76 100644
--- a/tensorflow/core/ops/sparse_ops.cc
+++ b/tensorflow/core/ops/sparse_ops.cc
@@ -363,7 +363,7 @@ REGISTER_OP("SparseConcat")
.Attr("T: type")
.SetShapeFn([](InferenceContext* c) {
// These accumulates the sum.
- const Dimension* output_row_count = c->MakeDim(0);
+ const Dimension* output_row_count = c->MakeDim(0ll);
// These are only merged.
const Dimension* output_ind_cols = c->UnknownDim();