diff options
Diffstat (limited to 'tensorflow/core/ops/functional_ops.cc')
-rw-r--r-- | tensorflow/core/ops/functional_ops.cc | 18 |
1 files changed, 15 insertions, 3 deletions
diff --git a/tensorflow/core/ops/functional_ops.cc b/tensorflow/core/ops/functional_ops.cc index ec22ce4177..5f262db2ce 100644 --- a/tensorflow/core/ops/functional_ops.cc +++ b/tensorflow/core/ops/functional_ops.cc @@ -40,7 +40,11 @@ REGISTER_OP("SymbolicGradient") if (types[i] == DT_RESOURCE) { const std::vector<shape_inference::ShapeAndType>* handle_type = c->input_handle_shapes_and_types(i); - c->set_output(i, handle_type->at(0).shape); + if (handle_type != nullptr) { + c->set_output(i, handle_type->at(0).shape); + } else { + c->set_output(i, c->UnknownShape()); + } } else { c->set_output(i, c->input(i)); } @@ -91,7 +95,7 @@ REGISTER_OP("If") .Output("output: Tout") .Attr("Tcond: type") .Attr("Tin: list(type) >= 0") - .Attr("Tout: list(type)") + .Attr("Tout: list(type) >= 0") .Attr("then_branch: func") .Attr("else_branch: func") .SetShapeFn(shape_inference::UnknownShape); @@ -153,7 +157,6 @@ REGISTER_OP("For") .Attr("body: func") .SetShapeFn(shape_inference::UnknownShape); -// TODO(b/73826847, b/37549631) Mark as stateful. REGISTER_OP("PartitionedCall") .Input("args: Tin") .Output("output: Tout") @@ -162,6 +165,15 @@ REGISTER_OP("PartitionedCall") .Attr("f: func") .SetShapeFn(shape_inference::UnknownShape); +REGISTER_OP("StatefulPartitionedCall") + .Input("args: Tin") + .Output("output: Tout") + .Attr("Tin: list(type) >= 0") + .Attr("Tout: list(type) >= 0") + .Attr("f: func") + .SetIsStateful() + .SetShapeFn(shape_inference::UnknownShape); + // This op is used as a placeholder in If branch functions. It doesn't provide a // valid output when run, so must either be removed (e.g. replaced with a // function input) or guaranteed not to be used (e.g. if mirroring an |