From c3014ec19e23e4aad7286b3fac6b25a5fb4a6326 Mon Sep 17 00:00:00 2001
From: Akshay Modi
Date: Wed, 19 Sep 2018 14:54:07 -0700
Subject: Allow the tape tensor to have unknown shapes.

This is done by making the TapeTensor a template rather than a concrete
struct.

PiperOrigin-RevId: 213700425
---
 tensorflow/c/eager/tape.h | 118 ++++++++++++++++++++++------------------------
 1 file changed, 56 insertions(+), 62 deletions(-)

(limited to 'tensorflow/c')

diff --git a/tensorflow/c/eager/tape.h b/tensorflow/c/eager/tape.h
index 49990b6249..41b5b8ff36 100644
--- a/tensorflow/c/eager/tape.h
+++ b/tensorflow/c/eager/tape.h
@@ -29,15 +29,8 @@ limitations under the License.
 namespace tensorflow {
 namespace eager {
 
-// Information about a tensor.
-struct TapeTensor {
-  int64 id;  // Expected to be unique in the lifetime of this process.
-  DataType dtype;
-  TensorShape shape;
-};
-
 // Represents an entry in the tape.
-template <typename BackwardFunction>
+template <typename BackwardFunction, typename TapeTensor>
 struct OpTapeEntry {
   string op_type;
   std::vector<TapeTensor> output_tensor_info;
@@ -57,8 +50,8 @@ struct OpTapeEntry {
 using TensorTape = gtl::FlatMap<int64, int64>;
 
 // Map from operation-id to tape entry.
-template <typename BackwardFunction>
-using OpTape = gtl::FlatMap<int64, OpTapeEntry<BackwardFunction>>;
+template <typename BackwardFunction, typename TapeTensor>
+using OpTape = gtl::FlatMap<int64, OpTapeEntry<BackwardFunction, TapeTensor>>;
 
 // Operations the tape needs to perform on tensors to do backpropagation. Named
 // "vspace" because a subset of these are related to a vector space, such as
@@ -79,7 +72,7 @@ using OpTape = gtl::FlatMap<int64, OpTapeEntry<BackwardFunction>>;
 // TODO(apassos) provide concrete template instantiations for TFE_TensorHandle
 // specialization, which is blocked by quite a few things needing to loop back
 // into python now.
-template <typename Gradient, typename BackwardFunction>
+template <typename Gradient, typename BackwardFunction, typename TapeTensor>
 class VSpace {
  public:
   virtual ~VSpace() {}
@@ -93,10 +86,10 @@ class VSpace {
       gtl::ArraySlice<Gradient*> gradient_tensors) const = 0;
 
   // Returns a tensor of the right shape and dtype filled with zeros.
-  virtual Gradient* Zeros(TensorShape shape, DataType dtype) const = 0;
+  virtual Gradient* Zeros(const TapeTensor& tensor) const = 0;
 
   // Returns a Tensor which is filled with ones and like the input.
-  virtual Gradient* Ones(TensorShape shape, DataType dtype) const = 0;
+  virtual Gradient* Ones(const TapeTensor& tensor) const = 0;
 
   // Calls the passed-in backward function.
   virtual Status CallBackwardFunction(
@@ -114,7 +107,7 @@ class VSpace {
 
 // Traces the execution of operations, doing eager garbage collection, and
 // exporting a full trace so other code can do backpropagation. Not thread-safe.
-template <typename Gradient, typename BackwardFunction>
+template <typename Gradient, typename BackwardFunction, typename TapeTensor>
 class GradientTape {
  public:
   // If `persistent` is true, GradientTape will not eagerly delete backward
@@ -134,7 +127,7 @@ class GradientTape {
   void Watch(int64 tensor_id);
 
   void RecordOperation(
-      const string& op_type, gtl::ArraySlice<TapeTensor> output_tensors,
+      const string& op_type, std::vector<TapeTensor>& output_tensors,
       gtl::ArraySlice<int64> input_tensor_id,
       gtl::ArraySlice<tensorflow::DataType> input_dtypes,
       BackwardFunction* backward_function,
@@ -146,17 +139,18 @@ class GradientTape {
   // once) and produces the gradient of the target tensors with respect to the
   // source tensors. The output gradients are used if not empty and not
   // null. The result is populated with one tensor per target element.
-  Status ComputeGradient(const VSpace<Gradient, BackwardFunction>& vspace,
-                         gtl::ArraySlice<int64> target_tensor_ids,
-                         gtl::ArraySlice<int64> source_tensor_id,
-                         gtl::ArraySlice<Gradient*> output_gradients,
-                         std::vector<Gradient*>* result);
+  Status ComputeGradient(
+      const VSpace<Gradient, BackwardFunction, TapeTensor>& vspace,
+      gtl::ArraySlice<int64> target_tensor_ids,
+      gtl::ArraySlice<int64> source_tensor_id,
+      gtl::ArraySlice<Gradient*> output_gradients,
+      std::vector<Gradient*>* result);
 
   bool IsPersistent() const { return persistent_; }
 
  private:
   TensorTape tensor_tape_;
-  OpTape<BackwardFunction> op_tape_;
+  OpTape<BackwardFunction, TapeTensor> op_tape_;
   int64 next_op_id_{0};
 
   // Map from tensor id to number of remaining usages (i.e. how many entries in
@@ -186,8 +180,8 @@ inline bool IsDtypeTrainable(DataType dtype) {
   }
 }
 
-template <typename Gradient, typename BackwardFunction>
-bool GradientTape<Gradient, BackwardFunction>::ShouldRecord(
+template <typename Gradient, typename BackwardFunction, typename TapeTensor>
+bool GradientTape<Gradient, BackwardFunction, TapeTensor>::ShouldRecord(
     gtl::ArraySlice<int64> tensor_ids,
     gtl::ArraySlice<tensorflow::DataType> dtypes) {
   CHECK_EQ(tensor_ids.size(), dtypes.size());
@@ -201,14 +195,15 @@ bool GradientTape<Gradient, BackwardFunction>::ShouldRecord(
   return false;
 }
 
-template <typename Gradient, typename BackwardFunction>
-void GradientTape<Gradient, BackwardFunction>::Watch(int64 tensor_id) {
+template <typename Gradient, typename BackwardFunction, typename TapeTensor>
+void GradientTape<Gradient, BackwardFunction, TapeTensor>::Watch(
+    int64 tensor_id) {
   tensor_tape_.emplace(tensor_id, -1);
 }
 
-template <typename Gradient, typename BackwardFunction>
-void GradientTape<Gradient, BackwardFunction>::RecordOperation(
-    const string& op_type, gtl::ArraySlice<TapeTensor> output_tensors,
+template <typename Gradient, typename BackwardFunction, typename TapeTensor>
+void GradientTape<Gradient, BackwardFunction, TapeTensor>::RecordOperation(
+    const string& op_type, std::vector<TapeTensor>& output_tensors,
     gtl::ArraySlice<int64> input_tensor_id,
     gtl::ArraySlice<tensorflow::DataType> input_dtypes,
    BackwardFunction* backward_function,
@@ -229,16 +224,18 @@ void GradientTape<Gradient, BackwardFunction>::RecordOperation(
   for (const TapeTensor& o : output_tensors) {
     // Note: the tensor can have already been watched and hence be in the tape,
     // so we cannot check that we're inserting it here.
-    tensor_tape_[o.id] = op_id;
-    tensor_usage_[o.id] = 1;
+    tensor_tape_[o.GetID()] = op_id;
+    tensor_usage_[o.GetID()] = 1;
     tensors.push_back(o);
   }
-  op_tape_[op_id] = OpTapeEntry<BackwardFunction>{
-      op_type, tensors, ids, backward_function, backward_function_deleter};
+  op_tape_[op_id] = OpTapeEntry<BackwardFunction, TapeTensor>{
+      op_type, std::move(tensors), ids, backward_function,
+      backward_function_deleter};
 }
 
-template <typename Gradient, typename BackwardFunction>
-void GradientTape<Gradient, BackwardFunction>::DeleteTrace(int64 tensor_id) {
+template <typename Gradient, typename BackwardFunction, typename TapeTensor>
+void GradientTape<Gradient, BackwardFunction, TapeTensor>::DeleteTrace(
+    int64 tensor_id) {
   auto it = tensor_usage_.find(tensor_id);
   if (it == tensor_usage_.end()) {
     return;
   }
@@ -261,7 +258,7 @@ void GradientTape<Gradient, BackwardFunction>::DeleteTrace(int64 tensor_id) {
   auto op_it = op_tape_.find(op_id);
   CHECK(op_it != op_tape_.end());
   for (const auto& output : op_it->second.output_tensor_info) {
-    if (tensor_usage_.find(output.id) != tensor_usage_.end()) {
+    if (tensor_usage_.find(output.GetID()) != tensor_usage_.end()) {
       // Found a usage for an output, so cannot delete the op.
       return;
     }
@@ -304,9 +301,9 @@ void GradientTape<Gradient, BackwardFunction>::DeleteTrace(int64 tensor_id) {
 
 namespace {
 
-template <typename BackwardFunction>
+template <typename BackwardFunction, typename TapeTensor>
 struct BackpropInitialState {
-  OpTape<BackwardFunction> op_tape;
+  OpTape<BackwardFunction, TapeTensor> op_tape;
 
   // Map from tensor ID to how many references still exist for this tensor in
   // the tape.
@@ -322,17 +319,17 @@ struct BackpropInitialState {
 // If `persistent_tape` is false, op_tape is cleared and backwards functions
 // not needed for gradient computation are deleted. Backwards functions that
 // are needed, are copied and returned in BackpropInitialState.
-template <typename BackwardFunction>
-BackpropInitialState<BackwardFunction> PrepareBackprop(
+template <typename BackwardFunction, typename TapeTensor>
+BackpropInitialState<BackwardFunction, TapeTensor> PrepareBackprop(
     gtl::ArraySlice<int64> target, const TensorTape& tensor_tape,
-    OpTape<BackwardFunction>* op_tape, const gtl::FlatSet<int64>& sources_set,
-    bool persistent_tape) {
+    OpTape<BackwardFunction, TapeTensor>* op_tape,
+    const gtl::FlatSet<int64>& sources_set, bool persistent_tape) {
   std::vector<int64> tensor_stack;
   tensor_stack.reserve(target.size());
   for (auto t : target) {
     tensor_stack.push_back(t);
   }
-  BackpropInitialState<BackwardFunction> result;
+  BackpropInitialState<BackwardFunction, TapeTensor> result;
   while (!tensor_stack.empty()) {
     int64 tensor_id = tensor_stack.back();
     tensor_stack.pop_back();
@@ -383,9 +380,9 @@ BackpropInitialState<BackwardFunction> PrepareBackprop(
   return result;
 }
 
-template <typename BackwardFunction>
+template <typename BackwardFunction, typename TapeTensor>
 std::vector<int64> InitialStack(
-    const OpTape<BackwardFunction>& op_tape,
+    const OpTape<BackwardFunction, TapeTensor>& op_tape,
     const gtl::FlatMap<int64, int64>& op_missing_tensor) {
   std::vector<int64> result;
   for (auto& op_entry : op_tape) {
@@ -396,13 +393,13 @@ std::vector<int64> InitialStack(
   return result;
 }
 
-template <typename Gradient, typename BackwardFunction>
-Status InitialGradients(const VSpace<Gradient, BackwardFunction>& vspace,
-                        gtl::ArraySlice<int64> target_tensor_ids,
-                        gtl::ArraySlice<Gradient*> output_gradients,
-                        const TensorTape& tensor_tape,
-                        const OpTape<BackwardFunction>& op_tape,
-                        gtl::FlatMap<int64, std::vector<Gradient*>>* result) {
+template <typename Gradient, typename BackwardFunction, typename TapeTensor>
+Status InitialGradients(
+    const VSpace<Gradient, BackwardFunction, TapeTensor>& vspace,
+    gtl::ArraySlice<int64> target_tensor_ids,
+    gtl::ArraySlice<Gradient*> output_gradients, const TensorTape& tensor_tape,
+    const OpTape<BackwardFunction, TapeTensor>& op_tape,
+    gtl::FlatMap<int64, std::vector<Gradient*>>* result) {
   for (int i = 0; i < target_tensor_ids.size(); ++i) {
     const int64 id = target_tensor_ids[i];
     if (output_gradients.empty() || output_gradients[i] == nullptr) {
@@ -416,11 +413,10 @@ Status InitialGradients(const VSpace<Gradient, BackwardFunction>& vspace,
       }
       bool found = false;
       for (int j = 0; j < op_it->second.output_tensor_info.size(); ++j) {
-        if (op_it->second.output_tensor_info[j].id == id) {
+        if (op_it->second.output_tensor_info[j].GetID() == id) {
           found = true;
           (*result)[id].push_back(
-              vspace.Ones(op_it->second.output_tensor_info[j].shape,
-                          op_it->second.output_tensor_info[j].dtype));
+              vspace.Ones(op_it->second.output_tensor_info[j]));
           break;
         }
       }
@@ -469,16 +465,16 @@ gtl::FlatMap<string, gtl::FlatSet<int>>* FunctionsAcceptingNoneForIndicesMap() {
 constexpr int kMinAggregateCount = 4;
 constexpr int kMinAggregateBytes = 128 * 1024 * 1024;
 
-template <typename Gradient, typename BackwardFunction>
-Status GradientTape<Gradient, BackwardFunction>::ComputeGradient(
-    const VSpace<Gradient, BackwardFunction>& vspace,
+template <typename Gradient, typename BackwardFunction, typename TapeTensor>
+Status GradientTape<Gradient, BackwardFunction, TapeTensor>::ComputeGradient(
+    const VSpace<Gradient, BackwardFunction, TapeTensor>& vspace,
     gtl::ArraySlice<int64> target_tensor_ids,
     gtl::ArraySlice<int64> source_tensor_ids,
     gtl::ArraySlice<Gradient*> output_gradients,
     std::vector<Gradient*>* result) {
   gtl::FlatSet<int64> sources_set(source_tensor_ids.begin(),
                                   source_tensor_ids.end());
-  BackpropInitialState<BackwardFunction> state = PrepareBackprop(
+  BackpropInitialState<BackwardFunction, TapeTensor> state = PrepareBackprop(
       target_tensor_ids, tensor_tape_, &op_tape_, sources_set, persistent_);
   std::vector<int64> op_stack =
       InitialStack(state.op_tape, state.op_missing_tensor);
@@ -522,7 +518,7 @@ Status GradientTape<Gradient, BackwardFunction>::ComputeGradient(
     out_gradients.reserve(trace.output_tensor_info.size());
     bool any_gradient_nonzero = false;
     for (int i = 0; i < trace.output_tensor_info.size(); ++i) {
-      const int64 id = trace.output_tensor_info[i].id;
+      const int64 id = trace.output_tensor_info[i].GetID();
       auto grad_it = gradients.find(id);
       if (grad_it == gradients.end()) {
         auto func_name_it =
@@ -531,9 +527,7 @@ Status GradientTape<Gradient, BackwardFunction>::ComputeGradient(
             func_name_it->second.find(i) != func_name_it->second.end()) {
           out_gradients.push_back(nullptr);
         } else {
-          out_gradients.push_back(
-              vspace.Zeros(trace.output_tensor_info[i].shape,
-                           trace.output_tensor_info[i].dtype));
+          out_gradients.push_back(vspace.Zeros(trace.output_tensor_info[i]));
         }
       } else {
         any_gradient_nonzero = true;
--
cgit v1.2.3
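
A note for readers outside the patch context: the commit message says tape tensors can now have unknown shapes because TapeTensor became a template parameter. The standalone C++ sketch below is not from tape.h; all the `My*` names and the optional-shape member are invented for illustration, and only `GetID`, `Zeros`, and `Ones` echo names from the diff. It shows the shape of the extension point: the tape itself only asks the tensor object for an id, and the vspace's `Zeros`/`Ones` now receive the whole object, so a concrete shape no longer has to exist when an operation is recorded.

// Standalone sketch (not part of the patch): hypothetical types showing the
// extension point the commit introduces.
#include <cstdint>
#include <iostream>
#include <optional>
#include <vector>

// A tape-tensor type as the templated tape now expects it: the tape only
// asks for an id, so the shape may legitimately be unknown at record time.
struct MyTapeTensor {
  int64_t id;                                 // unique per process, like the old struct's `id`
  std::optional<std::vector<int64_t>> shape;  // empty means "unknown shape"

  int64_t GetID() const { return id; }
};

// Hypothetical gradient value; a real vspace would wrap framework tensors.
struct MyGradient {
  std::vector<int64_t> shape;
  float fill;
};

// Sketch of the part of a vspace that the patch changes: Zeros/Ones take the
// whole tape tensor instead of a (TensorShape, DataType) pair, so this
// implementation decides what to do when the shape is not known yet.
struct MyVSpace {
  MyGradient Zeros(const MyTapeTensor& t) const {
    // Fall back to a scalar placeholder if the shape was unknown at record
    // time (a real implementation might query the runtime instead).
    return {t.shape.value_or(std::vector<int64_t>{}), 0.0f};
  }
  MyGradient Ones(const MyTapeTensor& t) const {
    return {t.shape.value_or(std::vector<int64_t>{}), 1.0f};
  }
};

int main() {
  MyVSpace vspace;
  MyTapeTensor known{/*id=*/1, std::vector<int64_t>{2, 3}};
  MyTapeTensor unknown{/*id=*/2, std::nullopt};  // shape unknown when recorded

  std::cout << "rank of ones(known): " << vspace.Ones(known).shape.size() << "\n"
            << "rank of ones(unknown): " << vspace.Ones(unknown).shape.size() << "\n";
}

Under the old interface, `Zeros(TensorShape shape, DataType dtype)` forced the shape to be known at the moment the operation was recorded; passing the tape tensor through defers that decision to the vspace implementation.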