path: root/tensorflow/c/eager
author     Alexandre Passos <apassos@google.com>          2018-05-08 14:42:35 -0700
committer  TensorFlower Gardener <gardener@tensorflow.org> 2018-05-08 15:52:59 -0700
commit  f58effe44dea9e8c7bf092c6779cd430994f7a72 (patch)
tree    97414aef95a5b2e05e538d90227151316dc0244e /tensorflow/c/eager
parent  d3f3fb5b5f2db18f890838b29cac94ba88335f0a (diff)
Do not differentiate integers in the eager API.
This is similar to the change made in https://github.com/tensorflow/tensorflow/commit/f63750645826df65b05cad505546a86f0e347674 for backpropagation during graph construction via tf.gradients().
PiperOrigin-RevId: 195878952
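
As a rough illustration of the user-visible effect (a minimal sketch assuming the Python eager API's tf.GradientTape; the variable names here are hypothetical), a watched integer tensor no longer produces a gradient, while floating-point tensors still do:

    import tensorflow as tf

    x_int = tf.constant(3)      # DT_INT32: not a trainable dtype
    x_float = tf.constant(3.0)  # DT_FLOAT: trainable

    with tf.GradientTape() as tape:
        # Watch both tensors; only the float one is effectively recorded.
        tape.watch(x_int)
        tape.watch(x_float)
        y = tf.cast(x_int, tf.float32) * x_float

    # The integer source is not differentiated, so its gradient comes back as None.
    print(tape.gradient(y, [x_int, x_float]))  # [None, tf.Tensor(3.0, ...)]

This mirrors the tf.gradients() behavior introduced by the commit referenced above.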
Diffstat (limited to 'tensorflow/c/eager')
-rw-r--r--  tensorflow/c/eager/tape.h  36
1 file changed, 29 insertions(+), 7 deletions(-)
diff --git a/tensorflow/c/eager/tape.h b/tensorflow/c/eager/tape.h
index 8026076b9e..e9ed3395c4 100644
--- a/tensorflow/c/eager/tape.h
+++ b/tensorflow/c/eager/tape.h
@@ -130,13 +130,15 @@ class GradientTape {
    }
  }
-  bool ShouldRecord(gtl::ArraySlice<int64> tensor_ids);
+  bool ShouldRecord(gtl::ArraySlice<int64> tensor_ids,
+                    gtl::ArraySlice<tensorflow::DataType> dtypes);
  void Watch(int64 tensor_id);
  void RecordOperation(const string& op_type,
                       gtl::ArraySlice<TapeTensor> output_tensors,
                       gtl::ArraySlice<int64> input_tensor_id,
+                      gtl::ArraySlice<tensorflow::DataType> input_dtypes,
                       BackwardFunction* backward_function,
                       const std::function<void()>& backward_function_deleter);
@@ -170,12 +172,30 @@ class GradientTape {
// Template instantiations here
+inline bool IsDtypeTrainable(DataType dtype) {
+  switch (dtype) {
+    case DT_HALF:
+    case DT_BFLOAT16:
+    case DT_FLOAT:
+    case DT_DOUBLE:
+    case DT_COMPLEX64:
+    case DT_COMPLEX128:
+    case DT_RESOURCE:
+    case DT_VARIANT:
+      return true;
+    default:
+      return false;
+  }
+}
+
template <typename Gradient, typename BackwardFunction>
bool GradientTape<Gradient, BackwardFunction>::ShouldRecord(
-    gtl::ArraySlice<int64> tensor_ids) {
-  for (int64 i : tensor_ids) {
-    if (tensor_tape_.find(i) != tensor_tape_.end()) {
-      return true;
+    gtl::ArraySlice<int64> tensor_ids,
+    gtl::ArraySlice<tensorflow::DataType> dtypes) {
+  CHECK_EQ(tensor_ids.size(), dtypes.size());
+  for (int i = 0; i < tensor_ids.size(); ++i) {
+    if (tensor_tape_.find(tensor_ids[i]) != tensor_tape_.end()) {
+      return IsDtypeTrainable(dtypes[i]);
    }
  }
  return false;
@@ -189,9 +209,11 @@ void GradientTape<Gradient, BackwardFunction>::Watch(int64 tensor_id) {
template <typename Gradient, typename BackwardFunction>
void GradientTape<Gradient, BackwardFunction>::RecordOperation(
    const string& op_type, gtl::ArraySlice<TapeTensor> output_tensors,
-    gtl::ArraySlice<int64> input_tensor_id, BackwardFunction* backward_function,
+    gtl::ArraySlice<int64> input_tensor_id,
+    gtl::ArraySlice<tensorflow::DataType> input_dtypes,
+    BackwardFunction* backward_function,
    const std::function<void()>& backward_function_deleter) {
-  if (!ShouldRecord(input_tensor_id)) {
+  if (!ShouldRecord(input_tensor_id, input_dtypes)) {
    backward_function_deleter();
    return;
  }