aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core/ops/parsing_ops_test.cc
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/core/ops/parsing_ops_test.cc')
-rw-r--r--tensorflow/core/ops/parsing_ops_test.cc82
1 files changed, 82 insertions, 0 deletions
diff --git a/tensorflow/core/ops/parsing_ops_test.cc b/tensorflow/core/ops/parsing_ops_test.cc
index 9121d7ae92..c65e66d1a8 100644
--- a/tensorflow/core/ops/parsing_ops_test.cc
+++ b/tensorflow/core/ops/parsing_ops_test.cc
@@ -143,6 +143,88 @@ TEST(ParsingOpsTest, ParseExample_ShapeFn) {
"?;?;?;?;?;?;?;?;?;?");
}
+TEST(ParsingOpsTest, ParseSequenceExample_ShapeFn) {
+ ShapeInferenceTestOp op("ParseSequenceExample");
+ auto set_outputs = [&op](int num_context_sparse, int num_context_dense,
+ int num_feature_list_sparse,
+ int num_feature_list_dense,
+ bool add_extra_shape = false) {
+ using NodeOutList = std::vector<NodeDefBuilder::NodeOut>;
+ using DataTypeList = std::vector<DataType>;
+ string string_in("test");
+ NodeDefBuilder::NodeOut node_in{"a", 0, DT_STRING};
+ TF_ASSERT_OK(
+ NodeDefBuilder("test", "ParseSequenceExample")
+ .Input("serialized", 0, DT_STRING)
+ .Input("debug_name", 0, DT_STRING)
+ .Input(NodeOutList(num_context_dense, node_in))
+ .Attr("Ncontext_sparse", num_context_sparse)
+ .Attr("Ncontext_dense", num_context_dense)
+ .Attr("Nfeature_list_sparse", num_feature_list_sparse)
+ .Attr("Nfeature_list_dense", num_feature_list_dense)
+ .Attr("feature_list_dense_missing_assumed_empty",
+ std::vector<string>(num_feature_list_dense, string_in))
+ .Attr("context_sparse_keys",
+ std::vector<string>(num_context_sparse, string_in))
+ .Attr("context_dense_keys",
+ std::vector<string>(num_context_dense, string_in))
+ .Attr("feature_list_sparse_keys",
+ std::vector<string>(num_feature_list_sparse, string_in))
+ .Attr("feature_list_dense_keys",
+ std::vector<string>(num_feature_list_dense, string_in))
+ .Attr("context_sparse_types",
+ DataTypeList(num_context_sparse, DT_FLOAT))
+ .Attr("context_dense_types",
+ DataTypeList(num_context_dense, DT_FLOAT))
+ .Attr("context_dense_shapes",
+ MakeDenseShapes(num_context_dense, add_extra_shape, 0))
+ .Attr("feature_list_sparse_types",
+ DataTypeList(num_feature_list_sparse, DT_FLOAT))
+ .Attr("feature_list_dense_types",
+ DataTypeList(num_feature_list_dense, DT_FLOAT))
+ .Attr("feature_list_dense_shapes",
+ MakeDenseShapes(num_feature_list_dense, add_extra_shape, 0))
+ .Finalize(&op.node_def));
+ };
+
+ // Verify inputs 'serialized' and 'debug_name'.
+ set_outputs(0, 0, 0, 0);
+ INFER_OK(op, "[?];[?]", "");
+ INFER_OK(op, "[8];[8]", "");
+ INFER_ERROR("must be rank 1", op, "[];[?]");
+ INFER_ERROR("must be rank 1", op, "[?];[]");
+
+ // context inputs with no feature_list inputs.
+ set_outputs(2 /* num_context_sparse */, 3 /* num_context_dense */, 0, 0);
+ INFER_OK(op, "[?];[?];?;?;?",
+ ("[?,2];[?,2];[?];[?];[2];[2];" // context sparse
+ "[d0_0,1];[d0_0,1,2];[d0_0,1,2,3]")); // context dense
+
+ // feature_list inputs with no context inputs.
+ set_outputs(0, 0, 2 /* num_feature_list_sparse */,
+ 3 /* num_feature_list_dense */);
+ INFER_OK(op, "[?];[?]",
+ ("[?,3];[?,3];[?];[?];[3];[3];" // feature_list sparse
+ "[d0_0,?,1];[d0_0,?,1,2];[d0_0,?,1,2,3];" // feature_list dense
+ "[d0_0];[d0_0];[d0_0]")); // feature_list length
+
+ // Combine previous two test cases.
+ set_outputs(2, 3, 2, 3);
+ INFER_OK(op, "[7];[7];?;?;?",
+ ("[?,2];[?,2];[?];[?];[2];[2];" // context sparse
+ "[d0_0,1];[d0_0,1,2];[d0_0,1,2,3];" // context dense
+ "[?,3];[?,3];[?];[?];[3];[3];" // feature_list sparse
+ "[d0_0,?,1];[d0_0,?,1,2];[d0_0,?,1,2,3];" // feature_list dense
+ "[d0_0];[d0_0];[d0_0]")); // feature_list length
+
+ // Confirm an error from ParseSequenceExampleAttrs.Init().
+ set_outputs(1, 1, 1, 1, true /* add_extra_shape */);
+ INFER_ERROR(
+ "num_context_dense (1) must match the size of context_dense_keys (1), "
+ "context_dense_types (1) and context_dense_shapes (2)",
+ op, "[?];[?];?");
+}
+
TEST(ParsingOpsTest, ParseSingleSequenceExample_ShapeFn) {
ShapeInferenceTestOp op("ParseSingleSequenceExample");
auto set_outputs = [&op](int num_context_sparse, int num_context_dense,