aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/cc/ops
diff options
context:
space:
mode:
authorGravatar Skye Wanderman-Milne <skyewm@google.com>2017-09-27 13:28:30 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2017-09-27 13:32:17 -0700
commit301b14c240fe99249dc2225132a7ebe5cbecbdc4 (patch)
tree8de10cf0180ff41211b294c05cae199208487c85 /tensorflow/cc/ops
parent545e3572f7d8928eeb220e8b55c71ad33a9343c6 (diff)
Basic while loop gradient functionality in C++
This change introduces the basic framework to create the gradient graph of a while loop using the C++ API. This supports building the gradient graph as long as the body function of the while loop contains no ops whose gradient function requires a stack. In other words, it doesn't support gradient functions that use the input values to the op (e.g. add will work, but multiply will not). It also doesn't support nested while loops, and doesn't detect all error cases. PiperOrigin-RevId: 170243281
Diffstat (limited to 'tensorflow/cc/ops')
-rw-r--r--tensorflow/cc/ops/while_loop.h7
1 files changed, 6 insertions, 1 deletions
diff --git a/tensorflow/cc/ops/while_loop.h b/tensorflow/cc/ops/while_loop.h
index 82181516d6..a04476056a 100644
--- a/tensorflow/cc/ops/while_loop.h
+++ b/tensorflow/cc/ops/while_loop.h
@@ -49,7 +49,12 @@ typedef std::function<Status(const Scope&, const std::vector<Output>& inputs,
// * outputs: output param that returns final loop variable outputs in non-error
// case. Must be non-null and empty.
// * create_while_ctx: if true, a WhileContext is created and populated for this
-// loop. See core/graph/while_context.h for more details.
+// loop. See core/graph/while_context.h for more details on
+// WhileContexts. This is set to false for loops used as part of gradient
+// computations, since they're part of the gradient for a loop in the
+// forward-pass.
+// TODO(skyewm): revisit this. Should we create WhileContexts for all loops,
+// even if we don't need them?
// * cond_output: if non-null, the output of the predicate is returned. This
// will always be a LoopCond node.
//