diff options
author | Rachel Lim <rachelim@google.com> | 2018-10-04 13:01:33 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-10-04 13:12:57 -0700 |
commit | 7fcb05ff475a0c6c1076eacf9d11e17323d98bc2 (patch) | |
tree | 84087a64563d10c3390f991c6263c7fa2cc65b11 /tensorflow/core/ops | |
parent | 074ff471fefbcf3bfd49914ad80bd9f9751df363 (diff) |
[tf.data] Add a notion of `captured args` to MapDefun
PiperOrigin-RevId: 215788485
Diffstat (limited to 'tensorflow/core/ops')
-rw-r--r-- | tensorflow/core/ops/dataset_ops.cc | 11 |
1 files changed, 8 insertions, 3 deletions
diff --git a/tensorflow/core/ops/dataset_ops.cc b/tensorflow/core/ops/dataset_ops.cc index 889a6a4640..ec22eee874 100644 --- a/tensorflow/core/ops/dataset_ops.cc +++ b/tensorflow/core/ops/dataset_ops.cc @@ -903,14 +903,18 @@ REGISTER_OP("ModelDataset") REGISTER_OP("MapDefun") .Input("arguments: Targuments") + .Input("captured_inputs: Tcaptured") .Output("output: output_types") .Attr("Targuments: list(type) >= 1") + .Attr("Tcaptured: list(type) >= 0 = []") .Attr("output_types: list(type) >= 1") .Attr("output_shapes: list(shape) >= 1") .Attr("f: func") .SetShapeFn([](shape_inference::InferenceContext* c) { std::vector<PartialTensorShape> output_shapes; TF_RETURN_IF_ERROR(c->GetAttr("output_shapes", &output_shapes)); + DataTypeVector t_args; + TF_RETURN_IF_ERROR(c->GetAttr("Targuments", &t_args)); if (output_shapes.size() != c->num_outputs()) { return errors::InvalidArgument( "`output_shapes` must be the same length as `output_types` (", @@ -918,10 +922,11 @@ REGISTER_OP("MapDefun") } int64 dim_zero = -1; - for (size_t i = 0; i < static_cast<size_t>(c->num_inputs()); ++i) { + for (size_t i = 0; i < t_args.size(); ++i) { if (c->Rank(c->input(i)) == 0) { return errors::InvalidArgument( - "Inputs must have rank at least 1. Input ", i, " has rank of 0"); + "Arguments must have rank at least 1. Input ", i, + " has rank of 0."); } auto dim_handle = c->Dim(c->input(i), 0); if (c->ValueKnown(dim_handle)) { @@ -929,7 +934,7 @@ REGISTER_OP("MapDefun") dim_zero = c->Value(dim_handle); } else if (c->Value(dim_handle) != dim_zero) { return errors::InvalidArgument( - "Inputs must have the same dimension 0."); + "Arguments must have the same dimension 0."); } } } |