diff options
author | Tong Shen <endlessroad@google.com> | 2018-09-28 15:14:43 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-09-28 15:18:31 -0700 |
commit | 2f559f2d5f75cf80183ae0d855110809404019f7 (patch) | |
tree | d8101da5ee6b7008fa48cee7e5000c3bb183e6ee /tensorflow/core/framework | |
parent | dee0481c07ed952d01b12704c89e50869a383c68 (diff) |
Handle noinline gradient function in control flow functionalization.
PiperOrigin-RevId: 215003704
Diffstat (limited to 'tensorflow/core/framework')
-rw-r--r-- | tensorflow/core/framework/function.cc | 8 | ||||
-rw-r--r-- | tensorflow/core/framework/function.h | 5 |
2 files changed, 13 insertions, 0 deletions
diff --git a/tensorflow/core/framework/function.cc b/tensorflow/core/framework/function.cc index a17959a448..20f957190b 100644 --- a/tensorflow/core/framework/function.cc +++ b/tensorflow/core/framework/function.cc @@ -1101,6 +1101,14 @@ Status FunctionLibraryDefinition::ReplaceFunction(const string& func, return Status::OK(); } +Status FunctionLibraryDefinition::ReplaceGradient(const GradientDef& grad) { + mutex_lock l(mu_); + bool added; + TF_RETURN_IF_ERROR(RemoveGradient(grad.function_name())); + TF_RETURN_IF_ERROR(AddGradientDefHelper(grad, &added)); + return Status::OK(); +} + Status FunctionLibraryDefinition::RemoveFunction(const string& func) { const auto& i = function_defs_.find(func); if (i == function_defs_.end()) { diff --git a/tensorflow/core/framework/function.h b/tensorflow/core/framework/function.h index e01eb7503d..4d6d68e214 100644 --- a/tensorflow/core/framework/function.h +++ b/tensorflow/core/framework/function.h @@ -331,6 +331,11 @@ class FunctionLibraryDefinition : public OpRegistryInterface { // a non-OK status if "func" was not found in the library, OK otherwise. Status ReplaceFunction(const string& func, const FunctionDef& fdef); + // Replaces the gradient corresponding to `grad.function_name()`. Returns + // a non-OK status if "grad.function_name()" was not found in the library, OK + // otherwise. + Status ReplaceGradient(const GradientDef& grad); + // Adds the functions and gradients in 'other' to this function library. // Duplicate functions and gradients are ignored. // This operation is atomic. |