diff options
author | Yifei Feng <yifeif@google.com> | 2018-04-23 21:19:14 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-04-23 21:21:38 -0700 |
commit | 22f3a97b8b089202f60bb0c7697feb0c8e0713cc (patch) | |
tree | d16f95826e4be15bbb3b0f22bed0ca25d3eb5897 /tensorflow/core/ops/dataset_ops.cc | |
parent | 24b7c9a800ab5086d45a7d83ebcd6218424dc9e3 (diff) |
Merge changes from github.
PiperOrigin-RevId: 194031845
Diffstat (limited to 'tensorflow/core/ops/dataset_ops.cc')
-rw-r--r-- | tensorflow/core/ops/dataset_ops.cc | 140 |
1 file changed, 121 insertions, 19 deletions
diff --git a/tensorflow/core/ops/dataset_ops.cc b/tensorflow/core/ops/dataset_ops.cc index 67c6c58fe2..4ba3f15ef0 100644 --- a/tensorflow/core/ops/dataset_ops.cc +++ b/tensorflow/core/ops/dataset_ops.cc @@ -148,7 +148,11 @@ REGISTER_OP("BytesProducedStatsDataset") .Output("handle: variant") .Attr("output_types: list(type) >= 1") .Attr("output_shapes: list(shape) >= 1") - .SetShapeFn(shape_inference::ScalarShape); + .SetShapeFn([](shape_inference::InferenceContext* c) { + shape_inference::ShapeHandle tag_shape; + TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &tag_shape)); + return shape_inference::ScalarShape(c); + }); REGISTER_OP("LatencyStatsDataset") .Input("input_dataset: variant") @@ -156,7 +160,11 @@ REGISTER_OP("LatencyStatsDataset") .Output("handle: variant") .Attr("output_types: list(type) >= 1") .Attr("output_shapes: list(shape) >= 1") - .SetShapeFn(shape_inference::ScalarShape); + .SetShapeFn([](shape_inference::InferenceContext* c) { + shape_inference::ShapeHandle tag_shape; + TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &tag_shape)); + return shape_inference::ScalarShape(c); + }); REGISTER_OP("SetStatsAggregatorDataset") .Input("input_dataset: variant") @@ -206,7 +214,12 @@ REGISTER_OP("PrefetchDataset") .Output("handle: variant") .Attr("output_types: list(type) >= 1") .Attr("output_shapes: list(shape) >= 1") - .SetShapeFn(shape_inference::ScalarShape); + .SetShapeFn([](shape_inference::InferenceContext* c) { + shape_inference::ShapeHandle unused; + // buffer_size should be a scalar. 
+ TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused)); + return shape_inference::ScalarShape(c); + }); REGISTER_OP("ScanDataset") .Input("input_dataset: variant") @@ -290,7 +303,12 @@ REGISTER_OP("BatchDataset") .Output("handle: variant") .Attr("output_types: list(type) >= 1") .Attr("output_shapes: list(shape) >= 1") - .SetShapeFn(shape_inference::ScalarShape); + .SetShapeFn([](shape_inference::InferenceContext* c) { + shape_inference::ShapeHandle unused; + // batch_size should be a scalar. + TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused)); + return shape_inference::ScalarShape(c); + }); // TODO(mrry): move SlideDataset to contrib in the future. REGISTER_OP("SlideDataset") @@ -300,7 +318,13 @@ REGISTER_OP("SlideDataset") .Output("handle: variant") .Attr("output_types: list(type) >= 1") .Attr("output_shapes: list(shape) >= 1") - .SetShapeFn(shape_inference::ScalarShape); + .SetShapeFn([](shape_inference::InferenceContext* c) { + shape_inference::ShapeHandle unused; + // window_size and stride should be scalars. + TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused)); + TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &unused)); + return shape_inference::ScalarShape(c); + }); REGISTER_OP("PaddedBatchDataset") .Input("input_dataset: variant") @@ -330,7 +354,14 @@ REGISTER_OP("DenseToSparseBatchDataset") .Output("handle: variant") .Attr("output_types: list(type) >= 1") .Attr("output_shapes: list(shape) >= 1") - .SetShapeFn(shape_inference::ScalarShape); + .SetShapeFn([](shape_inference::InferenceContext* c) { + shape_inference::ShapeHandle unused; + // batch_size should be a scalar. + TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused)); + // row_shape should be a 1-D vector. 
+ TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 1, &unused)); + return shape_inference::ScalarShape(c); + }); REGISTER_OP("RangeDataset") .Input("start: int64") @@ -341,7 +372,14 @@ REGISTER_OP("RangeDataset") .Attr("output_shapes: list(shape) >= 1") .SetIsStateful() // TODO(b/65524810): Source dataset ops must be marked // stateful to inhibit constant folding. - .SetShapeFn(shape_inference::ScalarShape); + .SetShapeFn([](shape_inference::InferenceContext* c) { + shape_inference::ShapeHandle unused; + // start, stop, and step should be scalars. + TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 0, &unused)); + TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused)); + TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &unused)); + return shape_inference::ScalarShape(c); + }); REGISTER_OP("RandomDataset") .Input("seed: int64") @@ -351,7 +389,13 @@ REGISTER_OP("RandomDataset") .Attr("output_shapes: list(shape) >= 1") .SetIsStateful() // TODO(b/65524810): Source dataset ops must be marked // stateful to inhibit constant folding. - .SetShapeFn(shape_inference::ScalarShape); + .SetShapeFn([](shape_inference::InferenceContext* c) { + shape_inference::ShapeHandle unused; + // buffer_size, seed, and seed2 should be scalars. + TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 0, &unused)); + TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused)); + return shape_inference::ScalarShape(c); + }); REGISTER_OP("ShuffleDataset") .Input("input_dataset: variant") @@ -362,7 +406,14 @@ REGISTER_OP("ShuffleDataset") .Attr("reshuffle_each_iteration: bool = true") .Attr("output_types: list(type) >= 1") .Attr("output_shapes: list(shape) >= 1") - .SetShapeFn(shape_inference::ScalarShape); + .SetShapeFn([](shape_inference::InferenceContext* c) { + shape_inference::ShapeHandle unused; + // buffer_size, seed, and seed2 should be scalars. 
+ TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused)); + TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &unused)); + TF_RETURN_IF_ERROR(c->WithRank(c->input(3), 0, &unused)); + return shape_inference::ScalarShape(c); + }); REGISTER_OP("ShuffleAndRepeatDataset") .Input("input_dataset: variant") @@ -373,7 +424,15 @@ REGISTER_OP("ShuffleAndRepeatDataset") .Output("handle: variant") .Attr("output_types: list(type) >= 1") .Attr("output_shapes: list(shape) >= 1") - .SetShapeFn(shape_inference::ScalarShape); + .SetShapeFn([](shape_inference::InferenceContext* c) { + shape_inference::ShapeHandle unused; + // buffer_size, seed, seed2, and count should be scalars. + TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused)); + TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &unused)); + TF_RETURN_IF_ERROR(c->WithRank(c->input(3), 0, &unused)); + TF_RETURN_IF_ERROR(c->WithRank(c->input(4), 0, &unused)); + return shape_inference::ScalarShape(c); + }); REGISTER_OP("CacheDataset") .Input("input_dataset: variant") @@ -381,7 +440,12 @@ REGISTER_OP("CacheDataset") .Output("handle: variant") .Attr("output_types: list(type) >= 1") .Attr("output_shapes: list(shape) >= 1") - .SetShapeFn(shape_inference::ScalarShape); + .SetShapeFn([](shape_inference::InferenceContext* c) { + shape_inference::ShapeHandle unused; + // filename should be a scalar. + TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused)); + return shape_inference::ScalarShape(c); + }); REGISTER_OP("TextLineDataset") .Input("filenames: string") @@ -390,10 +454,16 @@ REGISTER_OP("TextLineDataset") .Output("handle: variant") .SetIsStateful() // TODO(b/65524810): Source dataset ops must be marked // stateful to inhibit constant folding. - .SetShapeFn(shape_inference::ScalarShape); // TODO(mrry): validate - // that `filenames` is - // a scalar or a - // vector. + .SetShapeFn([](shape_inference::InferenceContext* c) { + shape_inference::ShapeHandle unused; + // `filenames` must be a scalar or a vector. 
+ TF_RETURN_IF_ERROR(c->WithRankAtMost(c->input(0), 1, &unused)); + return shape_inference::ScalarShape(c); + // `compression_type` could only be a scalar. + TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused)); + // `buffer_size` could only be a scalar. + TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &unused)); + }); REGISTER_OP("SqlDataset") .Input("driver_name: string") @@ -404,7 +474,14 @@ REGISTER_OP("SqlDataset") .Attr("output_shapes: list(shape) >= 1") .SetIsStateful() // TODO(b/65524810): Source dataset ops must be marked // stateful to inhibit constant folding. - .SetShapeFn(shape_inference::ScalarShape); + .SetShapeFn([](shape_inference::InferenceContext* c) { + shape_inference::ShapeHandle unused; + // driver_name, data_source_name, and query should be scalars. + TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 0, &unused)); + TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused)); + TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &unused)); + return shape_inference::ScalarShape(c); + }); REGISTER_OP("FixedLengthRecordDataset") .Input("filenames: string") @@ -415,7 +492,18 @@ REGISTER_OP("FixedLengthRecordDataset") .Output("handle: variant") .SetIsStateful() // TODO(b/65524810): Source dataset ops must be marked // stateful to inhibit constant folding. - .SetShapeFn(shape_inference::ScalarShape); + .SetShapeFn([](shape_inference::InferenceContext* c) { + shape_inference::ShapeHandle unused; + // `filenames` must be a scalar or a vector. + TF_RETURN_IF_ERROR(c->WithRankAtMost(c->input(0), 1, &unused)); + // header_bytes, record_bytes, footer_bytes, buffer_size should be + // scalars. 
+ TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused)); + TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &unused)); + TF_RETURN_IF_ERROR(c->WithRank(c->input(3), 0, &unused)); + TF_RETURN_IF_ERROR(c->WithRank(c->input(4), 0, &unused)); + return shape_inference::ScalarShape(c); + }); REGISTER_OP("TFRecordDataset") .Input("filenames: string") @@ -424,7 +512,16 @@ REGISTER_OP("TFRecordDataset") .Output("handle: variant") .SetIsStateful() // TODO(b/65524810): Source dataset ops must be marked // stateful to inhibit constant folding. - .SetShapeFn(shape_inference::ScalarShape); + .SetShapeFn([](shape_inference::InferenceContext* c) { + shape_inference::ShapeHandle unused; + // `filenames` must be a scalar or a vector. + TF_RETURN_IF_ERROR(c->WithRankAtMost(c->input(0), 1, &unused)); + // `compression_type` could only be a scalar. + TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused)); + // `buffer_size` could only be a scalar. + TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &unused)); + return shape_inference::ScalarShape(c); + }); REGISTER_OP("Iterator") .Output("handle: resource") @@ -540,7 +637,12 @@ REGISTER_OP("PrependFromQueueAndPaddedBatchDataset") // length of `output_types` is `N`, the `output_shapes` are // (as far as possible to tell statically) compatible with `padded_shapes`, // and that `padding_values` are all scalars. - .SetShapeFn(shape_inference::ScalarShape); + .SetShapeFn([](shape_inference::InferenceContext* c) { + shape_inference::ShapeHandle unused; + // batch_size should be a scalar. + TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused)); + return shape_inference::ScalarShape(c); + }); REGISTER_OP("EnqueueInQueueDataset") .Input("queue: variant") |