aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/c
diff options
context:
space:
mode:
authorGravatar Akshay Modi <nareshmodi@google.com>2018-10-01 15:10:19 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-10-01 15:16:07 -0700
commitc7237e6070dbf4acd1ade5a40dc676418cbd889b (patch)
tree9671709249d76f83b273192e9d729d20d8135910 /tensorflow/c
parent9084e999b3caf65833f9651c72bc09eb3094eba5 (diff)
Don't generate backward function and delete when its not necessary
PiperOrigin-RevId: 215288224
Diffstat (limited to 'tensorflow/c')
-rw-r--r--tensorflow/c/eager/tape.h7
1 files changed, 3 insertions, 4 deletions
diff --git a/tensorflow/c/eager/tape.h b/tensorflow/c/eager/tape.h
index 41b5b8ff36..5ba55a203f 100644
--- a/tensorflow/c/eager/tape.h
+++ b/tensorflow/c/eager/tape.h
@@ -130,7 +130,7 @@ class GradientTape {
const string& op_type, std::vector<TapeTensor>& output_tensors,
gtl::ArraySlice<int64> input_tensor_id,
gtl::ArraySlice<tensorflow::DataType> input_dtypes,
- BackwardFunction* backward_function,
+ const std::function<BackwardFunction*()>& backward_function_getter,
const std::function<void(BackwardFunction*)>& backward_function_deleter);
void DeleteTrace(int64 tensor_id);
@@ -206,10 +206,9 @@ void GradientTape<Gradient, BackwardFunction, TapeTensor>::RecordOperation(
const string& op_type, std::vector<TapeTensor>& output_tensors,
gtl::ArraySlice<int64> input_tensor_id,
gtl::ArraySlice<tensorflow::DataType> input_dtypes,
- BackwardFunction* backward_function,
+ const std::function<BackwardFunction*()>& backward_function_getter,
const std::function<void(BackwardFunction*)>& backward_function_deleter) {
if (!ShouldRecord(input_tensor_id, input_dtypes)) {
- backward_function_deleter(backward_function);
return;
}
std::vector<int64> ids;
@@ -229,7 +228,7 @@ void GradientTape<Gradient, BackwardFunction, TapeTensor>::RecordOperation(
tensors.push_back(o);
}
op_tape_[op_id] = OpTapeEntry<BackwardFunction, TapeTensor>{
- op_type, std::move(tensors), ids, backward_function,
+ op_type, std::move(tensors), std::move(ids), backward_function_getter(),
backward_function_deleter};
}