diff options
-rw-r--r-- | tensorflow/core/framework/common_shape_fns.cc | 102 | ||||
-rw-r--r-- | tensorflow/core/framework/common_shape_fns.h | 3 | ||||
-rw-r--r-- | tensorflow/core/framework/common_shape_fns_test.cc | 29 | ||||
-rw-r--r-- | tensorflow/core/framework/shape_inference.cc | 61 | ||||
-rw-r--r-- | tensorflow/core/framework/shape_inference.h | 4 | ||||
-rw-r--r-- | tensorflow/core/framework/shape_inference_test.cc | 55 | ||||
-rw-r--r-- | tensorflow/core/ops/nn_ops.cc | 1 |
7 files changed, 239 insertions, 16 deletions
diff --git a/tensorflow/core/framework/common_shape_fns.cc b/tensorflow/core/framework/common_shape_fns.cc index 1fcee93a86..dc03b68ec2 100644 --- a/tensorflow/core/framework/common_shape_fns.cc +++ b/tensorflow/core/framework/common_shape_fns.cc @@ -173,6 +173,17 @@ Status BiasAddGradShape(shape_inference::InferenceContext* c) { return Status::OK(); } +namespace { +Status CheckKnownDim(shape_inference::InferenceContext* c, const Dimension* dim, + const char* name) { + if (!c->ValueKnown(dim)) { + return errors::InvalidArgument("Cannot infer shape because dimension ", + name, " is not known."); + } + return Status::OK(); +} +} // namespace + Status Conv2DShape(shape_inference::InferenceContext* c) { const Shape* input_shape; TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 4, &input_shape)); @@ -213,17 +224,10 @@ Status Conv2DShape(shape_inference::InferenceContext* c) { const Dimension* output_depth_dim = c->Dim(filter_shape, 3); // At the moment we need to know the values of several fields. - auto CheckKnownDim = [&c](const Dimension* dim, const char* name) { - if (!c->ValueKnown(dim)) { - return errors::InvalidArgument("Cannot infer shape because dimension ", - name, " is not known."); - } - return Status::OK(); - }; - TF_RETURN_IF_ERROR(CheckKnownDim(in_rows_dim, "in_rows")); - TF_RETURN_IF_ERROR(CheckKnownDim(in_cols_dim, "in_cols")); - TF_RETURN_IF_ERROR(CheckKnownDim(filter_rows_dim, "filter_rows")); - TF_RETURN_IF_ERROR(CheckKnownDim(filter_cols_dim, "filter_cols")); + TF_RETURN_IF_ERROR(CheckKnownDim(c, in_rows_dim, "in_rows")); + TF_RETURN_IF_ERROR(CheckKnownDim(c, in_cols_dim, "in_cols")); + TF_RETURN_IF_ERROR(CheckKnownDim(c, filter_rows_dim, "filter_rows")); + TF_RETURN_IF_ERROR(CheckKnownDim(c, filter_cols_dim, "filter_cols")); auto in_rows = c->Value(in_rows_dim); auto in_cols = c->Value(in_cols_dim); @@ -248,14 +252,80 @@ Status Conv2DShape(shape_inference::InferenceContext* c) { const Shape* output_shape; if (data_format == "NCHW") { - output_shape = - c->MakeShape({{batch_size_dim, output_depth_dim, - c->MakeDim(output_rows), c->MakeDim(output_cols)}}); + output_shape = c->MakeShape( + {batch_size_dim, output_depth_dim, output_rows, output_cols}); } else { - output_shape = c->MakeShape({{batch_size_dim, c->MakeDim(output_rows), - c->MakeDim(output_cols), output_depth_dim}}); + output_shape = c->MakeShape( + {batch_size_dim, output_rows, output_cols, output_depth_dim}); + } + + c->set_output(0, output_shape); + return Status::OK(); +} + +Status DepthwiseConv2DNativeShape(shape_inference::InferenceContext* c) { + const Shape* input_shape; + TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 4, &input_shape)); + const Shape* filter_shape; + TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 4, &filter_shape)); + + std::vector<int32> strides; + TF_RETURN_IF_ERROR(c->GetAttr("strides", &strides)); + + if (strides.size() != 4) { + return errors::InvalidArgument( + "Conv2D requires the stride attribute to contain 4 values, but got: ", + strides.size()); } + const Dimension* batch_size_dim = c->Dim(input_shape, 0); + const Dimension* in_rows_dim = c->Dim(input_shape, 1); + const Dimension* in_cols_dim = c->Dim(input_shape, 2); + const Dimension* filter_rows_dim = c->Dim(filter_shape, 0); + const Dimension* filter_cols_dim = c->Dim(filter_shape, 1); + const Dimension* input_depth = c->Dim(filter_shape, 2); + const Dimension* depth_multiplier = c->Dim(filter_shape, 3); + + // At the moment we need to know the values of several fields. + TF_RETURN_IF_ERROR(CheckKnownDim(c, in_rows_dim, "in_rows")); + TF_RETURN_IF_ERROR(CheckKnownDim(c, in_cols_dim, "in_cols")); + TF_RETURN_IF_ERROR(CheckKnownDim(c, filter_rows_dim, "filter_rows")); + TF_RETURN_IF_ERROR(CheckKnownDim(c, filter_cols_dim, "filter_cols")); + TF_RETURN_IF_ERROR(CheckKnownDim(c, input_depth, "depth")); + TF_RETURN_IF_ERROR(CheckKnownDim(c, depth_multiplier, "depth_multiplier")); + + // Check that the input depths are compatible. + TF_RETURN_IF_ERROR( + c->Merge(c->Dim(input_shape, 3), input_depth, &input_depth)); + + const Dimension* output_depth; + TF_RETURN_IF_ERROR(c->Multiply(input_depth, depth_multiplier, &output_depth)); + + const int32 stride_rows = strides[1]; + const int32 stride_cols = strides[2]; + + Padding padding; + TF_RETURN_IF_ERROR(c->GetAttr("padding", &padding)); + + // TODO(mrry,shlens): Raise an error if the stride would cause + // information in the input to be ignored. This will require a change + // in the kernel implementation. + auto in_rows = c->Value(in_rows_dim); + auto in_cols = c->Value(in_cols_dim); + auto filter_rows = c->Value(filter_rows_dim); + auto filter_cols = c->Value(filter_cols_dim); + + int64 output_rows, output_cols; + int64 padding_before, padding_after; + TF_RETURN_IF_ERROR(GetWindowedOutputSizeVerbose( + in_rows, filter_rows, stride_rows, padding, &output_rows, &padding_before, + &padding_after)); + TF_RETURN_IF_ERROR(GetWindowedOutputSizeVerbose( + in_cols, filter_cols, stride_cols, padding, &output_cols, &padding_before, + &padding_after)); + + const Shape* output_shape = + c->MakeShape({batch_size_dim, output_rows, output_cols, output_depth}); c->set_output(0, output_shape); return Status::OK(); } diff --git a/tensorflow/core/framework/common_shape_fns.h b/tensorflow/core/framework/common_shape_fns.h index 20410d87be..1926277395 100644 --- a/tensorflow/core/framework/common_shape_fns.h +++ b/tensorflow/core/framework/common_shape_fns.h @@ -157,6 +157,9 @@ Status BiasAddGradShape(shape_inference::InferenceContext* c); // Shape function for Conv2D-like operations. Status Conv2DShape(shape_inference::InferenceContext* c); +// Shape function for DepthwiseConv2D-like operations. +Status DepthwiseConv2DNativeShape(shape_inference::InferenceContext* c); + } // namespace shape_inference } // namespace tensorflow diff --git a/tensorflow/core/framework/common_shape_fns_test.cc b/tensorflow/core/framework/common_shape_fns_test.cc index a8ba18fe42..a8a692cf08 100644 --- a/tensorflow/core/framework/common_shape_fns_test.cc +++ b/tensorflow/core/framework/common_shape_fns_test.cc @@ -419,5 +419,34 @@ TEST(CommonShapeFnsTest, Conv2DShapeTest) { INFER_OK(op, "[1,4,4,1];[2,2,1,1]", "[d0_0,4,4,d1_3]"); } +TEST(CommonShapeFnsTest, DepthwiseConv2DShapeTest) { + ShapeInferenceTestOp op("DepthwiseConv2dNative"); + std::vector<int32> strides = {{1, 1, 1, 1}}; + TF_CHECK_OK(NodeDefBuilder("test", "DepthwiseConv2dNative") + .Input("input", 0, DT_FLOAT) + .Input("filter", 0, DT_FLOAT) + .Attr("strides", strides) + .Attr("padding", "VALID") + .Finalize(&op.node_def)); + + // Most of DepthwiseConv2D is implicitly tested by Conv2D, so + // we test only the very-specific differences here. + + // 1x1 filter, depth multiplication + INFER_OK(op, "[1,2,2,3];[1,1,3,4]", "[d0_0,2,2,12]"); + + // Input depths not compatible + INFER_ERROR("Dimensions must be equal, but are 3 and 12", op, + "[1,2,2,3];[1,1,12,4]"); + + // No unknown dims in the critical fields. + INFER_ERROR("is not known", op, "[1,?,2,1];[1,1,1,1]"); + INFER_ERROR("is not known", op, "[1,2,?,1];[1,1,1,1]"); + INFER_ERROR("is not known", op, "[1,2,2,1];[?,1,1,1]"); + INFER_ERROR("is not known", op, "[1,2,2,1];[1,?,1,1]"); + INFER_ERROR("is not known", op, "[1,2,2,1];[1,1,?,1]"); + INFER_ERROR("is not known", op, "[1,2,2,1];[1,1,1,?]"); +} + } // namespace shape_inference } // namespace tensorflow diff --git a/tensorflow/core/framework/shape_inference.cc b/tensorflow/core/framework/shape_inference.cc index f2ecc773fa..023174e83a 100644 --- a/tensorflow/core/framework/shape_inference.cc +++ b/tensorflow/core/framework/shape_inference.cc @@ -519,5 +519,66 @@ Status InferenceContext::Add(const Dimension* first, DimensionOrConstant second, return Status::OK(); } +Status InferenceContext::Multiply(const Dimension* first, + DimensionOrConstant second, + const Dimension** out) { + int64 first_value = -1; + // Special cases for multiply are when the values are 0 or 1. + if (ValueKnown(first)) { + first_value = Value(first); + if (first_value == 0) { + *out = MakeDim(0); + return Status::OK(); + } + + // Output is whatever the second value is. + if (first_value == 1) { + *out = GetDimension(second); + return Status::OK(); + } + } + + // Same check for when the second argument is a known value. + // First find out if the value is known from DimOrConstant. + int64 second_value; + if (second.dim == nullptr) { + second_value = second.val; + } else { + if (!ValueKnown(second.dim)) { + // Second value is not known and first is not a special caase + *out = UnknownDim(); + return Status::OK(); + } + second_value = Value(second.dim); + } + + // Now that we know whether the value is known, apply the special + // casing. + if (second_value == 0) { + *out = MakeDim(0); + return Status::OK(); + } + + // Output is whatever the first value is. + if (second_value == 1) { + *out = first; + return Status::OK(); + } + + if (!ValueKnown(first)) { + // First value is not known and second is not a special caase + *out = UnknownDim(); + return Status::OK(); + } + + const int64 product = first_value * second_value; + if (product < 0) { + return errors::InvalidArgument("Negative dimension size from multiplying ", + first_value, " and ", second_value); + } + *out = MakeDim(product); + return Status::OK(); +} + } // namespace shape_inference } // namespace tensorflow diff --git a/tensorflow/core/framework/shape_inference.h b/tensorflow/core/framework/shape_inference.h index 1bdc74dc5a..e3ee40bb6f 100644 --- a/tensorflow/core/framework/shape_inference.h +++ b/tensorflow/core/framework/shape_inference.h @@ -248,6 +248,10 @@ class InferenceContext { Status Add(const Dimension* first, DimensionOrConstant second, const Dimension** out); + // Returns in <out> the product of <first> and <second>. + Status Multiply(const Dimension* first, DimensionOrConstant second, + const Dimension** out); + private: const Dimension* GetDimension(const DimensionOrConstant& d); diff --git a/tensorflow/core/framework/shape_inference_test.cc b/tensorflow/core/framework/shape_inference_test.cc index ce175de561..2a40e962ea 100644 --- a/tensorflow/core/framework/shape_inference_test.cc +++ b/tensorflow/core/framework/shape_inference_test.cc @@ -806,5 +806,60 @@ TEST(ShapeInferenceTest, Add) { c.Add(d_6, std::numeric_limits<int64>::max() - 5, &out).error_message()); } +TEST(ShapeInferenceTest, Multiply) { + NodeDef def; + InferenceContext c(&def, MakeOpDef(1, 2), {"[6,?,0,1]"}, {}); + + auto s = c.input(0); + auto d_6 = c.Dim(s, 0); + auto d_unknown = c.Dim(s, 1); + auto d_0 = c.Dim(s, 2); + auto d_1 = c.Dim(s, 3); + + // Multiplying non-zero to unknown gives new unknown. + const Dimension* out; + EXPECT_TRUE(c.Multiply(d_unknown, 1, &out).ok()); + EXPECT_EQ("?", c.DebugString(out)); + + // Multiplying 0 to anything gives 0. + EXPECT_TRUE(c.Multiply(d_unknown, static_cast<int64>(0), &out).ok()); + EXPECT_EQ("0", c.DebugString(out)); + EXPECT_TRUE(c.Multiply(d_unknown, d_0, &out).ok()); + EXPECT_EQ("0", c.DebugString(out)); + EXPECT_TRUE(c.Multiply(d_0, d_unknown, &out).ok()); + EXPECT_EQ("0", c.DebugString(out)); + + // Multiplying 1 to anything gives the original. + // (unknown -> unknown) + EXPECT_TRUE(c.Multiply(d_unknown, static_cast<int64>(1), &out).ok()); + EXPECT_EQ("?", c.DebugString(out)); + EXPECT_TRUE(c.Multiply(d_unknown, d_1, &out).ok()); + EXPECT_EQ("?", c.DebugString(out)); + EXPECT_TRUE(c.Multiply(d_1, d_unknown, &out).ok()); + EXPECT_EQ("?", c.DebugString(out)); + // (known -> known) + EXPECT_TRUE(c.Multiply(d_6, static_cast<int64>(1), &out).ok()); + EXPECT_EQ("6", c.DebugString(out)); + EXPECT_TRUE(c.Multiply(d_6, d_1, &out).ok()); + EXPECT_EQ("6", c.DebugString(out)); + EXPECT_TRUE(c.Multiply(d_1, d_6, &out).ok()); + EXPECT_EQ("6", c.DebugString(out)); + + // Test multiplication. + EXPECT_TRUE(c.Multiply(d_6, 2, &out).ok()); + EXPECT_EQ("12", c.DebugString(out)); + EXPECT_TRUE(c.Multiply(d_6, 6, &out).ok()); + EXPECT_EQ("36", c.DebugString(out)); + + // Test multiplication using dimension as second value. + EXPECT_TRUE(c.Multiply(d_6, c.MakeDim(2), &out).ok()); + EXPECT_EQ("12", c.DebugString(out)); + EXPECT_TRUE(c.Multiply(d_6, c.UnknownDim(), &out).ok()); + EXPECT_EQ("?", c.DebugString(out)); + + EXPECT_EQ("Negative dimension size from multiplying 6 and -7", + c.Multiply(d_6, -7, &out).error_message()); +} + } // namespace shape_inference } // namespace tensorflow diff --git a/tensorflow/core/ops/nn_ops.cc b/tensorflow/core/ops/nn_ops.cc index e24f9e0fb2..59ffcf430c 100644 --- a/tensorflow/core/ops/nn_ops.cc +++ b/tensorflow/core/ops/nn_ops.cc @@ -432,6 +432,7 @@ REGISTER_OP("DepthwiseConv2dNative") .Attr("T: {float, double}") .Attr("strides: list(int)") .Attr(GetPaddingAttrString()) + .SetShapeFn(shape_inference::DepthwiseConv2DNativeShape) .Doc(R"doc( Computes a 2-D depthwise convolution given 4-D `input` and `filter` tensors. |