aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core/ops
diff options
context:
space:
mode:
authorGravatar Rachel Lim <rachelim@google.com>2018-10-04 13:01:33 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-10-04 13:12:57 -0700
commit7fcb05ff475a0c6c1076eacf9d11e17323d98bc2 (patch)
tree84087a64563d10c3390f991c6263c7fa2cc65b11 /tensorflow/core/ops
parent074ff471fefbcf3bfd49914ad80bd9f9751df363 (diff)
[tf.data] Add a notion of `captured args` to MapDefun
PiperOrigin-RevId: 215788485
Diffstat (limited to 'tensorflow/core/ops')
-rw-r--r--tensorflow/core/ops/dataset_ops.cc11
1 files changed, 8 insertions, 3 deletions
diff --git a/tensorflow/core/ops/dataset_ops.cc b/tensorflow/core/ops/dataset_ops.cc
index 889a6a4640..ec22eee874 100644
--- a/tensorflow/core/ops/dataset_ops.cc
+++ b/tensorflow/core/ops/dataset_ops.cc
@@ -903,14 +903,18 @@ REGISTER_OP("ModelDataset")
REGISTER_OP("MapDefun")
.Input("arguments: Targuments")
+ .Input("captured_inputs: Tcaptured")
.Output("output: output_types")
.Attr("Targuments: list(type) >= 1")
+ .Attr("Tcaptured: list(type) >= 0 = []")
.Attr("output_types: list(type) >= 1")
.Attr("output_shapes: list(shape) >= 1")
.Attr("f: func")
.SetShapeFn([](shape_inference::InferenceContext* c) {
std::vector<PartialTensorShape> output_shapes;
TF_RETURN_IF_ERROR(c->GetAttr("output_shapes", &output_shapes));
+ DataTypeVector t_args;
+ TF_RETURN_IF_ERROR(c->GetAttr("Targuments", &t_args));
if (output_shapes.size() != c->num_outputs()) {
return errors::InvalidArgument(
"`output_shapes` must be the same length as `output_types` (",
@@ -918,10 +922,11 @@ REGISTER_OP("MapDefun")
}
int64 dim_zero = -1;
- for (size_t i = 0; i < static_cast<size_t>(c->num_inputs()); ++i) {
+ for (size_t i = 0; i < t_args.size(); ++i) {
if (c->Rank(c->input(i)) == 0) {
return errors::InvalidArgument(
- "Inputs must have rank at least 1. Input ", i, " has rank of 0");
+ "Arguments must have rank at least 1. Input ", i,
+ " has rank of 0.");
}
auto dim_handle = c->Dim(c->input(i), 0);
if (c->ValueKnown(dim_handle)) {
@@ -929,7 +934,7 @@ REGISTER_OP("MapDefun")
dim_zero = c->Value(dim_handle);
} else if (c->Value(dim_handle) != dim_zero) {
return errors::InvalidArgument(
- "Inputs must have the same dimension 0.");
+ "Arguments must have the same dimension 0.");
}
}
}