diff options
author | Alexandre Passos <apassos@google.com> | 2017-11-09 07:37:15 -0800 |
---|---|---|
committer | Andrew Selle <aselle@andyselle.com> | 2017-11-10 16:14:38 -0800 |
commit | c5a7366bfef9cfb00cd9855c98c12c6005dbb1bb (patch) | |
tree | 0f50da9cdddf92d5318f8477ea76e4e067fa8e7b /tensorflow/c | |
parent | 31d6b687da35fcbf4a1dd767fda06fab9213db31 (diff) |
Removes void*s from the tape gradient code, replacing with templates.
PiperOrigin-RevId: 175155685
Diffstat (limited to 'tensorflow/c')
-rw-r--r-- | tensorflow/c/eager/BUILD | 1 | ||||
-rw-r--r-- | tensorflow/c/eager/tape.cc | 410 | ||||
-rw-r--r-- | tensorflow/c/eager/tape.h | 473 |
3 files changed, 449 insertions, 435 deletions
diff --git a/tensorflow/c/eager/BUILD b/tensorflow/c/eager/BUILD index 74e94be8d6..d533758e36 100644 --- a/tensorflow/c/eager/BUILD +++ b/tensorflow/c/eager/BUILD @@ -106,7 +106,6 @@ tf_cc_test( cc_library( name = "tape", - srcs = ["tape.cc"], hdrs = ["tape.h"], visibility = ["//tensorflow:internal"], deps = [ diff --git a/tensorflow/c/eager/tape.cc b/tensorflow/c/eager/tape.cc deleted file mode 100644 index 459499bb69..0000000000 --- a/tensorflow/c/eager/tape.cc +++ /dev/null @@ -1,410 +0,0 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include <unordered_set> - -#include "tensorflow/c/eager/tape.h" - -namespace tensorflow { -namespace eager { - -bool GradientTape::ShouldRecord(gtl::ArraySlice<int64> tensor_ids) { - for (int64 i : tensor_ids) { - if (tensor_tape_.find(i) != tensor_tape_.end()) { - return true; - } - } - return false; -} - -void GradientTape::Watch(int64 tensor_id) { - tensor_tape_.emplace(tensor_id, -1); -} - -void GradientTape::RecordOperation( - const string& op_type, gtl::ArraySlice<TapeTensor> output_tensors, - gtl::ArraySlice<int64> input_tensor_id, void* backward_function, - const std::function<void()>& backward_function_deleter) { - if (!ShouldRecord(input_tensor_id)) { - backward_function_deleter(); - return; - } - std::vector<int64> ids; - ids.reserve(input_tensor_id.size()); - for (int64 i : input_tensor_id) { - tensor_usage_[i]++; - ids.push_back(i); - } - const int64 op_id = next_op_id_++; - std::vector<TapeTensor> tensors; - tensors.reserve(output_tensors.size()); - for (const TapeTensor& o : output_tensors) { - // Note: the tensor can have already been watched and hence be in the tape, - // so we cannot check that we're inserting it here. - tensor_tape_[o.id] = op_id; - tensor_usage_[o.id] = 1; - tensors.push_back(o); - } - op_tape_[op_id] = OpTapeEntry{op_type, tensors, ids, backward_function, - backward_function_deleter}; -} - -void GradientTape::DeleteTrace(int64 tensor_id) { - auto it = tensor_usage_.find(tensor_id); - if (it == tensor_usage_.end()) { - return; - } - it->second--; - if (it->second != 0) { - return; - } - tensor_usage_.erase(it); - auto tensor_op_it = tensor_tape_.find(tensor_id); - if (tensor_op_it == tensor_tape_.end()) { - return; - } - const int64 op_id = tensor_op_it->second; - if (op_id == -1) { - // Do not delete watched tensors. - return; - } - tensor_tape_.erase(tensor_op_it); - auto op_it = op_tape_.find(op_id); - CHECK(op_it != op_tape_.end()); - for (const auto& output : op_it->second.output_tensor_info) { - if (tensor_usage_.find(output.id) != tensor_usage_.end()) { - // Found a usage for an output, so cannot delete the op. - return; - } - } - for (int64 id : op_it->second.input_tensor_id) { - DeleteTrace(id); - } - op_it->second.backward_function_deleter(); - op_tape_.erase(op_it); -} - -// Terminology: -// -// - op: a possibly composite operation, which has an entry in the tape -// - target: dy in dx/dy -// - source: dx in dx/dy -// - tensor: one of the many inputs or outputs of an operation -// -// Below here we do the gradient algorithm. It works as follows: -// -// First we filter the tape to just the subset of operations we want to -// differentiate. In the process of doing so we count how many times each Tensor -// is used as an input to an op (so we know when we're done computing gradients -// for that Tensor). We also count, for each tape entry, how many of its output -// Tensors need gradients to be computed (Tensors which are not used do not need -// any gradients to be computed). -// -// Finally, we start a backprop stack with a set of tape entries for which we -// have all gradients available. This set usually is a subset of the set of -// targets (not all since targets which have outputs in the tape will not have -// gradients available initially). -// -// Then we repeatedly pop an entry from the stack, run its backprop, and update -// the gradients of its inputs. Once we have computed all gradients for a single -// input we can mark this input as done, and this can trigger adding an entry to -// the stack if all outputs of that entry are now done. -// -// When the stack is empty we have gradients for all tensors we're interested -// in. - -struct BackpropInitialState { - OpTape op_tape; - - // Map from tensor ID to how many references still exist for this tensor in - // the tape. - std::unordered_map<int64, int64> tensor_usage_counts; - - // Maps from op ID to how many output tensors of this op still need to have - // their gradients computed. - std::unordered_map<int64, int64> op_missing_tensor; -}; - -BackpropInitialState PrepareBackprop( - gtl::ArraySlice<int64> target, const TensorTape& tensor_tape, - OpTape op_tape, const std::unordered_set<int64>& sources_set) { - std::vector<int64> tensor_stack; - tensor_stack.reserve(target.size()); - for (auto t : target) { - tensor_stack.push_back(t); - } - BackpropInitialState result; - while (!tensor_stack.empty()) { - int64 tensor_id = tensor_stack.back(); - tensor_stack.pop_back(); - auto op_id_it = tensor_tape.find(tensor_id); - if (op_id_it == tensor_tape.end()) { - continue; - } - int64 op_id = op_id_it->second; - auto op_it = op_tape.find(op_id); - auto result_op_it = result.op_tape.find(op_id); - if (op_id == -1 || op_it == op_tape.end() || - result_op_it != result.op_tape.end()) { - continue; - } - CHECK(result.op_tape.emplace(op_id, op_it->second).second); - for (auto it : op_it->second.input_tensor_id) { - auto count_it = result.tensor_usage_counts.find(it); - if (count_it != result.tensor_usage_counts.end()) { - count_it->second++; - } else { - result.tensor_usage_counts[it] = 1; - if (sources_set.find(it) == sources_set.end() && - tensor_tape.find(it) != tensor_tape.end()) { - tensor_stack.push_back(it); - } - } - } - op_tape.erase(op_it); - } - for (auto& pair : result.tensor_usage_counts) { - auto it = tensor_tape.find(pair.first); - if (it != tensor_tape.end() && it->second != -1) { - result.op_missing_tensor[it->second] += 1; - } - } - // Call destructors for all unneeded gradient functions. - for (const auto& op_pair : op_tape) { - op_pair.second.backward_function_deleter(); - } - return result; -} - -std::vector<int64> InitialStack( - const OpTape& op_tape, - const std::unordered_map<int64, int64>& op_missing_tensor) { - std::vector<int64> result; - for (auto& op_entry : op_tape) { - if (op_missing_tensor.find(op_entry.first) == op_missing_tensor.end()) { - result.push_back(op_entry.first); - } - } - return result; -} - -Status InitialGradients(const VSpace& vspace, gtl::ArraySlice<void*> target, - gtl::ArraySlice<void*> output_gradients, - std::unordered_map<int64, int64> tensor_usage_counts, - std::unordered_map<int64, std::vector<void*>>* result) { - for (int i = 0; i < target.size(); ++i) { - int64 id = vspace.TensorId(target[i]); - if (tensor_usage_counts.find(id) != tensor_usage_counts.end()) { - if (!output_gradients.empty() && output_gradients[i] != nullptr) { - // TODO(apassos) figure out how to print debugging information here. - return errors::InvalidArgument( - "A gradient was provided for a tensor which is used as part of the " - "computation."); - } - } else { - if (output_gradients.empty() || output_gradients[i] == nullptr) { - (*result)[id].push_back(vspace.OnesLike(target[i])); - } else { - (*result)[id].push_back(output_gradients[i]); - } - } - } - return Status::OK(); -} - -// If over kMinAggregateCount gradients are accumulated and the total -// memory consumption is over kMinAggregateBytes, do an early aggregation -// so as to release the gradient tensor to save memory. -static const int kMinAggregateCount = 4; -static const int kMinAggregateBytes = 128 * 1024 * 1024; - -Status GradientTape::Gradient(const VSpace& vspace, - gtl::ArraySlice<void*> target, - gtl::ArraySlice<void*> sources, - gtl::ArraySlice<void*> output_gradients, - std::vector<void*>* result) { - std::vector<int64> id_sources; - id_sources.reserve(sources.size()); - for (void* s : sources) { - id_sources.push_back(vspace.TensorId(s)); - } - std::unordered_set<int64> sources_set(id_sources.begin(), id_sources.end()); - std::vector<int64> id_targets; - id_sources.reserve(target.size()); - for (void* t : target) { - id_targets.push_back(vspace.TensorId(t)); - } - BackpropInitialState state = PrepareBackprop( - id_targets, tensor_tape_, std::move(op_tape_), sources_set); - std::vector<int64> op_stack = - InitialStack(state.op_tape, state.op_missing_tensor); - std::unordered_map<int64, std::vector<void*>> gradients; - Status s = InitialGradients(vspace, target, output_gradients, - state.tensor_usage_counts, &gradients); - auto cleanup = [&state]() { - // Release all backprop functions - for (const auto& pair : state.op_tape) { - pair.second.backward_function_deleter(); - } - }; - if (!s.ok()) { - cleanup(); - return s; - } - std::unordered_map<int64, int64> gradients_size; - // TODO(apassos) multiple threads could be dequeuing from op_stack at the same - // time, for better CPU backprop performance. - VLOG(1) << "Initial stack:"; - if (VLOG_IS_ON(1)) { - for (auto t : op_stack) { - VLOG(1) << " " << t; - } - } - std::unordered_map<string, std::unordered_set<int>> - functions_accept_none_for_indices({ - {"SoftmaxCrossEntropyWithLogits", {1}}, - {"FusedBatchNorm", {1, 2, 3, 4}}, - }); - while (!op_stack.empty()) { - const int64 op = op_stack.back(); - VLOG(1) << "Popped " << op; - op_stack.pop_back(); - auto op_it = state.op_tape.find(op); - if (op_it == state.op_tape.end()) { - // It is possible for ops to end up on the stack if they are unrelated to - // the target; we should just skip them. - continue; - } - auto trace = std::move(op_it->second); - state.op_tape.erase(op_it); - std::vector<void*> out_gradients; - out_gradients.reserve(trace.output_tensor_info.size()); - for (int i = 0; i < trace.output_tensor_info.size(); ++i) { - const int64 id = trace.output_tensor_info[i].id; - auto grad_it = gradients.find(id); - if (grad_it == gradients.end()) { - auto func_name_it = - functions_accept_none_for_indices.find(trace.op_type); - if (func_name_it != functions_accept_none_for_indices.end() && - func_name_it->second.find(i) != func_name_it->second.end()) { - out_gradients.push_back(nullptr); - } else { - out_gradients.push_back( - vspace.Zeros(trace.output_tensor_info[i].shape, - trace.output_tensor_info[i].dtype)); - } - } else { - out_gradients.push_back(vspace.AggregateGradients(grad_it->second)); - if (sources_set.find(grad_it->first) == sources_set.end()) { - gradients.erase(grad_it); - } - } - } - std::vector<void*> in_gradients; - Status s = vspace.CallBackwardFunction(trace.backward_function, - out_gradients, &in_gradients); - if (!s.ok()) { - VLOG(1) << "Gradient function failed."; - cleanup(); - return s; - } - VLOG(1) << "Got " << in_gradients.size() << " in_gradients for " - << trace.input_tensor_id.size() << " sources"; - for (int i = 0; i < in_gradients.size(); ++i) { - const int64 id = trace.input_tensor_id[i]; - if (in_gradients[i] != nullptr) { - auto& unaggregated_grads = gradients[id]; - unaggregated_grads.push_back(in_gradients[i]); - if (unaggregated_grads.size() > kMinAggregateCount) { - auto size_it = gradients_size.find(id); - int64 size; - if (size_it == gradients_size.end()) { - size = vspace.NumElements(unaggregated_grads[0]); - gradients_size.emplace(id, size); - } else { - size = size_it->second; - } - if (unaggregated_grads.size() * size * 4 > kMinAggregateBytes) { - void* tensor = vspace.AggregateGradients(unaggregated_grads); - unaggregated_grads.clear(); - unaggregated_grads.push_back(tensor); - } - } - } - auto usage_count_it = state.tensor_usage_counts.find(id); - if (usage_count_it == state.tensor_usage_counts.end()) { - VLOG(1) << "Tensor " << id << " not used"; - continue; - } - usage_count_it->second--; - if (usage_count_it->second > 0) { - VLOG(1) << "Tensor " << id << " usage count " << usage_count_it->second; - continue; - } - auto tape_it = tensor_tape_.find(id); - if (tape_it == tensor_tape_.end()) { - VLOG(1) << "Tensor " << id - << " has no associated op. Deleting gradient"; - auto grad_it = gradients.find(id); - if (grad_it != gradients.end()) { - for (auto g : grad_it->second) { - vspace.DeleteTensor(g); - } - gradients.erase(grad_it); - } - continue; - } - const int64 op_id = tape_it->second; - if (op_id == -1) { - VLOG(1) << "Tensor " << id << " is source"; - continue; - } - auto missing_it = state.op_missing_tensor.find(op_id); - if (missing_it != state.op_missing_tensor.end()) { - missing_it->second--; - VLOG(1) << "Op " << op_id << " missing " << missing_it->second - << " output gradients"; - if (missing_it->second == 0) { - op_stack.push_back(op_id); - } - } - } - } - CHECK(state.op_tape.empty()); - result->reserve(sources.size()); - for (auto is : id_sources) { - auto grad_it = gradients.find(is); - if (grad_it == gradients.end()) { - result->push_back(nullptr); - } else { - if (grad_it->second.size() == 1) { - result->push_back(grad_it->second[0]); - } else { - result->push_back(vspace.AggregateGradients(grad_it->second)); - } - gradients.erase(grad_it); - } - } - VLOG(1) << "Final gradients size: " << gradients.size(); - for (auto grad_pair : gradients) { - for (const auto& g : grad_pair.second) { - vspace.DeleteTensor(g); - } - } - return Status::OK(); -} - -} // namespace eager -} // namespace tensorflow diff --git a/tensorflow/c/eager/tape.h b/tensorflow/c/eager/tape.h index 2bb62a7ab3..654ceb7bec 100644 --- a/tensorflow/c/eager/tape.h +++ b/tensorflow/c/eager/tape.h @@ -19,6 +19,7 @@ limitations under the License. // maintains the data structures required to do so. #include <unordered_map> +#include <unordered_set> #include <vector> #include "tensorflow/core/framework/tensor_shape.h" #include "tensorflow/core/framework/types.h" @@ -36,13 +37,14 @@ struct TapeTensor { }; // Represents an entry in the tape. +template <typename BackwardFunction> struct OpTapeEntry { string op_type; std::vector<TapeTensor> output_tensor_info; std::vector<int64> input_tensor_id; // TODO(apassos) consider narrowing down this interface. - void* backward_function; + BackwardFunction* backward_function; // Should be called before deleting the backward function. TODO(apassos) use // unique_ptrs to ensure this happens. @@ -55,51 +57,67 @@ struct OpTapeEntry { using TensorTape = std::unordered_map<int64, int64>; // Map from operation-id to tape entry. -using OpTape = std::unordered_map<int64, OpTapeEntry>; +template <typename BackwardFunction> +using OpTape = std::unordered_map<int64, OpTapeEntry<BackwardFunction>>; // Operations the tape needs to perform on tensors to do backpropagation. Named // "vspace" because a subset of these are related to a vector space, such as // adding gradients, getting zeroes, etc. Currently cannot be implemented // without using tensorflow python code, hence left unspecified here. // -// We currently use void* for tensors, backward functions, and gradients (which -// can be but are not required to be tensors). TODO(apassos) replace this first -// with templates to allow for pyobject specialization in the client followed by -// a TFE_TensorHandle specialization, which is blocked by quite a few things -// still. +// Tensor is a representation of a tensor. We need to take its ID, and it needs +// to match IDs in the tape. +// +// Gradient is the type returned by gradient functions. In Python TF it's either +// Tensor or IndexedSlices or None, which here we map to nullptr. Gradients need +// to allow their size to be computed and they need to be passable to a backward +// function and deleted (as the backprop code creates lots of gradients the user +// is not interested in). +// +// BackwardFunction needs to be a closure which stores intermediate activations +// from the forward computation and calls a vector-jacobian product function +// (also known as adjoint function) to compute, given downstream gradients, +// upstream gradients. +// +// TODO(apassos) provide concrete template instantiations for TFE_TensorHandle +// specialization, which is blocked by quite a few things needing to loop back +// into python now. +template <typename Tensor, typename Gradient, typename BackwardFunction> class VSpace { public: virtual ~VSpace() {} - // Returns the number of elements in the tensor. - virtual int64 NumElements(void* tensor) const = 0; + // Returns the number of elements in the gradient tensor. + virtual int64 NumElements(Gradient* tensor) const = 0; // Consumes references to the tensors in the gradient_tensors list and returns // a tensor with the result. - virtual void* AggregateGradients( - gtl::ArraySlice<void*> gradient_tensors) const = 0; + virtual Gradient* AggregateGradients( + gtl::ArraySlice<Gradient*> gradient_tensors) const = 0; // Returns a tensor of the right shape and dtype filled with zeros. - virtual void* Zeros(TensorShape shape, DataType dtype) const = 0; + virtual Gradient* Zeros(TensorShape shape, DataType dtype) const = 0; // Returns a Tensor which is filled with ones and like the input. - virtual void* OnesLike(void*) const = 0; + virtual Gradient* OnesLike(Tensor*) const = 0; // Returns an integer which is a unique-to-within-this-program handle for this // tensor. - virtual int64 TensorId(void* tensor) const = 0; + virtual int64 TensorId(Tensor* tensor) const = 0; // Calls the passed-in backward function. - virtual Status CallBackwardFunction(void* backward_function, - gtl::ArraySlice<void*> output_gradients, - std::vector<void*>* result) const = 0; + virtual Status CallBackwardFunction( + BackwardFunction* backward_function, + gtl::ArraySlice<Gradient*> output_gradients, + std::vector<Gradient*>* result) const = 0; // Deletes the input tensor. - virtual void DeleteTensor(void* tensor) const = 0; + virtual void DeleteGradient(Gradient* gradient) const = 0; }; // Traces the execution of operations, doing eager garbage collection, and // exporting a full trace so other code can do backpropagation. Not thread-safe. +template <typename Tensor, typename Gradient, typename BackwardFunction> class GradientTape { public: GradientTape() {} @@ -116,7 +134,7 @@ class GradientTape { void RecordOperation(const string& op_type, gtl::ArraySlice<TapeTensor> output_tensors, gtl::ArraySlice<int64> input_tensor_id, - void* backward_function, + BackwardFunction* backward_function, const std::function<void()>& backward_function_deleter); void DeleteTrace(int64 tensor_id); @@ -125,14 +143,15 @@ class GradientTape { // once) and produces the gradient of the target tensors with respect to the // source tensors. The output gradients are used if not empty and not // null. The result is populated with one tensor per target element. - Status Gradient(const VSpace& vspace, gtl::ArraySlice<void*> target, - gtl::ArraySlice<void*> sources, - gtl::ArraySlice<void*> output_gradients, - std::vector<void*>* result); + Status ComputeGradient( + const VSpace<Tensor, Gradient, BackwardFunction>& vspace, + gtl::ArraySlice<Tensor*> target, gtl::ArraySlice<Tensor*> sources, + gtl::ArraySlice<Gradient*> output_gradients, + std::vector<Gradient*>* result); private: TensorTape tensor_tape_; - OpTape op_tape_; + OpTape<BackwardFunction> op_tape_; int64 next_op_id_{0}; // Map from tensor id to number of remaining usages (i.e. how many entries in @@ -140,6 +159,412 @@ class GradientTape { std::unordered_map<int64, int64> tensor_usage_; }; +// Template instantiations here + +template <typename Tensor, typename Gradient, typename BackwardFunction> +bool GradientTape<Tensor, Gradient, BackwardFunction>::ShouldRecord( + gtl::ArraySlice<int64> tensor_ids) { + for (int64 i : tensor_ids) { + if (tensor_tape_.find(i) != tensor_tape_.end()) { + return true; + } + } + return false; +} + +template <typename Tensor, typename Gradient, typename BackwardFunction> +void GradientTape<Tensor, Gradient, BackwardFunction>::Watch(int64 tensor_id) { + tensor_tape_.emplace(tensor_id, -1); +} + +template <typename Tensor, typename Gradient, typename BackwardFunction> +void GradientTape<Tensor, Gradient, BackwardFunction>::RecordOperation( + const string& op_type, gtl::ArraySlice<TapeTensor> output_tensors, + gtl::ArraySlice<int64> input_tensor_id, BackwardFunction* backward_function, + const std::function<void()>& backward_function_deleter) { + if (!ShouldRecord(input_tensor_id)) { + backward_function_deleter(); + return; + } + std::vector<int64> ids; + ids.reserve(input_tensor_id.size()); + for (int64 i : input_tensor_id) { + tensor_usage_[i]++; + ids.push_back(i); + } + const int64 op_id = next_op_id_++; + std::vector<TapeTensor> tensors; + tensors.reserve(output_tensors.size()); + for (const TapeTensor& o : output_tensors) { + // Note: the tensor can have already been watched and hence be in the tape, + // so we cannot check that we're inserting it here. + tensor_tape_[o.id] = op_id; + tensor_usage_[o.id] = 1; + tensors.push_back(o); + } + op_tape_[op_id] = OpTapeEntry<BackwardFunction>{ + op_type, tensors, ids, backward_function, backward_function_deleter}; +} + +template <typename Tensor, typename Gradient, typename BackwardFunction> +void GradientTape<Tensor, Gradient, BackwardFunction>::DeleteTrace( + int64 tensor_id) { + auto it = tensor_usage_.find(tensor_id); + if (it == tensor_usage_.end()) { + return; + } + it->second--; + if (it->second != 0) { + return; + } + tensor_usage_.erase(it); + auto tensor_op_it = tensor_tape_.find(tensor_id); + if (tensor_op_it == tensor_tape_.end()) { + return; + } + const int64 op_id = tensor_op_it->second; + if (op_id == -1) { + // Do not delete watched tensors. + return; + } + tensor_tape_.erase(tensor_op_it); + auto op_it = op_tape_.find(op_id); + CHECK(op_it != op_tape_.end()); + for (const auto& output : op_it->second.output_tensor_info) { + if (tensor_usage_.find(output.id) != tensor_usage_.end()) { + // Found a usage for an output, so cannot delete the op. + return; + } + } + for (int64 id : op_it->second.input_tensor_id) { + DeleteTrace(id); + } + op_it->second.backward_function_deleter(); + op_tape_.erase(op_it); +} + +// Terminology: +// +// - op: a possibly composite operation, which has an entry in the tape +// - target: dy in dx/dy +// - source: dx in dx/dy +// - tensor: one of the many inputs or outputs of an operation +// +// Below here we do the gradient algorithm. It works as follows: +// +// First we filter the tape to just the subset of operations we want to +// differentiate. In the process of doing so we count how many times each Tensor +// is used as an input to an op (so we know when we're done computing gradients +// for that Tensor). We also count, for each tape entry, how many of its output +// Tensors need gradients to be computed (Tensors which are not used do not need +// any gradients to be computed). +// +// Finally, we start a backprop stack with a set of tape entries for which we +// have all gradients available. This set usually is a subset of the set of +// targets (not all since targets which have outputs in the tape will not have +// gradients available initially). +// +// Then we repeatedly pop an entry from the stack, run its backprop, and update +// the gradients of its inputs. Once we have computed all gradients for a single +// input we can mark this input as done, and this can trigger adding an entry to +// the stack if all outputs of that entry are now done. +// +// When the stack is empty we have gradients for all tensors we're interested +// in. + +namespace { + +template <typename BackwardFunction> +struct BackpropInitialState { + OpTape<BackwardFunction> op_tape; + + // Map from tensor ID to how many references still exist for this tensor in + // the tape. + std::unordered_map<int64, int64> tensor_usage_counts; + + // Maps from op ID to how many output tensors of this op still need to have + // their gradients computed. + std::unordered_map<int64, int64> op_missing_tensor; +}; + +template <typename BackwardFunction> +BackpropInitialState<BackwardFunction> PrepareBackprop( + gtl::ArraySlice<int64> target, const TensorTape& tensor_tape, + OpTape<BackwardFunction> op_tape, + const std::unordered_set<int64>& sources_set) { + std::vector<int64> tensor_stack; + tensor_stack.reserve(target.size()); + for (auto t : target) { + tensor_stack.push_back(t); + } + BackpropInitialState<BackwardFunction> result; + while (!tensor_stack.empty()) { + int64 tensor_id = tensor_stack.back(); + tensor_stack.pop_back(); + auto op_id_it = tensor_tape.find(tensor_id); + if (op_id_it == tensor_tape.end()) { + continue; + } + int64 op_id = op_id_it->second; + auto op_it = op_tape.find(op_id); + auto result_op_it = result.op_tape.find(op_id); + if (op_id == -1 || op_it == op_tape.end() || + result_op_it != result.op_tape.end()) { + continue; + } + CHECK(result.op_tape.emplace(op_id, op_it->second).second); + for (auto it : op_it->second.input_tensor_id) { + auto count_it = result.tensor_usage_counts.find(it); + if (count_it != result.tensor_usage_counts.end()) { + count_it->second++; + } else { + result.tensor_usage_counts[it] = 1; + if (sources_set.find(it) == sources_set.end() && + tensor_tape.find(it) != tensor_tape.end()) { + tensor_stack.push_back(it); + } + } + } + op_tape.erase(op_it); + } + for (auto& pair : result.tensor_usage_counts) { + auto it = tensor_tape.find(pair.first); + if (it != tensor_tape.end() && it->second != -1) { + result.op_missing_tensor[it->second] += 1; + } + } + // Call destructors for all unneeded gradient functions. + for (const auto& op_pair : op_tape) { + op_pair.second.backward_function_deleter(); + } + return result; +} + +template <typename BackwardFunction> +std::vector<int64> InitialStack( + const OpTape<BackwardFunction>& op_tape, + const std::unordered_map<int64, int64>& op_missing_tensor) { + std::vector<int64> result; + for (auto& op_entry : op_tape) { + if (op_missing_tensor.find(op_entry.first) == op_missing_tensor.end()) { + result.push_back(op_entry.first); + } + } + return result; +} + +template <typename Tensor, typename Gradient, typename BackwardFunction> +Status InitialGradients( + const VSpace<Tensor, Gradient, BackwardFunction>& vspace, + gtl::ArraySlice<Tensor*> target, + gtl::ArraySlice<Gradient*> output_gradients, + std::unordered_map<int64, int64> tensor_usage_counts, + std::unordered_map<int64, std::vector<Gradient*>>* result) { + for (int i = 0; i < target.size(); ++i) { + int64 id = vspace.TensorId(target[i]); + if (tensor_usage_counts.find(id) != tensor_usage_counts.end()) { + if (!output_gradients.empty() && output_gradients[i] != nullptr) { + // TODO(apassos) figure out how to print debugging information here. + return errors::InvalidArgument( + "A gradient was provided for a tensor which is used as part of the " + "computation."); + } + } else { + if (output_gradients.empty() || output_gradients[i] == nullptr) { + (*result)[id].push_back(vspace.OnesLike(target[i])); + } else { + (*result)[id].push_back(output_gradients[i]); + } + } + } + return Status::OK(); +} + +} // namespace + +// If over kMinAggregateCount gradients are accumulated and the total +// memory consumption is over kMinAggregateBytes, do an early aggregation +// so as to release the gradient tensor to save memory. +constexpr int kMinAggregateCount = 4; +constexpr int kMinAggregateBytes = 128 * 1024 * 1024; + +template <typename Tensor, typename Gradient, typename BackwardFunction> +Status GradientTape<Tensor, Gradient, BackwardFunction>::ComputeGradient( + const VSpace<Tensor, Gradient, BackwardFunction>& vspace, + gtl::ArraySlice<Tensor*> target, gtl::ArraySlice<Tensor*> sources, + gtl::ArraySlice<Gradient*> output_gradients, + std::vector<Gradient*>* result) { + std::vector<int64> id_sources; + id_sources.reserve(sources.size()); + for (Tensor* s : sources) { + id_sources.push_back(vspace.TensorId(s)); + } + std::unordered_set<int64> sources_set(id_sources.begin(), id_sources.end()); + std::vector<int64> id_targets; + id_sources.reserve(target.size()); + for (Tensor* t : target) { + id_targets.push_back(vspace.TensorId(t)); + } + BackpropInitialState<BackwardFunction> state = PrepareBackprop( + id_targets, tensor_tape_, std::move(op_tape_), sources_set); + std::vector<int64> op_stack = + InitialStack(state.op_tape, state.op_missing_tensor); + std::unordered_map<int64, std::vector<Gradient*>> gradients; + Status s = InitialGradients(vspace, target, output_gradients, + state.tensor_usage_counts, &gradients); + auto cleanup = [&state]() { + // Release all backprop functions + for (const auto& pair : state.op_tape) { + pair.second.backward_function_deleter(); + } + }; + if (!s.ok()) { + cleanup(); + return s; + } + std::unordered_map<int64, int64> gradients_size; + // TODO(apassos) multiple threads could be dequeuing from op_stack at the same + // time, for better CPU backprop performance. + VLOG(1) << "Initial stack:"; + if (VLOG_IS_ON(1)) { + for (auto t : op_stack) { + VLOG(1) << " " << t; + } + } + std::unordered_map<string, std::unordered_set<int>> + functions_accept_none_for_indices({ + {"SoftmaxCrossEntropyWithLogits", {1}}, + {"FusedBatchNorm", {1, 2, 3, 4}}, + }); + while (!op_stack.empty()) { + const int64 op = op_stack.back(); + VLOG(1) << "Popped " << op; + op_stack.pop_back(); + auto op_it = state.op_tape.find(op); + if (op_it == state.op_tape.end()) { + // It is possible for ops to end up on the stack if they are unrelated to + // the target; we should just skip them. + continue; + } + auto trace = std::move(op_it->second); + state.op_tape.erase(op_it); + std::vector<Gradient*> out_gradients; + out_gradients.reserve(trace.output_tensor_info.size()); + for (int i = 0; i < trace.output_tensor_info.size(); ++i) { + const int64 id = trace.output_tensor_info[i].id; + auto grad_it = gradients.find(id); + if (grad_it == gradients.end()) { + auto func_name_it = + functions_accept_none_for_indices.find(trace.op_type); + if (func_name_it != functions_accept_none_for_indices.end() && + func_name_it->second.find(i) != func_name_it->second.end()) { + out_gradients.push_back(nullptr); + } else { + out_gradients.push_back( + vspace.Zeros(trace.output_tensor_info[i].shape, + trace.output_tensor_info[i].dtype)); + } + } else { + out_gradients.push_back(vspace.AggregateGradients(grad_it->second)); + if (sources_set.find(grad_it->first) == sources_set.end()) { + gradients.erase(grad_it); + } + } + } + std::vector<Gradient*> in_gradients; + Status s = vspace.CallBackwardFunction(trace.backward_function, + out_gradients, &in_gradients); + if (!s.ok()) { + VLOG(1) << "Gradient function failed."; + cleanup(); + return s; + } + VLOG(1) << "Got " << in_gradients.size() << " in_gradients for " + << trace.input_tensor_id.size() << " sources"; + for (int i = 0; i < in_gradients.size(); ++i) { + const int64 id = trace.input_tensor_id[i]; + if (in_gradients[i] != nullptr) { + auto& unaggregated_grads = gradients[id]; + unaggregated_grads.push_back(in_gradients[i]); + if (unaggregated_grads.size() > kMinAggregateCount) { + auto size_it = gradients_size.find(id); + int64 size; + if (size_it == gradients_size.end()) { + size = vspace.NumElements(unaggregated_grads[0]); + gradients_size.emplace(id, size); + } else { + size = size_it->second; + } + if (unaggregated_grads.size() * size * 4 > kMinAggregateBytes) { + Gradient* grad = vspace.AggregateGradients(unaggregated_grads); + unaggregated_grads.clear(); + unaggregated_grads.push_back(grad); + } + } + } + auto usage_count_it = state.tensor_usage_counts.find(id); + if (usage_count_it == state.tensor_usage_counts.end()) { + VLOG(1) << "Tensor " << id << " not used"; + continue; + } + usage_count_it->second--; + if (usage_count_it->second > 0) { + VLOG(1) << "Tensor " << id << " usage count " << usage_count_it->second; + continue; + } + auto tape_it = tensor_tape_.find(id); + if (tape_it == tensor_tape_.end()) { + VLOG(1) << "Tensor " << id + << " has no associated op. Deleting gradient"; + auto grad_it = gradients.find(id); + if (grad_it != gradients.end()) { + for (auto g : grad_it->second) { + vspace.DeleteGradient(g); + } + gradients.erase(grad_it); + } + continue; + } + const int64 op_id = tape_it->second; + if (op_id == -1) { + VLOG(1) << "Tensor " << id << " is source"; + continue; + } + auto missing_it = state.op_missing_tensor.find(op_id); + if (missing_it != state.op_missing_tensor.end()) { + missing_it->second--; + VLOG(1) << "Op " << op_id << " missing " << missing_it->second + << " output gradients"; + if (missing_it->second == 0) { + op_stack.push_back(op_id); + } + } + } + } + CHECK(state.op_tape.empty()); + result->reserve(sources.size()); + for (auto is : id_sources) { + auto grad_it = gradients.find(is); + if (grad_it == gradients.end()) { + result->push_back(nullptr); + } else { + if (grad_it->second.size() == 1) { + result->push_back(grad_it->second[0]); + } else { + result->push_back(vspace.AggregateGradients(grad_it->second)); + } + gradients.erase(grad_it); + } + } + VLOG(1) << "Final gradients size: " << gradients.size(); + for (auto grad_pair : gradients) { + for (const auto& g : grad_pair.second) { + vspace.DeleteGradient(g); + } + } + return Status::OK(); +} + } // namespace eager } // namespace tensorflow |