aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core/ops/functional_ops.cc
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/core/ops/functional_ops.cc')
-rw-r--r--tensorflow/core/ops/functional_ops.cc26
1 files changed, 23 insertions, 3 deletions
diff --git a/tensorflow/core/ops/functional_ops.cc b/tensorflow/core/ops/functional_ops.cc
index 88553dff93..5f262db2ce 100644
--- a/tensorflow/core/ops/functional_ops.cc
+++ b/tensorflow/core/ops/functional_ops.cc
@@ -31,11 +31,23 @@ REGISTER_OP("SymbolicGradient")
if (c->num_inputs() < c->num_outputs()) {
return errors::InvalidArgument("len(inputs) < len(outputs)");
}
+ std::vector<DataType> types;
+ TF_RETURN_IF_ERROR(c->GetAttr("Tin", &types));
// Say, (u, v) = f(x, y, z), _symbolic_gradient(f) is a function of
// (x, y, z, du, dv) -> (dx, dy, dz). Therefore, shapes of its
// outputs (dx, dy, dz) are the same as (x, y, z).
for (int i = 0; i < c->num_outputs(); ++i) {
- c->set_output(i, c->input(i));
+ if (types[i] == DT_RESOURCE) {
+ const std::vector<shape_inference::ShapeAndType>* handle_type =
+ c->input_handle_shapes_and_types(i);
+ if (handle_type != nullptr) {
+ c->set_output(i, handle_type->at(0).shape);
+ } else {
+ c->set_output(i, c->UnknownShape());
+ }
+ } else {
+ c->set_output(i, c->input(i));
+ }
}
return Status::OK();
});
@@ -83,7 +95,7 @@ REGISTER_OP("If")
.Output("output: Tout")
.Attr("Tcond: type")
.Attr("Tin: list(type) >= 0")
- .Attr("Tout: list(type)")
+ .Attr("Tout: list(type) >= 0")
.Attr("then_branch: func")
.Attr("else_branch: func")
.SetShapeFn(shape_inference::UnknownShape);
@@ -145,7 +157,6 @@ REGISTER_OP("For")
.Attr("body: func")
.SetShapeFn(shape_inference::UnknownShape);
-// TODO(b/73826847, b/37549631) Mark as stateful.
REGISTER_OP("PartitionedCall")
.Input("args: Tin")
.Output("output: Tout")
@@ -154,6 +165,15 @@ REGISTER_OP("PartitionedCall")
.Attr("f: func")
.SetShapeFn(shape_inference::UnknownShape);
+REGISTER_OP("StatefulPartitionedCall")
+ .Input("args: Tin")
+ .Output("output: Tout")
+ .Attr("Tin: list(type) >= 0")
+ .Attr("Tout: list(type) >= 0")
+ .Attr("f: func")
+ .SetIsStateful()
+ .SetShapeFn(shape_inference::UnknownShape);
+
// This op is used as a placeholder in If branch functions. It doesn't provide a
// valid output when run, so must either be removed (e.g. replaced with a
// function input) or guaranteed not to be used (e.g. if mirroring an