aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core/ops/functional_ops.cc
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/core/ops/functional_ops.cc')
-rw-r--r--tensorflow/core/ops/functional_ops.cc18
1 files changed, 15 insertions, 3 deletions
diff --git a/tensorflow/core/ops/functional_ops.cc b/tensorflow/core/ops/functional_ops.cc
index ec22ce4177..5f262db2ce 100644
--- a/tensorflow/core/ops/functional_ops.cc
+++ b/tensorflow/core/ops/functional_ops.cc
@@ -40,7 +40,11 @@ REGISTER_OP("SymbolicGradient")
if (types[i] == DT_RESOURCE) {
const std::vector<shape_inference::ShapeAndType>* handle_type =
c->input_handle_shapes_and_types(i);
- c->set_output(i, handle_type->at(0).shape);
+ if (handle_type != nullptr) {
+ c->set_output(i, handle_type->at(0).shape);
+ } else {
+ c->set_output(i, c->UnknownShape());
+ }
} else {
c->set_output(i, c->input(i));
}
@@ -91,7 +95,7 @@ REGISTER_OP("If")
.Output("output: Tout")
.Attr("Tcond: type")
.Attr("Tin: list(type) >= 0")
- .Attr("Tout: list(type)")
+ .Attr("Tout: list(type) >= 0")
.Attr("then_branch: func")
.Attr("else_branch: func")
.SetShapeFn(shape_inference::UnknownShape);
@@ -153,7 +157,6 @@ REGISTER_OP("For")
.Attr("body: func")
.SetShapeFn(shape_inference::UnknownShape);
-// TODO(b/73826847, b/37549631) Mark as stateful.
REGISTER_OP("PartitionedCall")
.Input("args: Tin")
.Output("output: Tout")
@@ -162,6 +165,15 @@ REGISTER_OP("PartitionedCall")
.Attr("f: func")
.SetShapeFn(shape_inference::UnknownShape);
+REGISTER_OP("StatefulPartitionedCall")
+ .Input("args: Tin")
+ .Output("output: Tout")
+ .Attr("Tin: list(type) >= 0")
+ .Attr("Tout: list(type) >= 0")
+ .Attr("f: func")
+ .SetIsStateful()
+ .SetShapeFn(shape_inference::UnknownShape);
+
// This op is used as a placeholder in If branch functions. It doesn't provide a
// valid output when run, so must either be removed (e.g. replaced with a
// function input) or guaranteed not to be used (e.g. if mirroring an