author    Akshay Modi <nareshmodi@google.com>  2018-05-24 16:20:31 -0700
committer TensorFlower Gardener <gardener@tensorflow.org>  2018-05-24 16:23:23 -0700
commit c0dd400f43cf6335165ce772e290a13b50960b23 (patch)
tree   2bd0e0fab97f14c4ace8391a9f8d82fe4fe19686 /tensorflow/c/eager
parent 50d66adc550cc9bcd2337cd28d1561273db43de9 (diff)
Remove _get_backward_fn and depend on _gradient_function directly.
(_magic_gradient_function was renamed to _gradient_function)

Before:
entry {
  name: "MicroBenchmarks.benchmark_tf_gradient_forward_identity"
  iters: 30000
  wall_time: 5.88456789653
  extras { key: "examples_per_sec" value { double_value: 169936.011885 } }
}

After:
entry {
  name: "MicroBenchmarks.benchmark_tf_gradient_forward_identity"
  iters: 30000
  wall_time: 5.04853725433
  extras { key: "examples_per_sec" value { double_value: 198077.175551 } }
}

PiperOrigin-RevId: 197972668
Diffstat (limited to 'tensorflow/c/eager')
-rw-r--r--  tensorflow/c/eager/tape.h | 37
1 file changed, 16 insertions, 21 deletions
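The diff below removes the virtual VSpace::ReleaseBackwardFunction hook and instead has each tape entry's backward_function_deleter take the backward function pointer it should release. A minimal standalone sketch of that ownership pattern follows; it is not the TensorFlow code, and names such as MyBackwardFn and TapeEntry are purely illustrative.

// Sketch (illustrative only): each tape entry stores a deleter that receives
// the backward function to release, so no separate release hook is needed.
#include <functional>
#include <iostream>
#include <vector>

struct MyBackwardFn {
  // Stands in for whatever state a real backward function would hold.
  int op_id;
};

struct TapeEntry {
  MyBackwardFn* backward_function;
  // The deleter now takes the function it should release, mirroring the
  // std::function<void(BackwardFunction*)> signature introduced in the diff.
  std::function<void(MyBackwardFn*)> backward_function_deleter;
};

int main() {
  std::vector<TapeEntry> tape;
  tape.push_back({new MyBackwardFn{1},
                  [](MyBackwardFn* fn) { delete fn; }});

  // On tape teardown (or when an op is not recorded), each entry's deleter
  // is invoked with that entry's own backward function.
  for (auto& entry : tape) {
    entry.backward_function_deleter(entry.backward_function);
  }
  tape.clear();
  std::cout << "released all backward functions\n";
  return 0;
}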
diff --git a/tensorflow/c/eager/tape.h b/tensorflow/c/eager/tape.h
index 1833b25fea..734e712daa 100644
--- a/tensorflow/c/eager/tape.h
+++ b/tensorflow/c/eager/tape.h
@@ -48,7 +48,7 @@ struct OpTapeEntry {
// Should be called before deleting the backward function. TODO(apassos) use
// unique_ptrs to ensure this happens.
- std::function<void()> backward_function_deleter;
+ std::function<void(BackwardFunction*)> backward_function_deleter;
};
// Map from tensor_id to internally-defined operation-id of the operation which
@@ -110,12 +110,6 @@ class VSpace {
// Deletes the input tensor.
virtual void DeleteGradient(Gradient* gradient) const = 0;
-
- // Lets this VSpace know that it can release resources held by the
- // `backward_function`, It will not be called again.
- // `backward_function` must not be null.
- virtual void ReleaseBackwardFunction(
- BackwardFunction* backward_function) const = 0;
};
// Traces the execution of operations, doing eager garbage collection, and
@@ -130,7 +124,7 @@ class GradientTape {
GradientTape(bool persistent) : persistent_(persistent) {}
~GradientTape() {
for (const auto& pair : op_tape_) {
- pair.second.backward_function_deleter();
+ pair.second.backward_function_deleter(pair.second.backward_function);
}
}
@@ -139,12 +133,12 @@ class GradientTape {
void Watch(int64 tensor_id);
- void RecordOperation(const string& op_type,
- gtl::ArraySlice<TapeTensor> output_tensors,
- gtl::ArraySlice<int64> input_tensor_id,
- gtl::ArraySlice<tensorflow::DataType> input_dtypes,
- BackwardFunction* backward_function,
- const std::function<void()>& backward_function_deleter);
+ void RecordOperation(
+ const string& op_type, gtl::ArraySlice<TapeTensor> output_tensors,
+ gtl::ArraySlice<int64> input_tensor_id,
+ gtl::ArraySlice<tensorflow::DataType> input_dtypes,
+ BackwardFunction* backward_function,
+ const std::function<void(BackwardFunction*)>& backward_function_deleter);
void DeleteTrace(int64 tensor_id);
@@ -218,9 +212,9 @@ void GradientTape<Gradient, BackwardFunction>::RecordOperation(
gtl::ArraySlice<int64> input_tensor_id,
gtl::ArraySlice<tensorflow::DataType> input_dtypes,
BackwardFunction* backward_function,
- const std::function<void()>& backward_function_deleter) {
+ const std::function<void(BackwardFunction*)>& backward_function_deleter) {
if (!ShouldRecord(input_tensor_id, input_dtypes)) {
- backward_function_deleter();
+ backward_function_deleter(backward_function);
return;
}
std::vector<int64> ids;
@@ -275,7 +269,7 @@ void GradientTape<Gradient, BackwardFunction>::DeleteTrace(int64 tensor_id) {
for (int64 id : op_it->second.input_tensor_id) {
DeleteTrace(id);
}
- op_it->second.backward_function_deleter();
+ op_it->second.backward_function_deleter(op_it->second.backward_function);
op_tape_.erase(op_it);
}
@@ -381,7 +375,8 @@ BackpropInitialState<BackwardFunction> PrepareBackprop(
// backward functions that will be used for gradient computation
// has been transferred to `result`.
for (const auto& op_pair : *op_tape) {
- op_pair.second.backward_function_deleter();
+ op_pair.second.backward_function_deleter(
+ op_pair.second.backward_function);
}
op_tape->clear();
}
@@ -473,7 +468,7 @@ Status GradientTape<Gradient, BackwardFunction>::ComputeGradient(
if (!persistent_) {
// Release all backprop functions
for (const auto& pair : state.op_tape) {
- pair.second.backward_function_deleter();
+ pair.second.backward_function_deleter(pair.second.backward_function);
}
}
};
@@ -541,7 +536,7 @@ Status GradientTape<Gradient, BackwardFunction>::ComputeGradient(
Status s = vspace.CallBackwardFunction(trace.backward_function,
out_gradients, &in_gradients);
if (!persistent_) {
- vspace.ReleaseBackwardFunction(trace.backward_function);
+ trace.backward_function_deleter(trace.backward_function);
}
if (!s.ok()) {
cleanup();
@@ -550,7 +545,7 @@ Status GradientTape<Gradient, BackwardFunction>::ComputeGradient(
} else {
in_gradients.resize(trace.input_tensor_id.size());
if (!persistent_) {
- vspace.ReleaseBackwardFunction(trace.backward_function);
+ trace.backward_function_deleter(trace.backward_function);
}
for (Gradient* grad : out_gradients) {
if (grad != nullptr) {