author     Alexandre Passos <apassos@google.com>            2018-07-23 15:11:22 -0700
committer  TensorFlower Gardener <gardener@tensorflow.org>  2018-07-23 15:32:59 -0700
commit     f85d825500357603afb7a02d2c88ad306ee43006 (patch)
tree       a1e90f4af08265f8434a80285ebe4a7c6f8b3e0e
parent     931a3054d2c13c3438fc58978b3463a0bd268aee (diff)
Allow differentiating tfe.defun functions which contain conds.
PiperOrigin-RevId: 205732423
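
This means gradients can now be taken through a defun-wrapped tf.cond: the untaken branch of the cond yields dead tensors, which the function runtime is now allowed to return instead of failing. A minimal sketch of the resulting behavior (illustrative only, mirroring the test added below; assumes the TF 1.x-era contrib eager API, so exact import paths may differ by version):

    import tensorflow as tf
    import tensorflow.contrib.eager as tfe  # TF 1.x-era eager API (assumed)

    tf.enable_eager_execution()

    @tfe.defun
    def f(x):
      # Only one branch runs; the other branch's outputs are dead tensors,
      # which the runtime may now return instead of raising an internal error.
      return tf.cond(x > 0.5, lambda: 2 * x, lambda: 3 * x)

    x = tf.constant(1.0)
    with tf.GradientTape() as t:
      t.watch(x)
      y = f(x)
    print(t.gradient(y, x))  # expected 2.0, since x > 0.5 takes the first branch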
 tensorflow/core/common_runtime/direct_session.cc     |  3
 tensorflow/core/common_runtime/function.cc            | 19
 tensorflow/core/framework/function.cc                 |  5
 tensorflow/core/framework/function.h                  |  8
 tensorflow/core/kernels/partitioned_function_ops.cc   |  1
 tensorflow/python/eager/function_test.py              | 13
 6 files changed, 38 insertions(+), 11 deletions(-)
diff --git a/tensorflow/core/common_runtime/direct_session.cc b/tensorflow/core/common_runtime/direct_session.cc
index 44291b0b20..d1fd930d25 100644
--- a/tensorflow/core/common_runtime/direct_session.cc
+++ b/tensorflow/core/common_runtime/direct_session.cc
@@ -717,7 +717,8 @@ Status DirectSession::Run(const RunOptions& run_options,
// Receive outputs.
if (outputs) {
std::vector<Tensor> sorted_outputs;
- const Status s = call_frame.ConsumeRetvals(&sorted_outputs);
+ const Status s = call_frame.ConsumeRetvals(
+ &sorted_outputs, /* allow_dead_tensors = */ false);
if (errors::IsInternal(s)) {
return errors::InvalidArgument(s.error_message());
} else if (!s.ok()) {
diff --git a/tensorflow/core/common_runtime/function.cc b/tensorflow/core/common_runtime/function.cc
index a93cfa2ec5..54bbe84b57 100644
--- a/tensorflow/core/common_runtime/function.cc
+++ b/tensorflow/core/common_runtime/function.cc
@@ -746,6 +746,8 @@ void FunctionLibraryRuntimeImpl::RunRemote(const Options& opts, Handle handle,
rets_alloc_attrs.push_back(ret_alloc_attrs);
}
+ bool allow_dead_tensors = opts.allow_dead_tensors;
+
// The ProcFLR sends the arguments to the function from the source_device to
// the target_device. So here we receive those arguments. Similarly, when the
// computation is done and stored in *rets, we send the return values back
@@ -756,7 +758,7 @@ void FunctionLibraryRuntimeImpl::RunRemote(const Options& opts, Handle handle,
device_context, args_alloc_attrs, rendezvous, remote_args,
[frame, remote_args, item, source_device, target_device,
target_incarnation, rendezvous, device_context, rets, done, exec_args,
- rets_alloc_attrs](const Status& status) {
+ rets_alloc_attrs, allow_dead_tensors](const Status& status) {
Status s = status;
if (s.ok()) {
s = frame->SetArgs(*remote_args);
@@ -769,13 +771,13 @@ void FunctionLibraryRuntimeImpl::RunRemote(const Options& opts, Handle handle,
return;
}
item->exec->RunAsync(
- *exec_args,
- [frame, rets, done, source_device, target_device,
- target_incarnation, rendezvous, device_context, remote_args,
- exec_args, rets_alloc_attrs](const Status& status) {
+ *exec_args, [frame, rets, done, source_device, target_device,
+ target_incarnation, rendezvous, device_context,
+ remote_args, exec_args, rets_alloc_attrs,
+ allow_dead_tensors](const Status& status) {
Status s = status;
if (s.ok()) {
- s = frame->ConsumeRetvals(rets);
+ s = frame->ConsumeRetvals(rets, allow_dead_tensors);
}
delete frame;
if (!s.ok()) {
@@ -859,14 +861,15 @@ void FunctionLibraryRuntimeImpl::Run(const Options& opts, Handle handle,
return;
}
+ bool allow_dead_tensors = opts.allow_dead_tensors;
item->exec->RunAsync(
// Executor args
*exec_args,
// Done callback.
- [frame, rets, done, exec_args](const Status& status) {
+ [frame, rets, done, exec_args, allow_dead_tensors](const Status& status) {
Status s = status;
if (s.ok()) {
- s = frame->ConsumeRetvals(rets);
+ s = frame->ConsumeRetvals(rets, allow_dead_tensors);
}
delete frame;
delete exec_args;
diff --git a/tensorflow/core/framework/function.cc b/tensorflow/core/framework/function.cc
index 88d9d65f5a..57bcc0f513 100644
--- a/tensorflow/core/framework/function.cc
+++ b/tensorflow/core/framework/function.cc
@@ -865,12 +865,15 @@ Status FunctionCallFrame::GetRetvals(std::vector<Tensor>* rets) const {
return Status::OK();
}
-Status FunctionCallFrame::ConsumeRetvals(std::vector<Tensor>* rets) {
+Status FunctionCallFrame::ConsumeRetvals(std::vector<Tensor>* rets,
+ bool allow_dead_tensors) {
rets->clear();
rets->reserve(rets_.size());
for (size_t i = 0; i < rets_.size(); ++i) {
if (rets_[i].has_val) {
rets->emplace_back(std::move(rets_[i].val));
+ } else if (allow_dead_tensors) {
+ rets->emplace_back();
} else {
return errors::Internal("Retval[", i, "] does not have value");
}
diff --git a/tensorflow/core/framework/function.h b/tensorflow/core/framework/function.h
index 8e607b927c..5da9af7db3 100644
--- a/tensorflow/core/framework/function.h
+++ b/tensorflow/core/framework/function.h
@@ -261,7 +261,10 @@ class FunctionCallFrame : public CallFrameInterface {
// Caller methods.
Status SetArgs(gtl::ArraySlice<Tensor> args);
Status GetRetvals(std::vector<Tensor>* rets) const;
- Status ConsumeRetvals(std::vector<Tensor>* rets);
+
+ // Moves the return values from the frame to rets. If allow_dead_tensors is
+ // false it will fail if any of the retvals do not have a value.
+ Status ConsumeRetvals(std::vector<Tensor>* rets, bool allow_dead_tensors);
size_t num_args() const override { return arg_types_.size(); }
size_t num_retvals() const override { return ret_types_.size(); }
@@ -510,6 +513,9 @@ class FunctionLibraryRuntime {
// If true, we create a new IntraProcessRendezvous, else use the existing
// one.
bool create_rendezvous = false;
+
+ // If true, allow returning dead tensors.
+ bool allow_dead_tensors = false;
};
typedef std::function<void(const Status&)> DoneCallback;
virtual void Run(const Options& opts, Handle handle,
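
The comment on ConsumeRetvals above summarizes the new contract: when allow_dead_tensors is true, a return value that was never produced (a dead tensor, e.g. from the untaken branch of a cond) comes back as an empty tensor; when it is false, the call fails as before. A small illustrative sketch of that logic, written as plain Python rather than the actual C++ (the pair layout and names below are a simplification of the frame's rets_ vector):

    def consume_retvals(retvals, allow_dead_tensors=False):
      # retvals: list of (has_val, val) pairs, modelling FunctionCallFrame::rets_.
      out = []
      for i, (has_val, val) in enumerate(retvals):
        if has_val:
          out.append(val)        # move the produced value out
        elif allow_dead_tensors:
          out.append(None)       # stand-in for an empty (dead) Tensor
        else:
          raise RuntimeError("Retval[%d] does not have value" % i)
      return out

In this change, PartitionedCallOp (below) opts in by setting allow_dead_tensors = true, while DirectSession keeps the strict behavior by passing false.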
diff --git a/tensorflow/core/kernels/partitioned_function_ops.cc b/tensorflow/core/kernels/partitioned_function_ops.cc
index b5c6ba1da3..a7a9609c21 100644
--- a/tensorflow/core/kernels/partitioned_function_ops.cc
+++ b/tensorflow/core/kernels/partitioned_function_ops.cc
@@ -330,6 +330,7 @@ class PartitionedCallOp : public AsyncOpKernel {
// using device-specific threadpools when available.
opts.runner = ctx->runner();
opts.source_device = local_device_name_;
+ opts.allow_dead_tensors = true;
// TODO(akshayka): Accommodate the multiple-worker scenario by adding the
// constructed rendezvous to a rendezvous manager.
Rendezvous* rendez = new IntraProcessRendezvous(lib->device_mgr());
diff --git a/tensorflow/python/eager/function_test.py b/tensorflow/python/eager/function_test.py
index e6592b2e37..2e86563a7d 100644
--- a/tensorflow/python/eager/function_test.py
+++ b/tensorflow/python/eager/function_test.py
@@ -213,6 +213,19 @@ class FunctionTest(test.TestCase):
self.assertEqual(fn_op.output_shapes, None)
self.assertAllEqual(fn_op(x, x), None)
+ @test_util.run_in_graph_and_eager_modes()
+ def testDefunCondGradient(self):
+
+ @function.defun
+ def f(x):
+ return control_flow_ops.cond(x > 0.5, lambda: 2 * x, lambda: 3 * x)
+
+ with backprop.GradientTape() as t:
+ x = constant_op.constant(1.0)
+ t.watch(x)
+ y = f(x)
+ self.assertAllEqual(self.evaluate(t.gradient(y, x)), 2.0)
+
def testDefunCapturedInt32(self):
x = constant_op.constant(1, dtype=dtypes.int32)