diff options
author | Alexandre Passos <apassos@google.com> | 2018-07-23 15:11:22 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-07-23 15:32:59 -0700 |
commit | f85d825500357603afb7a02d2c88ad306ee43006 (patch) | |
tree | a1e90f4af08265f8434a80285ebe4a7c6f8b3e0e | |
parent | 931a3054d2c13c3438fc58978b3463a0bd268aee (diff) |
Allow differentiating tfe.defun functions which contain conds.
PiperOrigin-RevId: 205732423
-rw-r--r-- | tensorflow/core/common_runtime/direct_session.cc | 3 | ||||
-rw-r--r-- | tensorflow/core/common_runtime/function.cc | 19 | ||||
-rw-r--r-- | tensorflow/core/framework/function.cc | 5 | ||||
-rw-r--r-- | tensorflow/core/framework/function.h | 8 | ||||
-rw-r--r-- | tensorflow/core/kernels/partitioned_function_ops.cc | 1 | ||||
-rw-r--r-- | tensorflow/python/eager/function_test.py | 13 |
6 files changed, 38 insertions, 11 deletions
diff --git a/tensorflow/core/common_runtime/direct_session.cc b/tensorflow/core/common_runtime/direct_session.cc index 44291b0b20..d1fd930d25 100644 --- a/tensorflow/core/common_runtime/direct_session.cc +++ b/tensorflow/core/common_runtime/direct_session.cc @@ -717,7 +717,8 @@ Status DirectSession::Run(const RunOptions& run_options, // Receive outputs. if (outputs) { std::vector<Tensor> sorted_outputs; - const Status s = call_frame.ConsumeRetvals(&sorted_outputs); + const Status s = call_frame.ConsumeRetvals( + &sorted_outputs, /* allow_dead_tensors = */ false); if (errors::IsInternal(s)) { return errors::InvalidArgument(s.error_message()); } else if (!s.ok()) { diff --git a/tensorflow/core/common_runtime/function.cc b/tensorflow/core/common_runtime/function.cc index a93cfa2ec5..54bbe84b57 100644 --- a/tensorflow/core/common_runtime/function.cc +++ b/tensorflow/core/common_runtime/function.cc @@ -746,6 +746,8 @@ void FunctionLibraryRuntimeImpl::RunRemote(const Options& opts, Handle handle, rets_alloc_attrs.push_back(ret_alloc_attrs); } + bool allow_dead_tensors = opts.allow_dead_tensors; + // The ProcFLR sends the arguments to the function from the source_device to // the target_device. So here we receive those arguments. Similarly, when the // computation is done and stored in *rets, we send the return values back @@ -756,7 +758,7 @@ void FunctionLibraryRuntimeImpl::RunRemote(const Options& opts, Handle handle, device_context, args_alloc_attrs, rendezvous, remote_args, [frame, remote_args, item, source_device, target_device, target_incarnation, rendezvous, device_context, rets, done, exec_args, - rets_alloc_attrs](const Status& status) { + rets_alloc_attrs, allow_dead_tensors](const Status& status) { Status s = status; if (s.ok()) { s = frame->SetArgs(*remote_args); @@ -769,13 +771,13 @@ void FunctionLibraryRuntimeImpl::RunRemote(const Options& opts, Handle handle, return; } item->exec->RunAsync( - *exec_args, - [frame, rets, done, source_device, target_device, - target_incarnation, rendezvous, device_context, remote_args, - exec_args, rets_alloc_attrs](const Status& status) { + *exec_args, [frame, rets, done, source_device, target_device, + target_incarnation, rendezvous, device_context, + remote_args, exec_args, rets_alloc_attrs, + allow_dead_tensors](const Status& status) { Status s = status; if (s.ok()) { - s = frame->ConsumeRetvals(rets); + s = frame->ConsumeRetvals(rets, allow_dead_tensors); } delete frame; if (!s.ok()) { @@ -859,14 +861,15 @@ void FunctionLibraryRuntimeImpl::Run(const Options& opts, Handle handle, return; } + bool allow_dead_tensors = opts.allow_dead_tensors; item->exec->RunAsync( // Executor args *exec_args, // Done callback. - [frame, rets, done, exec_args](const Status& status) { + [frame, rets, done, exec_args, allow_dead_tensors](const Status& status) { Status s = status; if (s.ok()) { - s = frame->ConsumeRetvals(rets); + s = frame->ConsumeRetvals(rets, allow_dead_tensors); } delete frame; delete exec_args; diff --git a/tensorflow/core/framework/function.cc b/tensorflow/core/framework/function.cc index 88d9d65f5a..57bcc0f513 100644 --- a/tensorflow/core/framework/function.cc +++ b/tensorflow/core/framework/function.cc @@ -865,12 +865,15 @@ Status FunctionCallFrame::GetRetvals(std::vector<Tensor>* rets) const { return Status::OK(); } -Status FunctionCallFrame::ConsumeRetvals(std::vector<Tensor>* rets) { +Status FunctionCallFrame::ConsumeRetvals(std::vector<Tensor>* rets, + bool allow_dead_tensors) { rets->clear(); rets->reserve(rets_.size()); for (size_t i = 0; i < rets_.size(); ++i) { if (rets_[i].has_val) { rets->emplace_back(std::move(rets_[i].val)); + } else if (allow_dead_tensors) { + rets->emplace_back(); } else { return errors::Internal("Retval[", i, "] does not have value"); } diff --git a/tensorflow/core/framework/function.h b/tensorflow/core/framework/function.h index 8e607b927c..5da9af7db3 100644 --- a/tensorflow/core/framework/function.h +++ b/tensorflow/core/framework/function.h @@ -261,7 +261,10 @@ class FunctionCallFrame : public CallFrameInterface { // Caller methods. Status SetArgs(gtl::ArraySlice<Tensor> args); Status GetRetvals(std::vector<Tensor>* rets) const; - Status ConsumeRetvals(std::vector<Tensor>* rets); + + // Moves the return values from the frame to rets. If allow_dead_tensors is + // false it will fail if any of the retvals do not have a value. + Status ConsumeRetvals(std::vector<Tensor>* rets, bool allow_dead_tensors); size_t num_args() const override { return arg_types_.size(); } size_t num_retvals() const override { return ret_types_.size(); } @@ -510,6 +513,9 @@ class FunctionLibraryRuntime { // If true, we create a new IntraProcessRendezvous, else use the existing // one. bool create_rendezvous = false; + + // If True, allow returning dead tensors. + bool allow_dead_tensors = false; }; typedef std::function<void(const Status&)> DoneCallback; virtual void Run(const Options& opts, Handle handle, diff --git a/tensorflow/core/kernels/partitioned_function_ops.cc b/tensorflow/core/kernels/partitioned_function_ops.cc index b5c6ba1da3..a7a9609c21 100644 --- a/tensorflow/core/kernels/partitioned_function_ops.cc +++ b/tensorflow/core/kernels/partitioned_function_ops.cc @@ -330,6 +330,7 @@ class PartitionedCallOp : public AsyncOpKernel { // using device-specific threadpools when available. opts.runner = ctx->runner(); opts.source_device = local_device_name_; + opts.allow_dead_tensors = true; // TODO(akshayka): Accommodate the multiple-worker scenario by adding the // constructed rendezvous to a rendezvous manager. Rendezvous* rendez = new IntraProcessRendezvous(lib->device_mgr()); diff --git a/tensorflow/python/eager/function_test.py b/tensorflow/python/eager/function_test.py index e6592b2e37..2e86563a7d 100644 --- a/tensorflow/python/eager/function_test.py +++ b/tensorflow/python/eager/function_test.py @@ -213,6 +213,19 @@ class FunctionTest(test.TestCase): self.assertEqual(fn_op.output_shapes, None) self.assertAllEqual(fn_op(x, x), None) + @test_util.run_in_graph_and_eager_modes() + def testDefunCondGradient(self): + + @function.defun + def f(x): + return control_flow_ops.cond(x > 0.5, lambda: 2 * x, lambda: 3 * x) + + with backprop.GradientTape() as t: + x = constant_op.constant(1.0) + t.watch(x) + y = f(x) + self.assertAllEqual(self.evaluate(t.gradient(y, x)), 2.0) + def testDefunCapturedInt32(self): x = constant_op.constant(1, dtype=dtypes.int32) |