author     A. Unique TensorFlower <gardener@tensorflow.org>  2016-07-19 16:22:02 -0800
committer  TensorFlower Gardener <gardener@tensorflow.org>   2016-07-19 17:33:10 -0700
commit     f3a613e9db95958316569d74748d4fdb632ffbb4 (patch)
tree       2e722d5bffd912ab275bef9c9094249773f4a4d8
parent     15afaf7b2ed6aa77b62dde0007afffaffc6a051d (diff)
Add C++ shape inference functions for more ops in nn_ops.cc.
Add shape_inference::InferenceContext::ReplaceDim. Change: 127893881
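
For illustration, here is a minimal sketch of a shape function built on the new InferenceContext::ReplaceDim, mirroring the SparseSoftmaxCrossEntropyWithLogits change in the nn_ops.cc hunk below. The function and the op it serves are hypothetical; only the InferenceContext calls are from this commit.

// Sketch only: assumes a hypothetical op with inputs features:[batch_size,N]
// and labels:[batch_size]; demonstrates Merge followed by the new ReplaceDim.
#include "tensorflow/core/framework/shape_inference.h"

using tensorflow::Status;
using tensorflow::shape_inference::Dimension;
using tensorflow::shape_inference::InferenceContext;
using tensorflow::shape_inference::Shape;

Status ExampleLossShapeFn(InferenceContext* c) {
  const Shape* features;
  const Shape* labels;
  TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 2, &features));
  TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 1, &labels));

  // Merge the two batch dims, then write the merged dim back into
  // features.dim[0] via the new ReplaceDim.
  const Dimension* batch_size;
  TF_RETURN_IF_ERROR(
      c->Merge(c->Dim(features, 0), c->Dim(labels, 0), &batch_size));
  TF_RETURN_IF_ERROR(c->ReplaceDim(features, 0, batch_size, &features));

  c->set_output(0, c->Vector(batch_size));  // loss: [batch_size]
  c->set_output(1, features);               // backprop: [batch_size,N]
  return Status::OK();
}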
-rw-r--r--  tensorflow/core/framework/common_shape_fns.h       |   8
-rw-r--r--  tensorflow/core/framework/shape_inference.cc       |  11
-rw-r--r--  tensorflow/core/framework/shape_inference.h        |   5
-rw-r--r--  tensorflow/core/framework/shape_inference_test.cc  |  18
-rw-r--r--  tensorflow/core/ops/linalg_ops.cc                  |   9
-rw-r--r--  tensorflow/core/ops/nn_ops.cc                      | 191
-rw-r--r--  tensorflow/core/ops/nn_ops_test.cc                 | 230
7 files changed, 465 insertions(+), 7 deletions(-)
diff --git a/tensorflow/core/framework/common_shape_fns.h b/tensorflow/core/framework/common_shape_fns.h
index 9bb329e520..2439751991 100644
--- a/tensorflow/core/framework/common_shape_fns.h
+++ b/tensorflow/core/framework/common_shape_fns.h
@@ -61,6 +61,14 @@ inline Status ScalarShape(shape_inference::InferenceContext* c) {
return Status::OK();
}
+// Shape function for binary ops where both inputs and the output match.
+inline Status MergeBothInputsShapeFn(InferenceContext* c) {
+ const Shape* out;
+ TF_RETURN_IF_ERROR(c->Merge(c->input(0), c->input(1), &out));
+ c->set_output(0, out);
+ return Status::OK();
+}
+
inline Status MatMulShape(shape_inference::InferenceContext* c) {
const Shape* a;
TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 2, &a));
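
For context, attaching the new MergeBothInputsShapeFn to an op mirrors the ReluGrad, Relu6Grad, EluGrad, SoftplusGrad, and SoftsignGrad changes in the nn_ops.cc hunks below; a minimal sketch with a hypothetical op name:

// Sketch only: "ExampleBinaryGrad" is hypothetical; the real call sites added
// by this commit are the *Grad op registrations in nn_ops.cc.
REGISTER_OP("ExampleBinaryGrad")
    .Input("gradients: T")
    .Input("features: T")
    .Output("backprops: T")
    .Attr("T: realnumbertype")
    .SetShapeFn(shape_inference::MergeBothInputsShapeFn);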
diff --git a/tensorflow/core/framework/shape_inference.cc b/tensorflow/core/framework/shape_inference.cc
index c84856b058..1f5f3e9357 100644
--- a/tensorflow/core/framework/shape_inference.cc
+++ b/tensorflow/core/framework/shape_inference.cc
@@ -329,6 +329,17 @@ Status InferenceContext::Concatenate(const Shape* s1, const Shape* s2,
return ReturnCreatedShape(dims, out);
}
+Status InferenceContext::ReplaceDim(const Shape* s, int dim_index,
+ const Dimension* new_dim,
+ const Shape** out) {
+ if (!RankKnown(s)) {
+ return ReturnUnknownShape(out);
+ }
+ std::vector<const Dimension*> dims(s->dims_);
+ dims[dim_index] = new_dim;
+ return ReturnCreatedShape(dims, out);
+}
+
const Shape* InferenceContext::MakeShape(
const std::vector<const Dimension*>& dims) {
all_shapes_.push_back(new Shape(dims));
diff --git a/tensorflow/core/framework/shape_inference.h b/tensorflow/core/framework/shape_inference.h
index 4cb80233e9..eec4b0f263 100644
--- a/tensorflow/core/framework/shape_inference.h
+++ b/tensorflow/core/framework/shape_inference.h
@@ -192,6 +192,11 @@ class InferenceContext {
Status Concatenate(const Shape* s1, const Shape* s2,
const Shape** out) TF_MUST_USE_RESULT;
+ // Returns in <out> the shape from replacing <s.dim[dim_index]> with
+ // <new_dim>.
+ Status ReplaceDim(const Shape* s, int dim_index, const Dimension* new_dim,
+ const Shape** out) TF_MUST_USE_RESULT;
+
// Returns a new shape with the given dims. The returned value is owned by
// this context.
const Shape* MakeShape(const std::vector<const Dimension*>& dims);
diff --git a/tensorflow/core/framework/shape_inference_test.cc b/tensorflow/core/framework/shape_inference_test.cc
index 340672d5d1..323992ac5d 100644
--- a/tensorflow/core/framework/shape_inference_test.cc
+++ b/tensorflow/core/framework/shape_inference_test.cc
@@ -465,6 +465,24 @@ TEST(ShapeInferenceTest, Concatenate) {
}
}
+TEST(ShapeInferenceTest, ReplaceDim) {
+ NodeDef def;
+ InferenceContext c(&def, MakeOpDef(2, 0), {"[1,2,3]", "?"}, {});
+
+ auto in = c.input(0);
+ auto unknown = c.input(1);
+
+ const Shape* replaced;
+ EXPECT_TRUE(c.ReplaceDim(in, 0, c.Dim(in, 1), &replaced).ok());
+ EXPECT_EQ("[2,2,3]", c.DebugString(replaced));
+ EXPECT_TRUE(c.ReplaceDim(in, 2, c.Dim(in, 1), &replaced).ok());
+ EXPECT_EQ("[1,2,2]", c.DebugString(replaced));
+ EXPECT_TRUE(c.ReplaceDim(in, 1, c.Dim(in, 2), &replaced).ok());
+ EXPECT_EQ("[1,3,3]", c.DebugString(replaced));
+ EXPECT_TRUE(c.ReplaceDim(unknown, 0, c.Dim(in, 1), &replaced).ok());
+ EXPECT_EQ("?", c.DebugString(replaced));
+}
+
TEST(ShapeInferenceTest, MakeShape) {
NodeDef def;
InferenceContext c(&def, MakeOpDef(1, 2), {"[1,2,3,?,5]"}, {});
diff --git a/tensorflow/core/ops/linalg_ops.cc b/tensorflow/core/ops/linalg_ops.cc
index 85442c7f66..25887ebd77 100644
--- a/tensorflow/core/ops/linalg_ops.cc
+++ b/tensorflow/core/ops/linalg_ops.cc
@@ -71,13 +71,12 @@ Status SquareMatrixSolveShapeFn(InferenceContext* c) {
TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 2, &rhs));
// lhs and rhs have the same number of rows. Make a new output
- // shape that has the merged-rows and the rest of the rhs.
+ // shape that uses rows to replace rhs.dim[0].
const Dimension* rows;
TF_RETURN_IF_ERROR(c->Merge(c->Dim(lhs, 0), c->Dim(rhs, 0), &rows));
- const Shape* rhs_remaining;
- TF_RETURN_IF_ERROR(c->Subshape(rhs, 1, &rhs_remaining));
- TF_RETURN_IF_ERROR(c->Concatenate(c->Vector(rows), rhs_remaining, &rhs));
- c->set_output(0, rhs);
+ const Shape* out;
+ TF_RETURN_IF_ERROR(c->ReplaceDim(rhs, 0, rows, &out));
+ c->set_output(0, out);
return Status::OK();
}
diff --git a/tensorflow/core/ops/nn_ops.cc b/tensorflow/core/ops/nn_ops.cc
index 8311f47f3f..5671d042d5 100644
--- a/tensorflow/core/ops/nn_ops.cc
+++ b/tensorflow/core/ops/nn_ops.cc
@@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
+#include "tensorflow/core/framework/common_shape_fns.h"
#include "tensorflow/core/framework/numeric_op.h"
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/shape_inference.h"
@@ -25,6 +26,28 @@ typedef shape_inference::Dimension Dimension;
typedef shape_inference::InferenceContext InferenceContext;
typedef shape_inference::Shape Shape;
+namespace {
+
+// A shape function that uses the tensor value at <input_idx> as a shape for
+// output 0. If the tensor value is not available, it uses a shape with <ndims>
+// unknown dims.
+Status InputTensorShapeOrUnknown(InferenceContext* c, int input_idx,
+ int ndims) {
+ const Shape* out;
+ const Tensor* input = c->input_tensor(input_idx);
+ if (input == nullptr) {
+ std::vector<const Dimension*> dims;
+ for (int i = 0; i < ndims; ++i) dims.push_back(c->UnknownDim());
+ out = c->MakeShape(dims);
+ } else {
+ TF_RETURN_IF_ERROR(c->MakeShapeFromShapeTensor(input_idx, &out));
+ }
+ c->set_output(0, out);
+ return Status::OK();
+}
+
+} // namespace
+
// --------------------------------------------------------------------------
REGISTER_OP("AvgPool")
@@ -62,6 +85,13 @@ REGISTER_OP("AvgPoolGrad")
.Attr(GetPaddingAttrString())
.Attr(GetConvnetDataFormatAttrString())
.Attr("T: {float, half, double}")
+ .SetShapeFn([](InferenceContext* c) {
+ // NOTE(mrry): We could in principle work out the shape from the
+ // gradients and the attrs, but if we do not know orig_input_shape
+ // statically, then we are unlikely to know the shape of the
+ // gradients either.
+ return InputTensorShapeOrUnknown(c, 0 /* input_idx */, 4 /* ndims */);
+ })
.Doc(R"doc(
Computes gradients of the average pooling function.
@@ -92,6 +122,22 @@ REGISTER_OP("BatchNormWithGlobalNormalization")
.Attr("variance_epsilon: float")
.Attr("scale_after_normalization: bool")
.Deprecated(9, "Use tf.nn.batch_normalization()")
+ .SetShapeFn([](InferenceContext* c) {
+ const Shape* input;
+ TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 4, &input));
+
+ const Dimension* last_dim = c->Dim(input, 3);
+ for (int i = 1; i < 5; ++i) { // covers m, v, beta, gamma
+ const Shape* vec;
+ TF_RETURN_IF_ERROR(c->WithRank(c->input(i), 1, &vec));
+ TF_RETURN_IF_ERROR(c->Merge(last_dim, c->Dim(vec, 0), &last_dim));
+ }
+
+ const Shape* out;
+ TF_RETURN_IF_ERROR(c->ReplaceDim(input, 3, last_dim, &out));
+ c->set_output(0, out);
+ return Status::OK();
+ })
.Doc(R"doc(
Batch normalization.
@@ -129,6 +175,30 @@ REGISTER_OP("BatchNormWithGlobalNormalizationGrad")
.Attr("variance_epsilon: float")
.Attr("scale_after_normalization: bool")
.Deprecated(9, "Use tf.nn.batch_normalization()")
+ .SetShapeFn([](InferenceContext* c) {
+ const Shape* input;
+ TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 4, &input));
+ TF_RETURN_IF_ERROR(
+ c->Merge(input, c->input(4), &input)); // with backprop
+
+ const Dimension* last_dim = c->Dim(input, 3);
+ for (int i = 1; i < 4; ++i) { // covers m, v, gamma
+ const Shape* vec;
+ TF_RETURN_IF_ERROR(c->WithRank(c->input(i), 1, &vec));
+ TF_RETURN_IF_ERROR(c->Merge(last_dim, c->Dim(vec, 0), &last_dim));
+ }
+
+ const Shape* dx;
+ TF_RETURN_IF_ERROR(c->ReplaceDim(input, 3, last_dim, &dx));
+ c->set_output(0, dx);
+
+ const Shape* vector_shape = c->Vector(last_dim);
+ c->set_output(1, vector_shape);
+ c->set_output(2, vector_shape);
+ c->set_output(3, vector_shape);
+ c->set_output(4, vector_shape);
+ return Status::OK();
+ })
.Doc(R"doc(
Gradients for batch normalization.
@@ -280,6 +350,13 @@ REGISTER_OP("Conv2DBackpropInput")
.Attr("use_cudnn_on_gpu: bool = true")
.Attr(GetPaddingAttrString())
.Attr(GetConvnetDataFormatAttrString())
+ .SetShapeFn([](InferenceContext* c) {
+      // NOTE(mrry): We could in principle work out the shape from the
+      // gradients and the attrs, but if we do not know input_sizes
+      // statically, then we are unlikely to know the shape of the
+      // gradients either.
+ return InputTensorShapeOrUnknown(c, 0 /* input_idx */, 4 /* ndims */);
+ })
.Doc(R"doc(
Computes the gradients of convolution with respect to the input.
@@ -315,6 +392,13 @@ REGISTER_OP("Conv2DBackpropFilter")
.Attr("use_cudnn_on_gpu: bool = true")
.Attr(GetPaddingAttrString())
.Attr(GetConvnetDataFormatAttrString())
+ .SetShapeFn([](InferenceContext* c) {
+      // NOTE(mrry): We could in principle work out the shape from the
+      // gradients and the attrs, but if we do not know filter_sizes
+      // statically, then we are unlikely to know the shape of the
+      // gradients either.
+ return InputTensorShapeOrUnknown(c, 1 /* input_idx */, 4 /* ndims */);
+ })
.Doc(R"doc(
Computes the gradients of convolution with respect to the filter.
@@ -380,6 +464,13 @@ REGISTER_OP("DepthwiseConv2dNativeBackpropInput")
.Attr("T: {float, double}")
.Attr("strides: list(int)")
.Attr(GetPaddingAttrString())
+ .SetShapeFn([](InferenceContext* c) {
+      // NOTE(mrry): We could in principle work out the shape from the
+      // gradients and the attrs, but if we do not know input_sizes
+      // statically, then we are unlikely to know the shape of the
+      // gradients either.
+ return InputTensorShapeOrUnknown(c, 0 /* input_idx */, 4 /* ndims */);
+ })
.Doc(R"doc(
Computes the gradients of depthwise convolution with respect to the input.
@@ -404,6 +495,13 @@ REGISTER_OP("DepthwiseConv2dNativeBackpropFilter")
.Attr("T: {float, double}")
.Attr("strides: list(int)")
.Attr(GetPaddingAttrString())
+ .SetShapeFn([](InferenceContext* c) {
+      // NOTE(mrry): We could in principle work out the shape from the
+      // gradients and the attrs, but if we do not know filter_sizes
+      // statically, then we are unlikely to know the shape of the
+      // gradients either.
+ return InputTensorShapeOrUnknown(c, 1 /* input_idx */, 4 /* ndims */);
+ })
.Doc(R"doc(
Computes the gradients of depthwise convolution with respect to the filter.
@@ -456,6 +554,9 @@ REGISTER_OP("Conv3DBackpropInput")
.Attr("strides: list(int) >= 5")
.Attr(GetPaddingAttrString())
.Deprecated(10, "Use Conv3DBackpropInputV2")
+ .SetShapeFn([](InferenceContext* c) {
+ return UnchangedShapeWithRank(c, 5);
+ })
.Doc(R"doc(
Computes the gradients of 3-D convolution with respect to the input.
@@ -479,6 +580,12 @@ REGISTER_OP("Conv3DBackpropFilter")
.Attr("strides: list(int) >= 5")
.Attr(GetPaddingAttrString())
.Deprecated(10, "Use Conv3DBackpropFilterV2")
+ .SetShapeFn([](InferenceContext* c) {
+ const Shape* out;
+ TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 5, &out));
+ c->set_output(0, out);
+ return Status::OK();
+ })
.Doc(R"doc(
Computes the gradients of 3-D convolution with respect to the filter.
@@ -501,6 +608,13 @@ REGISTER_OP("Conv3DBackpropInputV2")
.Attr("T: numbertype")
.Attr("strides: list(int) >= 5")
.Attr(GetPaddingAttrString())
+ .SetShapeFn([](InferenceContext* c) {
+ const Shape* s;
+ TF_RETURN_IF_ERROR(c->MakeShapeFromShapeTensor(0, &s));
+ TF_RETURN_IF_ERROR(c->WithRank(s, 5, &s));
+ c->set_output(0, s);
+ return Status::OK();
+ })
.Doc(R"doc(
Computes the gradients of 3-D convolution with respect to the input.
@@ -525,6 +639,13 @@ REGISTER_OP("Conv3DBackpropFilterV2")
.Attr("T: numbertype")
.Attr("strides: list(int) >= 5")
.Attr(GetPaddingAttrString())
+ .SetShapeFn([](InferenceContext* c) {
+ const Shape* s;
+ TF_RETURN_IF_ERROR(c->MakeShapeFromShapeTensor(1, &s));
+ TF_RETURN_IF_ERROR(c->WithRank(s, 5, &s));
+ c->set_output(0, s);
+ return Status::OK();
+ })
.Doc(R"doc(
Computes the gradients of 3-D convolution with respect to the filter.
@@ -570,6 +691,13 @@ REGISTER_OP("AvgPool3DGrad")
.Attr("strides: list(int) >= 5")
.Attr(GetPaddingAttrString())
.Attr("T: numbertype")
+ .SetShapeFn([](InferenceContext* c) {
+ const Shape* s;
+ TF_RETURN_IF_ERROR(c->MakeShapeFromShapeTensor(0, &s));
+ TF_RETURN_IF_ERROR(c->WithRank(s, 5, &s));
+ c->set_output(0, s);
+ return Status::OK();
+ })
.Doc(R"doc(
Computes gradients of average pooling function.
@@ -613,6 +741,9 @@ REGISTER_OP("MaxPool3DGrad")
.Attr("strides: list(int) >= 5")
.Attr(GetPaddingAttrString())
.Attr("T: numbertype")
+ .SetShapeFn([](InferenceContext* c) {
+ return UnchangedShapeWithRank(c, 5);
+ })
.Doc(R"doc(
Computes gradients of max pooling function.
@@ -686,6 +817,14 @@ REGISTER_OP("LRNGrad")
.Attr("alpha: float = 1.0")
.Attr("beta: float = 0.5")
.Attr("T: {float, half} = DT_FLOAT")
+ .SetShapeFn([](InferenceContext* c) {
+ const Shape* s;
+ TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 4, &s)); // input_grads
+ TF_RETURN_IF_ERROR(c->Merge(s, c->input(1), &s)); // input_image
+ TF_RETURN_IF_ERROR(c->Merge(s, c->input(2), &s)); // output_image
+ c->set_output(0, s);
+ return Status::OK();
+ })
.Doc(R"doc(
Gradients for Local Response Normalization.
@@ -735,6 +874,9 @@ REGISTER_OP("MaxPoolGrad")
.Input("grad: T")
.Output("output: T")
.Attr("T: {float, half} = DT_FLOAT")
+ .SetShapeFn([](InferenceContext* c) {
+ return UnchangedShapeWithRank(c, 4);
+ })
.Doc(R"doc(
Computes gradients of the maxpooling function.
@@ -788,6 +930,9 @@ REGISTER_OP("MaxPoolGradWithArgmax")
.Input("argmax: Targmax")
.Output("output: T")
.Attr("T: {float, half} = DT_FLOAT")
+ .SetShapeFn([](InferenceContext* c) {
+ return UnchangedShapeWithRank(c, 4);
+ })
.Doc(R"doc(
Computes gradients of the maxpooling function.
@@ -858,6 +1003,7 @@ REGISTER_OP("Dilation2DBackpropInput")
.Attr("strides: list(int) >= 4")
.Attr("rates: list(int) >= 4")
.Attr(GetPaddingAttrString())
+ .SetShapeFn(shape_inference::UnchangedShape)
.Doc(R"doc(
Computes the gradient of morphological 2-D dilation with respect to the input.
@@ -881,6 +1027,10 @@ REGISTER_OP("Dilation2DBackpropFilter")
.Attr("strides: list(int) >= 4")
.Attr("rates: list(int) >= 4")
.Attr(GetPaddingAttrString())
+ .SetShapeFn([](InferenceContext* c) {
+ c->set_output(0, c->input(1));
+ return Status::OK();
+ })
.Doc(R"doc(
Computes the gradient of morphological 2-D dilation with respect to the filter.
@@ -910,6 +1060,7 @@ REGISTER_OP("ReluGrad")
.Input("features: T")
.Output("backprops: T")
.Attr("T: realnumbertype")
+ .SetShapeFn(shape_inference::MergeBothInputsShapeFn)
.Doc(R"doc(
Computes rectified linear gradients for a Relu operation.
@@ -932,6 +1083,7 @@ REGISTER_OP("Relu6Grad")
.Input("features: T")
.Output("backprops: T")
.Attr("T: realnumbertype")
+ .SetShapeFn(shape_inference::MergeBothInputsShapeFn)
.Doc(R"doc(
Computes rectified linear 6 gradients for a Relu6 operation.
@@ -957,6 +1109,7 @@ REGISTER_OP("EluGrad")
.Input("outputs: T")
.Output("backprops: T")
.Attr("T: {float, double}")
+ .SetShapeFn(shape_inference::MergeBothInputsShapeFn)
.Doc(R"doc(
Computes gradients for the exponential linear (Elu) operation.
@@ -979,6 +1132,7 @@ REGISTER_OP("SoftplusGrad")
.Input("features: T")
.Output("backprops: T")
.Attr("T: realnumbertype")
+ .SetShapeFn(shape_inference::MergeBothInputsShapeFn)
.Doc(R"doc(
Computes softplus gradients for a softplus operation.
@@ -1000,6 +1154,7 @@ REGISTER_OP("SoftsignGrad")
.Input("features: T")
.Output("backprops: T")
.Attr("T: realnumbertype")
+ .SetShapeFn(shape_inference::MergeBothInputsShapeFn)
.Doc(R"doc(
Computes softsign gradients for a softsign operation.
@@ -1050,6 +1205,16 @@ REGISTER_OP("SoftmaxCrossEntropyWithLogits")
.Output("loss: T")
.Output("backprop: T")
.Attr("T: {half, float, double}")
+ .SetShapeFn([](InferenceContext* c) {
+ const Shape* input;
+ TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 2, &input));
+ TF_RETURN_IF_ERROR(c->Merge(input, c->input(1), &input));
+
+ const Dimension* batch_size = c->Dim(input, 0);
+ c->set_output(0, c->Vector(batch_size));
+ c->set_output(1, input);
+ return Status::OK();
+ })
.Doc(R"doc(
Computes softmax cross entropy cost and gradients to backpropagate.
@@ -1070,6 +1235,21 @@ REGISTER_OP("SparseSoftmaxCrossEntropyWithLogits")
.Output("backprop: T")
.Attr("T: {half, float, double}")
.Attr("Tlabels: {int32, int64} = DT_INT64")
+ .SetShapeFn([](InferenceContext* c) {
+ const Shape* features;
+ const Shape* labels;
+ TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 2, &features));
+ TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 1, &labels));
+
+ const Dimension* batch_size;
+ TF_RETURN_IF_ERROR(
+ c->Merge(c->Dim(features, 0), c->Dim(labels, 0), &batch_size));
+ TF_RETURN_IF_ERROR(c->ReplaceDim(features, 0, batch_size, &features));
+
+ c->set_output(0, c->Vector(batch_size));
+ c->set_output(1, features);
+ return Status::OK();
+ })
.Doc(R"doc(
Computes softmax cross entropy cost and gradients to backpropagate.
@@ -1095,6 +1275,17 @@ REGISTER_OP("InTopK")
.Output("precision: bool")
.Attr("k: int")
.Attr("T: {int32, int64} = DT_INT32")
+ .SetShapeFn([](InferenceContext* c) {
+ const Shape* predictions;
+ const Shape* targets;
+ TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 2, &predictions));
+ TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 1, &targets));
+ const Dimension* batch_size;
+ TF_RETURN_IF_ERROR(
+ c->Merge(c->Dim(predictions, 0), c->Dim(targets, 0), &batch_size));
+ c->set_output(0, c->Vector(batch_size));
+ return Status::OK();
+ })
.Doc(R"doc(
Says whether the targets are in the top `K` predictions.
diff --git a/tensorflow/core/ops/nn_ops_test.cc b/tensorflow/core/ops/nn_ops_test.cc
index 412584efd9..732bb38dca 100644
--- a/tensorflow/core/ops/nn_ops_test.cc
+++ b/tensorflow/core/ops/nn_ops_test.cc
@@ -22,7 +22,7 @@ limitations under the License.
namespace tensorflow {
-TEST(ArrayOpsTest, TopK_ShapeFn) {
+TEST(NNOpsTest, TopK_ShapeFn) {
ShapeInferenceTestOp op("TopK");
auto set_k = [&op](int k) {
TF_CHECK_OK(NodeDefBuilder("test", "Pack")
@@ -51,7 +51,7 @@ TEST(ArrayOpsTest, TopK_ShapeFn) {
INFER_ERROR("Need k >= 0, got -1", op, "[1,2,3,4]");
}
-TEST(ArrayOpsTest, TopKV2_ShapeFn) {
+TEST(NNOpsTest, TopKV2_ShapeFn) {
ShapeInferenceTestOp op("TopKV2");
op.input_tensors.resize(2);
@@ -80,4 +80,230 @@ TEST(ArrayOpsTest, TopKV2_ShapeFn) {
op, "[1,2,3,4];[]");
}
+TEST(NNOpsTest, InputTensorShapeOrUnknown2D_ShapeFn) {
+ typedef std::pair<const char*, int> NameAndInputIndex;
+ for (const auto& p :
+ {NameAndInputIndex("AvgPoolGrad", 0),
+ NameAndInputIndex("Conv2DBackpropInput", 0),
+ NameAndInputIndex("Conv2DBackpropFilter", 1),
+ NameAndInputIndex("DepthwiseConv2dNativeBackpropInput", 0),
+ NameAndInputIndex("DepthwiseConv2dNativeBackpropFilter", 1)}) {
+ ShapeInferenceTestOp op(p.first);
+ op.input_tensors.resize(2);
+
+ // When the input tensor is not known, the output is 4 unknown dims.
+ INFER_OK(op, "?;?", "[?,?,?,?]");
+ INFER_OK(op, "[4];?", "[?,?,?,?]");
+
+    // When the input tensor is known, its values determine the output shape.
+ std::vector<int32> shape{1, 2, 3, 4};
+ Tensor shape_t = test::AsTensor<int32>(shape);
+ op.input_tensors[p.second] = &shape_t;
+ INFER_OK(op, "[4];?", "[1,2,3,4]");
+ }
+}
+
+TEST(NNOpsTest, InputTensorShapeOrUnknown3D_ShapeFn) {
+ typedef std::pair<const char*, int> NameAndInputIndex;
+ for (const auto& p : {NameAndInputIndex("AvgPool3DGrad", 0),
+ NameAndInputIndex("Conv3DBackpropInputV2", 0),
+ NameAndInputIndex("Conv3DBackpropFilterV2", 1)}) {
+ ShapeInferenceTestOp op(p.first);
+ op.input_tensors.resize(2);
+
+    // When the input tensor is not known, the output is 5 unknown dims.
+ INFER_OK(op, "?;?;?", "[?,?,?,?,?]");
+ INFER_OK(op, "[5];?;?", "[?,?,?,?,?]");
+
+    // When the input tensor is known, its values determine the output shape.
+ std::vector<int32> shape{1, 2, 3, 4, 5};
+ Tensor shape_t = test::AsTensor<int32>(shape);
+ op.input_tensors[p.second] = &shape_t;
+ INFER_OK(op, "[5];?;?", "[1,2,3,4,5]");
+ }
+}
+
+TEST(NNOpsTest, BatchNormWithGlobalNormalization_ShapeFn) {
+ ShapeInferenceTestOp op("BatchNormWithGlobalNormalization");
+
+ // Test rank errors.
+ INFER_ERROR("Shape must be rank 4 but is rank 3", op, "[1,2,3];?;?;?;?");
+ INFER_ERROR("Shape must be rank 1 but is rank 3", op, "?;[1,2,3];?;?;?");
+ INFER_ERROR("Shape must be rank 1 but is rank 3", op, "?;?;[1,2,3];?;?");
+ INFER_ERROR("Shape must be rank 1 but is rank 3", op, "?;?;?;[1,2,3];?");
+ INFER_ERROR("Shape must be rank 1 but is rank 3", op, "?;?;?;?;[1,2,3]");
+
+  // The last dim of the first input is merged with the single dim of the
+  // other 4 inputs.
+ INFER_OK(op, "?;?;?;?;?", "[?,?,?,?]");
+ INFER_OK(op, "?;[1];?;?;?", "[?,?,?,d1_0]");
+ INFER_OK(op, "?;?;[1];?;?", "[?,?,?,d2_0]");
+ INFER_OK(op, "?;?;?;[1];?", "[?,?,?,d3_0]");
+ INFER_OK(op, "?;?;?;?;[1]", "[?,?,?,d4_0]");
+ INFER_OK(op, "[1,2,3,4];[4];[4];[4];[4]",
+ "[d0_0,d0_1,d0_2,d0_3|d1_0|d2_0|d3_0|d4_0]");
+}
+
+TEST(NNOpsTest, BatchNormWithGlobalNormalizationGrad_ShapeFn) {
+ ShapeInferenceTestOp op("BatchNormWithGlobalNormalizationGrad");
+
+ // Test rank errors.
+ INFER_ERROR("Shape must be rank 4 but is rank 3", op, "[1,2,3];?;?;?;?");
+ INFER_ERROR("Shape must be rank 1 but is rank 3", op, "?;[1,2,3];?;?;?");
+ INFER_ERROR("Shape must be rank 1 but is rank 3", op, "?;?;[1,2,3];?;?");
+ INFER_ERROR("Shape must be rank 1 but is rank 3", op, "?;?;?;[1,2,3];?");
+ INFER_ERROR("Shapes must be equal rank, but are 4 and 3", op,
+ "?;?;?;?;[1,2,3]");
+
+  // The first output is the merge of the first and last inputs. The single
+  // dim of each remaining input is merged into the last dim of that result,
+  // and a vector of that merged dim forms the last 4 outputs.
+ INFER_OK(op, "?;?;?;?;?", "[?,?,?,?];[?];[?];[?];[?]");
+ INFER_OK(op, "?;[1];?;?;?", "[?,?,?,d1_0];[d1_0];[d1_0];[d1_0];[d1_0]");
+ INFER_OK(op, "?;?;[1];?;?", "[?,?,?,d2_0];[d2_0];[d2_0];[d2_0];[d2_0]");
+ INFER_OK(op, "?;?;?;[1];?", "[?,?,?,d3_0];[d3_0];[d3_0];[d3_0];[d3_0]");
+ INFER_OK(op, "[1,?,3,?];[?];[?];[?];[?,2,?,4]",
+ "[d0_0,d4_1,d0_2,d4_3];[d4_3];[d4_3];[d4_3];[d4_3]");
+}
+
+TEST(NNOpsTest, Conv3DBackpropInput_ShapeFn) {
+ ShapeInferenceTestOp op("Conv3DBackpropInput");
+
+ // Test rank error.
+ INFER_ERROR("Shape must be rank 5 but is rank 3", op, "[1,2,3];?;?");
+
+  // input[0] is transferred to output after asserting its rank.
+ INFER_OK(op, "?;?;?", "[?,?,?,?,?]");
+ INFER_OK(op, "[?,?,?,?,?];?;?", "in0");
+ INFER_OK(op, "[?,2,?,4,?];?;?", "in0");
+}
+
+TEST(NNOpsTest, Conv3DBackpropFilter_ShapeFn) {
+ ShapeInferenceTestOp op("Conv3DBackpropFilter");
+
+ // Test rank error.
+ INFER_ERROR("Shape must be rank 5 but is rank 3", op, "?;[1,2,3];?");
+
+ // input[1] is transferred to output after asserting its rank.
+ INFER_OK(op, "?;?;?", "[?,?,?,?,?]");
+ INFER_OK(op, "?;[?,?,?,?,?];?", "in1");
+ INFER_OK(op, "?;[?,2,?,4,?];?", "in1");
+}
+
+TEST(NNOpsTest, MaxPool3DGrad_ShapeFn) {
+ ShapeInferenceTestOp op("MaxPool3DGrad");
+
+ // Test rank error.
+ INFER_ERROR("Shape must be rank 5 but is rank 3", op, "[1,2,3];?;?");
+
+ // input[0] is transferred to output after asserting its rank.
+ INFER_OK(op, "?;?;?", "[?,?,?,?,?]");
+ INFER_OK(op, "[?,?,?,?,?];?;?", "in0");
+ INFER_OK(op, "[?,2,?,4,?];?;?", "in0");
+}
+
+TEST(NNOpsTest, LRNGrad_ShapeFn) {
+ ShapeInferenceTestOp op("LRNGrad");
+
+  // LRNGrad's output is the merge of all three inputs, each of rank 4.
+ INFER_OK(op, "[1,?,?,4];[?,2,?,?];[?,?,3,?]", "[d0_0,d1_1,d2_2,d0_3]");
+
+ // Test rank errors.
+ INFER_ERROR("Shape must be rank 4 but is rank 3", op, "[1,2,3];?;?");
+ INFER_ERROR("Shapes must be equal rank, but are 4 and 3", op, "?;[1,2,3];?");
+ INFER_ERROR("Shapes must be equal rank, but are 4 and 3", op, "?;?;[1,2,3]");
+}
+
+TEST(NNOpsTest, MaxPoolGrad_ShapeFn) {
+ for (const char* op_name : {"MaxPoolGrad", "MaxPoolGradWithArgmax"}) {
+ ShapeInferenceTestOp op(op_name);
+
+ // Test rank error.
+ INFER_ERROR("Shape must be rank 4 but is rank 3", op, "[1,2,3];?;?");
+
+ // input[0] is transferred to output after asserting its rank.
+ INFER_OK(op, "?;?;?", "[?,?,?,?]");
+ INFER_OK(op, "[?,?,?,?];?;?", "in0");
+ INFER_OK(op, "[?,2,?,4];?;?", "in0");
+ }
+}
+
+TEST(NNOpsTest, Dilation2DBackpropInput_ShapeFn) {
+ ShapeInferenceTestOp op("Dilation2DBackpropInput");
+
+ // input[0] is transferred to output.
+ INFER_OK(op, "?;?;?", "in0");
+ INFER_OK(op, "?;[?,?,?,?,?];?", "in0");
+ INFER_OK(op, "?;[?,2,?,4,?];?", "in0");
+}
+
+TEST(NNOpsTest, Dilation2DBackpropFilter_ShapeFn) {
+ ShapeInferenceTestOp op("Dilation2DBackpropFilter");
+
+ // input[1] is transferred to output.
+ INFER_OK(op, "?;?;?", "in1");
+ INFER_OK(op, "?;[?,?,?,?,?];?", "in1");
+ INFER_OK(op, "?;[?,2,?,4,?];?", "in1");
+}
+
+TEST(NNOpsTest, MergeBothInputs_ShapeFn) {
+ for (const char* op_name :
+ {"ReluGrad", "Relu6Grad", "EluGrad", "SoftplusGrad", "SoftsignGrad"}) {
+ ShapeInferenceTestOp op(op_name);
+
+ INFER_OK(op, "?;?", "in0|in1");
+ INFER_OK(op, "?;[1,?,3]", "in1");
+ INFER_OK(op, "[1,?,3];?", "in0");
+ INFER_OK(op, "[1,?];[?,2]", "[d0_0,d1_1]");
+ INFER_ERROR("Dimension 1 in both shapes must be equal, but are 3 and 2", op,
+ "[1,3];[?,2]");
+ }
+}
+
+TEST(NNOpsTest, SoftmaxCrossEntropyWithLogits_ShapeFn) {
+ ShapeInferenceTestOp op("SoftmaxCrossEntropyWithLogits");
+
+  // Inputs are [batch_size,N] and [batch_size,N], and outputs are
+  // [batch_size] and [batch_size,N].
+ INFER_OK(op, "?;?", "[?];[?,?]");
+ INFER_OK(op, "[?,?];[?,?]", "[d0_0|d1_0];in0|in1");
+ INFER_OK(op, "[1,2];[?,2]", "[d0_0];in0");
+ INFER_OK(op, "[1,?];[?,2]", "[d0_0];[d0_0,d0_1|d1_1]");
+ INFER_OK(op, "[?,2];[1,2]", "[d1_0];in1");
+
+ INFER_ERROR("Dimension 0 in both shapes must be equal, but are 1 and 2", op,
+ "[1,?];[2,?]");
+ INFER_ERROR("Shape must be rank 2 but is rank 3", op, "[1,2,3];?");
+ INFER_ERROR("Shapes must be equal rank, but are 2 and 3", op, "?;[1,2,3]");
+}
+
+TEST(NNOpsTest, SparseSoftmaxCrossEntropyWithLogits_ShapeFn) {
+ ShapeInferenceTestOp op("SparseSoftmaxCrossEntropyWithLogits");
+
+ // Inputs are [batch_size,N] and [batch_size], and outputs are [batch_size]
+ // and [batch_size,N].
+ INFER_OK(op, "?;?", "[?];[?,?]");
+ INFER_OK(op, "[?,?];[?]", "[d0_0|d1_0];[d0_0|d1_0,d0_1]");
+ INFER_OK(op, "[1,2];[1]", "[d0_0|d1_0];[d0_0|d1_0,d0_1]");
+ INFER_OK(op, "[?,2];[1]", "[d1_0];[d1_0,d0_1]");
+
+ INFER_ERROR("Dimensions must be equal, but are 1 and 2", op, "[1,?];[2]");
+ INFER_ERROR("Shape must be rank 2 but is rank 3", op, "[1,2,3];?");
+ INFER_ERROR("Shape must be rank 1 but is rank 2", op, "?;[1,2]");
+}
+
+TEST(NNOpsTest, InTopK_ShapeFn) {
+ ShapeInferenceTestOp op("InTopK");
+
+ // Inputs are [batch_size,N] and [batch_size], and output is [batch_size].
+ INFER_OK(op, "?;?", "[?]");
+ INFER_OK(op, "[?,?];[?]", "[d0_0|d1_0]");
+ INFER_OK(op, "[1,2];[1]", "[d0_0|d1_0]");
+ INFER_OK(op, "[?,2];[1]", "[d1_0]");
+
+ INFER_ERROR("Dimensions must be equal, but are 1 and 2", op, "[1,?];[2]");
+ INFER_ERROR("Shape must be rank 2 but is rank 3", op, "[1,2,3];?");
+ INFER_ERROR("Shape must be rank 1 but is rank 2", op, "?;[1,2]");
+}
+
} // end namespace tensorflow