 tensorflow/core/framework/common_shape_fns.cc       | 102
 tensorflow/core/framework/common_shape_fns.h         |   3
 tensorflow/core/framework/common_shape_fns_test.cc   |  29
 tensorflow/core/framework/shape_inference.cc         |  61
 tensorflow/core/framework/shape_inference.h          |   4
 tensorflow/core/framework/shape_inference_test.cc    |  55
 tensorflow/core/ops/nn_ops.cc                        |   1
 7 files changed, 239 insertions(+), 16 deletions(-)
diff --git a/tensorflow/core/framework/common_shape_fns.cc b/tensorflow/core/framework/common_shape_fns.cc
index 1fcee93a86..dc03b68ec2 100644
--- a/tensorflow/core/framework/common_shape_fns.cc
+++ b/tensorflow/core/framework/common_shape_fns.cc
@@ -173,6 +173,17 @@ Status BiasAddGradShape(shape_inference::InferenceContext* c) {
return Status::OK();
}
+namespace {
+Status CheckKnownDim(shape_inference::InferenceContext* c, const Dimension* dim,
+ const char* name) {
+ if (!c->ValueKnown(dim)) {
+ return errors::InvalidArgument("Cannot infer shape because dimension ",
+ name, " is not known.");
+ }
+ return Status::OK();
+}
+} // namespace
+
Status Conv2DShape(shape_inference::InferenceContext* c) {
const Shape* input_shape;
TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 4, &input_shape));
@@ -213,17 +224,10 @@ Status Conv2DShape(shape_inference::InferenceContext* c) {
const Dimension* output_depth_dim = c->Dim(filter_shape, 3);
// At the moment we need to know the values of several fields.
- auto CheckKnownDim = [&c](const Dimension* dim, const char* name) {
- if (!c->ValueKnown(dim)) {
- return errors::InvalidArgument("Cannot infer shape because dimension ",
- name, " is not known.");
- }
- return Status::OK();
- };
- TF_RETURN_IF_ERROR(CheckKnownDim(in_rows_dim, "in_rows"));
- TF_RETURN_IF_ERROR(CheckKnownDim(in_cols_dim, "in_cols"));
- TF_RETURN_IF_ERROR(CheckKnownDim(filter_rows_dim, "filter_rows"));
- TF_RETURN_IF_ERROR(CheckKnownDim(filter_cols_dim, "filter_cols"));
+ TF_RETURN_IF_ERROR(CheckKnownDim(c, in_rows_dim, "in_rows"));
+ TF_RETURN_IF_ERROR(CheckKnownDim(c, in_cols_dim, "in_cols"));
+ TF_RETURN_IF_ERROR(CheckKnownDim(c, filter_rows_dim, "filter_rows"));
+ TF_RETURN_IF_ERROR(CheckKnownDim(c, filter_cols_dim, "filter_cols"));
auto in_rows = c->Value(in_rows_dim);
auto in_cols = c->Value(in_cols_dim);
@@ -248,14 +252,80 @@ Status Conv2DShape(shape_inference::InferenceContext* c) {
const Shape* output_shape;
if (data_format == "NCHW") {
- output_shape =
- c->MakeShape({{batch_size_dim, output_depth_dim,
- c->MakeDim(output_rows), c->MakeDim(output_cols)}});
+ output_shape = c->MakeShape(
+ {batch_size_dim, output_depth_dim, output_rows, output_cols});
} else {
- output_shape = c->MakeShape({{batch_size_dim, c->MakeDim(output_rows),
- c->MakeDim(output_cols), output_depth_dim}});
+ output_shape = c->MakeShape(
+ {batch_size_dim, output_rows, output_cols, output_depth_dim});
+ }
+
+ c->set_output(0, output_shape);
+ return Status::OK();
+}
+
+Status DepthwiseConv2DNativeShape(shape_inference::InferenceContext* c) {
+ const Shape* input_shape;
+ TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 4, &input_shape));
+ const Shape* filter_shape;
+ TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 4, &filter_shape));
+
+ std::vector<int32> strides;
+ TF_RETURN_IF_ERROR(c->GetAttr("strides", &strides));
+
+ if (strides.size() != 4) {
+ return errors::InvalidArgument(
+ "Conv2D requires the stride attribute to contain 4 values, but got: ",
+ strides.size());
}
+ const Dimension* batch_size_dim = c->Dim(input_shape, 0);
+ const Dimension* in_rows_dim = c->Dim(input_shape, 1);
+ const Dimension* in_cols_dim = c->Dim(input_shape, 2);
+ const Dimension* filter_rows_dim = c->Dim(filter_shape, 0);
+ const Dimension* filter_cols_dim = c->Dim(filter_shape, 1);
+ const Dimension* input_depth = c->Dim(filter_shape, 2);
+ const Dimension* depth_multiplier = c->Dim(filter_shape, 3);
+
+ // At the moment we need to know the values of several fields.
+ TF_RETURN_IF_ERROR(CheckKnownDim(c, in_rows_dim, "in_rows"));
+ TF_RETURN_IF_ERROR(CheckKnownDim(c, in_cols_dim, "in_cols"));
+ TF_RETURN_IF_ERROR(CheckKnownDim(c, filter_rows_dim, "filter_rows"));
+ TF_RETURN_IF_ERROR(CheckKnownDim(c, filter_cols_dim, "filter_cols"));
+ TF_RETURN_IF_ERROR(CheckKnownDim(c, input_depth, "depth"));
+ TF_RETURN_IF_ERROR(CheckKnownDim(c, depth_multiplier, "depth_multiplier"));
+
+ // Check that the input depths are compatible.
+ TF_RETURN_IF_ERROR(
+ c->Merge(c->Dim(input_shape, 3), input_depth, &input_depth));
+
+ const Dimension* output_depth;
+ TF_RETURN_IF_ERROR(c->Multiply(input_depth, depth_multiplier, &output_depth));
+
+ const int32 stride_rows = strides[1];
+ const int32 stride_cols = strides[2];
+
+ Padding padding;
+ TF_RETURN_IF_ERROR(c->GetAttr("padding", &padding));
+
+ // TODO(mrry,shlens): Raise an error if the stride would cause
+ // information in the input to be ignored. This will require a change
+ // in the kernel implementation.
+ auto in_rows = c->Value(in_rows_dim);
+ auto in_cols = c->Value(in_cols_dim);
+ auto filter_rows = c->Value(filter_rows_dim);
+ auto filter_cols = c->Value(filter_cols_dim);
+
+ int64 output_rows, output_cols;
+ int64 padding_before, padding_after;
+ TF_RETURN_IF_ERROR(GetWindowedOutputSizeVerbose(
+ in_rows, filter_rows, stride_rows, padding, &output_rows, &padding_before,
+ &padding_after));
+ TF_RETURN_IF_ERROR(GetWindowedOutputSizeVerbose(
+ in_cols, filter_cols, stride_cols, padding, &output_cols, &padding_before,
+ &padding_after));
+
+ const Shape* output_shape =
+ c->MakeShape({batch_size_dim, output_rows, output_cols, output_depth});
c->set_output(0, output_shape);
return Status::OK();
}
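
Both Conv2DShape and the new DepthwiseConv2DNativeShape hand the spatial arithmetic to GetWindowedOutputSizeVerbose, and the depthwise variant additionally derives its output depth by multiplying the filter's input depth by the depth multiplier. As a rough illustration of what that spatial computation amounts to, here is a minimal, self-contained C++ sketch assuming the usual TensorFlow padding conventions (VALID: output = (in - filter) / stride + 1; SAME: output = ceil(in / stride), with any odd padding placed after the input). The names WindowedOutputSize and Padding below are illustrative stand-ins, not the real TensorFlow declarations.

#include <algorithm>
#include <cstdint>
#include <iostream>

enum class Padding { VALID, SAME };

// Computes the spatial output size for one dimension, plus the padding applied
// before and after the input, following the conventional TensorFlow formulas.
int64_t WindowedOutputSize(int64_t in, int64_t filter, int64_t stride,
                           Padding padding, int64_t* pad_before,
                           int64_t* pad_after) {
  int64_t out;
  if (padding == Padding::VALID) {
    out = (in - filter) / stride + 1;
    *pad_before = 0;
    *pad_after = 0;
  } else {  // Padding::SAME
    out = (in + stride - 1) / stride;  // ceil(in / stride)
    const int64_t pad_total =
        std::max<int64_t>((out - 1) * stride + filter - in, 0);
    *pad_before = pad_total / 2;  // The extra pixel of odd padding goes after.
    *pad_after = pad_total - *pad_before;
  }
  return out;
}

int main() {
  int64_t before = 0, after = 0;
  // Mirrors the 1x1-filter case in common_shape_fns_test.cc: 2x2 input,
  // 1x1 filter, stride 1, VALID padding -> spatial output stays 2.
  std::cout << WindowedOutputSize(2, 1, 1, Padding::VALID, &before, &after)
            << "\n";  // prints 2
  return 0;
}

For that test case the spatial output is 2x2 and the output depth becomes input_depth * depth_multiplier = 3 * 4 = 12, which matches the "[d0_0,2,2,12]" expectation in the test file below.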
diff --git a/tensorflow/core/framework/common_shape_fns.h b/tensorflow/core/framework/common_shape_fns.h
index 20410d87be..1926277395 100644
--- a/tensorflow/core/framework/common_shape_fns.h
+++ b/tensorflow/core/framework/common_shape_fns.h
@@ -157,6 +157,9 @@ Status BiasAddGradShape(shape_inference::InferenceContext* c);
// Shape function for Conv2D-like operations.
Status Conv2DShape(shape_inference::InferenceContext* c);
+// Shape function for DepthwiseConv2D-like operations.
+Status DepthwiseConv2DNativeShape(shape_inference::InferenceContext* c);
+
} // namespace shape_inference
} // namespace tensorflow
diff --git a/tensorflow/core/framework/common_shape_fns_test.cc b/tensorflow/core/framework/common_shape_fns_test.cc
index a8ba18fe42..a8a692cf08 100644
--- a/tensorflow/core/framework/common_shape_fns_test.cc
+++ b/tensorflow/core/framework/common_shape_fns_test.cc
@@ -419,5 +419,34 @@ TEST(CommonShapeFnsTest, Conv2DShapeTest) {
INFER_OK(op, "[1,4,4,1];[2,2,1,1]", "[d0_0,4,4,d1_3]");
}
+TEST(CommonShapeFnsTest, DepthwiseConv2DShapeTest) {
+ ShapeInferenceTestOp op("DepthwiseConv2dNative");
+ std::vector<int32> strides = {{1, 1, 1, 1}};
+ TF_CHECK_OK(NodeDefBuilder("test", "DepthwiseConv2dNative")
+ .Input("input", 0, DT_FLOAT)
+ .Input("filter", 0, DT_FLOAT)
+ .Attr("strides", strides)
+ .Attr("padding", "VALID")
+ .Finalize(&op.node_def));
+
+ // Most of DepthwiseConv2D is implicitly tested by Conv2D, so
+ // we test only the depthwise-specific differences here.
+
+ // 1x1 filter, depth multiplication
+ INFER_OK(op, "[1,2,2,3];[1,1,3,4]", "[d0_0,2,2,12]");
+
+ // Input depths not compatible
+ INFER_ERROR("Dimensions must be equal, but are 3 and 12", op,
+ "[1,2,2,3];[1,1,12,4]");
+
+ // No unknown dims in the critical fields.
+ INFER_ERROR("is not known", op, "[1,?,2,1];[1,1,1,1]");
+ INFER_ERROR("is not known", op, "[1,2,?,1];[1,1,1,1]");
+ INFER_ERROR("is not known", op, "[1,2,2,1];[?,1,1,1]");
+ INFER_ERROR("is not known", op, "[1,2,2,1];[1,?,1,1]");
+ INFER_ERROR("is not known", op, "[1,2,2,1];[1,1,?,1]");
+ INFER_ERROR("is not known", op, "[1,2,2,1];[1,1,1,?]");
+}
+
} // namespace shape_inference
} // namespace tensorflow
diff --git a/tensorflow/core/framework/shape_inference.cc b/tensorflow/core/framework/shape_inference.cc
index f2ecc773fa..023174e83a 100644
--- a/tensorflow/core/framework/shape_inference.cc
+++ b/tensorflow/core/framework/shape_inference.cc
@@ -519,5 +519,66 @@ Status InferenceContext::Add(const Dimension* first, DimensionOrConstant second,
return Status::OK();
}
+Status InferenceContext::Multiply(const Dimension* first,
+ DimensionOrConstant second,
+ const Dimension** out) {
+ int64 first_value = -1;
+ // Special cases for multiply are when the values are 0 or 1.
+ if (ValueKnown(first)) {
+ first_value = Value(first);
+ if (first_value == 0) {
+ *out = MakeDim(0);
+ return Status::OK();
+ }
+
+ // Output is whatever the second value is.
+ if (first_value == 1) {
+ *out = GetDimension(second);
+ return Status::OK();
+ }
+ }
+
+ // Apply the same special cases to the second argument.
+ // First, determine whether its value is known from the DimensionOrConstant.
+ int64 second_value;
+ if (second.dim == nullptr) {
+ second_value = second.val;
+ } else {
+ if (!ValueKnown(second.dim)) {
+ // Second value is not known and first is not a special case
+ *out = UnknownDim();
+ return Status::OK();
+ }
+ second_value = Value(second.dim);
+ }
+
+ // Now that the second value is known, apply the same special
+ // casing.
+ if (second_value == 0) {
+ *out = MakeDim(0);
+ return Status::OK();
+ }
+
+ // Output is whatever the first value is.
+ if (second_value == 1) {
+ *out = first;
+ return Status::OK();
+ }
+
+ if (!ValueKnown(first)) {
+ // First value is not known and second is not a special case
+ *out = UnknownDim();
+ return Status::OK();
+ }
+
+ const int64 product = first_value * second_value;
+ if (product < 0) {
+ return errors::InvalidArgument("Negative dimension size from multiplying ",
+ first_value, " and ", second_value);
+ }
+ *out = MakeDim(product);
+ return Status::OK();
+}
+
} // namespace shape_inference
} // namespace tensorflow
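
The new Multiply applies the algebraic shortcuts before falling back to a plain product: zero absorbs everything, even an unknown dimension; one is the identity, so the other operand is passed through unchanged; any remaining unknown operand makes the result unknown; and a negative product is rejected. A minimal sketch of the same rules is below, with an unknown dimension modeled as an empty std::optional. MultiplyDims and Dim are illustrative names, not the InferenceContext API, and the negative-product check is omitted for brevity.

#include <cassert>
#include <cstdint>
#include <optional>

// An unknown dimension is modeled here as an empty optional.
using Dim = std::optional<int64_t>;

Dim MultiplyDims(Dim first, Dim second) {
  // 0 times anything, including an unknown dimension, is 0.
  if ((first && *first == 0) || (second && *second == 0)) return 0;
  // 1 is the identity, so the other operand is passed through unchanged.
  if (first && *first == 1) return second;
  if (second && *second == 1) return first;
  // Any remaining unknown operand makes the product unknown.
  if (!first || !second) return std::nullopt;
  return *first * *second;
}

int main() {
  assert(MultiplyDims(6, 2) == Dim(12));
  assert(MultiplyDims(std::nullopt, 0) == Dim(0));        // unknown * 0 == 0
  assert(MultiplyDims(std::nullopt, 1) == std::nullopt);  // unknown * 1 stays unknown
  assert(MultiplyDims(6, std::nullopt) == std::nullopt);  // unknown otherwise propagates
  return 0;
}

The assertions mirror the expectations in the Multiply test added to shape_inference_test.cc below.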
diff --git a/tensorflow/core/framework/shape_inference.h b/tensorflow/core/framework/shape_inference.h
index 1bdc74dc5a..e3ee40bb6f 100644
--- a/tensorflow/core/framework/shape_inference.h
+++ b/tensorflow/core/framework/shape_inference.h
@@ -248,6 +248,10 @@ class InferenceContext {
Status Add(const Dimension* first, DimensionOrConstant second,
const Dimension** out);
+ // Returns in <out> the product of <first> and <second>.
+ Status Multiply(const Dimension* first, DimensionOrConstant second,
+ const Dimension** out);
+
private:
const Dimension* GetDimension(const DimensionOrConstant& d);
diff --git a/tensorflow/core/framework/shape_inference_test.cc b/tensorflow/core/framework/shape_inference_test.cc
index ce175de561..2a40e962ea 100644
--- a/tensorflow/core/framework/shape_inference_test.cc
+++ b/tensorflow/core/framework/shape_inference_test.cc
@@ -806,5 +806,60 @@ TEST(ShapeInferenceTest, Add) {
c.Add(d_6, std::numeric_limits<int64>::max() - 5, &out).error_message());
}
+TEST(ShapeInferenceTest, Multiply) {
+ NodeDef def;
+ InferenceContext c(&def, MakeOpDef(1, 2), {"[6,?,0,1]"}, {});
+
+ auto s = c.input(0);
+ auto d_6 = c.Dim(s, 0);
+ auto d_unknown = c.Dim(s, 1);
+ auto d_0 = c.Dim(s, 2);
+ auto d_1 = c.Dim(s, 3);
+
+ // Multiplying non-zero to unknown gives new unknown.
+ const Dimension* out;
+ EXPECT_TRUE(c.Multiply(d_unknown, 1, &out).ok());
+ EXPECT_EQ("?", c.DebugString(out));
+
+ // Multiplying 0 to anything gives 0.
+ EXPECT_TRUE(c.Multiply(d_unknown, static_cast<int64>(0), &out).ok());
+ EXPECT_EQ("0", c.DebugString(out));
+ EXPECT_TRUE(c.Multiply(d_unknown, d_0, &out).ok());
+ EXPECT_EQ("0", c.DebugString(out));
+ EXPECT_TRUE(c.Multiply(d_0, d_unknown, &out).ok());
+ EXPECT_EQ("0", c.DebugString(out));
+
+ // Multiplying 1 to anything gives the original.
+ // (unknown -> unknown)
+ EXPECT_TRUE(c.Multiply(d_unknown, static_cast<int64>(1), &out).ok());
+ EXPECT_EQ("?", c.DebugString(out));
+ EXPECT_TRUE(c.Multiply(d_unknown, d_1, &out).ok());
+ EXPECT_EQ("?", c.DebugString(out));
+ EXPECT_TRUE(c.Multiply(d_1, d_unknown, &out).ok());
+ EXPECT_EQ("?", c.DebugString(out));
+ // (known -> known)
+ EXPECT_TRUE(c.Multiply(d_6, static_cast<int64>(1), &out).ok());
+ EXPECT_EQ("6", c.DebugString(out));
+ EXPECT_TRUE(c.Multiply(d_6, d_1, &out).ok());
+ EXPECT_EQ("6", c.DebugString(out));
+ EXPECT_TRUE(c.Multiply(d_1, d_6, &out).ok());
+ EXPECT_EQ("6", c.DebugString(out));
+
+ // Test multiplication.
+ EXPECT_TRUE(c.Multiply(d_6, 2, &out).ok());
+ EXPECT_EQ("12", c.DebugString(out));
+ EXPECT_TRUE(c.Multiply(d_6, 6, &out).ok());
+ EXPECT_EQ("36", c.DebugString(out));
+
+ // Test multiplication using dimension as second value.
+ EXPECT_TRUE(c.Multiply(d_6, c.MakeDim(2), &out).ok());
+ EXPECT_EQ("12", c.DebugString(out));
+ EXPECT_TRUE(c.Multiply(d_6, c.UnknownDim(), &out).ok());
+ EXPECT_EQ("?", c.DebugString(out));
+
+ EXPECT_EQ("Negative dimension size from multiplying 6 and -7",
+ c.Multiply(d_6, -7, &out).error_message());
+}
+
} // namespace shape_inference
} // namespace tensorflow
diff --git a/tensorflow/core/ops/nn_ops.cc b/tensorflow/core/ops/nn_ops.cc
index e24f9e0fb2..59ffcf430c 100644
--- a/tensorflow/core/ops/nn_ops.cc
+++ b/tensorflow/core/ops/nn_ops.cc
@@ -432,6 +432,7 @@ REGISTER_OP("DepthwiseConv2dNative")
.Attr("T: {float, double}")
.Attr("strides: list(int)")
.Attr(GetPaddingAttrString())
+ .SetShapeFn(shape_inference::DepthwiseConv2DNativeShape)
.Doc(R"doc(
Computes a 2-D depthwise convolution given 4-D `input` and `filter` tensors.