diff options
Diffstat (limited to 'tensorflow/core/ops/functional_ops.cc')
-rw-r--r-- | tensorflow/core/ops/functional_ops.cc | 26 |
1 files changed, 23 insertions, 3 deletions
diff --git a/tensorflow/core/ops/functional_ops.cc b/tensorflow/core/ops/functional_ops.cc index 88553dff93..5f262db2ce 100644 --- a/tensorflow/core/ops/functional_ops.cc +++ b/tensorflow/core/ops/functional_ops.cc @@ -31,11 +31,23 @@ REGISTER_OP("SymbolicGradient") if (c->num_inputs() < c->num_outputs()) { return errors::InvalidArgument("len(inputs) < len(outputs)"); } + std::vector<DataType> types; + TF_RETURN_IF_ERROR(c->GetAttr("Tin", &types)); // Say, (u, v) = f(x, y, z), _symbolic_gradient(f) is a function of // (x, y, z, du, dv) -> (dx, dy, dz). Therefore, shapes of its // outputs (dx, dy, dz) are the same as (x, y, z). for (int i = 0; i < c->num_outputs(); ++i) { - c->set_output(i, c->input(i)); + if (types[i] == DT_RESOURCE) { + const std::vector<shape_inference::ShapeAndType>* handle_type = + c->input_handle_shapes_and_types(i); + if (handle_type != nullptr) { + c->set_output(i, handle_type->at(0).shape); + } else { + c->set_output(i, c->UnknownShape()); + } + } else { + c->set_output(i, c->input(i)); + } } return Status::OK(); }); @@ -83,7 +95,7 @@ REGISTER_OP("If") .Output("output: Tout") .Attr("Tcond: type") .Attr("Tin: list(type) >= 0") - .Attr("Tout: list(type)") + .Attr("Tout: list(type) >= 0") .Attr("then_branch: func") .Attr("else_branch: func") .SetShapeFn(shape_inference::UnknownShape); @@ -145,7 +157,6 @@ REGISTER_OP("For") .Attr("body: func") .SetShapeFn(shape_inference::UnknownShape); -// TODO(b/73826847, b/37549631) Mark as stateful. REGISTER_OP("PartitionedCall") .Input("args: Tin") .Output("output: Tout") @@ -154,6 +165,15 @@ REGISTER_OP("PartitionedCall") .Attr("f: func") .SetShapeFn(shape_inference::UnknownShape); +REGISTER_OP("StatefulPartitionedCall") + .Input("args: Tin") + .Output("output: Tout") + .Attr("Tin: list(type) >= 0") + .Attr("Tout: list(type) >= 0") + .Attr("f: func") + .SetIsStateful() + .SetShapeFn(shape_inference::UnknownShape); + // This op is used as a placeholder in If branch functions. It doesn't provide a // valid output when run, so must either be removed (e.g. replaced with a // function input) or guaranteed not to be used (e.g. if mirroring an |