diff options
author | Asim Shankar <ashankar@google.com> | 2018-05-10 09:38:11 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-05-10 09:41:04 -0700 |
commit | e696dc1bd07f62c6621a7224e15c8d3fbc160054 (patch) | |
tree | 053a25637e965762bca20ac8496e20a0b6ed343d /tensorflow/c/eager | |
parent | 4522626aff528815bc4087ab5b43c88b2d17a832 (diff) |
Automated g4 rollback of changelist 195878952
PiperOrigin-RevId: 196127751
Diffstat (limited to 'tensorflow/c/eager')
-rw-r--r-- | tensorflow/c/eager/tape.h | 36 |
1 files changed, 7 insertions, 29 deletions
diff --git a/tensorflow/c/eager/tape.h b/tensorflow/c/eager/tape.h index e9ed3395c4..8026076b9e 100644 --- a/tensorflow/c/eager/tape.h +++ b/tensorflow/c/eager/tape.h @@ -130,15 +130,13 @@ class GradientTape { } } - bool ShouldRecord(gtl::ArraySlice<int64> tensor_ids, - gtl::ArraySlice<tensorflow::DataType> dtypes); + bool ShouldRecord(gtl::ArraySlice<int64> tensor_ids); void Watch(int64 tensor_id); void RecordOperation(const string& op_type, gtl::ArraySlice<TapeTensor> output_tensors, gtl::ArraySlice<int64> input_tensor_id, - gtl::ArraySlice<tensorflow::DataType> input_dtypes, BackwardFunction* backward_function, const std::function<void()>& backward_function_deleter); @@ -172,30 +170,12 @@ class GradientTape { // Template instantiations here -inline bool IsDtypeTrainable(DataType dtype) { - switch (dtype) { - case DT_HALF: - case DT_BFLOAT16: - case DT_FLOAT: - case DT_DOUBLE: - case DT_COMPLEX64: - case DT_COMPLEX128: - case DT_RESOURCE: - case DT_VARIANT: - return true; - default: - return false; - } -} - template <typename Gradient, typename BackwardFunction> bool GradientTape<Gradient, BackwardFunction>::ShouldRecord( - gtl::ArraySlice<int64> tensor_ids, - gtl::ArraySlice<tensorflow::DataType> dtypes) { - CHECK_EQ(tensor_ids.size(), dtypes.size()); - for (int i = 0; i < tensor_ids.size(); ++i) { - if (tensor_tape_.find(tensor_ids[i]) != tensor_tape_.end()) { - return IsDtypeTrainable(dtypes[i]); + gtl::ArraySlice<int64> tensor_ids) { + for (int64 i : tensor_ids) { + if (tensor_tape_.find(i) != tensor_tape_.end()) { + return true; } } return false; @@ -209,11 +189,9 @@ void GradientTape<Gradient, BackwardFunction>::Watch(int64 tensor_id) { template <typename Gradient, typename BackwardFunction> void GradientTape<Gradient, BackwardFunction>::RecordOperation( const string& op_type, gtl::ArraySlice<TapeTensor> output_tensors, - gtl::ArraySlice<int64> input_tensor_id, - gtl::ArraySlice<tensorflow::DataType> input_dtypes, - BackwardFunction* backward_function, + gtl::ArraySlice<int64> input_tensor_id, BackwardFunction* backward_function, const std::function<void()>& backward_function_deleter) { - if (!ShouldRecord(input_tensor_id, input_dtypes)) { + if (!ShouldRecord(input_tensor_id)) { backward_function_deleter(); return; } |