diff options
author | Akshay Modi <nareshmodi@google.com> | 2018-05-24 16:20:31 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-05-24 16:23:23 -0700 |
commit | c0dd400f43cf6335165ce772e290a13b50960b23 (patch) | |
tree | 2bd0e0fab97f14c4ace8391a9f8d82fe4fe19686 /tensorflow/c/eager | |
parent | 50d66adc550cc9bcd2337cd28d1561273db43de9 (diff) |
Remove _get_backward_fn and depend on _gradient_function directly.
(_magic_gradient_function was renamed to _gradient_function)
Before:
entry {
name: "MicroBenchmarks.benchmark_tf_gradient_forward_identity"
iters: 30000
wall_time: 5.88456789653
extras {
key: "examples_per_sec"
value {
double_value: 169936.011885
}
}
}
After:
entry {
name: "MicroBenchmarks.benchmark_tf_gradient_forward_identity"
iters: 30000
wall_time: 5.04853725433
extras {
key: "examples_per_sec"
value {
double_value: 198077.175551
}
}
}
PiperOrigin-RevId: 197972668
Diffstat (limited to 'tensorflow/c/eager')
-rw-r--r-- | tensorflow/c/eager/tape.h | 37 |
1 files changed, 16 insertions, 21 deletions
diff --git a/tensorflow/c/eager/tape.h b/tensorflow/c/eager/tape.h index 1833b25fea..734e712daa 100644 --- a/tensorflow/c/eager/tape.h +++ b/tensorflow/c/eager/tape.h @@ -48,7 +48,7 @@ struct OpTapeEntry { // Should be called before deleting the backward function. TODO(apassos) use // unique_ptrs to ensure this happens. - std::function<void()> backward_function_deleter; + std::function<void(BackwardFunction*)> backward_function_deleter; }; // Map from tensor_id to internally-defined operation-id of the operation which @@ -110,12 +110,6 @@ class VSpace { // Deletes the input tensor. virtual void DeleteGradient(Gradient* gradient) const = 0; - - // Lets this VSpace know that it can release resources held by the - // `backward_function`, It will not be called again. - // `backward_function` must not be null. - virtual void ReleaseBackwardFunction( - BackwardFunction* backward_function) const = 0; }; // Traces the execution of operations, doing eager garbage collection, and @@ -130,7 +124,7 @@ class GradientTape { GradientTape(bool persistent) : persistent_(persistent) {} ~GradientTape() { for (const auto& pair : op_tape_) { - pair.second.backward_function_deleter(); + pair.second.backward_function_deleter(pair.second.backward_function); } } @@ -139,12 +133,12 @@ class GradientTape { void Watch(int64 tensor_id); - void RecordOperation(const string& op_type, - gtl::ArraySlice<TapeTensor> output_tensors, - gtl::ArraySlice<int64> input_tensor_id, - gtl::ArraySlice<tensorflow::DataType> input_dtypes, - BackwardFunction* backward_function, - const std::function<void()>& backward_function_deleter); + void RecordOperation( + const string& op_type, gtl::ArraySlice<TapeTensor> output_tensors, + gtl::ArraySlice<int64> input_tensor_id, + gtl::ArraySlice<tensorflow::DataType> input_dtypes, + BackwardFunction* backward_function, + const std::function<void(BackwardFunction*)>& backward_function_deleter); void DeleteTrace(int64 tensor_id); @@ -218,9 +212,9 @@ void GradientTape<Gradient, BackwardFunction>::RecordOperation( gtl::ArraySlice<int64> input_tensor_id, gtl::ArraySlice<tensorflow::DataType> input_dtypes, BackwardFunction* backward_function, - const std::function<void()>& backward_function_deleter) { + const std::function<void(BackwardFunction*)>& backward_function_deleter) { if (!ShouldRecord(input_tensor_id, input_dtypes)) { - backward_function_deleter(); + backward_function_deleter(backward_function); return; } std::vector<int64> ids; @@ -275,7 +269,7 @@ void GradientTape<Gradient, BackwardFunction>::DeleteTrace(int64 tensor_id) { for (int64 id : op_it->second.input_tensor_id) { DeleteTrace(id); } - op_it->second.backward_function_deleter(); + op_it->second.backward_function_deleter(op_it->second.backward_function); op_tape_.erase(op_it); } @@ -381,7 +375,8 @@ BackpropInitialState<BackwardFunction> PrepareBackprop( // backward functions that will be used for gradient computation // has been transferred to `result`. for (const auto& op_pair : *op_tape) { - op_pair.second.backward_function_deleter(); + op_pair.second.backward_function_deleter( + op_pair.second.backward_function); } op_tape->clear(); } @@ -473,7 +468,7 @@ Status GradientTape<Gradient, BackwardFunction>::ComputeGradient( if (!persistent_) { // Release all backprop functions for (const auto& pair : state.op_tape) { - pair.second.backward_function_deleter(); + pair.second.backward_function_deleter(pair.second.backward_function); } } }; @@ -541,7 +536,7 @@ Status GradientTape<Gradient, BackwardFunction>::ComputeGradient( Status s = vspace.CallBackwardFunction(trace.backward_function, out_gradients, &in_gradients); if (!persistent_) { - vspace.ReleaseBackwardFunction(trace.backward_function); + trace.backward_function_deleter(trace.backward_function); } if (!s.ok()) { cleanup(); @@ -550,7 +545,7 @@ Status GradientTape<Gradient, BackwardFunction>::ComputeGradient( } else { in_gradients.resize(trace.input_tensor_id.size()); if (!persistent_) { - vspace.ReleaseBackwardFunction(trace.backward_function); + trace.backward_function_deleter(trace.backward_function); } for (Gradient* grad : out_gradients) { if (grad != nullptr) { |