From f3a613e9db95958316569d74748d4fdb632ffbb4 Mon Sep 17 00:00:00 2001
From: "A. Unique TensorFlower"
Date: Tue, 19 Jul 2016 16:22:02 -0800
Subject: Add C++ shape inference functions for more functions in nn_ops.cc.
 Add shape_inference::InferenceContext::ReplaceDim.

Change: 127893881
---
 tensorflow/core/framework/common_shape_fns.h      |   8 +
 tensorflow/core/framework/shape_inference.cc      |  11 ++
 tensorflow/core/framework/shape_inference.h       |   5 +
 tensorflow/core/framework/shape_inference_test.cc |  18 ++
 tensorflow/core/ops/linalg_ops.cc                 |   9 +-
 tensorflow/core/ops/nn_ops.cc                     | 191 ++++++++++++++++++
 tensorflow/core/ops/nn_ops_test.cc                | 230 +++++++++++++++++++++-
 7 files changed, 465 insertions(+), 7 deletions(-)

diff --git a/tensorflow/core/framework/common_shape_fns.h b/tensorflow/core/framework/common_shape_fns.h
index 9bb329e520..2439751991 100644
--- a/tensorflow/core/framework/common_shape_fns.h
+++ b/tensorflow/core/framework/common_shape_fns.h
@@ -61,6 +61,14 @@ inline Status ScalarShape(shape_inference::InferenceContext* c) {
   return Status::OK();
 }
 
+// Shape function for binary ops where both inputs and the output match.
+inline Status MergeBothInputsShapeFn(InferenceContext* c) {
+  const Shape* out;
+  TF_RETURN_IF_ERROR(c->Merge(c->input(0), c->input(1), &out));
+  c->set_output(0, out);
+  return Status::OK();
+}
+
 inline Status MatMulShape(shape_inference::InferenceContext* c) {
   const Shape* a;
   TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 2, &a));
diff --git a/tensorflow/core/framework/shape_inference.cc b/tensorflow/core/framework/shape_inference.cc
index c84856b058..1f5f3e9357 100644
--- a/tensorflow/core/framework/shape_inference.cc
+++ b/tensorflow/core/framework/shape_inference.cc
@@ -329,6 +329,17 @@ Status InferenceContext::Concatenate(const Shape* s1, const Shape* s2,
   return ReturnCreatedShape(dims, out);
 }
 
+Status InferenceContext::ReplaceDim(const Shape* s, int dim_index,
+                                    const Dimension* new_dim,
+                                    const Shape** out) {
+  if (!RankKnown(s)) {
+    return ReturnUnknownShape(out);
+  }
+  std::vector<const Dimension*> dims(s->dims_);
+  dims[dim_index] = new_dim;
+  return ReturnCreatedShape(dims, out);
+}
+
 const Shape* InferenceContext::MakeShape(
     const std::vector<const Dimension*>& dims) {
   all_shapes_.push_back(new Shape(dims));
diff --git a/tensorflow/core/framework/shape_inference.h b/tensorflow/core/framework/shape_inference.h
index 4cb80233e9..eec4b0f263 100644
--- a/tensorflow/core/framework/shape_inference.h
+++ b/tensorflow/core/framework/shape_inference.h
@@ -192,6 +192,11 @@ class InferenceContext {
   Status Concatenate(const Shape* s1, const Shape* s2,
                      const Shape** out) TF_MUST_USE_RESULT;
 
+  // Returns in <*out> the shape from replacing <s.dim[dim_index]> with
+  // <new_dim>.
+  Status ReplaceDim(const Shape* s, int dim_index, const Dimension* new_dim,
+                    const Shape** out) TF_MUST_USE_RESULT;
+
   // Returns a new shape with the given dims. The returned value is owned by
   // this context.
  const Shape* MakeShape(const std::vector<const Dimension*>& dims);
diff --git a/tensorflow/core/framework/shape_inference_test.cc b/tensorflow/core/framework/shape_inference_test.cc
index 340672d5d1..323992ac5d 100644
--- a/tensorflow/core/framework/shape_inference_test.cc
+++ b/tensorflow/core/framework/shape_inference_test.cc
@@ -465,6 +465,24 @@ TEST(ShapeInferenceTest, Concatenate) {
   }
 }
 
+TEST(ShapeInferenceTest, ReplaceDim) {
+  NodeDef def;
+  InferenceContext c(&def, MakeOpDef(2, 0), {"[1,2,3]", "?"}, {});
+
+  auto in = c.input(0);
+  auto unknown = c.input(1);
+
+  const Shape* replaced;
+  EXPECT_TRUE(c.ReplaceDim(in, 0, c.Dim(in, 1), &replaced).ok());
+  EXPECT_EQ("[2,2,3]", c.DebugString(replaced));
+  EXPECT_TRUE(c.ReplaceDim(in, 2, c.Dim(in, 1), &replaced).ok());
+  EXPECT_EQ("[1,2,2]", c.DebugString(replaced));
+  EXPECT_TRUE(c.ReplaceDim(in, 1, c.Dim(in, 2), &replaced).ok());
+  EXPECT_EQ("[1,3,3]", c.DebugString(replaced));
+  EXPECT_TRUE(c.ReplaceDim(unknown, 0, c.Dim(in, 1), &replaced).ok());
+  EXPECT_EQ("?", c.DebugString(replaced));
+}
+
 TEST(ShapeInferenceTest, MakeShape) {
   NodeDef def;
   InferenceContext c(&def, MakeOpDef(1, 2), {"[1,2,3,?,5]"}, {});
diff --git a/tensorflow/core/ops/linalg_ops.cc b/tensorflow/core/ops/linalg_ops.cc
index 85442c7f66..25887ebd77 100644
--- a/tensorflow/core/ops/linalg_ops.cc
+++ b/tensorflow/core/ops/linalg_ops.cc
@@ -71,13 +71,12 @@ Status SquareMatrixSolveShapeFn(InferenceContext* c) {
   TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 2, &rhs));
 
   // lhs and rhs have the same number of rows. Make a new output
-  // shape that has the merged-rows and the rest of the rhs.
+  // shape that uses rows to replace rhs.dim[0].
   const Dimension* rows;
   TF_RETURN_IF_ERROR(c->Merge(c->Dim(lhs, 0), c->Dim(rhs, 0), &rows));
-  const Shape* rhs_remaining;
-  TF_RETURN_IF_ERROR(c->Subshape(rhs, 1, &rhs_remaining));
-  TF_RETURN_IF_ERROR(c->Concatenate(c->Vector(rows), rhs_remaining, &rhs));
-  c->set_output(0, rhs);
+  const Shape* out;
+  TF_RETURN_IF_ERROR(c->ReplaceDim(rhs, 0, rows, &out));
+  c->set_output(0, out);
   return Status::OK();
 }
 
diff --git a/tensorflow/core/ops/nn_ops.cc b/tensorflow/core/ops/nn_ops.cc
index 8311f47f3f..5671d042d5 100644
--- a/tensorflow/core/ops/nn_ops.cc
+++ b/tensorflow/core/ops/nn_ops.cc
@@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
 limitations under the License.
 ==============================================================================*/
 
+#include "tensorflow/core/framework/common_shape_fns.h"
 #include "tensorflow/core/framework/numeric_op.h"
 #include "tensorflow/core/framework/op.h"
 #include "tensorflow/core/framework/shape_inference.h"
@@ -25,6 +26,28 @@ typedef shape_inference::Dimension Dimension;
 typedef shape_inference::InferenceContext InferenceContext;
 typedef shape_inference::Shape Shape;
 
+namespace {
+
+// A shape function that uses the tensor value at <input_idx> as a shape for
+// output 0. If the tensor value is not available, it uses a shape with
+// <ndims> unknown dims.
+Status InputTensorShapeOrUnknown(InferenceContext* c, int input_idx,
+                                 int ndims) {
+  const Shape* out;
+  const Tensor* input = c->input_tensor(input_idx);
+  if (input == nullptr) {
+    std::vector<const Dimension*> dims;
+    for (int i = 0; i < ndims; ++i) dims.push_back(c->UnknownDim());
+    out = c->MakeShape(dims);
+  } else {
+    TF_RETURN_IF_ERROR(c->MakeShapeFromShapeTensor(input_idx, &out));
+  }
+  c->set_output(0, out);
+  return Status::OK();
+}
+
+}  // namespace
+
 // --------------------------------------------------------------------------
 
 REGISTER_OP("AvgPool")
@@ -62,6 +85,13 @@ REGISTER_OP("AvgPoolGrad")
     .Attr(GetPaddingAttrString())
     .Attr(GetConvnetDataFormatAttrString())
     .Attr("T: {float, half, double}")
+    .SetShapeFn([](InferenceContext* c) {
+      // NOTE(mrry): We could in principle work out the shape from the
+      // gradients and the attrs, but if we do not know orig_input_shape
+      // statically, then we are unlikely to know the shape of the
+      // gradients either.
+      return InputTensorShapeOrUnknown(c, 0 /* input_idx */, 4 /* ndims */);
+    })
     .Doc(R"doc(
 Computes gradients of the average pooling function.
 
@@ -92,6 +122,22 @@ REGISTER_OP("BatchNormWithGlobalNormalization")
     .Attr("variance_epsilon: float")
     .Attr("scale_after_normalization: bool")
     .Deprecated(9, "Use tf.nn.batch_normalization()")
+    .SetShapeFn([](InferenceContext* c) {
+      const Shape* input;
+      TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 4, &input));
+
+      const Dimension* last_dim = c->Dim(input, 3);
+      for (int i = 1; i < 5; ++i) {  // covers m, v, beta, gamma
+        const Shape* vec;
+        TF_RETURN_IF_ERROR(c->WithRank(c->input(i), 1, &vec));
+        TF_RETURN_IF_ERROR(c->Merge(last_dim, c->Dim(vec, 0), &last_dim));
+      }
+
+      const Shape* out;
+      TF_RETURN_IF_ERROR(c->ReplaceDim(input, 3, last_dim, &out));
+      c->set_output(0, out);
+      return Status::OK();
+    })
     .Doc(R"doc(
 Batch normalization.
 
@@ -129,6 +175,30 @@ REGISTER_OP("BatchNormWithGlobalNormalizationGrad")
     .Attr("variance_epsilon: float")
     .Attr("scale_after_normalization: bool")
     .Deprecated(9, "Use tf.nn.batch_normalization()")
+    .SetShapeFn([](InferenceContext* c) {
+      const Shape* input;
+      TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 4, &input));
+      TF_RETURN_IF_ERROR(
+          c->Merge(input, c->input(4), &input));  // with backprop
+
+      const Dimension* last_dim = c->Dim(input, 3);
+      for (int i = 1; i < 4; ++i) {  // covers m, v, gamma
+        const Shape* vec;
+        TF_RETURN_IF_ERROR(c->WithRank(c->input(i), 1, &vec));
+        TF_RETURN_IF_ERROR(c->Merge(last_dim, c->Dim(vec, 0), &last_dim));
+      }
+
+      const Shape* dx;
+      TF_RETURN_IF_ERROR(c->ReplaceDim(input, 3, last_dim, &dx));
+      c->set_output(0, dx);
+
+      const Shape* vector_shape = c->Vector(last_dim);
+      c->set_output(1, vector_shape);
+      c->set_output(2, vector_shape);
+      c->set_output(3, vector_shape);
+      c->set_output(4, vector_shape);
+      return Status::OK();
+    })
     .Doc(R"doc(
 Gradients for batch normalization.
 
@@ -280,6 +350,13 @@ REGISTER_OP("Conv2DBackpropInput")
     .Attr("use_cudnn_on_gpu: bool = true")
     .Attr(GetPaddingAttrString())
     .Attr(GetConvnetDataFormatAttrString())
+    .SetShapeFn([](InferenceContext* c) {
+      // NOTE(mrry): We could in principle work out the shape from the
+      // gradients and the attrs, but if we do not know orig_input_shape
+      // statically, then we are unlikely to know the shape of the
+      // gradients either.
+      return InputTensorShapeOrUnknown(c, 0 /* input_idx */, 4 /* ndims */);
+    })
     .Doc(R"doc(
 Computes the gradients of convolution with respect to the input.
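
A short aside on the pattern above: the BatchNormWithGlobalNormalization shape function merges the input's channel dim against each rank-1 input, then writes the merged dim back with the new ReplaceDim. The same logic can be modeled in standalone C++; the sketch below is only illustrative (plain vectors, -1 standing in for an unknown dimension, hypothetical names MergeDim/ReplaceDim), not the real InferenceContext API.

#include <cassert>
#include <vector>

using Dim = long long;  // -1 == unknown dimension
using Shape = std::vector<Dim>;

// Merge two dims: an unknown dim takes the other value; known dims must agree.
Dim MergeDim(Dim a, Dim b) {
  if (a == -1) return b;
  if (b == -1) return a;
  assert(a == b && "incompatible dimensions");
  return a;
}

// ReplaceDim: copy the shape and swap out one dimension.
Shape ReplaceDim(const Shape& s, int dim_index, Dim new_dim) {
  Shape out(s);
  out[dim_index] = new_dim;
  return out;
}

int main() {
  // BatchNormWithGlobalNormalization pattern: merge the input's channel dim
  // against each rank-1 input (m, v, beta, gamma), then write the merged dim
  // back into the output shape.
  Shape input{-1, 32, 32, -1};  // NHWC with unknown batch and depth
  std::vector<Shape> vecs{{64}, {-1}, {64}, {-1}};  // m, v, beta, gamma
  Dim last_dim = input[3];
  for (const Shape& v : vecs) last_dim = MergeDim(last_dim, v[0]);
  Shape out = ReplaceDim(input, 3, last_dim);
  assert(out[3] == 64);  // [?,32,32,64]
  return 0;
}

The real Merge additionally refines both input handles so later lookups observe the merged value; the sketch only returns it.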
@@ -315,6 +392,13 @@ REGISTER_OP("Conv2DBackpropFilter")
     .Attr("use_cudnn_on_gpu: bool = true")
     .Attr(GetPaddingAttrString())
     .Attr(GetConvnetDataFormatAttrString())
+    .SetShapeFn([](InferenceContext* c) {
+      // NOTE(mrry): We could in principle work out the shape from the
+      // gradients and the attrs, but if we do not know orig_input_shape
+      // statically, then we are unlikely to know the shape of the
+      // gradients either.
+      return InputTensorShapeOrUnknown(c, 1 /* input_idx */, 4 /* ndims */);
+    })
     .Doc(R"doc(
 Computes the gradients of convolution with respect to the filter.
 
@@ -380,6 +464,13 @@ REGISTER_OP("DepthwiseConv2dNativeBackpropInput")
     .Attr("T: {float, double}")
     .Attr("strides: list(int)")
     .Attr(GetPaddingAttrString())
+    .SetShapeFn([](InferenceContext* c) {
+      // NOTE(mrry): We could in principle work out the shape from the
+      // gradients and the attrs, but if we do not know orig_input_shape
+      // statically, then we are unlikely to know the shape of the
+      // gradients either.
+      return InputTensorShapeOrUnknown(c, 0 /* input_idx */, 4 /* ndims */);
+    })
     .Doc(R"doc(
 Computes the gradients of depthwise convolution with respect to the input.
 
@@ -404,6 +495,13 @@ REGISTER_OP("DepthwiseConv2dNativeBackpropFilter")
     .Attr("T: {float, double}")
     .Attr("strides: list(int)")
     .Attr(GetPaddingAttrString())
+    .SetShapeFn([](InferenceContext* c) {
+      // NOTE(mrry): We could in principle work out the shape from the
+      // gradients and the attrs, but if we do not know orig_input_shape
+      // statically, then we are unlikely to know the shape of the
+      // gradients either.
+      return InputTensorShapeOrUnknown(c, 1 /* input_idx */, 4 /* ndims */);
+    })
     .Doc(R"doc(
 Computes the gradients of depthwise convolution with respect to the filter.
 
@@ -456,6 +554,9 @@ REGISTER_OP("Conv3DBackpropInput")
     .Attr("strides: list(int) >= 5")
     .Attr(GetPaddingAttrString())
     .Deprecated(10, "Use Conv3DBackpropInputV2")
+    .SetShapeFn([](InferenceContext* c) {
+      return UnchangedShapeWithRank(c, 5);
+    })
     .Doc(R"doc(
 Computes the gradients of 3-D convolution with respect to the input.
 
@@ -479,6 +580,12 @@ REGISTER_OP("Conv3DBackpropFilter")
     .Attr("strides: list(int) >= 5")
     .Attr(GetPaddingAttrString())
     .Deprecated(10, "Use Conv3DBackpropFilterV2")
+    .SetShapeFn([](InferenceContext* c) {
+      const Shape* out;
+      TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 5, &out));
+      c->set_output(0, out);
+      return Status::OK();
+    })
     .Doc(R"doc(
 Computes the gradients of 3-D convolution with respect to the filter.
 
@@ -501,6 +608,13 @@ REGISTER_OP("Conv3DBackpropInputV2")
     .Attr("T: numbertype")
     .Attr("strides: list(int) >= 5")
     .Attr(GetPaddingAttrString())
+    .SetShapeFn([](InferenceContext* c) {
+      const Shape* s;
+      TF_RETURN_IF_ERROR(c->MakeShapeFromShapeTensor(0, &s));
+      TF_RETURN_IF_ERROR(c->WithRank(s, 5, &s));
+      c->set_output(0, s);
+      return Status::OK();
+    })
     .Doc(R"doc(
 Computes the gradients of 3-D convolution with respect to the input.
 
@@ -525,6 +639,13 @@ REGISTER_OP("Conv3DBackpropFilterV2")
     .Attr("T: numbertype")
     .Attr("strides: list(int) >= 5")
     .Attr(GetPaddingAttrString())
+    .SetShapeFn([](InferenceContext* c) {
+      const Shape* s;
+      TF_RETURN_IF_ERROR(c->MakeShapeFromShapeTensor(1, &s));
+      TF_RETURN_IF_ERROR(c->WithRank(s, 5, &s));
+      c->set_output(0, s);
+      return Status::OK();
+    })
     .Doc(R"doc(
 Computes the gradients of 3-D convolution with respect to the filter.
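
The InputTensorShapeOrUnknown helper drives all of the SetShapeFn lambdas above that read an orig_input_shape or filter_sizes input: if the shape tensor's value is known at graph-construction time it becomes output 0, otherwise output 0 is ndims unknown dims. Below is a standalone sketch of that fallback with illustrative names and -1 for an unknown dimension; the real helper goes through input_tensor() and MakeShapeFromShapeTensor().

#include <cassert>
#include <vector>

using Dim = long long;  // -1 == unknown dimension
using Shape = std::vector<Dim>;

// If the shape tensor's value is known at graph-construction time, it becomes
// the output shape; otherwise fall back to ndims unknown dimensions.
Shape ShapeFromTensorOrUnknown(const std::vector<Dim>* tensor_value,
                               int ndims) {
  if (tensor_value == nullptr) return Shape(ndims, -1);
  return *tensor_value;
}

int main() {
  // Conv2DBackpropInput-style call, where orig_input_shape is input 0.
  Shape unknown = ShapeFromTensorOrUnknown(nullptr, 4);
  assert(unknown.size() == 4 && unknown[0] == -1);  // [?,?,?,?]

  std::vector<Dim> value{1, 2, 3, 4};
  Shape known = ShapeFromTensorOrUnknown(&value, 4);
  assert(known[3] == 4);  // [1,2,3,4]
  return 0;
}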
@@ -570,6 +691,13 @@ REGISTER_OP("AvgPool3DGrad")
     .Attr("strides: list(int) >= 5")
     .Attr(GetPaddingAttrString())
     .Attr("T: numbertype")
+    .SetShapeFn([](InferenceContext* c) {
+      const Shape* s;
+      TF_RETURN_IF_ERROR(c->MakeShapeFromShapeTensor(0, &s));
+      TF_RETURN_IF_ERROR(c->WithRank(s, 5, &s));
+      c->set_output(0, s);
+      return Status::OK();
+    })
     .Doc(R"doc(
 Computes gradients of average pooling function.
 
@@ -613,6 +741,9 @@ REGISTER_OP("MaxPool3DGrad")
     .Attr("strides: list(int) >= 5")
     .Attr(GetPaddingAttrString())
     .Attr("T: numbertype")
+    .SetShapeFn([](InferenceContext* c) {
+      return UnchangedShapeWithRank(c, 5);
+    })
     .Doc(R"doc(
 Computes gradients of max pooling function.
 
@@ -686,6 +817,14 @@ REGISTER_OP("LRNGrad")
     .Attr("alpha: float = 1.0")
     .Attr("beta: float = 0.5")
     .Attr("T: {float, half} = DT_FLOAT")
+    .SetShapeFn([](InferenceContext* c) {
+      const Shape* s;
+      TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 4, &s));  // input_grads
+      TF_RETURN_IF_ERROR(c->Merge(s, c->input(1), &s));     // input_image
+      TF_RETURN_IF_ERROR(c->Merge(s, c->input(2), &s));     // output_image
+      c->set_output(0, s);
+      return Status::OK();
+    })
     .Doc(R"doc(
 Gradients for Local Response Normalization.
 
@@ -735,6 +874,9 @@ REGISTER_OP("MaxPoolGrad")
     .Input("grad: T")
     .Output("output: T")
     .Attr("T: {float, half} = DT_FLOAT")
+    .SetShapeFn([](InferenceContext* c) {
+      return UnchangedShapeWithRank(c, 4);
+    })
     .Doc(R"doc(
 Computes gradients of the maxpooling function.
 
@@ -788,6 +930,9 @@ REGISTER_OP("MaxPoolGradWithArgmax")
     .Input("argmax: Targmax")
     .Output("output: T")
     .Attr("T: {float, half} = DT_FLOAT")
+    .SetShapeFn([](InferenceContext* c) {
+      return UnchangedShapeWithRank(c, 4);
+    })
     .Doc(R"doc(
 Computes gradients of the maxpooling function.
 
@@ -858,6 +1003,7 @@ REGISTER_OP("Dilation2DBackpropInput")
     .Attr("strides: list(int) >= 4")
     .Attr("rates: list(int) >= 4")
     .Attr(GetPaddingAttrString())
+    .SetShapeFn(shape_inference::UnchangedShape)
     .Doc(R"doc(
 Computes the gradient of morphological 2-D dilation with respect to the input.
 
@@ -881,6 +1027,10 @@ REGISTER_OP("Dilation2DBackpropFilter")
     .Attr("strides: list(int) >= 4")
     .Attr("rates: list(int) >= 4")
     .Attr(GetPaddingAttrString())
+    .SetShapeFn([](InferenceContext* c) {
+      c->set_output(0, c->input(1));
+      return Status::OK();
+    })
     .Doc(R"doc(
 Computes the gradient of morphological 2-D dilation with respect to the filter.
 
@@ -910,6 +1060,7 @@ REGISTER_OP("ReluGrad")
     .Input("features: T")
     .Output("backprops: T")
     .Attr("T: realnumbertype")
+    .SetShapeFn(shape_inference::MergeBothInputsShapeFn)
    .Doc(R"doc(
 Computes rectified linear gradients for a Relu operation.
 
@@ -932,6 +1083,7 @@ REGISTER_OP("Relu6Grad")
     .Input("features: T")
     .Output("backprops: T")
     .Attr("T: realnumbertype")
+    .SetShapeFn(shape_inference::MergeBothInputsShapeFn)
     .Doc(R"doc(
 Computes rectified linear 6 gradients for a Relu6 operation.
 
@@ -957,6 +1109,7 @@ REGISTER_OP("EluGrad")
     .Input("outputs: T")
     .Output("backprops: T")
     .Attr("T: {float, double}")
+    .SetShapeFn(shape_inference::MergeBothInputsShapeFn)
     .Doc(R"doc(
 Computes gradients for the exponential linear (Elu) operation.
 
@@ -979,6 +1132,7 @@ REGISTER_OP("SoftplusGrad")
     .Input("features: T")
     .Output("backprops: T")
     .Attr("T: realnumbertype")
+    .SetShapeFn(shape_inference::MergeBothInputsShapeFn)
     .Doc(R"doc(
 Computes softplus gradients for a softplus operation.
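
MergeBothInputsShapeFn, attached to the elementwise gradient ops above (ReluGrad, Relu6Grad, EluGrad, SoftplusGrad), asserts that the two inputs have compatible shapes and outputs the merged result. A minimal standalone sketch of such an elementwise merge, assuming equal known ranks for brevity (illustrative names, -1 for an unknown dimension, not the real Merge implementation):

#include <cassert>
#include <cstddef>
#include <vector>

using Dim = long long;  // -1 == unknown dimension
using Shape = std::vector<Dim>;

// Merge two shapes elementwise: unknown dims take the known value, and known
// dims must agree.
Shape MergeShapes(const Shape& a, const Shape& b) {
  assert(a.size() == b.size() && "incompatible ranks");
  Shape out(a.size());
  for (std::size_t i = 0; i < a.size(); ++i) {
    if (a[i] == -1) {
      out[i] = b[i];
    } else {
      assert(b[i] == -1 || a[i] == b[i]);
      out[i] = a[i];
    }
  }
  return out;
}

int main() {
  // ReluGrad-style: gradients [1,?] merged with features [?,2] -> [1,2].
  Shape out = MergeShapes({1, -1}, {-1, 2});
  assert(out[0] == 1 && out[1] == 2);
  return 0;
}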
@@ -1000,6 +1154,7 @@ REGISTER_OP("SoftsignGrad")
     .Input("features: T")
     .Output("backprops: T")
     .Attr("T: realnumbertype")
+    .SetShapeFn(shape_inference::MergeBothInputsShapeFn)
     .Doc(R"doc(
 Computes softsign gradients for a softsign operation.
 
@@ -1050,6 +1205,16 @@ REGISTER_OP("SoftmaxCrossEntropyWithLogits")
     .Output("loss: T")
     .Output("backprop: T")
     .Attr("T: {half, float, double}")
+    .SetShapeFn([](InferenceContext* c) {
+      const Shape* input;
+      TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 2, &input));
+      TF_RETURN_IF_ERROR(c->Merge(input, c->input(1), &input));
+
+      const Dimension* batch_size = c->Dim(input, 0);
+      c->set_output(0, c->Vector(batch_size));
+      c->set_output(1, input);
+      return Status::OK();
+    })
     .Doc(R"doc(
 Computes softmax cross entropy cost and gradients to backpropagate.
 
@@ -1070,6 +1235,21 @@ REGISTER_OP("SparseSoftmaxCrossEntropyWithLogits")
     .Output("backprop: T")
     .Attr("T: {half, float, double}")
     .Attr("Tlabels: {int32, int64} = DT_INT64")
+    .SetShapeFn([](InferenceContext* c) {
+      const Shape* features;
+      const Shape* labels;
+      TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 2, &features));
+      TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 1, &labels));
+
+      const Dimension* batch_size;
+      TF_RETURN_IF_ERROR(
+          c->Merge(c->Dim(features, 0), c->Dim(labels, 0), &batch_size));
+      TF_RETURN_IF_ERROR(c->ReplaceDim(features, 0, batch_size, &features));
+
+      c->set_output(0, c->Vector(batch_size));
+      c->set_output(1, features);
+      return Status::OK();
+    })
     .Doc(R"doc(
 Computes softmax cross entropy cost and gradients to backpropagate.
 
@@ -1095,6 +1275,17 @@ REGISTER_OP("InTopK")
     .Output("precision: bool")
     .Attr("k: int")
     .Attr("T: {int32, int64} = DT_INT32")
+    .SetShapeFn([](InferenceContext* c) {
+      const Shape* predictions;
+      const Shape* targets;
+      TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 2, &predictions));
+      TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 1, &targets));
+      const Dimension* batch_size;
+      TF_RETURN_IF_ERROR(
+          c->Merge(c->Dim(predictions, 0), c->Dim(targets, 0), &batch_size));
+      c->set_output(0, c->Vector(batch_size));
+      return Status::OK();
+    })
     .Doc(R"doc(
 Says whether the targets are in the top `K` predictions.
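
The three shape functions just above share one pattern: merge the batch dim of the rank-2 features/predictions input with the dim of the rank-1 labels/targets input, emit a [batch_size] vector, and, for the cross-entropy ops, refine the features shape via ReplaceDim. A standalone sketch of that batch-size handling (illustrative names, -1 for an unknown dimension, not the real InferenceContext API):

#include <cassert>
#include <vector>

using Dim = long long;  // -1 == unknown dimension
using Shape = std::vector<Dim>;

// Merge two dims: an unknown dim takes the other value; known dims must agree.
Dim MergeDim(Dim a, Dim b) {
  if (a == -1) return b;
  if (b == -1) return a;
  assert(a == b && "incompatible dimensions");
  return a;
}

int main() {
  Shape features{-1, 2};  // [batch_size, N] with unknown batch
  Shape labels{8};        // [batch_size]

  // Merge the two batch dims, then write the result back into the features
  // shape, mirroring the Merge and ReplaceDim calls above.
  Dim batch_size = MergeDim(features[0], labels[0]);
  features[0] = batch_size;

  Shape loss{batch_size};  // output 0 is the vector [batch_size]
  assert(loss[0] == 8 && features[0] == 8);
  return 0;
}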
+ INFER_OK(op, "?;?", "[?,?,?,?]"); + INFER_OK(op, "[4];?", "[?,?,?,?]"); + + // When input tensor is known, its values determine output shape. + std::vector shape{1, 2, 3, 4}; + Tensor shape_t = test::AsTensor(shape); + op.input_tensors[p.second] = &shape_t; + INFER_OK(op, "[4];?", "[1,2,3,4]"); + } +} + +TEST(NNOpsTest, InputTensorShapeOrUnknown3D_ShapeFn) { + typedef std::pair NameAndInputIndex; + for (const auto& p : {NameAndInputIndex("AvgPool3DGrad", 0), + NameAndInputIndex("Conv3DBackpropInputV2", 0), + NameAndInputIndex("Conv3DBackpropFilterV2", 1)}) { + ShapeInferenceTestOp op(p.first); + op.input_tensors.resize(2); + + // When the input tensor is not known, the output is 4 unknown dims. + INFER_OK(op, "?;?;?", "[?,?,?,?,?]"); + INFER_OK(op, "[5];?;?", "[?,?,?,?,?]"); + + // When input tensor is known, its values determine output shape. + std::vector shape{1, 2, 3, 4, 5}; + Tensor shape_t = test::AsTensor(shape); + op.input_tensors[p.second] = &shape_t; + INFER_OK(op, "[5];?;?", "[1,2,3,4,5]"); + } +} + +TEST(NNOpsTest, BatchNormWithGlobalNormalization_ShapeFn) { + ShapeInferenceTestOp op("BatchNormWithGlobalNormalization"); + + // Test rank errors. + INFER_ERROR("Shape must be rank 4 but is rank 3", op, "[1,2,3];?;?;?;?"); + INFER_ERROR("Shape must be rank 1 but is rank 3", op, "?;[1,2,3];?;?;?"); + INFER_ERROR("Shape must be rank 1 but is rank 3", op, "?;?;[1,2,3];?;?"); + INFER_ERROR("Shape must be rank 1 but is rank 3", op, "?;?;?;[1,2,3];?"); + INFER_ERROR("Shape must be rank 1 but is rank 3", op, "?;?;?;?;[1,2,3]"); + + // last dim of first input is merged with the single dim in other 4 inputs. + INFER_OK(op, "?;?;?;?;?", "[?,?,?,?]"); + INFER_OK(op, "?;[1];?;?;?", "[?,?,?,d1_0]"); + INFER_OK(op, "?;?;[1];?;?", "[?,?,?,d2_0]"); + INFER_OK(op, "?;?;?;[1];?", "[?,?,?,d3_0]"); + INFER_OK(op, "?;?;?;?;[1]", "[?,?,?,d4_0]"); + INFER_OK(op, "[1,2,3,4];[4];[4];[4];[4]", + "[d0_0,d0_1,d0_2,d0_3|d1_0|d2_0|d3_0|d4_0]"); +} + +TEST(NNOpsTest, BatchNormWithGlobalNormalizationGrad_ShapeFn) { + ShapeInferenceTestOp op("BatchNormWithGlobalNormalizationGrad"); + + // Test rank errors. + INFER_ERROR("Shape must be rank 4 but is rank 3", op, "[1,2,3];?;?;?;?"); + INFER_ERROR("Shape must be rank 1 but is rank 3", op, "?;[1,2,3];?;?;?"); + INFER_ERROR("Shape must be rank 1 but is rank 3", op, "?;?;[1,2,3];?;?"); + INFER_ERROR("Shape must be rank 1 but is rank 3", op, "?;?;?;[1,2,3];?"); + INFER_ERROR("Shapes must be equal rank, but are 4 and 3", op, + "?;?;?;?;[1,2,3]"); + + // The first output comes from the first and last inputs merged together. + // Other inputs are merged with the last dim of that merge result, and that + // merged vector dim is the last 4 outputs. + INFER_OK(op, "?;?;?;?;?", "[?,?,?,?];[?];[?];[?];[?]"); + INFER_OK(op, "?;[1];?;?;?", "[?,?,?,d1_0];[d1_0];[d1_0];[d1_0];[d1_0]"); + INFER_OK(op, "?;?;[1];?;?", "[?,?,?,d2_0];[d2_0];[d2_0];[d2_0];[d2_0]"); + INFER_OK(op, "?;?;?;[1];?", "[?,?,?,d3_0];[d3_0];[d3_0];[d3_0];[d3_0]"); + INFER_OK(op, "[1,?,3,?];[?];[?];[?];[?,2,?,4]", + "[d0_0,d4_1,d0_2,d4_3];[d4_3];[d4_3];[d4_3];[d4_3]"); +} + +TEST(NNOpsTest, Conv3DBackpropInput_ShapeFn) { + ShapeInferenceTestOp op("Conv3DBackpropInput"); + + // Test rank error. + INFER_ERROR("Shape must be rank 5 but is rank 3", op, "[1,2,3];?;?"); + + // input[1] is transferred to output after asserting its rank. 
+ INFER_OK(op, "?;?;?", "[?,?,?,?,?]"); + INFER_OK(op, "[?,?,?,?,?];?;?", "in0"); + INFER_OK(op, "[?,2,?,4,?];?;?", "in0"); +} + +TEST(NNOpsTest, Conv3DBackpropFilter_ShapeFn) { + ShapeInferenceTestOp op("Conv3DBackpropFilter"); + + // Test rank error. + INFER_ERROR("Shape must be rank 5 but is rank 3", op, "?;[1,2,3];?"); + + // input[1] is transferred to output after asserting its rank. + INFER_OK(op, "?;?;?", "[?,?,?,?,?]"); + INFER_OK(op, "?;[?,?,?,?,?];?", "in1"); + INFER_OK(op, "?;[?,2,?,4,?];?", "in1"); +} + +TEST(NNOpsTest, MaxPool3DGrad_ShapeFn) { + ShapeInferenceTestOp op("MaxPool3DGrad"); + + // Test rank error. + INFER_ERROR("Shape must be rank 5 but is rank 3", op, "[1,2,3];?;?"); + + // input[0] is transferred to output after asserting its rank. + INFER_OK(op, "?;?;?", "[?,?,?,?,?]"); + INFER_OK(op, "[?,?,?,?,?];?;?", "in0"); + INFER_OK(op, "[?,2,?,4,?];?;?", "in0"); +} + +TEST(NNOpsTest, LRNGrad_ShapeFn) { + ShapeInferenceTestOp op("LRNGrad"); + + // LRN Grad is a merge of all three inputs, of rank 4. + INFER_OK(op, "[1,?,?,4];[?,2,?,?];[?,?,3,?]", "[d0_0,d1_1,d2_2,d0_3]"); + + // Test rank errors. + INFER_ERROR("Shape must be rank 4 but is rank 3", op, "[1,2,3];?;?"); + INFER_ERROR("Shapes must be equal rank, but are 4 and 3", op, "?;[1,2,3];?"); + INFER_ERROR("Shapes must be equal rank, but are 4 and 3", op, "?;?;[1,2,3]"); +} + +TEST(NNOpsTest, MaxPoolGrad_ShapeFn) { + for (const char* op_name : {"MaxPoolGrad", "MaxPoolGradWithArgmax"}) { + ShapeInferenceTestOp op(op_name); + + // Test rank error. + INFER_ERROR("Shape must be rank 4 but is rank 3", op, "[1,2,3];?;?"); + + // input[0] is transferred to output after asserting its rank. + INFER_OK(op, "?;?;?", "[?,?,?,?]"); + INFER_OK(op, "[?,?,?,?];?;?", "in0"); + INFER_OK(op, "[?,2,?,4];?;?", "in0"); + } +} + +TEST(NNOpsTest, Dilation2DBackpropInput_ShapeFn) { + ShapeInferenceTestOp op("Dilation2DBackpropInput"); + + // input[0] is transferred to output. + INFER_OK(op, "?;?;?", "in0"); + INFER_OK(op, "?;[?,?,?,?,?];?", "in0"); + INFER_OK(op, "?;[?,2,?,4,?];?", "in0"); +} + +TEST(NNOpsTest, Dilation2DBackpropFilter_ShapeFn) { + ShapeInferenceTestOp op("Dilation2DBackpropFilter"); + + // input[1] is transferred to output. + INFER_OK(op, "?;?;?", "in1"); + INFER_OK(op, "?;[?,?,?,?,?];?", "in1"); + INFER_OK(op, "?;[?,2,?,4,?];?", "in1"); +} + +TEST(NNOpsTest, MergeBothInputs_ShapeFn) { + for (const char* op_name : + {"ReluGrad", "Relu6Grad", "EluGrad", "SoftplusGrad", "SoftsignGrad"}) { + ShapeInferenceTestOp op(op_name); + + INFER_OK(op, "?;?", "in0|in1"); + INFER_OK(op, "?;[1,?,3]", "in1"); + INFER_OK(op, "[1,?,3];?", "in0"); + INFER_OK(op, "[1,?];[?,2]", "[d0_0,d1_1]"); + INFER_ERROR("Dimension 1 in both shapes must be equal, but are 3 and 2", op, + "[1,3];[?,2]"); + } +} + +TEST(NNOpsTest, SoftmaxCrossEntropyWithLogits_ShapeFn) { + ShapeInferenceTestOp op("SoftmaxCrossEntropyWithLogits"); + + // Inputs are [batch_size,N] and [batch_size,N], and outputs are [batch_size] + // and + // [batch_size,N]. 
+ INFER_OK(op, "?;?", "[?];[?,?]"); + INFER_OK(op, "[?,?];[?,?]", "[d0_0|d1_0];in0|in1"); + INFER_OK(op, "[1,2];[?,2]", "[d0_0];in0"); + INFER_OK(op, "[1,?];[?,2]", "[d0_0];[d0_0,d0_1|d1_1]"); + INFER_OK(op, "[?,2];[1,2]", "[d1_0];in1"); + + INFER_ERROR("Dimension 0 in both shapes must be equal, but are 1 and 2", op, + "[1,?];[2,?]"); + INFER_ERROR("Shape must be rank 2 but is rank 3", op, "[1,2,3];?"); + INFER_ERROR("Shapes must be equal rank, but are 2 and 3", op, "?;[1,2,3]"); +} + +TEST(NNOpsTest, SparseSoftmaxCrossEntropyWithLogits_ShapeFn) { + ShapeInferenceTestOp op("SparseSoftmaxCrossEntropyWithLogits"); + + // Inputs are [batch_size,N] and [batch_size], and outputs are [batch_size] + // and [batch_size,N]. + INFER_OK(op, "?;?", "[?];[?,?]"); + INFER_OK(op, "[?,?];[?]", "[d0_0|d1_0];[d0_0|d1_0,d0_1]"); + INFER_OK(op, "[1,2];[1]", "[d0_0|d1_0];[d0_0|d1_0,d0_1]"); + INFER_OK(op, "[?,2];[1]", "[d1_0];[d1_0,d0_1]"); + + INFER_ERROR("Dimensions must be equal, but are 1 and 2", op, "[1,?];[2]"); + INFER_ERROR("Shape must be rank 2 but is rank 3", op, "[1,2,3];?"); + INFER_ERROR("Shape must be rank 1 but is rank 2", op, "?;[1,2]"); +} + +TEST(NNOpsTest, InTopK_ShapeFn) { + ShapeInferenceTestOp op("InTopK"); + + // Inputs are [batch_size,N] and [batch_size], and output is [batch_size]. + INFER_OK(op, "?;?", "[?]"); + INFER_OK(op, "[?,?];[?]", "[d0_0|d1_0]"); + INFER_OK(op, "[1,2];[1]", "[d0_0|d1_0]"); + INFER_OK(op, "[?,2];[1]", "[d1_0]"); + + INFER_ERROR("Dimensions must be equal, but are 1 and 2", op, "[1,?];[2]"); + INFER_ERROR("Shape must be rank 2 but is rank 3", op, "[1,2,3];?"); + INFER_ERROR("Shape must be rank 1 but is rank 2", op, "?;[1,2]"); +} + } // end namespace tensorflow -- cgit v1.2.3