diff options
author | Akshay Modi <nareshmodi@google.com> | 2018-09-19 14:54:07 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-09-19 14:57:27 -0700 |
commit | c3014ec19e23e4aad7286b3fac6b25a5fb4a6326 (patch) | |
tree | 1ffab9c512fe71884a19851b63c1025d487dbe3e /tensorflow/c | |
parent | 4e7d5f008be62bb7ca3e1646af8d4f22287d9e50 (diff) |
Allow the tape tensor to have unknown shapes.
This is done by making the TapeTensor a template rather than a concrete struct.
PiperOrigin-RevId: 213700425
Diffstat (limited to 'tensorflow/c')
-rw-r--r-- | tensorflow/c/eager/tape.h | 118 |
1 file changed, 56 insertions(+), 62 deletions(-)
diff --git a/tensorflow/c/eager/tape.h b/tensorflow/c/eager/tape.h index 49990b6249..41b5b8ff36 100644 --- a/tensorflow/c/eager/tape.h +++ b/tensorflow/c/eager/tape.h @@ -29,15 +29,8 @@ limitations under the License. namespace tensorflow { namespace eager { -// Information about a tensor. -struct TapeTensor { - int64 id; // Expected to be unique in the lifetime of this process. - DataType dtype; - TensorShape shape; -}; - // Represents an entry in the tape. -template <typename BackwardFunction> +template <typename BackwardFunction, typename TapeTensor> struct OpTapeEntry { string op_type; std::vector<TapeTensor> output_tensor_info; @@ -57,8 +50,8 @@ struct OpTapeEntry { using TensorTape = gtl::FlatMap<int64, int64>; // Map from operation-id to tape entry. -template <typename BackwardFunction> -using OpTape = gtl::FlatMap<int64, OpTapeEntry<BackwardFunction>>; +template <typename BackwardFunction, typename TapeTensor> +using OpTape = gtl::FlatMap<int64, OpTapeEntry<BackwardFunction, TapeTensor>>; // Operations the tape needs to perform on tensors to do backpropagation. Named // "vspace" because a subset of these are related to a vector space, such as @@ -79,7 +72,7 @@ using OpTape = gtl::FlatMap<int64, OpTapeEntry<BackwardFunction>>; // TODO(apassos) provide concrete template instantiations for TFE_TensorHandle // specialization, which is blocked by quite a few things needing to loop back // into python now. -template <typename Gradient, typename BackwardFunction> +template <typename Gradient, typename BackwardFunction, typename TapeTensor> class VSpace { public: virtual ~VSpace() {} @@ -93,10 +86,10 @@ class VSpace { gtl::ArraySlice<Gradient*> gradient_tensors) const = 0; // Returns a tensor of the right shape and dtype filled with zeros. - virtual Gradient* Zeros(TensorShape shape, DataType dtype) const = 0; + virtual Gradient* Zeros(const TapeTensor& tensor) const = 0; // Returns a Tensor which is filled with ones and like the input. 
- virtual Gradient* Ones(TensorShape shape, DataType dtype) const = 0; + virtual Gradient* Ones(const TapeTensor& tensor) const = 0; // Calls the passed-in backward function. virtual Status CallBackwardFunction( @@ -114,7 +107,7 @@ class VSpace { // Traces the execution of operations, doing eager garbage collection, and // exporting a full trace so other code can do backpropagation. Not thread-safe. -template <typename Gradient, typename BackwardFunction> +template <typename Gradient, typename BackwardFunction, typename TapeTensor> class GradientTape { public: // If `persistent` is true, GradientTape will not eagerly delete backward @@ -134,7 +127,7 @@ class GradientTape { void Watch(int64 tensor_id); void RecordOperation( - const string& op_type, gtl::ArraySlice<TapeTensor> output_tensors, + const string& op_type, std::vector<TapeTensor>& output_tensors, gtl::ArraySlice<int64> input_tensor_id, gtl::ArraySlice<tensorflow::DataType> input_dtypes, BackwardFunction* backward_function, @@ -146,17 +139,18 @@ class GradientTape { // once) and produces the gradient of the target tensors with respect to the // source tensors. The output gradients are used if not empty and not // null. The result is populated with one tensor per target element. 
- Status ComputeGradient(const VSpace<Gradient, BackwardFunction>& vspace, - gtl::ArraySlice<int64> target_tensor_ids, - gtl::ArraySlice<int64> source_tensor_id, - gtl::ArraySlice<Gradient*> output_gradients, - std::vector<Gradient*>* result); + Status ComputeGradient( + const VSpace<Gradient, BackwardFunction, TapeTensor>& vspace, + gtl::ArraySlice<int64> target_tensor_ids, + gtl::ArraySlice<int64> source_tensor_id, + gtl::ArraySlice<Gradient*> output_gradients, + std::vector<Gradient*>* result); bool IsPersistent() const { return persistent_; } private: TensorTape tensor_tape_; - OpTape<BackwardFunction> op_tape_; + OpTape<BackwardFunction, TapeTensor> op_tape_; int64 next_op_id_{0}; // Map from tensor id to number of remaining usages (i.e. how many entries in @@ -186,8 +180,8 @@ inline bool IsDtypeTrainable(DataType dtype) { } } -template <typename Gradient, typename BackwardFunction> -bool GradientTape<Gradient, BackwardFunction>::ShouldRecord( +template <typename Gradient, typename BackwardFunction, typename TapeTensor> +bool GradientTape<Gradient, BackwardFunction, TapeTensor>::ShouldRecord( gtl::ArraySlice<int64> tensor_ids, gtl::ArraySlice<tensorflow::DataType> dtypes) { CHECK_EQ(tensor_ids.size(), dtypes.size()); @@ -201,14 +195,15 @@ bool GradientTape<Gradient, BackwardFunction>::ShouldRecord( return false; } -template <typename Gradient, typename BackwardFunction> -void GradientTape<Gradient, BackwardFunction>::Watch(int64 tensor_id) { +template <typename Gradient, typename BackwardFunction, typename TapeTensor> +void GradientTape<Gradient, BackwardFunction, TapeTensor>::Watch( + int64 tensor_id) { tensor_tape_.emplace(tensor_id, -1); } -template <typename Gradient, typename BackwardFunction> -void GradientTape<Gradient, BackwardFunction>::RecordOperation( - const string& op_type, gtl::ArraySlice<TapeTensor> output_tensors, +template <typename Gradient, typename BackwardFunction, typename TapeTensor> +void GradientTape<Gradient, BackwardFunction, 
TapeTensor>::RecordOperation( + const string& op_type, std::vector<TapeTensor>& output_tensors, gtl::ArraySlice<int64> input_tensor_id, gtl::ArraySlice<tensorflow::DataType> input_dtypes, BackwardFunction* backward_function, @@ -229,16 +224,18 @@ void GradientTape<Gradient, BackwardFunction>::RecordOperation( for (const TapeTensor& o : output_tensors) { // Note: the tensor can have already been watched and hence be in the tape, // so we cannot check that we're inserting it here. - tensor_tape_[o.id] = op_id; - tensor_usage_[o.id] = 1; + tensor_tape_[o.GetID()] = op_id; + tensor_usage_[o.GetID()] = 1; tensors.push_back(o); } - op_tape_[op_id] = OpTapeEntry<BackwardFunction>{ - op_type, tensors, ids, backward_function, backward_function_deleter}; + op_tape_[op_id] = OpTapeEntry<BackwardFunction, TapeTensor>{ + op_type, std::move(tensors), ids, backward_function, + backward_function_deleter}; } -template <typename Gradient, typename BackwardFunction> -void GradientTape<Gradient, BackwardFunction>::DeleteTrace(int64 tensor_id) { +template <typename Gradient, typename BackwardFunction, typename TapeTensor> +void GradientTape<Gradient, BackwardFunction, TapeTensor>::DeleteTrace( + int64 tensor_id) { auto it = tensor_usage_.find(tensor_id); if (it == tensor_usage_.end()) { return; @@ -261,7 +258,7 @@ void GradientTape<Gradient, BackwardFunction>::DeleteTrace(int64 tensor_id) { auto op_it = op_tape_.find(op_id); CHECK(op_it != op_tape_.end()); for (const auto& output : op_it->second.output_tensor_info) { - if (tensor_usage_.find(output.id) != tensor_usage_.end()) { + if (tensor_usage_.find(output.GetID()) != tensor_usage_.end()) { // Found a usage for an output, so cannot delete the op. 
return; } @@ -304,9 +301,9 @@ void GradientTape<Gradient, BackwardFunction>::DeleteTrace(int64 tensor_id) { namespace { -template <typename BackwardFunction> +template <typename BackwardFunction, typename TapeTensor> struct BackpropInitialState { - OpTape<BackwardFunction> op_tape; + OpTape<BackwardFunction, TapeTensor> op_tape; // Map from tensor ID to how many references still exist for this tensor in // the tape. @@ -322,17 +319,17 @@ struct BackpropInitialState { // If `persistent_tape` is false, op_tape is cleared and backwards functions // not needed for gradient computation are deleted. Backwards functions that // are needed, are copied and returned in BackpropInitialState. -template <typename BackwardFunction> -BackpropInitialState<BackwardFunction> PrepareBackprop( +template <typename BackwardFunction, typename TapeTensor> +BackpropInitialState<BackwardFunction, TapeTensor> PrepareBackprop( gtl::ArraySlice<int64> target, const TensorTape& tensor_tape, - OpTape<BackwardFunction>* op_tape, const gtl::FlatSet<int64>& sources_set, - bool persistent_tape) { + OpTape<BackwardFunction, TapeTensor>* op_tape, + const gtl::FlatSet<int64>& sources_set, bool persistent_tape) { std::vector<int64> tensor_stack; tensor_stack.reserve(target.size()); for (auto t : target) { tensor_stack.push_back(t); } - BackpropInitialState<BackwardFunction> result; + BackpropInitialState<BackwardFunction, TapeTensor> result; while (!tensor_stack.empty()) { int64 tensor_id = tensor_stack.back(); tensor_stack.pop_back(); @@ -383,9 +380,9 @@ BackpropInitialState<BackwardFunction> PrepareBackprop( return result; } -template <typename BackwardFunction> +template <typename BackwardFunction, typename TapeTensor> std::vector<int64> InitialStack( - const OpTape<BackwardFunction>& op_tape, + const OpTape<BackwardFunction, TapeTensor>& op_tape, const gtl::FlatMap<int64, int64>& op_missing_tensor) { std::vector<int64> result; for (auto& op_entry : op_tape) { @@ -396,13 +393,13 @@ std::vector<int64> 
InitialStack( return result; } -template <typename Gradient, typename BackwardFunction> -Status InitialGradients(const VSpace<Gradient, BackwardFunction>& vspace, - gtl::ArraySlice<int64> target_tensor_ids, - gtl::ArraySlice<Gradient*> output_gradients, - const TensorTape& tensor_tape, - const OpTape<BackwardFunction>& op_tape, - gtl::FlatMap<int64, std::vector<Gradient*>>* result) { +template <typename Gradient, typename BackwardFunction, typename TapeTensor> +Status InitialGradients( + const VSpace<Gradient, BackwardFunction, TapeTensor>& vspace, + gtl::ArraySlice<int64> target_tensor_ids, + gtl::ArraySlice<Gradient*> output_gradients, const TensorTape& tensor_tape, + const OpTape<BackwardFunction, TapeTensor>& op_tape, + gtl::FlatMap<int64, std::vector<Gradient*>>* result) { for (int i = 0; i < target_tensor_ids.size(); ++i) { const int64 id = target_tensor_ids[i]; if (output_gradients.empty() || output_gradients[i] == nullptr) { @@ -416,11 +413,10 @@ Status InitialGradients(const VSpace<Gradient, BackwardFunction>& vspace, } bool found = false; for (int j = 0; j < op_it->second.output_tensor_info.size(); ++j) { - if (op_it->second.output_tensor_info[j].id == id) { + if (op_it->second.output_tensor_info[j].GetID() == id) { found = true; (*result)[id].push_back( - vspace.Ones(op_it->second.output_tensor_info[j].shape, - op_it->second.output_tensor_info[j].dtype)); + vspace.Ones(op_it->second.output_tensor_info[j])); break; } } @@ -469,16 +465,16 @@ gtl::FlatMap<string, gtl::FlatSet<int>>* FunctionsAcceptingNoneForIndicesMap() { constexpr int kMinAggregateCount = 4; constexpr int kMinAggregateBytes = 128 * 1024 * 1024; -template <typename Gradient, typename BackwardFunction> -Status GradientTape<Gradient, BackwardFunction>::ComputeGradient( - const VSpace<Gradient, BackwardFunction>& vspace, +template <typename Gradient, typename BackwardFunction, typename TapeTensor> +Status GradientTape<Gradient, BackwardFunction, TapeTensor>::ComputeGradient( + const 
VSpace<Gradient, BackwardFunction, TapeTensor>& vspace, gtl::ArraySlice<int64> target_tensor_ids, gtl::ArraySlice<int64> source_tensor_ids, gtl::ArraySlice<Gradient*> output_gradients, std::vector<Gradient*>* result) { gtl::FlatSet<int64> sources_set(source_tensor_ids.begin(), source_tensor_ids.end()); - BackpropInitialState<BackwardFunction> state = PrepareBackprop( + BackpropInitialState<BackwardFunction, TapeTensor> state = PrepareBackprop( target_tensor_ids, tensor_tape_, &op_tape_, sources_set, persistent_); std::vector<int64> op_stack = InitialStack(state.op_tape, state.op_missing_tensor); @@ -522,7 +518,7 @@ Status GradientTape<Gradient, BackwardFunction>::ComputeGradient( out_gradients.reserve(trace.output_tensor_info.size()); bool any_gradient_nonzero = false; for (int i = 0; i < trace.output_tensor_info.size(); ++i) { - const int64 id = trace.output_tensor_info[i].id; + const int64 id = trace.output_tensor_info[i].GetID(); auto grad_it = gradients.find(id); if (grad_it == gradients.end()) { auto func_name_it = @@ -531,9 +527,7 @@ Status GradientTape<Gradient, BackwardFunction>::ComputeGradient( func_name_it->second.find(i) != func_name_it->second.end()) { out_gradients.push_back(nullptr); } else { - out_gradients.push_back( - vspace.Zeros(trace.output_tensor_info[i].shape, - trace.output_tensor_info[i].dtype)); + out_gradients.push_back(vspace.Zeros(trace.output_tensor_info[i])); } } else { any_gradient_nonzero = true; |