diff options
Diffstat (limited to 'tensorflow/core/ops/parsing_ops_test.cc')
-rw-r--r-- | tensorflow/core/ops/parsing_ops_test.cc | 82 |
1 files changed, 82 insertions, 0 deletions
diff --git a/tensorflow/core/ops/parsing_ops_test.cc b/tensorflow/core/ops/parsing_ops_test.cc index 9121d7ae92..c65e66d1a8 100644 --- a/tensorflow/core/ops/parsing_ops_test.cc +++ b/tensorflow/core/ops/parsing_ops_test.cc @@ -143,6 +143,88 @@ TEST(ParsingOpsTest, ParseExample_ShapeFn) { "?;?;?;?;?;?;?;?;?;?"); } +TEST(ParsingOpsTest, ParseSequenceExample_ShapeFn) { + ShapeInferenceTestOp op("ParseSequenceExample"); + auto set_outputs = [&op](int num_context_sparse, int num_context_dense, + int num_feature_list_sparse, + int num_feature_list_dense, + bool add_extra_shape = false) { + using NodeOutList = std::vector<NodeDefBuilder::NodeOut>; + using DataTypeList = std::vector<DataType>; + string string_in("test"); + NodeDefBuilder::NodeOut node_in{"a", 0, DT_STRING}; + TF_ASSERT_OK( + NodeDefBuilder("test", "ParseSequenceExample") + .Input("serialized", 0, DT_STRING) + .Input("debug_name", 0, DT_STRING) + .Input(NodeOutList(num_context_dense, node_in)) + .Attr("Ncontext_sparse", num_context_sparse) + .Attr("Ncontext_dense", num_context_dense) + .Attr("Nfeature_list_sparse", num_feature_list_sparse) + .Attr("Nfeature_list_dense", num_feature_list_dense) + .Attr("feature_list_dense_missing_assumed_empty", + std::vector<string>(num_feature_list_dense, string_in)) + .Attr("context_sparse_keys", + std::vector<string>(num_context_sparse, string_in)) + .Attr("context_dense_keys", + std::vector<string>(num_context_dense, string_in)) + .Attr("feature_list_sparse_keys", + std::vector<string>(num_feature_list_sparse, string_in)) + .Attr("feature_list_dense_keys", + std::vector<string>(num_feature_list_dense, string_in)) + .Attr("context_sparse_types", + DataTypeList(num_context_sparse, DT_FLOAT)) + .Attr("context_dense_types", + DataTypeList(num_context_dense, DT_FLOAT)) + .Attr("context_dense_shapes", + MakeDenseShapes(num_context_dense, add_extra_shape, 0)) + .Attr("feature_list_sparse_types", + DataTypeList(num_feature_list_sparse, DT_FLOAT)) + .Attr("feature_list_dense_types", + DataTypeList(num_feature_list_dense, DT_FLOAT)) + .Attr("feature_list_dense_shapes", + MakeDenseShapes(num_feature_list_dense, add_extra_shape, 0)) + .Finalize(&op.node_def)); + }; + + // Verify inputs 'serialized' and 'debug_name'. + set_outputs(0, 0, 0, 0); + INFER_OK(op, "[?];[?]", ""); + INFER_OK(op, "[8];[8]", ""); + INFER_ERROR("must be rank 1", op, "[];[?]"); + INFER_ERROR("must be rank 1", op, "[?];[]"); + + // context inputs with no feature_list inputs. + set_outputs(2 /* num_context_sparse */, 3 /* num_context_dense */, 0, 0); + INFER_OK(op, "[?];[?];?;?;?", + ("[?,2];[?,2];[?];[?];[2];[2];" // context sparse + "[d0_0,1];[d0_0,1,2];[d0_0,1,2,3]")); // context dense + + // feature_list inputs with no context inputs. + set_outputs(0, 0, 2 /* num_feature_list_sparse */, + 3 /* num_feature_list_dense */); + INFER_OK(op, "[?];[?]", + ("[?,3];[?,3];[?];[?];[3];[3];" // feature_list sparse + "[d0_0,?,1];[d0_0,?,1,2];[d0_0,?,1,2,3];" // feature_list dense + "[d0_0];[d0_0];[d0_0]")); // feature_list length + + // Combine previous two test cases. + set_outputs(2, 3, 2, 3); + INFER_OK(op, "[7];[7];?;?;?", + ("[?,2];[?,2];[?];[?];[2];[2];" // context sparse + "[d0_0,1];[d0_0,1,2];[d0_0,1,2,3];" // context dense + "[?,3];[?,3];[?];[?];[3];[3];" // feature_list sparse + "[d0_0,?,1];[d0_0,?,1,2];[d0_0,?,1,2,3];" // feature_list dense + "[d0_0];[d0_0];[d0_0]")); // feature_list length + + // Confirm an error from ParseSequenceExampleAttrs.Init(). + set_outputs(1, 1, 1, 1, true /* add_extra_shape */); + INFER_ERROR( + "num_context_dense (1) must match the size of context_dense_keys (1), " + "context_dense_types (1) and context_dense_shapes (2)", + op, "[?];[?];?"); +} + TEST(ParsingOpsTest, ParseSingleSequenceExample_ShapeFn) { ShapeInferenceTestOp op("ParseSingleSequenceExample"); auto set_outputs = [&op](int num_context_sparse, int num_context_dense, |