aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core/framework
diff options
context:
space:
mode:
authorGravatar Tong Shen <endlessroad@google.com>2018-09-28 15:14:43 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-09-28 15:18:31 -0700
commit2f559f2d5f75cf80183ae0d855110809404019f7 (patch)
treed8101da5ee6b7008fa48cee7e5000c3bb183e6ee /tensorflow/core/framework
parentdee0481c07ed952d01b12704c89e50869a383c68 (diff)
Handle noinline gradient function in control flow functionalization.
PiperOrigin-RevId: 215003704
Diffstat (limited to 'tensorflow/core/framework')
-rw-r--r--tensorflow/core/framework/function.cc8
-rw-r--r--tensorflow/core/framework/function.h5
2 files changed, 13 insertions, 0 deletions
diff --git a/tensorflow/core/framework/function.cc b/tensorflow/core/framework/function.cc
index a17959a448..20f957190b 100644
--- a/tensorflow/core/framework/function.cc
+++ b/tensorflow/core/framework/function.cc
@@ -1101,6 +1101,14 @@ Status FunctionLibraryDefinition::ReplaceFunction(const string& func,
return Status::OK();
}
+Status FunctionLibraryDefinition::ReplaceGradient(const GradientDef& grad) {
+ mutex_lock l(mu_);
+ bool added;
+ TF_RETURN_IF_ERROR(RemoveGradient(grad.function_name()));
+ TF_RETURN_IF_ERROR(AddGradientDefHelper(grad, &added));
+ return Status::OK();
+}
+
Status FunctionLibraryDefinition::RemoveFunction(const string& func) {
const auto& i = function_defs_.find(func);
if (i == function_defs_.end()) {
diff --git a/tensorflow/core/framework/function.h b/tensorflow/core/framework/function.h
index e01eb7503d..4d6d68e214 100644
--- a/tensorflow/core/framework/function.h
+++ b/tensorflow/core/framework/function.h
@@ -331,6 +331,11 @@ class FunctionLibraryDefinition : public OpRegistryInterface {
// a non-OK status if "func" was not found in the library, OK otherwise.
Status ReplaceFunction(const string& func, const FunctionDef& fdef);
+ // Replaces the gradient corresponding to `grad.function_name()`. Returns
+ // a non-OK status if "grad.function_name()" was not found in the library, OK
+ // otherwise.
+ Status ReplaceGradient(const GradientDef& grad);
+
// Adds the functions and gradients in 'other' to this function library.
// Duplicate functions and gradients are ignored.
// This operation is atomic.