Diffstat (limited to 'tensorflow/c/eager/tape.h')
-rw-r--r--  tensorflow/c/eager/tape.h | 501
1 file changed, 8 insertions(+), 493 deletions(-)
diff --git a/tensorflow/c/eager/tape.h b/tensorflow/c/eager/tape.h
index 29d73c5ca4..df51f300eb 100644
--- a/tensorflow/c/eager/tape.h
+++ b/tensorflow/c/eager/tape.h
@@ -19,7 +19,6 @@ limitations under the License.
// maintains the data structures required to do so.
#include <unordered_map>
-#include <unordered_set>
#include <vector>
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/framework/types.h"
@@ -37,14 +36,13 @@ struct TapeTensor {
};
// Represents an entry in the tape.
-template <typename BackwardFunction>
struct OpTapeEntry {
string op_type;
std::vector<TapeTensor> output_tensor_info;
std::vector<int64> input_tensor_id;
// TODO(apassos) consider narrowing down this interface.
- BackwardFunction* backward_function;
+ void* backward_function;
// Should be called before deleting the backward function. TODO(apassos) use
// unique_ptrs to ensure this happens.
@@ -57,68 +55,13 @@ struct OpTapeEntry {
using TensorTape = std::unordered_map<int64, int64>;
// Map from operation-id to tape entry.
-template <typename BackwardFunction>
-using OpTape = std::unordered_map<int64, OpTapeEntry<BackwardFunction>>;
-
-// Operations the tape needs to perform on tensors to do backpropagation. Named
-// "vspace" because a subset of these are related to a vector space, such as
-// adding gradients, getting zeroes, etc. Currently cannot be implemented
-// without using tensorflow python code, hence left unspecified here.
-//
-// Gradient is the type returned by gradient functions. In Python TF it's either
-// Tensor or IndexedSlices or None, which here we map to nullptr. Gradients need
-// to allow their size to be computed and they need to be passable to a backward
-// function and deleted (as the backprop code creates lots of gradients the user
-// is not interested in).
-//
-// BackwardFunction needs to be a closure which stores intermediate activations
-// from the forward computation and calls a vector-jacobian product function
-// (also known as adjoint function) to compute, given downstream gradients,
-// upstream gradients.
-//
-// TODO(apassos) provide concrete template instantiations for TFE_TensorHandle
-// specialization, which is blocked by quite a few things needing to loop back
-// into python now.
-template <typename Gradient, typename BackwardFunction>
-class VSpace {
- public:
- virtual ~VSpace() {}
-
- // Returns the number of elements in the gradient tensor.
- virtual int64 NumElements(Gradient* tensor) const = 0;
-
- // Consumes references to the tensors in the gradient_tensors list and returns
- // a tensor with the result.
- virtual Gradient* AggregateGradients(
- gtl::ArraySlice<Gradient*> gradient_tensors) const = 0;
-
- // Returns a tensor of the right shape and dtype filled with zeros.
- virtual Gradient* Zeros(TensorShape shape, DataType dtype) const = 0;
-
- // Returns a Tensor which is filled with ones and like the input.
- virtual Gradient* Ones(TensorShape shape, DataType dtype) const = 0;
-
- // Calls the passed-in backward function.
- virtual Status CallBackwardFunction(
- BackwardFunction* backward_function,
- gtl::ArraySlice<Gradient*> output_gradients,
- std::vector<Gradient*>* result) const = 0;
-
- // Deletes the input tensor.
- virtual void DeleteGradient(Gradient* gradient) const = 0;
-};
+using OpTape = std::unordered_map<int64, OpTapeEntry>;
// Traces the execution of operations, doing eager garbage collection, and
// exporting a full trace so other code can do backpropagation. Not thread-safe.
-template <typename Gradient, typename BackwardFunction>
class GradientTape {
public:
GradientTape() {}
- ~GradientTape() {
- for (const auto& pair : op_tape_) {
- pair.second.backward_function_deleter();
- }
- }
bool ShouldRecord(gtl::ArraySlice<int64> tensor_ids);
@@ -127,24 +70,19 @@ class GradientTape {
void RecordOperation(const string& op_type,
gtl::ArraySlice<TapeTensor> output_tensors,
gtl::ArraySlice<int64> input_tensor_id,
- BackwardFunction* backward_function,
+ void* backward_function,
const std::function<void()>& backward_function_deleter);
void DeleteTrace(int64 tensor_id);
- // Consumes the internal state of the tape (so cannot be called more than
- // once) and produces the gradient of the target tensors with respect to the
- // source tensors. The output gradients are used if not empty and not
- // null. The result is populated with one tensor per target element.
- Status ComputeGradient(const VSpace<Gradient, BackwardFunction>& vspace,
- gtl::ArraySlice<int64> target_tensor_ids,
- gtl::ArraySlice<int64> source_tensor_id,
- gtl::ArraySlice<Gradient*> output_gradients,
- std::vector<Gradient*>* result);
+ // Note: it is only valid to call Export once per tape, and after calling
+ // export the tape is no longer valid (i.e. calls to ShouldRecord, Watch,
+ // Record, and Delete have undefined behavior).
+ std::pair<TensorTape, OpTape> Export();
private:
TensorTape tensor_tape_;
- OpTape<BackwardFunction> op_tape_;
+ OpTape op_tape_;
int64 next_op_id_{0};
// Map from tensor id to number of remaining usages (i.e. how many entries in
@@ -152,429 +90,6 @@ class GradientTape {
std::unordered_map<int64, int64> tensor_usage_;
};
-// Template instantiations here
-
-template <typename Gradient, typename BackwardFunction>
-bool GradientTape<Gradient, BackwardFunction>::ShouldRecord(
- gtl::ArraySlice<int64> tensor_ids) {
- for (int64 i : tensor_ids) {
- if (tensor_tape_.find(i) != tensor_tape_.end()) {
- return true;
- }
- }
- return false;
-}
-
-template <typename Gradient, typename BackwardFunction>
-void GradientTape<Gradient, BackwardFunction>::Watch(int64 tensor_id) {
- tensor_tape_.emplace(tensor_id, -1);
-}
-
-template <typename Gradient, typename BackwardFunction>
-void GradientTape<Gradient, BackwardFunction>::RecordOperation(
- const string& op_type, gtl::ArraySlice<TapeTensor> output_tensors,
- gtl::ArraySlice<int64> input_tensor_id, BackwardFunction* backward_function,
- const std::function<void()>& backward_function_deleter) {
- if (!ShouldRecord(input_tensor_id)) {
- backward_function_deleter();
- return;
- }
- std::vector<int64> ids;
- ids.reserve(input_tensor_id.size());
- for (int64 i : input_tensor_id) {
- tensor_usage_[i]++;
- ids.push_back(i);
- }
- const int64 op_id = next_op_id_++;
- std::vector<TapeTensor> tensors;
- tensors.reserve(output_tensors.size());
- for (const TapeTensor& o : output_tensors) {
- // Note: the tensor can have already been watched and hence be in the tape,
- // so we cannot check that we're inserting it here.
- tensor_tape_[o.id] = op_id;
- tensor_usage_[o.id] = 1;
- tensors.push_back(o);
- }
- op_tape_[op_id] = OpTapeEntry<BackwardFunction>{
- op_type, tensors, ids, backward_function, backward_function_deleter};
-}
-
-template <typename Gradient, typename BackwardFunction>
-void GradientTape<Gradient, BackwardFunction>::DeleteTrace(int64 tensor_id) {
- auto it = tensor_usage_.find(tensor_id);
- if (it == tensor_usage_.end()) {
- return;
- }
- it->second--;
- if (it->second != 0) {
- return;
- }
- tensor_usage_.erase(it);
- auto tensor_op_it = tensor_tape_.find(tensor_id);
- if (tensor_op_it == tensor_tape_.end()) {
- return;
- }
- const int64 op_id = tensor_op_it->second;
- if (op_id == -1) {
- // Do not delete watched tensors.
- return;
- }
- tensor_tape_.erase(tensor_op_it);
- auto op_it = op_tape_.find(op_id);
- CHECK(op_it != op_tape_.end());
- for (const auto& output : op_it->second.output_tensor_info) {
- if (tensor_usage_.find(output.id) != tensor_usage_.end()) {
- // Found a usage for an output, so cannot delete the op.
- return;
- }
- }
- for (int64 id : op_it->second.input_tensor_id) {
- DeleteTrace(id);
- }
- op_it->second.backward_function_deleter();
- op_tape_.erase(op_it);
-}
-
-// Terminology:
-//
-// - op: a possibly composite operation, which has an entry in the tape
-// - target: dy in dx/dy
-// - source: dx in dx/dy
-// - tensor: one of the many inputs or outputs of an operation
-//
-// Below here we do the gradient algorithm. It works as follows:
-//
-// First we filter the tape to just the subset of operations we want to
-// differentiate. In the process of doing so we count how many times each Tensor
-// is used as an input to an op (so we know when we're done computing gradients
-// for that Tensor). We also count, for each tape entry, how many of its output
-// Tensors need gradients to be computed (Tensors which are not used do not need
-// any gradients to be computed).
-//
-// Finally, we start a backprop stack with a set of tape entries for which we
-// have all gradients available. This set usually is a subset of the set of
-// targets (not all since targets which have outputs in the tape will not have
-// gradients available initially).
-//
-// Then we repeatedly pop an entry from the stack, run its backprop, and update
-// the gradients of its inputs. Once we have computed all gradients for a single
-// input we can mark this input as done, and this can trigger adding an entry to
-// the stack if all outputs of that entry are now done.
-//
-// When the stack is empty we have gradients for all tensors we're interested
-// in.
-
-namespace {
-
-template <typename BackwardFunction>
-struct BackpropInitialState {
- OpTape<BackwardFunction> op_tape;
-
- // Map from tensor ID to how many references still exist for this tensor in
- // the tape.
- std::unordered_map<int64, int64> tensor_usage_counts;
-
- // Maps from op ID to how many output tensors of this op still need to have
- // their gradients computed.
- std::unordered_map<int64, int64> op_missing_tensor;
-};
-
-template <typename BackwardFunction>
-BackpropInitialState<BackwardFunction> PrepareBackprop(
- gtl::ArraySlice<int64> target, const TensorTape& tensor_tape,
- OpTape<BackwardFunction> op_tape,
- const std::unordered_set<int64>& sources_set) {
- std::vector<int64> tensor_stack;
- tensor_stack.reserve(target.size());
- for (auto t : target) {
- tensor_stack.push_back(t);
- }
- BackpropInitialState<BackwardFunction> result;
- while (!tensor_stack.empty()) {
- int64 tensor_id = tensor_stack.back();
- tensor_stack.pop_back();
- auto op_id_it = tensor_tape.find(tensor_id);
- if (op_id_it == tensor_tape.end()) {
- continue;
- }
- int64 op_id = op_id_it->second;
- auto op_it = op_tape.find(op_id);
- auto result_op_it = result.op_tape.find(op_id);
- if (op_id == -1 || op_it == op_tape.end() ||
- result_op_it != result.op_tape.end()) {
- continue;
- }
- CHECK(result.op_tape.emplace(op_id, op_it->second).second);
- for (auto it : op_it->second.input_tensor_id) {
- auto count_it = result.tensor_usage_counts.find(it);
- if (count_it != result.tensor_usage_counts.end()) {
- count_it->second++;
- } else {
- result.tensor_usage_counts[it] = 1;
- if (sources_set.find(it) == sources_set.end() &&
- tensor_tape.find(it) != tensor_tape.end()) {
- tensor_stack.push_back(it);
- }
- }
- }
- op_tape.erase(op_it);
- }
- for (auto& pair : result.tensor_usage_counts) {
- auto it = tensor_tape.find(pair.first);
- if (it != tensor_tape.end() && it->second != -1) {
- result.op_missing_tensor[it->second] += 1;
- }
- }
- // Call destructors for all unneeded gradient functions.
- for (const auto& op_pair : op_tape) {
- op_pair.second.backward_function_deleter();
- }
- return result;
-}
-
-template <typename BackwardFunction>
-std::vector<int64> InitialStack(
- const OpTape<BackwardFunction>& op_tape,
- const std::unordered_map<int64, int64>& op_missing_tensor) {
- std::vector<int64> result;
- for (auto& op_entry : op_tape) {
- if (op_missing_tensor.find(op_entry.first) == op_missing_tensor.end()) {
- result.push_back(op_entry.first);
- }
- }
- return result;
-}
-
-template <typename Gradient, typename BackwardFunction>
-Status InitialGradients(
- const VSpace<Gradient, BackwardFunction>& vspace,
- gtl::ArraySlice<int64> target_tensor_ids,
- gtl::ArraySlice<Gradient*> output_gradients, const TensorTape& tensor_tape,
- const OpTape<BackwardFunction>& op_tape,
- const std::unordered_map<int64, int64>& tensor_usage_counts,
- std::unordered_map<int64, std::vector<Gradient*>>* result) {
- for (int i = 0; i < target_tensor_ids.size(); ++i) {
- const int64 id = target_tensor_ids[i];
- if (tensor_usage_counts.find(id) != tensor_usage_counts.end()) {
- if (!output_gradients.empty() && output_gradients[i] != nullptr) {
- // TODO(apassos) figure out how to print debugging information here.
- return errors::InvalidArgument(
- "A gradient was provided for a tensor which is used as part of the "
- "computation.");
- }
- } else {
- if (output_gradients.empty() || output_gradients[i] == nullptr) {
- auto tensor_it = tensor_tape.find(id);
- if (tensor_it != tensor_tape.end() && tensor_it->second != -1) {
- auto op_it = op_tape.find(tensor_it->second);
- if (op_it == op_tape.end()) {
- return errors::Internal(
- "Internal state of the gradient tape is invalid.");
- }
- bool found = false;
- for (int j = 0; j < op_it->second.output_tensor_info.size(); ++j) {
- if (op_it->second.output_tensor_info[j].id == id) {
- found = true;
- (*result)[id].push_back(
- vspace.Ones(op_it->second.output_tensor_info[j].shape,
- op_it->second.output_tensor_info[j].dtype));
- break;
- }
- }
- if (!found) {
- return errors::Internal(
- "Internal state of the gradient tape is invalid.");
- }
- } else {
- // No record of the target tensor found on the tape, so no gradient
- // needs to be computed from it. Do nothing.
- }
- } else {
- (*result)[id].push_back(output_gradients[i]);
- }
- }
- }
- return Status::OK();
-}
-
-} // namespace
-
-// If over kMinAggregateCount gradients are accumulated and the total
-// memory consumption is over kMinAggregateBytes, do an early aggregation
-// so as to release the gradient tensor to save memory.
-constexpr int kMinAggregateCount = 4;
-constexpr int kMinAggregateBytes = 128 * 1024 * 1024;
-
-template <typename Gradient, typename BackwardFunction>
-Status GradientTape<Gradient, BackwardFunction>::ComputeGradient(
- const VSpace<Gradient, BackwardFunction>& vspace,
- gtl::ArraySlice<int64> target_tensor_ids,
- gtl::ArraySlice<int64> source_tensor_ids,
- gtl::ArraySlice<Gradient*> output_gradients,
- std::vector<Gradient*>* result) {
- std::unordered_set<int64> sources_set(source_tensor_ids.begin(),
- source_tensor_ids.end());
- BackpropInitialState<BackwardFunction> state = PrepareBackprop(
- target_tensor_ids, tensor_tape_, std::move(op_tape_), sources_set);
- std::vector<int64> op_stack =
- InitialStack(state.op_tape, state.op_missing_tensor);
- std::unordered_map<int64, std::vector<Gradient*>> gradients;
- Status s = InitialGradients(vspace, target_tensor_ids, output_gradients,
- tensor_tape_, state.op_tape,
- state.tensor_usage_counts, &gradients);
- auto cleanup = [&state]() {
- // Release all backprop functions
- for (const auto& pair : state.op_tape) {
- pair.second.backward_function_deleter();
- }
- };
- if (!s.ok()) {
- cleanup();
- return s;
- }
- std::unordered_map<int64, int64> gradients_size;
- // TODO(apassos) multiple threads could be dequeuing from op_stack at the same
- // time, for better CPU backprop performance.
- VLOG(1) << "Initial stack:";
- if (VLOG_IS_ON(1)) {
- for (auto t : op_stack) {
- VLOG(1) << " " << t;
- }
- }
- std::unordered_map<string, std::unordered_set<int>>
- functions_accept_none_for_indices({
- {"SoftmaxCrossEntropyWithLogits", {1}},
- {"FusedBatchNorm", {1, 2, 3, 4}},
- });
- while (!op_stack.empty()) {
- const int64 op = op_stack.back();
- VLOG(1) << "Popped " << op;
- op_stack.pop_back();
- auto op_it = state.op_tape.find(op);
- if (op_it == state.op_tape.end()) {
- // It is possible for ops to end up on the stack if they are unrelated to
- // the target; we should just skip them.
- continue;
- }
- auto trace = std::move(op_it->second);
- state.op_tape.erase(op_it);
- std::vector<Gradient*> out_gradients;
- out_gradients.reserve(trace.output_tensor_info.size());
- for (int i = 0; i < trace.output_tensor_info.size(); ++i) {
- const int64 id = trace.output_tensor_info[i].id;
- auto grad_it = gradients.find(id);
- if (grad_it == gradients.end()) {
- auto func_name_it =
- functions_accept_none_for_indices.find(trace.op_type);
- if (func_name_it != functions_accept_none_for_indices.end() &&
- func_name_it->second.find(i) != func_name_it->second.end()) {
- out_gradients.push_back(nullptr);
- } else {
- out_gradients.push_back(
- vspace.Zeros(trace.output_tensor_info[i].shape,
- trace.output_tensor_info[i].dtype));
- }
- } else {
- out_gradients.push_back(vspace.AggregateGradients(grad_it->second));
- if (sources_set.find(grad_it->first) == sources_set.end()) {
- gradients.erase(grad_it);
- }
- }
- }
- std::vector<Gradient*> in_gradients;
- Status s = vspace.CallBackwardFunction(trace.backward_function,
- out_gradients, &in_gradients);
- if (!s.ok()) {
- VLOG(1) << "Gradient function failed.";
- cleanup();
- return s;
- }
- VLOG(1) << "Got " << in_gradients.size() << " in_gradients for "
- << trace.input_tensor_id.size() << " sources";
- for (int i = 0; i < in_gradients.size(); ++i) {
- const int64 id = trace.input_tensor_id[i];
- if (in_gradients[i] != nullptr) {
- auto& unaggregated_grads = gradients[id];
- unaggregated_grads.push_back(in_gradients[i]);
- if (unaggregated_grads.size() > kMinAggregateCount) {
- auto size_it = gradients_size.find(id);
- int64 size;
- if (size_it == gradients_size.end()) {
- size = vspace.NumElements(unaggregated_grads[0]);
- gradients_size.emplace(id, size);
- } else {
- size = size_it->second;
- }
- if (unaggregated_grads.size() * size * 4 > kMinAggregateBytes) {
- Gradient* grad = vspace.AggregateGradients(unaggregated_grads);
- unaggregated_grads.clear();
- unaggregated_grads.push_back(grad);
- }
- }
- }
- auto usage_count_it = state.tensor_usage_counts.find(id);
- if (usage_count_it == state.tensor_usage_counts.end()) {
- VLOG(1) << "Tensor " << id << " not used";
- continue;
- }
- usage_count_it->second--;
- if (usage_count_it->second > 0) {
- VLOG(1) << "Tensor " << id << " usage count " << usage_count_it->second;
- continue;
- }
- auto tape_it = tensor_tape_.find(id);
- if (tape_it == tensor_tape_.end()) {
- VLOG(1) << "Tensor " << id
- << " has no associated op. Deleting gradient";
- auto grad_it = gradients.find(id);
- if (grad_it != gradients.end()) {
- for (auto g : grad_it->second) {
- vspace.DeleteGradient(g);
- }
- gradients.erase(grad_it);
- }
- continue;
- }
- const int64 op_id = tape_it->second;
- if (op_id == -1) {
- VLOG(1) << "Tensor " << id << " is source";
- continue;
- }
- auto missing_it = state.op_missing_tensor.find(op_id);
- if (missing_it != state.op_missing_tensor.end()) {
- missing_it->second--;
- VLOG(1) << "Op " << op_id << " missing " << missing_it->second
- << " output gradients";
- if (missing_it->second == 0) {
- op_stack.push_back(op_id);
- }
- }
- }
- }
- CHECK(state.op_tape.empty());
- result->reserve(source_tensor_ids.size());
- for (auto is : source_tensor_ids) {
- auto grad_it = gradients.find(is);
- if (grad_it == gradients.end()) {
- result->push_back(nullptr);
- } else {
- if (grad_it->second.size() == 1) {
- result->push_back(grad_it->second[0]);
- } else {
- result->push_back(vspace.AggregateGradients(grad_it->second));
- }
- gradients.erase(grad_it);
- }
- }
- VLOG(1) << "Final gradients size: " << gradients.size();
- for (auto grad_pair : gradients) {
- for (const auto& g : grad_pair.second) {
- vspace.DeleteGradient(g);
- }
- }
- return Status::OK();
-}
-
} // namespace eager
} // namespace tensorflow
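
A minimal caller-side sketch of the interface this patch leaves in tape.h, written against the post-patch header. It is not part of the patch: the opaque closure type PyBackwardFunction, the tensor ids, and the op type are illustrative assumptions; only GradientTape, TapeTensor, TensorTape, OpTape, Watch, ShouldRecord, RecordOperation, and Export come from the header above. It shows the new contract spelled out in the added comment: backward functions are stored as untyped void* plus a deleter, and Export() consumes the tape, after which further use is undefined.

// Hypothetical usage sketch, not part of this patch.
#include <utility>

#include "tensorflow/c/eager/tape.h"

namespace tensorflow {
namespace eager {

// Stand-in for an opaque closure the tape never inspects (illustrative).
struct PyBackwardFunction {};

std::pair<TensorTape, OpTape> RecordOneOpAndExport() {
  GradientTape tape;
  tape.Watch(/*tensor_id=*/1);  // mark an input as watched so recording happens

  TapeTensor out;
  out.id = 2;
  out.dtype = DT_FLOAT;
  out.shape = TensorShape({2, 2});

  auto* fn = new PyBackwardFunction;
  if (tape.ShouldRecord({1})) {
    tape.RecordOperation("MatMul", {out}, {1},
                         static_cast<void*>(fn),   // stored untyped
                         [fn]() { delete fn; });   // deleter owns cleanup
  } else {
    delete fn;
  }
  // Export() consumes the tape's internal state; per the comment in the
  // header, the tape must not be used again after this call. Backprop over
  // the exported TensorTape/OpTape now happens in the caller, not in tape.h.
  return tape.Export();
}

}  // namespace eager
}  // namespace tensorflow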
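For context on the abstraction this patch deletes, here is a toy sketch of a VSpace specialization; it compiles only against the pre-patch header, since VSpace itself is removed by this change. ToyGradient and ToyBackwardFn are illustrative assumptions; the real Gradient/BackwardFunction specialization lived in the Python binding layer, which is why the removed comment says the interface "cannot be implemented without using tensorflow python code".

// Toy sketch against the pre-patch tape.h, not from the TensorFlow codebase.
#include <functional>
#include <vector>

#include "tensorflow/c/eager/tape.h"  // pre-patch version, which still declares VSpace

namespace tensorflow {
namespace eager {

// Illustrative Gradient type: a flat dense buffer.
struct ToyGradient {
  std::vector<float> values;
};

// Illustrative BackwardFunction: a closure mapping output gradients to input gradients.
using ToyBackwardFn = std::function<Status(gtl::ArraySlice<ToyGradient*>,
                                           std::vector<ToyGradient*>*)>;

class ToyVSpace : public VSpace<ToyGradient, ToyBackwardFn> {
 public:
  int64 NumElements(ToyGradient* g) const override {
    return static_cast<int64>(g->values.size());
  }
  ToyGradient* AggregateGradients(
      gtl::ArraySlice<ToyGradient*> gs) const override {
    // Sums element-wise and, per the interface comment, consumes the inputs.
    auto* out = new ToyGradient{std::vector<float>(gs[0]->values.size(), 0.0f)};
    for (ToyGradient* g : gs) {
      for (size_t i = 0; i < g->values.size(); ++i) out->values[i] += g->values[i];
      delete g;
    }
    return out;
  }
  ToyGradient* Zeros(TensorShape shape, DataType /*dtype*/) const override {
    return new ToyGradient{std::vector<float>(shape.num_elements(), 0.0f)};
  }
  ToyGradient* Ones(TensorShape shape, DataType /*dtype*/) const override {
    return new ToyGradient{std::vector<float>(shape.num_elements(), 1.0f)};
  }
  Status CallBackwardFunction(ToyBackwardFn* backward_function,
                              gtl::ArraySlice<ToyGradient*> output_gradients,
                              std::vector<ToyGradient*>* result) const override {
    return (*backward_function)(output_gradients, result);
  }
  void DeleteGradient(ToyGradient* gradient) const override { delete gradient; }
};

}  // namespace eager
}  // namespace tensorflow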