diff options
author | 2018-05-08 14:42:35 -0700 | |
---|---|---|
committer | 2018-05-08 15:52:59 -0700 | |
commit | f58effe44dea9e8c7bf092c6779cd430994f7a72 (patch) | |
tree | 97414aef95a5b2e05e538d90227151316dc0244e /tensorflow/c/eager | |
parent | d3f3fb5b5f2db18f890838b29cac94ba88335f0a (diff) |
Do not differentiate integers in the eager API.
This is similar to the change made in:
https://github.com/tensorflow/tensorflow/commit/f63750645826df65b05cad505546a86f0e347674
for backpropagation during graph construction via tf.gradients()
PiperOrigin-RevId: 195878952
Diffstat (limited to 'tensorflow/c/eager')
-rw-r--r-- | tensorflow/c/eager/tape.h | 36 |
1 file changed, 29 insertions, 7 deletions
diff --git a/tensorflow/c/eager/tape.h b/tensorflow/c/eager/tape.h index 8026076b9e..e9ed3395c4 100644 --- a/tensorflow/c/eager/tape.h +++ b/tensorflow/c/eager/tape.h @@ -130,13 +130,15 @@ class GradientTape { } } - bool ShouldRecord(gtl::ArraySlice<int64> tensor_ids); + bool ShouldRecord(gtl::ArraySlice<int64> tensor_ids, + gtl::ArraySlice<tensorflow::DataType> dtypes); void Watch(int64 tensor_id); void RecordOperation(const string& op_type, gtl::ArraySlice<TapeTensor> output_tensors, gtl::ArraySlice<int64> input_tensor_id, + gtl::ArraySlice<tensorflow::DataType> input_dtypes, BackwardFunction* backward_function, const std::function<void()>& backward_function_deleter); @@ -170,12 +172,30 @@ class GradientTape { // Template instantiations here +inline bool IsDtypeTrainable(DataType dtype) { + switch (dtype) { + case DT_HALF: + case DT_BFLOAT16: + case DT_FLOAT: + case DT_DOUBLE: + case DT_COMPLEX64: + case DT_COMPLEX128: + case DT_RESOURCE: + case DT_VARIANT: + return true; + default: + return false; + } +} + template <typename Gradient, typename BackwardFunction> bool GradientTape<Gradient, BackwardFunction>::ShouldRecord( - gtl::ArraySlice<int64> tensor_ids) { - for (int64 i : tensor_ids) { - if (tensor_tape_.find(i) != tensor_tape_.end()) { - return true; + gtl::ArraySlice<int64> tensor_ids, + gtl::ArraySlice<tensorflow::DataType> dtypes) { + CHECK_EQ(tensor_ids.size(), dtypes.size()); + for (int i = 0; i < tensor_ids.size(); ++i) { + if (tensor_tape_.find(tensor_ids[i]) != tensor_tape_.end()) { + return IsDtypeTrainable(dtypes[i]); } } return false; @@ -189,9 +209,11 @@ void GradientTape<Gradient, BackwardFunction>::Watch(int64 tensor_id) { template <typename Gradient, typename BackwardFunction> void GradientTape<Gradient, BackwardFunction>::RecordOperation( const string& op_type, gtl::ArraySlice<TapeTensor> output_tensors, - gtl::ArraySlice<int64> input_tensor_id, BackwardFunction* backward_function, + gtl::ArraySlice<int64> input_tensor_id, + 
gtl::ArraySlice<tensorflow::DataType> input_dtypes, + BackwardFunction* backward_function, const std::function<void()>& backward_function_deleter) { - if (!ShouldRecord(input_tensor_id)) { + if (!ShouldRecord(input_tensor_id, input_dtypes)) { backward_function_deleter(); return; }