aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core/ops/dataset_ops.cc
diff options
context:
space:
mode:
authorGravatar Yifei Feng <yifeif@google.com>2018-04-23 21:19:14 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-04-23 21:21:38 -0700
commit22f3a97b8b089202f60bb0c7697feb0c8e0713cc (patch)
treed16f95826e4be15bbb3b0f22bed0ca25d3eb5897 /tensorflow/core/ops/dataset_ops.cc
parent24b7c9a800ab5086d45a7d83ebcd6218424dc9e3 (diff)
Merge changes from github.
PiperOrigin-RevId: 194031845
Diffstat (limited to 'tensorflow/core/ops/dataset_ops.cc')
-rw-r--r--tensorflow/core/ops/dataset_ops.cc140
1 files changed, 121 insertions, 19 deletions
diff --git a/tensorflow/core/ops/dataset_ops.cc b/tensorflow/core/ops/dataset_ops.cc
index 67c6c58fe2..4ba3f15ef0 100644
--- a/tensorflow/core/ops/dataset_ops.cc
+++ b/tensorflow/core/ops/dataset_ops.cc
@@ -148,7 +148,11 @@ REGISTER_OP("BytesProducedStatsDataset")
.Output("handle: variant")
.Attr("output_types: list(type) >= 1")
.Attr("output_shapes: list(shape) >= 1")
- .SetShapeFn(shape_inference::ScalarShape);
+ .SetShapeFn([](shape_inference::InferenceContext* c) {
+ shape_inference::ShapeHandle tag_shape;
+ TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &tag_shape));
+ return shape_inference::ScalarShape(c);
+ });
REGISTER_OP("LatencyStatsDataset")
.Input("input_dataset: variant")
@@ -156,7 +160,11 @@ REGISTER_OP("LatencyStatsDataset")
.Output("handle: variant")
.Attr("output_types: list(type) >= 1")
.Attr("output_shapes: list(shape) >= 1")
- .SetShapeFn(shape_inference::ScalarShape);
+ .SetShapeFn([](shape_inference::InferenceContext* c) {
+ shape_inference::ShapeHandle tag_shape;
+ TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &tag_shape));
+ return shape_inference::ScalarShape(c);
+ });
REGISTER_OP("SetStatsAggregatorDataset")
.Input("input_dataset: variant")
@@ -206,7 +214,12 @@ REGISTER_OP("PrefetchDataset")
.Output("handle: variant")
.Attr("output_types: list(type) >= 1")
.Attr("output_shapes: list(shape) >= 1")
- .SetShapeFn(shape_inference::ScalarShape);
+ .SetShapeFn([](shape_inference::InferenceContext* c) {
+ shape_inference::ShapeHandle unused;
+ // buffer_size should be a scalar.
+ TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused));
+ return shape_inference::ScalarShape(c);
+ });
REGISTER_OP("ScanDataset")
.Input("input_dataset: variant")
@@ -290,7 +303,12 @@ REGISTER_OP("BatchDataset")
.Output("handle: variant")
.Attr("output_types: list(type) >= 1")
.Attr("output_shapes: list(shape) >= 1")
- .SetShapeFn(shape_inference::ScalarShape);
+ .SetShapeFn([](shape_inference::InferenceContext* c) {
+ shape_inference::ShapeHandle unused;
+ // batch_size should be a scalar.
+ TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused));
+ return shape_inference::ScalarShape(c);
+ });
// TODO(mrry): move SlideDataset to contrib in the future.
REGISTER_OP("SlideDataset")
@@ -300,7 +318,13 @@ REGISTER_OP("SlideDataset")
.Output("handle: variant")
.Attr("output_types: list(type) >= 1")
.Attr("output_shapes: list(shape) >= 1")
- .SetShapeFn(shape_inference::ScalarShape);
+ .SetShapeFn([](shape_inference::InferenceContext* c) {
+ shape_inference::ShapeHandle unused;
+ // window_size and stride should be scalars.
+ TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused));
+ TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &unused));
+ return shape_inference::ScalarShape(c);
+ });
REGISTER_OP("PaddedBatchDataset")
.Input("input_dataset: variant")
@@ -330,7 +354,14 @@ REGISTER_OP("DenseToSparseBatchDataset")
.Output("handle: variant")
.Attr("output_types: list(type) >= 1")
.Attr("output_shapes: list(shape) >= 1")
- .SetShapeFn(shape_inference::ScalarShape);
+ .SetShapeFn([](shape_inference::InferenceContext* c) {
+ shape_inference::ShapeHandle unused;
+ // batch_size should be a scalar.
+ TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused));
+ // row_shape should be a 1-D vector.
+ TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 1, &unused));
+ return shape_inference::ScalarShape(c);
+ });
REGISTER_OP("RangeDataset")
.Input("start: int64")
@@ -341,7 +372,14 @@ REGISTER_OP("RangeDataset")
.Attr("output_shapes: list(shape) >= 1")
.SetIsStateful() // TODO(b/65524810): Source dataset ops must be marked
// stateful to inhibit constant folding.
- .SetShapeFn(shape_inference::ScalarShape);
+ .SetShapeFn([](shape_inference::InferenceContext* c) {
+ shape_inference::ShapeHandle unused;
+ // start, stop, and step should be scalars.
+ TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 0, &unused));
+ TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused));
+ TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &unused));
+ return shape_inference::ScalarShape(c);
+ });
REGISTER_OP("RandomDataset")
.Input("seed: int64")
@@ -351,7 +389,13 @@ REGISTER_OP("RandomDataset")
.Attr("output_shapes: list(shape) >= 1")
.SetIsStateful() // TODO(b/65524810): Source dataset ops must be marked
// stateful to inhibit constant folding.
- .SetShapeFn(shape_inference::ScalarShape);
+ .SetShapeFn([](shape_inference::InferenceContext* c) {
+ shape_inference::ShapeHandle unused;
+ // buffer_size, seed, and seed2 should be scalars.
+ TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 0, &unused));
+ TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused));
+ return shape_inference::ScalarShape(c);
+ });
REGISTER_OP("ShuffleDataset")
.Input("input_dataset: variant")
@@ -362,7 +406,14 @@ REGISTER_OP("ShuffleDataset")
.Attr("reshuffle_each_iteration: bool = true")
.Attr("output_types: list(type) >= 1")
.Attr("output_shapes: list(shape) >= 1")
- .SetShapeFn(shape_inference::ScalarShape);
+ .SetShapeFn([](shape_inference::InferenceContext* c) {
+ shape_inference::ShapeHandle unused;
+ // buffer_size, seed, and seed2 should be scalars.
+ TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused));
+ TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &unused));
+ TF_RETURN_IF_ERROR(c->WithRank(c->input(3), 0, &unused));
+ return shape_inference::ScalarShape(c);
+ });
REGISTER_OP("ShuffleAndRepeatDataset")
.Input("input_dataset: variant")
@@ -373,7 +424,15 @@ REGISTER_OP("ShuffleAndRepeatDataset")
.Output("handle: variant")
.Attr("output_types: list(type) >= 1")
.Attr("output_shapes: list(shape) >= 1")
- .SetShapeFn(shape_inference::ScalarShape);
+ .SetShapeFn([](shape_inference::InferenceContext* c) {
+ shape_inference::ShapeHandle unused;
+ // buffer_size, seed, seed2, and count should be scalars.
+ TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused));
+ TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &unused));
+ TF_RETURN_IF_ERROR(c->WithRank(c->input(3), 0, &unused));
+ TF_RETURN_IF_ERROR(c->WithRank(c->input(4), 0, &unused));
+ return shape_inference::ScalarShape(c);
+ });
REGISTER_OP("CacheDataset")
.Input("input_dataset: variant")
@@ -381,7 +440,12 @@ REGISTER_OP("CacheDataset")
.Output("handle: variant")
.Attr("output_types: list(type) >= 1")
.Attr("output_shapes: list(shape) >= 1")
- .SetShapeFn(shape_inference::ScalarShape);
+ .SetShapeFn([](shape_inference::InferenceContext* c) {
+ shape_inference::ShapeHandle unused;
+ // filename should be a scalar.
+ TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused));
+ return shape_inference::ScalarShape(c);
+ });
REGISTER_OP("TextLineDataset")
.Input("filenames: string")
@@ -390,10 +454,16 @@ REGISTER_OP("TextLineDataset")
.Output("handle: variant")
.SetIsStateful() // TODO(b/65524810): Source dataset ops must be marked
// stateful to inhibit constant folding.
- .SetShapeFn(shape_inference::ScalarShape); // TODO(mrry): validate
- // that `filenames` is
- // a scalar or a
- // vector.
+ .SetShapeFn([](shape_inference::InferenceContext* c) {
+ shape_inference::ShapeHandle unused;
+ // `filenames` must be a scalar or a vector.
+ TF_RETURN_IF_ERROR(c->WithRankAtMost(c->input(0), 1, &unused));
+ return shape_inference::ScalarShape(c);
+ // `compression_type` could only be a scalar.
+ TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused));
+ // `buffer_size` could only be a scalar.
+ TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &unused));
+ });
REGISTER_OP("SqlDataset")
.Input("driver_name: string")
@@ -404,7 +474,14 @@ REGISTER_OP("SqlDataset")
.Attr("output_shapes: list(shape) >= 1")
.SetIsStateful() // TODO(b/65524810): Source dataset ops must be marked
// stateful to inhibit constant folding.
- .SetShapeFn(shape_inference::ScalarShape);
+ .SetShapeFn([](shape_inference::InferenceContext* c) {
+ shape_inference::ShapeHandle unused;
+ // driver_name, data_source_name, and query should be scalars.
+ TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 0, &unused));
+ TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused));
+ TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &unused));
+ return shape_inference::ScalarShape(c);
+ });
REGISTER_OP("FixedLengthRecordDataset")
.Input("filenames: string")
@@ -415,7 +492,18 @@ REGISTER_OP("FixedLengthRecordDataset")
.Output("handle: variant")
.SetIsStateful() // TODO(b/65524810): Source dataset ops must be marked
// stateful to inhibit constant folding.
- .SetShapeFn(shape_inference::ScalarShape);
+ .SetShapeFn([](shape_inference::InferenceContext* c) {
+ shape_inference::ShapeHandle unused;
+ // `filenames` must be a scalar or a vector.
+ TF_RETURN_IF_ERROR(c->WithRankAtMost(c->input(0), 1, &unused));
+ // header_bytes, record_bytes, footer_bytes, buffer_size should be
+ // scalars.
+ TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused));
+ TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &unused));
+ TF_RETURN_IF_ERROR(c->WithRank(c->input(3), 0, &unused));
+ TF_RETURN_IF_ERROR(c->WithRank(c->input(4), 0, &unused));
+ return shape_inference::ScalarShape(c);
+ });
REGISTER_OP("TFRecordDataset")
.Input("filenames: string")
@@ -424,7 +512,16 @@ REGISTER_OP("TFRecordDataset")
.Output("handle: variant")
.SetIsStateful() // TODO(b/65524810): Source dataset ops must be marked
// stateful to inhibit constant folding.
- .SetShapeFn(shape_inference::ScalarShape);
+ .SetShapeFn([](shape_inference::InferenceContext* c) {
+ shape_inference::ShapeHandle unused;
+ // `filenames` must be a scalar or a vector.
+ TF_RETURN_IF_ERROR(c->WithRankAtMost(c->input(0), 1, &unused));
+ // `compression_type` could only be a scalar.
+ TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused));
+ // `buffer_size` could only be a scalar.
+ TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &unused));
+ return shape_inference::ScalarShape(c);
+ });
REGISTER_OP("Iterator")
.Output("handle: resource")
@@ -540,7 +637,12 @@ REGISTER_OP("PrependFromQueueAndPaddedBatchDataset")
// length of `output_types` is `N`, the `output_shapes` are
// (as far as possible to tell statically) compatible with `padded_shapes`,
// and that `padding_values` are all scalars.
- .SetShapeFn(shape_inference::ScalarShape);
+ .SetShapeFn([](shape_inference::InferenceContext* c) {
+ shape_inference::ShapeHandle unused;
+ // batch_size should be a scalar.
+ TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused));
+ return shape_inference::ScalarShape(c);
+ });
REGISTER_OP("EnqueueInQueueDataset")
.Input("queue: variant")