aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core/common_runtime/function.cc
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/core/common_runtime/function.cc')
-rw-r--r--tensorflow/core/common_runtime/function.cc35
1 files changed, 21 insertions, 14 deletions
diff --git a/tensorflow/core/common_runtime/function.cc b/tensorflow/core/common_runtime/function.cc
index 403cece230..47bd6c56ec 100644
--- a/tensorflow/core/common_runtime/function.cc
+++ b/tensorflow/core/common_runtime/function.cc
@@ -1170,37 +1170,44 @@ FunctionBody* SymbolicGradientHelper::Compute() {
Copy();
Graph* g = gbody_->graph;
+
+ const int num_y = gbody_->ret_nodes.size();
+
+ // Populate 'y_node_outputs_' with node function body outputs.
// Populate 'y_grad_nodes' with initial gradient nodes for each return node of
// the original function body (these will be 'arg' nodes in the function
// gradient body).
- const int num_y = gbody_->ret_nodes.size();
- std::vector<Node*> y_grad_nodes;
- y_grad_nodes.reserve(num_y);
+ std::vector<NodeOut> y_node_outputs;
+ y_node_outputs.reserve(num_y);
+ std::vector<NodeOut> y_grad_node_outputs;
+ y_grad_node_outputs.reserve(num_y);
for (int i = 0; i < num_y; ++i) {
Node* y = gbody_->ret_nodes[i];
+ y_node_outputs.push_back({y, 0});
DCHECK_EQ(y->type_string(), kRetOp);
const DataType dtype = y->input_type(0);
const int index = gbody_->arg_nodes.size();
Node* dy = AddArg(g, dtype, index);
gbody_->arg_types.push_back(dtype);
gbody_->arg_nodes.push_back(dy);
- y_grad_nodes.push_back(dy);
+ y_grad_node_outputs.push_back({dy, 0});
}
- // Populate 'x_nodes' with function args (not including 'y_grad_nodes').
+ // Populate 'x_nodes' with function args (excluding 'y_grad_node_outputs').
const int num_x = fbody_->arg_nodes.size();
- std::vector<Node*> x_nodes;
- x_nodes.reserve(num_x);
+ std::vector<NodeOut> x_node_outputs;
+ x_node_outputs.reserve(num_x);
for (size_t i = 0; i < fbody_->arg_nodes.size(); ++i) {
- x_nodes.push_back(gbody_->arg_nodes[i]);
+ x_node_outputs.push_back({gbody_->arg_nodes[i], 0});
}
// Call AddSymbolicGradients which will add nodes to graph 'g' that
- // compute the function gradient (adding an entry in 'x_grad_nodes' for
- // each node in 'x_nodes').
- std::vector<GradNodeOutput> x_grad_nodes(x_nodes.size());
- TF_CHECK_OK(AddSymbolicGradients(gbody_->ret_nodes, x_nodes, y_grad_nodes,
- &x_grad_nodes, g));
+ // compute the function gradient (adding an entry in 'x_grad_node_outputs' for
+ // each node in 'x_node_outputs').
+ std::vector<NodeOut> x_grad_node_outputs;
+ TF_CHECK_OK(AddSymbolicGradients(y_node_outputs, x_node_outputs,
+ y_grad_node_outputs, &x_grad_node_outputs,
+ g));
// Remove the old return nodes from the function body.
for (Node* n : gbody_->ret_nodes) {
@@ -1211,7 +1218,7 @@ FunctionBody* SymbolicGradientHelper::Compute() {
// Add new return nodes to the function gradient body for each node
// in 'x_grad_nodes'.
for (size_t i = 0; i < fbody_->arg_types.size(); ++i) {
- Endpoint grad = {x_grad_nodes[i].node, x_grad_nodes[i].index};
+ Endpoint grad = {x_grad_node_outputs[i].node, x_grad_node_outputs[i].index};
Node* ret = AddRet(g, grad, i);
gbody_->ret_nodes.push_back(ret);
}