aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core/ops/nn_ops_test.cc
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/core/ops/nn_ops_test.cc')
-rw-r--r--tensorflow/core/ops/nn_ops_test.cc24
1 files changed, 24 insertions, 0 deletions
diff --git a/tensorflow/core/ops/nn_ops_test.cc b/tensorflow/core/ops/nn_ops_test.cc
index 4628b725f8..94ecf4d5db 100644
--- a/tensorflow/core/ops/nn_ops_test.cc
+++ b/tensorflow/core/ops/nn_ops_test.cc
@@ -81,6 +81,30 @@ TEST(NNOpsTest, TopKV2_ShapeFn) {
op, "[1,2,3,4];[]");
}
+TEST(NNOpsTest, NthElement_ShapeFn) {
+ ShapeInferenceTestOp op("NthElement");
+ op.input_tensors.resize(2);
+
+ Tensor n_t;
+ op.input_tensors[1] = &n_t;
+ n_t = test::AsScalar<int32>(20);
+
+ INFER_OK(op, "?;[]", "?");
+ INFER_OK(op, "[21];[]", "[]");
+ INFER_OK(op, "[2,?,?];[]", "[d0_0,d0_1]");
+ INFER_OK(op, "[?,3,?,21];[]", "[d0_0,d0_1,d0_2]");
+
+ INFER_ERROR("Shape must be at least rank 1 but is rank 0", op, "[];[]");
+ INFER_ERROR("Input must have last dimension > n = 20 but is 1", op,
+ "[1];[]");
+ INFER_ERROR("Input must have last dimension > n = 20 but is 20", op,
+ "[1,2,3,20];[]");
+ n_t = test::AsScalar<int32>(-1);
+ INFER_ERROR(
+ "Dimension size, given by scalar input 1, must be non-negative but is -1",
+ op, "[1,2,3,4];[]");
+}
+
TEST(NNOpsTest, BatchNormWithGlobalNormalization_ShapeFn) {
ShapeInferenceTestOp op("BatchNormWithGlobalNormalization");