From c7237e6070dbf4acd1ade5a40dc676418cbd889b Mon Sep 17 00:00:00 2001 From: Akshay Modi Date: Mon, 1 Oct 2018 15:10:19 -0700 Subject: Don't generate backward function and delete when its not necessary PiperOrigin-RevId: 215288224 --- tensorflow/c/eager/tape.h | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) (limited to 'tensorflow/c') diff --git a/tensorflow/c/eager/tape.h b/tensorflow/c/eager/tape.h index 41b5b8ff36..5ba55a203f 100644 --- a/tensorflow/c/eager/tape.h +++ b/tensorflow/c/eager/tape.h @@ -130,7 +130,7 @@ class GradientTape { const string& op_type, std::vector& output_tensors, gtl::ArraySlice input_tensor_id, gtl::ArraySlice input_dtypes, - BackwardFunction* backward_function, + const std::function& backward_function_getter, const std::function& backward_function_deleter); void DeleteTrace(int64 tensor_id); @@ -206,10 +206,9 @@ void GradientTape::RecordOperation( const string& op_type, std::vector& output_tensors, gtl::ArraySlice input_tensor_id, gtl::ArraySlice input_dtypes, - BackwardFunction* backward_function, + const std::function& backward_function_getter, const std::function& backward_function_deleter) { if (!ShouldRecord(input_tensor_id, input_dtypes)) { - backward_function_deleter(backward_function); return; } std::vector ids; @@ -229,7 +228,7 @@ void GradientTape::RecordOperation( tensors.push_back(o); } op_tape_[op_id] = OpTapeEntry{ - op_type, std::move(tensors), ids, backward_function, + op_type, std::move(tensors), std::move(ids), backward_function_getter(), backward_function_deleter}; } -- cgit v1.2.3