diff options
Diffstat (limited to 'tensorflow/core/ops/nn_ops_test.cc')
-rw-r--r-- | tensorflow/core/ops/nn_ops_test.cc | 24 |
1 files changed, 24 insertions, 0 deletions
diff --git a/tensorflow/core/ops/nn_ops_test.cc b/tensorflow/core/ops/nn_ops_test.cc index 4628b725f8..94ecf4d5db 100644 --- a/tensorflow/core/ops/nn_ops_test.cc +++ b/tensorflow/core/ops/nn_ops_test.cc @@ -81,6 +81,30 @@ TEST(NNOpsTest, TopKV2_ShapeFn) { op, "[1,2,3,4];[]"); } +TEST(NNOpsTest, NthElement_ShapeFn) { + ShapeInferenceTestOp op("NthElement"); + op.input_tensors.resize(2); + + Tensor n_t; + op.input_tensors[1] = &n_t; + n_t = test::AsScalar<int32>(20); + + INFER_OK(op, "?;[]", "?"); + INFER_OK(op, "[21];[]", "[]"); + INFER_OK(op, "[2,?,?];[]", "[d0_0,d0_1]"); + INFER_OK(op, "[?,3,?,21];[]", "[d0_0,d0_1,d0_2]"); + + INFER_ERROR("Shape must be at least rank 1 but is rank 0", op, "[];[]"); + INFER_ERROR("Input must have last dimension > n = 20 but is 1", op, + "[1];[]"); + INFER_ERROR("Input must have last dimension > n = 20 but is 20", op, + "[1,2,3,20];[]"); + n_t = test::AsScalar<int32>(-1); + INFER_ERROR( + "Dimension size, given by scalar input 1, must be non-negative but is -1", + op, "[1,2,3,4];[]"); +} + TEST(NNOpsTest, BatchNormWithGlobalNormalization_ShapeFn) { ShapeInferenceTestOp op("BatchNormWithGlobalNormalization"); |