aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/c
diff options
context:
space:
mode:
authorGravatar Alexandre Passos <apassos@google.com>2018-08-08 10:48:23 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-08-08 10:51:51 -0700
commit4a4ae62c75f1de3455c3adea96802d22c7e986e3 (patch)
treeb499f2cb33fc2aaac2963afa9c12be2bfb8cd37a /tensorflow/c
parent826e190e053abcf722cd33084ef2d31b6ed1b2aa (diff)
Allows differentiating tfe.defun functions with loops in eager mode.
Adopts a minimal sensible policy for step containers: starting a graident tape creates a step container; inner tapes do nothing; popping out of the outermost tape will reset that step container. This should allow us to have reasonable behavior in the presence of step-container-scoped things for a while. Ideally we'll move away from them in favor of lists but the infrastructure isn't ready yet. PiperOrigin-RevId: 207911091
Diffstat (limited to 'tensorflow/c')
-rw-r--r--tensorflow/c/eager/c_api.cc4
-rw-r--r--tensorflow/c/eager/c_api.h10
2 files changed, 14 insertions, 0 deletions
diff --git a/tensorflow/c/eager/c_api.cc b/tensorflow/c/eager/c_api.cc
index a0a44440c8..d7073d8e05 100644
--- a/tensorflow/c/eager/c_api.cc
+++ b/tensorflow/c/eager/c_api.cc
@@ -719,6 +719,10 @@ TFE_Op* GetFunc(TFE_Context* ctx, const tensorflow::NameAttrList& func,
}
} // namespace
+void TFE_ContextStartStep(TFE_Context* ctx) { ctx->context.StartStep(); }
+
+void TFE_ContextEndStep(TFE_Context* ctx) { ctx->context.EndStep(); }
+
namespace tensorflow {
void SetOpAttrValueScalar(TFE_Context* ctx, TFE_Op* op,
const tensorflow::AttrValue& default_value,
diff --git a/tensorflow/c/eager/c_api.h b/tensorflow/c/eager/c_api.h
index 25cf7adbc7..092af45731 100644
--- a/tensorflow/c/eager/c_api.h
+++ b/tensorflow/c/eager/c_api.h
@@ -380,6 +380,16 @@ TF_CAPI_EXPORT extern void TFE_ContextExportRunMetadata(TFE_Context* ctx,
TF_Buffer* buf,
TF_Status* status);
+// Some TF ops need a step container to be set to limit the lifetime of some
+// resources (mostly TensorArray and Stack, used in while loop gradients in
+// graph mode). Calling this on a context tells it to start a step.
+TF_CAPI_EXPORT extern void TFE_ContextStartStep(TFE_Context* ctx);
+
+// Ends a step. When there is no active step (that is, every started step has
+// been ended) step containers will be cleared. Note: it is not safe to call
+// TFE_ContextEndStep while ops which rely on the step container may be running.
+TF_CAPI_EXPORT extern void TFE_ContextEndStep(TFE_Context* ctx);
+
#ifdef __cplusplus
} /* end extern "C" */
#endif