aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorGravatar Alexandre Passos <apassos@google.com>2017-11-08 13:44:26 -0800
committerGravatar Andrew Selle <aselle@andyselle.com>2017-11-10 16:14:37 -0800
commit2545c4e93b7c1ee21ddb3666580ff4922630d974 (patch)
treeed3da37ca4f30f365822785f2e8b3aa2bf26388f
parentfd52578963fdc3474be30c38fa9027c1c407301b (diff)
Moves imperative_grad to C
Neutral-to-positive on all benchmarks. Also reduces overhead of should_record. PiperOrigin-RevId: 175057104
-rw-r--r--tensorflow/c/eager/BUILD1
-rw-r--r--tensorflow/c/eager/tape.cc312
-rw-r--r--tensorflow/c/eager/tape.h58
-rw-r--r--tensorflow/python/eager/BUILD7
-rw-r--r--tensorflow/python/eager/backprop.py14
-rw-r--r--tensorflow/python/eager/backprop_test.py57
-rw-r--r--tensorflow/python/eager/imperative_grad.py194
-rw-r--r--tensorflow/python/eager/pywrap_tensor.cc8
-rw-r--r--tensorflow/python/eager/pywrap_tensor.h25
-rw-r--r--tensorflow/python/eager/pywrap_tfe.h13
-rw-r--r--tensorflow/python/eager/pywrap_tfe_src.cc313
-rw-r--r--tensorflow/python/eager/tape.py12
-rw-r--r--tensorflow/python/eager/tape_test.py20
-rw-r--r--tensorflow/python/pywrap_tfe.i4
14 files changed, 702 insertions, 336 deletions
diff --git a/tensorflow/c/eager/BUILD b/tensorflow/c/eager/BUILD
index c77896b80b..74e94be8d6 100644
--- a/tensorflow/c/eager/BUILD
+++ b/tensorflow/c/eager/BUILD
@@ -39,6 +39,7 @@ tf_cuda_library(
tf_cuda_library(
name = "c_api_internal",
hdrs = ["c_api_internal.h"],
+ visibility = ["//tensorflow:internal"],
deps = [
":c_api",
":runtime",
diff --git a/tensorflow/c/eager/tape.cc b/tensorflow/c/eager/tape.cc
index 464612a81e..459499bb69 100644
--- a/tensorflow/c/eager/tape.cc
+++ b/tensorflow/c/eager/tape.cc
@@ -13,6 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
+#include <unordered_set>
+
#include "tensorflow/c/eager/tape.h"
namespace tensorflow {
@@ -94,8 +96,314 @@ void GradientTape::DeleteTrace(int64 tensor_id) {
op_tape_.erase(op_it);
}
-std::pair<TensorTape, OpTape> GradientTape::Export() {
- return {std::move(tensor_tape_), std::move(op_tape_)};
+// Terminology:
+//
+// - op: a possibly composite operation, which has an entry in the tape
+// - target: dy in dx/dy
+// - source: dx in dx/dy
+// - tensor: one of the many inputs or outputs of an operation
+//
+// Below here we do the gradient algorithm. It works as follows:
+//
+// First we filter the tape to just the subset of operations we want to
+// differentiate. In the process of doing so we count how many times each Tensor
+// is used as an input to an op (so we know when we're done computing gradients
+// for that Tensor). We also count, for each tape entry, how many of its output
+// Tensors need gradients to be computed (Tensors which are not used do not need
+// any gradients to be computed).
+//
+// Finally, we start a backprop stack with a set of tape entries for which we
+// have all gradients available. This set usually is a subset of the set of
+// targets (not all since targets which have outputs in the tape will not have
+// gradients available initially).
+//
+// Then we repeatedly pop an entry from the stack, run its backprop, and update
+// the gradients of its inputs. Once we have computed all gradients for a single
+// input we can mark this input as done, and this can trigger adding an entry to
+// the stack if all outputs of that entry are now done.
+//
+// When the stack is empty we have gradients for all tensors we're interested
+// in.
+
+struct BackpropInitialState {
+ OpTape op_tape;
+
+ // Map from tensor ID to how many references still exist for this tensor in
+ // the tape.
+ std::unordered_map<int64, int64> tensor_usage_counts;
+
+ // Maps from op ID to how many output tensors of this op still need to have
+ // their gradients computed.
+ std::unordered_map<int64, int64> op_missing_tensor;
+};
+
+BackpropInitialState PrepareBackprop(
+ gtl::ArraySlice<int64> target, const TensorTape& tensor_tape,
+ OpTape op_tape, const std::unordered_set<int64>& sources_set) {
+ std::vector<int64> tensor_stack;
+ tensor_stack.reserve(target.size());
+ for (auto t : target) {
+ tensor_stack.push_back(t);
+ }
+ BackpropInitialState result;
+ while (!tensor_stack.empty()) {
+ int64 tensor_id = tensor_stack.back();
+ tensor_stack.pop_back();
+ auto op_id_it = tensor_tape.find(tensor_id);
+ if (op_id_it == tensor_tape.end()) {
+ continue;
+ }
+ int64 op_id = op_id_it->second;
+ auto op_it = op_tape.find(op_id);
+ auto result_op_it = result.op_tape.find(op_id);
+ if (op_id == -1 || op_it == op_tape.end() ||
+ result_op_it != result.op_tape.end()) {
+ continue;
+ }
+ CHECK(result.op_tape.emplace(op_id, op_it->second).second);
+ for (auto it : op_it->second.input_tensor_id) {
+ auto count_it = result.tensor_usage_counts.find(it);
+ if (count_it != result.tensor_usage_counts.end()) {
+ count_it->second++;
+ } else {
+ result.tensor_usage_counts[it] = 1;
+ if (sources_set.find(it) == sources_set.end() &&
+ tensor_tape.find(it) != tensor_tape.end()) {
+ tensor_stack.push_back(it);
+ }
+ }
+ }
+ op_tape.erase(op_it);
+ }
+ for (auto& pair : result.tensor_usage_counts) {
+ auto it = tensor_tape.find(pair.first);
+ if (it != tensor_tape.end() && it->second != -1) {
+ result.op_missing_tensor[it->second] += 1;
+ }
+ }
+ // Call destructors for all unneeded gradient functions.
+ for (const auto& op_pair : op_tape) {
+ op_pair.second.backward_function_deleter();
+ }
+ return result;
+}
+
+std::vector<int64> InitialStack(
+ const OpTape& op_tape,
+ const std::unordered_map<int64, int64>& op_missing_tensor) {
+ std::vector<int64> result;
+ for (auto& op_entry : op_tape) {
+ if (op_missing_tensor.find(op_entry.first) == op_missing_tensor.end()) {
+ result.push_back(op_entry.first);
+ }
+ }
+ return result;
+}
+
+Status InitialGradients(const VSpace& vspace, gtl::ArraySlice<void*> target,
+ gtl::ArraySlice<void*> output_gradients,
+ std::unordered_map<int64, int64> tensor_usage_counts,
+ std::unordered_map<int64, std::vector<void*>>* result) {
+ for (int i = 0; i < target.size(); ++i) {
+ int64 id = vspace.TensorId(target[i]);
+ if (tensor_usage_counts.find(id) != tensor_usage_counts.end()) {
+ if (!output_gradients.empty() && output_gradients[i] != nullptr) {
+ // TODO(apassos) figure out how to print debugging information here.
+ return errors::InvalidArgument(
+ "A gradient was provided for a tensor which is used as part of the "
+ "computation.");
+ }
+ } else {
+ if (output_gradients.empty() || output_gradients[i] == nullptr) {
+ (*result)[id].push_back(vspace.OnesLike(target[i]));
+ } else {
+ (*result)[id].push_back(output_gradients[i]);
+ }
+ }
+ }
+ return Status::OK();
+}
+
+// If over kMinAggregateCount gradients are accumulated and the total
+// memory consumption is over kMinAggregateBytes, do an early aggregation
+// so as to release the gradient tensor to save memory.
+static const int kMinAggregateCount = 4;
+static const int kMinAggregateBytes = 128 * 1024 * 1024;
+
+Status GradientTape::Gradient(const VSpace& vspace,
+ gtl::ArraySlice<void*> target,
+ gtl::ArraySlice<void*> sources,
+ gtl::ArraySlice<void*> output_gradients,
+ std::vector<void*>* result) {
+ std::vector<int64> id_sources;
+ id_sources.reserve(sources.size());
+ for (void* s : sources) {
+ id_sources.push_back(vspace.TensorId(s));
+ }
+ std::unordered_set<int64> sources_set(id_sources.begin(), id_sources.end());
+ std::vector<int64> id_targets;
+ id_sources.reserve(target.size());
+ for (void* t : target) {
+ id_targets.push_back(vspace.TensorId(t));
+ }
+ BackpropInitialState state = PrepareBackprop(
+ id_targets, tensor_tape_, std::move(op_tape_), sources_set);
+ std::vector<int64> op_stack =
+ InitialStack(state.op_tape, state.op_missing_tensor);
+ std::unordered_map<int64, std::vector<void*>> gradients;
+ Status s = InitialGradients(vspace, target, output_gradients,
+ state.tensor_usage_counts, &gradients);
+ auto cleanup = [&state]() {
+ // Release all backprop functions
+ for (const auto& pair : state.op_tape) {
+ pair.second.backward_function_deleter();
+ }
+ };
+ if (!s.ok()) {
+ cleanup();
+ return s;
+ }
+ std::unordered_map<int64, int64> gradients_size;
+ // TODO(apassos) multiple threads could be dequeuing from op_stack at the same
+ // time, for better CPU backprop performance.
+ VLOG(1) << "Initial stack:";
+ if (VLOG_IS_ON(1)) {
+ for (auto t : op_stack) {
+ VLOG(1) << " " << t;
+ }
+ }
+ std::unordered_map<string, std::unordered_set<int>>
+ functions_accept_none_for_indices({
+ {"SoftmaxCrossEntropyWithLogits", {1}},
+ {"FusedBatchNorm", {1, 2, 3, 4}},
+ });
+ while (!op_stack.empty()) {
+ const int64 op = op_stack.back();
+ VLOG(1) << "Popped " << op;
+ op_stack.pop_back();
+ auto op_it = state.op_tape.find(op);
+ if (op_it == state.op_tape.end()) {
+ // It is possible for ops to end up on the stack if they are unrelated to
+ // the target; we should just skip them.
+ continue;
+ }
+ auto trace = std::move(op_it->second);
+ state.op_tape.erase(op_it);
+ std::vector<void*> out_gradients;
+ out_gradients.reserve(trace.output_tensor_info.size());
+ for (int i = 0; i < trace.output_tensor_info.size(); ++i) {
+ const int64 id = trace.output_tensor_info[i].id;
+ auto grad_it = gradients.find(id);
+ if (grad_it == gradients.end()) {
+ auto func_name_it =
+ functions_accept_none_for_indices.find(trace.op_type);
+ if (func_name_it != functions_accept_none_for_indices.end() &&
+ func_name_it->second.find(i) != func_name_it->second.end()) {
+ out_gradients.push_back(nullptr);
+ } else {
+ out_gradients.push_back(
+ vspace.Zeros(trace.output_tensor_info[i].shape,
+ trace.output_tensor_info[i].dtype));
+ }
+ } else {
+ out_gradients.push_back(vspace.AggregateGradients(grad_it->second));
+ if (sources_set.find(grad_it->first) == sources_set.end()) {
+ gradients.erase(grad_it);
+ }
+ }
+ }
+ std::vector<void*> in_gradients;
+ Status s = vspace.CallBackwardFunction(trace.backward_function,
+ out_gradients, &in_gradients);
+ if (!s.ok()) {
+ VLOG(1) << "Gradient function failed.";
+ cleanup();
+ return s;
+ }
+ VLOG(1) << "Got " << in_gradients.size() << " in_gradients for "
+ << trace.input_tensor_id.size() << " sources";
+ for (int i = 0; i < in_gradients.size(); ++i) {
+ const int64 id = trace.input_tensor_id[i];
+ if (in_gradients[i] != nullptr) {
+ auto& unaggregated_grads = gradients[id];
+ unaggregated_grads.push_back(in_gradients[i]);
+ if (unaggregated_grads.size() > kMinAggregateCount) {
+ auto size_it = gradients_size.find(id);
+ int64 size;
+ if (size_it == gradients_size.end()) {
+ size = vspace.NumElements(unaggregated_grads[0]);
+ gradients_size.emplace(id, size);
+ } else {
+ size = size_it->second;
+ }
+ if (unaggregated_grads.size() * size * 4 > kMinAggregateBytes) {
+ void* tensor = vspace.AggregateGradients(unaggregated_grads);
+ unaggregated_grads.clear();
+ unaggregated_grads.push_back(tensor);
+ }
+ }
+ }
+ auto usage_count_it = state.tensor_usage_counts.find(id);
+ if (usage_count_it == state.tensor_usage_counts.end()) {
+ VLOG(1) << "Tensor " << id << " not used";
+ continue;
+ }
+ usage_count_it->second--;
+ if (usage_count_it->second > 0) {
+ VLOG(1) << "Tensor " << id << " usage count " << usage_count_it->second;
+ continue;
+ }
+ auto tape_it = tensor_tape_.find(id);
+ if (tape_it == tensor_tape_.end()) {
+ VLOG(1) << "Tensor " << id
+ << " has no associated op. Deleting gradient";
+ auto grad_it = gradients.find(id);
+ if (grad_it != gradients.end()) {
+ for (auto g : grad_it->second) {
+ vspace.DeleteTensor(g);
+ }
+ gradients.erase(grad_it);
+ }
+ continue;
+ }
+ const int64 op_id = tape_it->second;
+ if (op_id == -1) {
+ VLOG(1) << "Tensor " << id << " is source";
+ continue;
+ }
+ auto missing_it = state.op_missing_tensor.find(op_id);
+ if (missing_it != state.op_missing_tensor.end()) {
+ missing_it->second--;
+ VLOG(1) << "Op " << op_id << " missing " << missing_it->second
+ << " output gradients";
+ if (missing_it->second == 0) {
+ op_stack.push_back(op_id);
+ }
+ }
+ }
+ }
+ CHECK(state.op_tape.empty());
+ result->reserve(sources.size());
+ for (auto is : id_sources) {
+ auto grad_it = gradients.find(is);
+ if (grad_it == gradients.end()) {
+ result->push_back(nullptr);
+ } else {
+ if (grad_it->second.size() == 1) {
+ result->push_back(grad_it->second[0]);
+ } else {
+ result->push_back(vspace.AggregateGradients(grad_it->second));
+ }
+ gradients.erase(grad_it);
+ }
+ }
+ VLOG(1) << "Final gradients size: " << gradients.size();
+ for (auto grad_pair : gradients) {
+ for (const auto& g : grad_pair.second) {
+ vspace.DeleteTensor(g);
+ }
+ }
+ return Status::OK();
}
} // namespace eager
diff --git a/tensorflow/c/eager/tape.h b/tensorflow/c/eager/tape.h
index df51f300eb..2bb62a7ab3 100644
--- a/tensorflow/c/eager/tape.h
+++ b/tensorflow/c/eager/tape.h
@@ -57,11 +57,57 @@ using TensorTape = std::unordered_map<int64, int64>;
// Map from operation-id to tape entry.
using OpTape = std::unordered_map<int64, OpTapeEntry>;
+// Operations the tape needs to perform on tensors to do backpropagation. Named
+// "vspace" because a subset of these are related to a vector space, such as
+// adding gradients, getting zeroes, etc. Currently cannot be implemented
+// without using tensorflow python code, hence left unspecified here.
+//
+// We currently use void* for tensors, backward functions, and gradients (which
+// can be but are not required to be tensors). TODO(apassos) replace this first
+// with templates to allow for pyobject specialization in the client followed by
+// a TFE_TensorHandle specialization, which is blocked by quite a few things
+// still.
+class VSpace {
+ public:
+ virtual ~VSpace() {}
+
+ // Returns the number of elements in the tensor.
+ virtual int64 NumElements(void* tensor) const = 0;
+
+ // Consumes references to the tensors in the gradient_tensors list and returns
+ // a tensor with the result.
+ virtual void* AggregateGradients(
+ gtl::ArraySlice<void*> gradient_tensors) const = 0;
+
+ // Returns a tensor of the right shape and dtype filled with zeros.
+ virtual void* Zeros(TensorShape shape, DataType dtype) const = 0;
+
+ // Returns a Tensor which is filled with ones and like the input.
+ virtual void* OnesLike(void*) const = 0;
+
+ // Returns an integer which is a unique-to-within-this-program handle for this
+ // tensor.
+ virtual int64 TensorId(void* tensor) const = 0;
+
+ // Calls the passed-in backward function.
+ virtual Status CallBackwardFunction(void* backward_function,
+ gtl::ArraySlice<void*> output_gradients,
+ std::vector<void*>* result) const = 0;
+
+ // Deletes the input tensor.
+ virtual void DeleteTensor(void* tensor) const = 0;
+};
+
// Traces the execution of operations, doing eager garbage collection, and
// exporting a full trace so other code can do backpropagation. Not thread-safe.
class GradientTape {
public:
GradientTape() {}
+ ~GradientTape() {
+ for (const auto& pair : op_tape_) {
+ pair.second.backward_function_deleter();
+ }
+ }
bool ShouldRecord(gtl::ArraySlice<int64> tensor_ids);
@@ -75,10 +121,14 @@ class GradientTape {
void DeleteTrace(int64 tensor_id);
- // Note: it is only valid to call Export once per tape, and after calling
- // export the tape is no longer valid (i.e. calls to ShouldRecord, Watch,
- // Record, and Delete have undefined behavior).
- std::pair<TensorTape, OpTape> Export();
+ // Consumes the internal state of the tape (so cannot be called more than
+ // once) and produces the gradient of the target tensors with respect to the
+ // source tensors. The output gradients are used if not empty and not
+ // null. The result is populated with one tensor per target element.
+ Status Gradient(const VSpace& vspace, gtl::ArraySlice<void*> target,
+ gtl::ArraySlice<void*> sources,
+ gtl::ArraySlice<void*> output_gradients,
+ std::vector<void*>* result);
private:
TensorTape tensor_tape_;
diff --git a/tensorflow/python/eager/BUILD b/tensorflow/python/eager/BUILD
index bcd1e1d0dc..c36647b21c 100644
--- a/tensorflow/python/eager/BUILD
+++ b/tensorflow/python/eager/BUILD
@@ -14,11 +14,16 @@ cc_library(
"pywrap_tensor.cc",
"pywrap_tfe_src.cc",
],
- hdrs = ["pywrap_tfe.h"],
+ hdrs = [
+ "pywrap_tensor.h",
+ "pywrap_tfe.h",
+ ],
visibility = ["//tensorflow:internal"],
deps = [
"//tensorflow/c:c_api",
+ "//tensorflow/c:c_api_internal",
"//tensorflow/c/eager:c_api",
+ "//tensorflow/c/eager:c_api_internal",
"//tensorflow/c/eager:tape",
"//tensorflow/core:lib",
"//tensorflow/python:ndarray_tensor",
diff --git a/tensorflow/python/eager/backprop.py b/tensorflow/python/eager/backprop.py
index 86b3776b8c..111d7cef56 100644
--- a/tensorflow/python/eager/backprop.py
+++ b/tensorflow/python/eager/backprop.py
@@ -727,11 +727,23 @@ def _num_elements(grad):
raise ValueError("`grad` not a Tensor or IndexedSlices.")
+_last_shape_dtype = [None, None]
+_last_zero = [None]
+
+
+def _zeros(shape, dtype):
+ """Wraps array_ops.zeros to cache last zero for a given shape and dtype."""
+ if [shape, dtype] != _last_shape_dtype:
+ _last_shape_dtype[:] = [shape, dtype]
+ _last_zero[0] = array_ops.zeros(shape, dtype)
+ return _last_zero[0]
+
+
_default_vspace = imperative_grad.VSpace(
num_elements_fn=_num_elements,
aggregate_fn=_aggregate_grads,
tensor_id=ops.tensor_id,
- zeros=array_ops.zeros,
+ zeros=_zeros,
ones_like=lambda x: ops.convert_to_tensor(array_ops.ones_like(x)))
diff --git a/tensorflow/python/eager/backprop_test.py b/tensorflow/python/eager/backprop_test.py
index ed54b8e12e..ec9a185b73 100644
--- a/tensorflow/python/eager/backprop_test.py
+++ b/tensorflow/python/eager/backprop_test.py
@@ -24,11 +24,11 @@ from tensorflow.python import pywrap_tensorflow
from tensorflow.python.eager import backprop
from tensorflow.python.eager import context
from tensorflow.python.eager import custom_gradient
-from tensorflow.python.eager import imperative_grad
from tensorflow.python.eager import tape
from tensorflow.python.eager import test
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import errors_impl
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_shape
from tensorflow.python.ops import array_ops
@@ -41,7 +41,6 @@ from tensorflow.python.ops import random_ops
from tensorflow.python.ops import resource_variable_ops
from tensorflow.python.ops import variables
from tensorflow.python.training import training
-from tensorflow.python.util import compat
class BackpropTest(test.TestCase):
@@ -103,6 +102,18 @@ class BackpropTest(test.TestCase):
grad_fn = backprop.gradients_function(f)
self.assertAllEqual(2., grad_fn(1., dy=2.)[0])
+ def testErrors(self):
+
+ @custom_gradient.custom_gradient
+ def f(x):
+ def grad(_):
+ raise RuntimeError('x')
+ return x, grad
+
+ # TODO(apassos) raise the right error here
+ with self.assertRaises(errors_impl.InternalError):
+ backprop.gradients_function(f)(constant_op.constant(1.0))
+
def testImplicitGradOverEmbeddingLookup(self):
batch_size = 8
embedding_size = 512
@@ -483,48 +494,6 @@ class BackpropTest(test.TestCase):
initial_value=1., name='testSameObjectForMultipleArguments.Variable')
self.assertAllEqual([1., 1.], np_g(v, v))
- def testEarlyGradAggregation(self):
- # Needs to be a list so mutations by the callback affect this function.
- add_n = []
- def callback(op_type, unused_1, unused_2, unused_3, unused_4):
- if compat.as_bytes(op_type) == compat.as_bytes('AddN'):
- add_n.append(1)
- context.context().add_post_execution_callback(callback)
-
- v = resource_variable_ops.ResourceVariable(constant_op.constant(2.0),
- name='v')
- def fn():
- outputs = []
- for _ in range(20):
- outputs.append(v * constant_op.constant(2.0))
- return math_ops.add_n(outputs)
-
- # By default the aggregation count is 2.
- _ = backprop.implicit_grad(fn)()[0][1]
- self.assertEqual(len(add_n), 2)
- del add_n[:]
-
- # Reduce the aggregation limit, cause the backprop to do some
- # early aggregation.
- # pylint: disable=protected-access
- old_cnt = imperative_grad._MIN_AGGREGATE_COUNT
- old_bytes = imperative_grad._MIN_AGGREGATE_BYTES
- imperative_grad._MIN_AGGREGATE_COUNT = 10
- imperative_grad._MIN_AGGREGATE_BYTES = 1
- _ = backprop.implicit_grad(fn)()
- self.assertEqual(len(add_n), 6)
- del add_n[:]
-
- # Aggregation is also limited by the memory.
- imperative_grad._MIN_AGGREGATE_BYTES = 10000
- _ = backprop.implicit_grad(fn)()
- self.assertEqual(len(add_n), 2)
-
- imperative_grad._MIN_AGGREGATE_COUNT = old_cnt
- imperative_grad._MIN_AGGREGATE_BYTES = old_bytes
- # pylint: enable=protected-access
- context.context().clear_post_execution_callbacks()
-
def testImplicitGradientsCustomGradientAndCachedVariableValue(self):
@custom_gradient.custom_gradient
diff --git a/tensorflow/python/eager/imperative_grad.py b/tensorflow/python/eager/imperative_grad.py
index c87719f84a..8932b7157b 100644
--- a/tensorflow/python/eager/imperative_grad.py
+++ b/tensorflow/python/eager/imperative_grad.py
@@ -20,102 +20,8 @@ from __future__ import print_function
import collections
-from tensorflow.python.eager import tape as tape_module
-
-
-# Terminology:
-#
-# - op: a possibly composite operation, which has an entry in the tape
-# - target: dy in dx/dy
-# - source: dx in dx/dy
-# - tensor: one of the many inputs or outputs of an operation
-#
-# Below here we do the gradient algorithm. It works as follows:
-#
-# First we filter the tape to just the subset of operations we want to
-# differentiate. In the process of doing so we count how many times each Tensor
-# is used as an input to an op (so we know when we're done computing gradients
-# for that Tensor). We also count, for each tape entry, how many of its output
-# Tensors need gradients to be computed (Tensors which are not used do not need
-# any gradients to be computed).
-#
-# Finally, we start a backprop stack with a set of tape entries for which we
-# have all gradients available. This set usually is a subset of the set of
-# targets (not all since targets which have outputs in the tape will not have
-# gradients available initially).
-#
-# Then we repeatedly pop an entry from the stack, run its backprop, and update
-# the gradients of its inputs. Once we have computed all gradients for a single
-# input we can mark this input as done, and this can trigger adding an entry to
-# the stack if all outputs of that entry are now done.
-#
-# When the stack is empty we have gradients for all tensors we're interested in.
-def _prepare_backprop(vspace, target, tensor_to_op, op_to_entry, id_sources):
- """Filters the tape to only include relevant entries and counts tensor usages.
-
- Args:
- vspace: information about the space we're differentiating in.
- target: the target to optimize.
- tensor_to_op: Map from tensor id to key in op_to_entry that produced it.
- op_to_entry: Map from op id to a tape.TapeEntry object
- id_sources: the ids of the sources wrt the gradient is being taken.
-
- Returns:
- usage counts (how many entries downstream from a tensor use it)
- op_to_entry_map: entry map (a filtered tape, with only the relevant
- entries),
- missing: map from tensor id to how many downstream gradients still need
- to be computed before this tensor's gradient can be computed.
- """
- tensor_stack = [vspace.tensor_id(x) for x in target]
- tensor_usage_counts = {}
- o_to_e = {} # Copy of just the bits we need from op_to_entry
- while tensor_stack:
- t = tensor_stack.pop()
- op = tensor_to_op.get(t, None)
- # op is None or -1 if the tensor is a source (i.e. was watched directly)
- if op is None or op == -1 or op in o_to_e:
- continue
- op_trace = tape_module.TapeEntry(*op_to_entry[op])
- o_to_e[op] = op_trace
- for it in op_trace.input_ids:
- if it in tensor_usage_counts:
- tensor_usage_counts[it] += 1
- else:
- tensor_usage_counts[it] = 1
- if it not in id_sources and it in tensor_to_op:
- tensor_stack.append(it)
- op_missing_tensor_counts = collections.defaultdict(int)
- for t in tensor_usage_counts:
- if t in tensor_to_op and tensor_to_op[t] is not None:
- op_missing_tensor_counts[tensor_to_op[t]] += 1
- return tensor_usage_counts, o_to_e, op_missing_tensor_counts
-
-
-def _initialize_backprop_stack(op_to_entry, op_missing_tensor):
- """Returns the set of tape entries which are available for backprop."""
- ready_ops = []
- for op in op_to_entry:
- if op not in op_missing_tensor:
- ready_ops.append(op)
- return ready_ops
-
-
-def _initial_gradients(vspace, target, output_gradients, tensor_usage_counts):
- """Computes the initial gradients for each Tensor."""
- # Initialize the backprop stack
- gradients = collections.defaultdict(list)
- for i, t in enumerate(target):
- if vspace.tensor_id(t) in tensor_usage_counts:
- # Can't provide a gradient of something we're trying to differentiate
- assert output_gradients is None or output_gradients[i] is None
- else:
- if output_gradients is None or output_gradients[i] is None:
- out_grad = vspace.ones_like(t)
- else:
- out_grad = output_gradients[i]
- gradients[vspace.tensor_id(t)].append(out_grad)
- return gradients
+from tensorflow.python import pywrap_tensorflow
+from tensorflow.python.framework import errors
VSpace = collections.namedtuple(
@@ -123,13 +29,6 @@ VSpace = collections.namedtuple(
["aggregate_fn", "num_elements_fn", "tensor_id", "zeros", "ones_like"])
-# If over MIN_AGGREGATE_COUNT gradients are accumulated and the total
-# memory consumption is over MIN_AGGREGATE_BYTES, do an early aggregation
-# so as to release the gradient tensor to save memory.
-_MIN_AGGREGATE_COUNT = 4
-_MIN_AGGREGATE_BYTES = 128 * 1024 * 1024
-
-
def imperative_grad(
vspace,
tape,
@@ -161,89 +60,6 @@ def imperative_grad(
or if only non-differentiable functions of the source were used in the
computation of target.
"""
- tensor_to_op, op_to_entry = tape.export()
- # This overwrites the op_to_entry variable, which will release all memory used
- # to keep traces that are irrelevant to the gradient computation we're doing
- # here.
- id_sources = [vspace.tensor_id(t) for t in sources]
- tensor_usage_counts, op_to_entry, op_missing_tensor = _prepare_backprop(
- vspace, target, tensor_to_op, op_to_entry, id_sources)
- ready_ops = _initialize_backprop_stack(op_to_entry, op_missing_tensor)
- gradients = _initial_gradients(vspace, target, output_gradients,
- tensor_usage_counts)
- gradients_size = dict()
- # Now exhaust the backprop stack
- while ready_ops:
- op = ready_ops.pop()
- op_trace = op_to_entry.pop(op)
- out_gradients = [gradients.pop(t, None) for t in op_trace.output_ids]
-
- # Cache the last used zero tensor. We reuse it if the next one
- # we need is of the same shape and dtype. This is very helpful in
- # large splits and should have negligible overhead in other cases.
- last_shape_and_dtype = None
- last_zeros = None
- for i in range(len(out_gradients)):
- if out_gradients[i] is None:
- # TODO(apassos) this should be in the right device
- none_indices = _grad_fn_accepts_none_for_indices.get(
- op_trace.op_type, None)
- if none_indices is None or i not in none_indices:
- shape_and_dtype = op_trace.output_shape_and_dtype[i]
- if shape_and_dtype != last_shape_and_dtype:
- last_shape_and_dtype = shape_and_dtype
- last_zeros = vspace.zeros(*shape_and_dtype)
- out_gradients[i] = last_zeros
- else:
- out_gradients[i] = vspace.aggregate_fn(out_gradients[i])
-
- in_gradients = op_trace.backward_function(*(out_gradients))
- for i, t in enumerate(op_trace.input_ids):
- if in_gradients[i] is not None:
- t_grads = gradients.setdefault(t, [])
- t_grads.append(in_gradients[i])
- if len(t_grads) >= _MIN_AGGREGATE_COUNT:
- if t not in gradients_size:
- gradients_size[t] = vspace.num_elements_fn(t_grads[-1])
- size = gradients_size[t]
-
- if len(t_grads) * size * 4 > _MIN_AGGREGATE_BYTES:
- t_grads[:] = [vspace.aggregate_fn(t_grads)]
- if tensor_usage_counts.get(t, 0) > 0:
- tensor_usage_counts[t] -= 1
- if (t in tensor_to_op
- and tensor_usage_counts[t] == 0
- and t not in id_sources):
- in_op = tensor_to_op[t]
- if in_op is None or in_op == -1:
- continue
- if op_missing_tensor.get(in_op, 0) > 0:
- op_missing_tensor[in_op] -= 1
- if op_missing_tensor.get(in_op, 0) == 0:
- ready_ops.append(in_op)
- result = []
- for i, s in enumerate(sources):
- g = gradients.get(vspace.tensor_id(s), None)
- if g is None:
- result.append(None)
- else:
- result.append(vspace.aggregate_fn(g))
- return result
-
-
-# TODO(agarwal): use an automatic mechanism for handling None arguments to
-# gradient functions.
-# Some gradient functions can accept None arguments for gradients. The following
-# maps the operation name to the indices at which the corresponding gradient
-# function can accept None values.
-# e.g. FusedBatchNorm outputs 5 values and hence receives 5 gradient values
-# during backprop. However the gradient function uses only the first of those
-# values and ignores the rest. The entry, "FusedBatchNorm": [1, 2, 3, 4],
-# indicates that only the gradient corresponding to index 0 is used, and the
-# gradient values at indices 1-4 are ignored (and hence can be None). The
-# backprop algorithm can then leverage this by not constructing zeros to
-# pass for those indices.
-_grad_fn_accepts_none_for_indices = {
- "SoftmaxCrossEntropyWithLogits": [1],
- "FusedBatchNorm": [1, 2, 3, 4]
-}
+ with errors.raise_exception_on_not_ok_status() as status:
+ return pywrap_tensorflow.TFE_Py_TapeGradient(
+ tape._tape, vspace, target, sources, output_gradients, status) # pylint: disable=protected-access
diff --git a/tensorflow/python/eager/pywrap_tensor.cc b/tensorflow/python/eager/pywrap_tensor.cc
index ca283862f9..653f3ef84e 100644
--- a/tensorflow/python/eager/pywrap_tensor.cc
+++ b/tensorflow/python/eager/pywrap_tensor.cc
@@ -20,6 +20,7 @@ limitations under the License.
#include "tensorflow/python/lib/core/py_seq_tensor.h"
#include "tensorflow/python/lib/core/safe_ptr.h"
+#include "tensorflow/python/eager/pywrap_tensor.h"
#include "tensorflow/python/eager/pywrap_tfe.h"
#include "tensorflow/c/c_api.h"
@@ -573,7 +574,7 @@ bool EagerTensor_CheckExact(const PyObject* o) {
return Py_TYPE(o) == EagerTensorType;
}
-TFE_TensorHandle* EagerTensorHandle(const PyObject* o) {
+TFE_TensorHandle* EagerTensor_Handle(const PyObject* o) {
return reinterpret_cast<const EagerTensor*>(o)->handle;
}
@@ -594,6 +595,11 @@ PyObject* EagerTensorFromHandle(TFE_TensorHandle* handle) {
return reinterpret_cast<PyObject*>(t);
}
+tensorflow::int64 EagerTensor_id(const PyObject* tensor) {
+ CHECK(EagerTensor_CheckExact(tensor));
+ return reinterpret_cast<const EagerTensor*>(tensor)->id;
+}
+
PyObject* TFE_Py_InitEagerTensor(PyObject* base_class) {
if (!PyType_Check(base_class)) {
PyErr_SetString(
diff --git a/tensorflow/python/eager/pywrap_tensor.h b/tensorflow/python/eager/pywrap_tensor.h
new file mode 100644
index 0000000000..aa1efdd1b8
--- /dev/null
+++ b/tensorflow/python/eager/pywrap_tensor.h
@@ -0,0 +1,25 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#ifndef TENSORFLOW_PYTHON_EAGER_PYWRAP_TENSOR_H_
+#define TENSORFLOW_PYTHON_EAGER_PYWRAP_TENSOR_H_
+
+#include "tensorflow/c/eager/c_api.h"
+#include "tensorflow/core/platform/types.h"
+#include "tensorflow/python/lib/core/numpy.h"
+
+bool EagerTensor_CheckExact(const PyObject* o);
+tensorflow::int64 EagerTensor_id(const PyObject* tensor);
+
+#endif // TENSORFLOW_PYTHON_EAGER_PYWRAP_TENSOR_H_
diff --git a/tensorflow/python/eager/pywrap_tfe.h b/tensorflow/python/eager/pywrap_tfe.h
index 1d03df2933..6705483f3b 100644
--- a/tensorflow/python/eager/pywrap_tfe.h
+++ b/tensorflow/python/eager/pywrap_tfe.h
@@ -81,7 +81,7 @@ bool EagerTensor_CheckExact(const PyObject* o);
PyObject* EagerTensorFromHandle(TFE_TensorHandle* handle);
// Extracts the handle inside EagerTensor object `o`. Returns nullptr on error.
-TFE_TensorHandle* EagerTensorHandle(const PyObject* o);
+TFE_TensorHandle* EagerTensor_Handle(const PyObject* o);
// Creates the `EagerTensor` class by subclassing `base_class` and returns the
// newly created type, or nullptr on error.
@@ -103,7 +103,16 @@ void TFE_Py_TapeRecordOperation(PyObject* tape, PyObject* op_type,
PyObject* output_tensors,
PyObject* input_tensor_ids,
PyObject* backward_function);
-PyObject* TFE_Py_TapeExport(PyObject* tape);
+
+// Computes a gradient based on information recorded on the tape.`tape` must
+// have been produced by TFE_Py_NewTape. `vspace` must be a
+// imperative_grad.py:VSpace named tuple. `target` and `sources` must be python
+// lists of Tensor objects. `output_gradients` is either None or a python list
+// of either Tensor or None, and if not None should have the same length as
+// target.
+PyObject* TFE_Py_TapeGradient(PyObject* tape, PyObject* vspace,
+ PyObject* target, PyObject* sources,
+ PyObject* output_gradients, TF_Status* status);
// Returns an EagerTensor of dimension [len(`tensor_list`)] containing
// the `slice_dim`'th dimension of each tensor in `tensor_list`. In other words,
diff --git a/tensorflow/python/eager/pywrap_tfe_src.cc b/tensorflow/python/eager/pywrap_tfe_src.cc
index 7456eb10f8..a00a7615d7 100644
--- a/tensorflow/python/eager/pywrap_tfe_src.cc
+++ b/tensorflow/python/eager/pywrap_tfe_src.cc
@@ -16,10 +16,13 @@ limitations under the License.
#include "tensorflow/python/eager/pywrap_tfe.h"
#include "tensorflow/c/c_api.h"
+#include "tensorflow/c/c_api_internal.h"
+#include "tensorflow/c/eager/c_api_internal.h"
#include "tensorflow/c/eager/tape.h"
#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/platform/mutex.h"
#include "tensorflow/core/platform/types.h"
+#include "tensorflow/python/eager/pywrap_tensor.h"
using tensorflow::string;
@@ -515,18 +518,50 @@ static std::vector<tensorflow::int64> MakeIntList(PyObject* list) {
}
PyObject* TFE_Py_TapeShouldRecord(PyObject* py_tape, PyObject* tensors) {
+ if (tensors == Py_None) {
+ Py_RETURN_FALSE;
+ }
+ PyObject* seq = PySequence_Fast(tensors, "expected a sequence");
+ if (seq == nullptr) {
+ return nullptr;
+ }
+ int len = PySequence_Fast_GET_SIZE(seq);
+ // TODO(apassos) consider not building a list and changing the API to check
+ // each tensor individually.
+ std::vector<tensorflow::int64> tensor_ids;
+ tensor_ids.reserve(len);
+ for (int i = 0; i < len; ++i) {
+ PyObject* item = PySequence_Fast_GET_ITEM(seq, i);
+ if (EagerTensor_CheckExact(item)) {
+ tensor_ids.push_back(EagerTensor_id(item));
+ } else {
+ PyObject* id_field = PyObject_GetAttrString(item, "_id");
+ if (id_field == nullptr) {
+ return nullptr;
+ }
+ tensor_ids.push_back(MakeInt(id_field));
+ Py_DECREF(id_field);
+ }
+ }
+ Py_DECREF(seq);
TFE_Py_Tape* tape = reinterpret_cast<TFE_Py_Tape*>(py_tape);
- return PyBool_FromLong(tape->tape->ShouldRecord(MakeIntList(tensors)));
+ if (tape->tape->ShouldRecord(tensor_ids)) {
+ Py_RETURN_TRUE;
+ } else {
+ Py_RETURN_FALSE;
+ }
}
void TFE_Py_TapeWatch(PyObject* tape, tensorflow::int64 tensor_id) {
reinterpret_cast<TFE_Py_Tape*>(tape)->tape->Watch(tensor_id);
}
-// TODO(apassos) have a fast path for eager tensors here which gets information
-// from the handle instead of from the python object, and use this only for the
-// case of graph tensors.
static tensorflow::eager::TapeTensor TapeTensorFromTensor(PyObject* tensor) {
+ if (EagerTensor_CheckExact(tensor)) {
+ TFE_TensorHandle* t = EagerTensor_Handle(tensor);
+ tensorflow::int64 id = EagerTensor_id(tensor);
+ return tensorflow::eager::TapeTensor{id, t->t.dtype(), t->t.shape()};
+ }
PyObject* id_field = PyObject_GetAttrString(tensor, "_id");
tensorflow::int64 id = MakeInt(id_field);
Py_DECREF(id_field);
@@ -592,64 +627,224 @@ void TFE_Py_TapeDeleteTrace(PyObject* tape, tensorflow::int64 tensor_id) {
reinterpret_cast<TFE_Py_Tape*>(tape)->tape->DeleteTrace(tensor_id);
}
-// TODO(apassos) when backprop.py moves to C most of this exporting logic can
-// disappear.
-PyObject* TFE_Py_TapeExport(PyObject* tape) {
- std::pair<tensorflow::eager::TensorTape, tensorflow::eager::OpTape> exported =
- reinterpret_cast<TFE_Py_Tape*>(tape)->tape->Export();
- PyObject* tensor_tape = PyDict_New();
- for (const auto& pair : exported.first) {
- PyObject* tid = PyLong_FromLong(pair.first);
- PyObject* opid = PyLong_FromLong(pair.second);
- PyDict_SetItem(tensor_tape, tid, opid);
- Py_DECREF(tid);
- Py_DECREF(opid);
- }
-
- PyObject* op_tape = PyDict_New();
- for (const auto& pair : exported.second) {
- PyObject* opid = PyLong_FromLong(pair.first);
- const auto& entry = pair.second;
- PyObject* op_type = PyBytes_FromString(entry.op_type.c_str());
- PyObject* output_ids = PyList_New(entry.output_tensor_info.size());
- for (int i = 0; i < entry.output_tensor_info.size(); ++i) {
- PyObject* tid = PyLong_FromLong(entry.output_tensor_info[i].id);
- PyList_SET_ITEM(output_ids, i, tid);
+// TODO(apassos): cache the attribute lookups as member variables and decref
+// them in the destructor.
+class PyVSpace : public tensorflow::eager::VSpace {
+ public:
+ explicit PyVSpace(PyObject* py_vspace) : py_vspace_(py_vspace) {}
+
+ tensorflow::Status Initialize() {
+ num_elements_ = PyObject_GetAttrString(py_vspace_, "num_elements_fn");
+ if (num_elements_ == nullptr) {
+ return tensorflow::errors::InvalidArgument("invalid vspace");
+ }
+ aggregate_fn_ = PyObject_GetAttrString(py_vspace_, "aggregate_fn");
+ if (aggregate_fn_ == nullptr) {
+ return tensorflow::errors::InvalidArgument("invalid vspace");
+ }
+ zeros_ = PyObject_GetAttrString(py_vspace_, "zeros");
+ if (zeros_ == nullptr) {
+ return tensorflow::errors::InvalidArgument("invalid vspace");
}
- PyObject* input_ids = PyList_New(entry.input_tensor_id.size());
- for (int i = 0; i < entry.input_tensor_id.size(); ++i) {
- PyObject* tid = PyLong_FromLong(entry.input_tensor_id[i]);
- PyList_SET_ITEM(input_ids, i, tid);
+ ones_like_ = PyObject_GetAttrString(reinterpret_cast<PyObject*>(py_vspace_),
+ "ones_like");
+ if (ones_like_ == nullptr) {
+ return tensorflow::errors::InvalidArgument("invalid vspace");
}
- PyObject* backward_function =
- reinterpret_cast<PyObject*>(entry.backward_function);
- PyObject* output_shape_and_dtype =
- PyList_New(entry.output_tensor_info.size());
- for (int i = 0; i < entry.output_tensor_info.size(); ++i) {
- const tensorflow::TensorShape& shape = entry.output_tensor_info[i].shape;
- PyObject* shape_list = PyList_New(shape.dims());
- for (int j = 0; j < shape.dims(); ++j) {
- PyList_SET_ITEM(shape_list, j, PyLong_FromLong(shape.dim_size(j)));
+ return tensorflow::Status::OK();
+ }
+
+ ~PyVSpace() override {
+ Py_XDECREF(num_elements_);
+ Py_XDECREF(aggregate_fn_);
+ Py_XDECREF(zeros_);
+ Py_XDECREF(ones_like_);
+ }
+
+ tensorflow::int64 NumElements(void* tensor) const final {
+ PyObject* arglist =
+ Py_BuildValue("(O)", reinterpret_cast<PyObject*>(tensor));
+ PyObject* result = PyEval_CallObject(num_elements_, arglist);
+ tensorflow::int64 r = MakeInt(result);
+ Py_DECREF(result);
+ Py_DECREF(arglist);
+ return r;
+ }
+
+ void* AggregateGradients(
+ tensorflow::gtl::ArraySlice<void*> gradient_tensors) const final {
+ PyObject* list = PyList_New(gradient_tensors.size());
+ for (int i = 0; i < gradient_tensors.size(); ++i) {
+ // Note: stealing a reference to the gradient tensors.
+ CHECK(gradient_tensors[i] != nullptr);
+ CHECK(gradient_tensors[i] != Py_None);
+ PyList_SET_ITEM(list, i,
+ reinterpret_cast<PyObject*>(gradient_tensors[i]));
+ }
+ PyObject* arglist = Py_BuildValue("(O)", list);
+ CHECK(arglist != nullptr);
+ PyObject* result = PyEval_CallObject(aggregate_fn_, arglist);
+ Py_DECREF(arglist);
+ Py_DECREF(list);
+ return result;
+ }
+
+ void* Zeros(tensorflow::TensorShape shape,
+ tensorflow::DataType dtype) const final {
+ PyObject* py_shape = PyTuple_New(shape.dims());
+ for (int i = 0; i < shape.dims(); ++i) {
+ PyTuple_SET_ITEM(py_shape, i, PyLong_FromLong(shape.dim_size(i)));
+ }
+ PyObject* py_dtype = PyLong_FromLong(static_cast<int>(dtype));
+ PyObject* arg_list = Py_BuildValue("OO", py_shape, py_dtype);
+ PyObject* result = PyEval_CallObject(zeros_, arg_list);
+ Py_DECREF(arg_list);
+ Py_DECREF(py_dtype);
+ Py_DECREF(py_shape);
+ return reinterpret_cast<void*>(result);
+ }
+
+ void* OnesLike(void* tensor) const final {
+ PyObject* arg_list = Py_BuildValue("(O)", tensor);
+ PyObject* result = PyEval_CallObject(ones_like_, arg_list);
+ if (result == nullptr) {
+ VLOG(1) << "Call to ones_like failed";
+ }
+ Py_DECREF(arg_list);
+ return reinterpret_cast<void*>(result);
+ }
+
+ tensorflow::int64 TensorId(void* tensor) const final {
+ PyObject* py_tensor = reinterpret_cast<PyObject*>(tensor);
+ PyObject* id_field = PyObject_GetAttrString(py_tensor, "_id");
+ tensorflow::int64 id = MakeInt(id_field);
+ Py_DECREF(id_field);
+ return id;
+ }
+
+ tensorflow::Status CallBackwardFunction(
+ void* backward_function,
+ tensorflow::gtl::ArraySlice<void*> output_gradients,
+ std::vector<void*>* result) const final {
+ PyObject* grads = PyTuple_New(output_gradients.size());
+ for (int i = 0; i < output_gradients.size(); ++i) {
+ if (output_gradients[i] == nullptr) {
+ Py_INCREF(Py_None);
+ PyTuple_SET_ITEM(grads, i, Py_None);
+ } else {
+ PyTuple_SET_ITEM(grads, i,
+ reinterpret_cast<PyObject*>(output_gradients[i]));
}
- PyObject* type_enum = PyLong_FromLong(entry.output_tensor_info[i].dtype);
- PyObject* tuple = PyTuple_Pack(2, shape_list, type_enum);
- Py_DECREF(shape_list);
- Py_DECREF(type_enum);
- PyList_SET_ITEM(output_shape_and_dtype, i, tuple);
}
- PyObject* opinfo = PyTuple_Pack(5, op_type, output_ids, input_ids,
- backward_function, output_shape_and_dtype);
- Py_DECREF(op_type);
- Py_DECREF(output_ids);
- Py_DECREF(input_ids);
+ PyObject* py_result = PyEval_CallObject(
+ reinterpret_cast<PyObject*>(backward_function), grads);
+ Py_DECREF(grads);
Py_DECREF(backward_function);
- Py_DECREF(output_shape_and_dtype);
- PyDict_SetItem(op_tape, opid, opinfo);
- Py_DECREF(opid);
- Py_DECREF(opinfo);
- }
- PyObject* retval = PyTuple_Pack(2, tensor_tape, op_tape);
- Py_DECREF(tensor_tape);
- Py_DECREF(op_tape);
- return retval;
+ if (py_result == nullptr) {
+ VLOG(1) << "Gradient function threw exceptions";
+ if (VLOG_IS_ON(1)) {
+ PyErr_Print();
+ }
+ return tensorflow::errors::Internal("gradient function threw exceptions");
+ }
+ result->clear();
+ PyObject* seq =
+ PySequence_Fast(py_result, "expected a sequence of gradients");
+ if (seq == nullptr) {
+ return tensorflow::errors::InvalidArgument(
+ "gradient function did not return a list");
+ }
+ int len = PySequence_Fast_GET_SIZE(seq);
+ VLOG(1) << "Gradient length is " << len;
+ result->reserve(len);
+ for (int i = 0; i < len; ++i) {
+ PyObject* item = PySequence_Fast_GET_ITEM(seq, i);
+ if (item == Py_None) {
+ result->push_back(nullptr);
+ } else {
+ Py_INCREF(item);
+ result->push_back(item);
+ }
+ }
+ Py_DECREF(seq);
+ Py_DECREF(py_result);
+ return tensorflow::Status::OK();
+ }
+
+ void DeleteTensor(void* tensor) const final {
+ Py_XDECREF(reinterpret_cast<PyObject*>(tensor));
+ }
+
+ private:
+ PyObject* py_vspace_;
+
+ PyObject* num_elements_;
+ PyObject* aggregate_fn_;
+ PyObject* zeros_;
+ PyObject* ones_like_;
+};
+
+std::vector<void*> MakeTensorList(PyObject* tensors) {
+ PyObject* seq = PySequence_Fast(tensors, "expected a sequence");
+ if (seq == nullptr) {
+ return {};
+ }
+ int len = PySequence_Fast_GET_SIZE(seq);
+ std::vector<void*> list;
+ list.reserve(len);
+ for (int i = 0; i < len; ++i) {
+ list.push_back(PySequence_Fast_GET_ITEM(seq, i));
+ }
+ Py_DECREF(seq);
+ return list;
+}
+
+PyObject* TFE_Py_TapeGradient(PyObject* tape, PyObject* vspace,
+ PyObject* target, PyObject* sources,
+ PyObject* output_gradients, TF_Status* status) {
+ PyVSpace c_vspace(vspace);
+ if (!c_vspace.Initialize().ok()) {
+ return nullptr;
+ }
+
+ std::vector<void*> target_vec = MakeTensorList(target);
+ if (PyErr_Occurred()) {
+ return nullptr;
+ }
+ std::vector<void*> sources_vec = MakeTensorList(sources);
+ if (PyErr_Occurred()) {
+ return nullptr;
+ }
+ std::vector<void*> outgrad_vec;
+ if (output_gradients != Py_None) {
+ outgrad_vec = MakeTensorList(output_gradients);
+ if (PyErr_Occurred()) {
+ return nullptr;
+ }
+ for (void* tensor : outgrad_vec) {
+ // Calling the backward function will eat a reference to the tensors in
+ // outgrad_vec, so we need to increase their reference count.
+ Py_INCREF(reinterpret_cast<PyObject*>(tensor));
+ }
+ }
+ TFE_Py_Tape* tape_obj = reinterpret_cast<TFE_Py_Tape*>(tape);
+ std::vector<void*> result;
+ status->status = tape_obj->tape->Gradient(c_vspace, target_vec, sources_vec,
+ outgrad_vec, &result);
+ if (!status->status.ok()) {
+ return nullptr;
+ }
+ if (!result.empty()) {
+ PyObject* py_result = PyList_New(result.size());
+ for (int i = 0; i < result.size(); ++i) {
+ if (result[i] == nullptr) {
+ Py_INCREF(Py_None);
+ result[i] = Py_None;
+ }
+ PyList_SET_ITEM(py_result, i, reinterpret_cast<PyObject*>(result[i]));
+ }
+ return py_result;
+ }
+ Py_INCREF(Py_None);
+ return Py_None;
}
diff --git a/tensorflow/python/eager/tape.py b/tensorflow/python/eager/tape.py
index c16aa8c2f7..a06f5e1a67 100644
--- a/tensorflow/python/eager/tape.py
+++ b/tensorflow/python/eager/tape.py
@@ -72,7 +72,7 @@ class Tape(object):
True if any of the tensors is in the tape.
"""
return pywrap_tensorflow.TFE_Py_TapeShouldRecord(
- self._tape, [x._id for x in tensors]) # pylint: disable=protected-access
+ self._tape, tensors)
def watch(self, tensor):
"""Adds a tensor to the tape."""
@@ -99,16 +99,6 @@ class Tape(object):
"""Deletes any trace we have for this tensor."""
self._delete_tensor_id(tensor_id)
- def export(self):
- """Exports the internal state of this tape.
-
- Returns:
- tensor_tape: a map from tensor_id(tensor) to <identifier for op>
- responsible for generating that tensor.
- op_tape: a map from <identifier for op> to TapeEntry for that op.
- """
- return pywrap_tensorflow.TFE_Py_TapeExport(self._tape)
-
class _TapeStack(threading.local):
diff --git a/tensorflow/python/eager/tape_test.py b/tensorflow/python/eager/tape_test.py
index c97cb62125..b490bac66d 100644
--- a/tensorflow/python/eager/tape_test.py
+++ b/tensorflow/python/eager/tape_test.py
@@ -22,7 +22,6 @@ from __future__ import print_function
from tensorflow.python.eager import backprop
from tensorflow.python.eager import context
from tensorflow.python.eager import custom_gradient
-from tensorflow.python.eager import tape
from tensorflow.python.eager import test
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
@@ -166,25 +165,6 @@ class TapeTest(test.TestCase):
g, = backprop.gradients_function(fn, [0])(t)
self.assertAllEqual(g, 1.0)
- def testTapeGC(self):
- # TODO(apassos) figure out how to test this without using tape internal
- # APIs.
- tape.push_new_tape()
-
- def f():
- x = constant_op.constant(1.0)
- tape.watch(x)
- x = gradient_is_constant(x)
- x = gradient_is_constant(x)
- x = gradient_is_constant(x)
-
- f()
- t = tape.pop_tape()
- tensor_tape, op_tape = t.export()
- self.assertEqual(len(tensor_tape), 1) # The watched tensor will remain on
- # the tape
- self.assertEqual(len(op_tape), 0) # No operations should remain on the tape
-
def testCustomGradientGraphMode(self):
with context.graph_mode(), self.test_session():
diff --git a/tensorflow/python/pywrap_tfe.i b/tensorflow/python/pywrap_tfe.i
index 637f738fed..cbacf458a0 100644
--- a/tensorflow/python/pywrap_tfe.i
+++ b/tensorflow/python/pywrap_tfe.i
@@ -29,7 +29,7 @@ limitations under the License.
%rename("%s") TFE_Py_TapeWatch;
%rename("%s") TFE_Py_TapeDeleteTrace;
%rename("%s") TFE_Py_TapeRecordOperation;
-%rename("%s") TFE_Py_TapeExport;
+%rename("%s") TFE_Py_TapeGradient;
%rename("%s") TFE_NewContextOptions;
%rename("%s") TFE_ContextOptionsSetConfig;
%rename("%s") TFE_ContextOptionsSetDevicePlacementPolicy;
@@ -125,7 +125,7 @@ limitations under the License.
SWIG_fail;
}
if (EagerTensor_CheckExact(elem)) {
- (*$1)[i] = EagerTensorHandle(elem);
+ (*$1)[i] = EagerTensor_Handle(elem);
} else {
SWIG_exception_fail(SWIG_TypeError,
"provided list of inputs contains objects other "