Diffstat (limited to 'tensorflow/core/common_runtime/function.cc')
-rw-r--r-- | tensorflow/core/common_runtime/function.cc | 35
1 file changed, 21 insertions, 14 deletions
diff --git a/tensorflow/core/common_runtime/function.cc b/tensorflow/core/common_runtime/function.cc
index 403cece230..47bd6c56ec 100644
--- a/tensorflow/core/common_runtime/function.cc
+++ b/tensorflow/core/common_runtime/function.cc
@@ -1170,37 +1170,44 @@ FunctionBody* SymbolicGradientHelper::Compute() {
   Copy();
 
   Graph* g = gbody_->graph;
+
+  const int num_y = gbody_->ret_nodes.size();
+
+  // Populate 'y_node_outputs_' with node function body outputs.
   // Populate 'y_grad_nodes' with initial gradient nodes for each return node of
   // the original function body (these will be 'arg' nodes in the function
   // gradient body).
-  const int num_y = gbody_->ret_nodes.size();
-  std::vector<Node*> y_grad_nodes;
-  y_grad_nodes.reserve(num_y);
+  std::vector<NodeOut> y_node_outputs;
+  y_node_outputs.reserve(num_y);
+  std::vector<NodeOut> y_grad_node_outputs;
+  y_grad_node_outputs.reserve(num_y);
   for (int i = 0; i < num_y; ++i) {
     Node* y = gbody_->ret_nodes[i];
+    y_node_outputs.push_back({y, 0});
     DCHECK_EQ(y->type_string(), kRetOp);
     const DataType dtype = y->input_type(0);
     const int index = gbody_->arg_nodes.size();
     Node* dy = AddArg(g, dtype, index);
     gbody_->arg_types.push_back(dtype);
     gbody_->arg_nodes.push_back(dy);
-    y_grad_nodes.push_back(dy);
+    y_grad_node_outputs.push_back({dy, 0});
   }
 
-  // Populate 'x_nodes' with function args (not including 'y_grad_nodes').
+  // Populate 'x_nodes' with function args (excluding 'y_grad_node_outputs').
   const int num_x = fbody_->arg_nodes.size();
-  std::vector<Node*> x_nodes;
-  x_nodes.reserve(num_x);
+  std::vector<NodeOut> x_node_outputs;
+  x_node_outputs.reserve(num_x);
   for (size_t i = 0; i < fbody_->arg_nodes.size(); ++i) {
-    x_nodes.push_back(gbody_->arg_nodes[i]);
+    x_node_outputs.push_back({gbody_->arg_nodes[i], 0});
   }
 
   // Call AddSymbolicGradients which will add nodes to graph 'g' that
-  // compute the function gradient (adding an entry in 'x_grad_nodes' for
-  // each node in 'x_nodes').
-  std::vector<GradNodeOutput> x_grad_nodes(x_nodes.size());
-  TF_CHECK_OK(AddSymbolicGradients(gbody_->ret_nodes, x_nodes, y_grad_nodes,
-                                   &x_grad_nodes, g));
+  // compute the function gradient (adding an entry in 'x_grad_node_outputs' for
+  // each node in 'x_node_outputs').
+  std::vector<NodeOut> x_grad_node_outputs;
+  TF_CHECK_OK(AddSymbolicGradients(y_node_outputs, x_node_outputs,
+                                   y_grad_node_outputs, &x_grad_node_outputs,
+                                   g));
 
   // Remove the old return nodes from the function body.
   for (Node* n : gbody_->ret_nodes) {
@@ -1211,7 +1218,7 @@ FunctionBody* SymbolicGradientHelper::Compute() {
   // Add new return nodes to the function gradient body for each node
   // in 'x_grad_nodes'.
   for (size_t i = 0; i < fbody_->arg_types.size(); ++i) {
-    Endpoint grad = {x_grad_nodes[i].node, x_grad_nodes[i].index};
+    Endpoint grad = {x_grad_node_outputs[i].node, x_grad_node_outputs[i].index};
     Node* ret = AddRet(g, grad, i);
     gbody_->ret_nodes.push_back(ret);
   }
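
The change in this diff replaces bare Node* vectors with {node, output index} pairs (the NodeOut entries seen in the hunks) when calling AddSymbolicGradients, so each gradient endpoint identifies a specific output slot of a node rather than the node alone. The following is a minimal, self-contained sketch (not the TensorFlow sources; the stand-in types and the AddSymbolicGradientsSketch helper are hypothetical) illustrating the shape of the new interface:

// Minimal sketch of the {node, index} calling convention, assuming NodeOut is
// a simple pair of a node pointer and an output slot index as shown in the diff.
#include <cstdio>
#include <string>
#include <vector>

struct Node {              // stand-in for tensorflow::Node
  std::string name;
};

struct NodeOut {           // {node, output index}, mirroring the diff
  Node* node;
  int index;
};

// Hypothetical stand-in mirroring the new argument order in the diff: function
// outputs (y), function inputs (x), seed gradients (dy), and an output vector
// that receives one gradient endpoint per x.
void AddSymbolicGradientsSketch(const std::vector<NodeOut>& y_node_outputs,
                                const std::vector<NodeOut>& x_node_outputs,
                                const std::vector<NodeOut>& y_grad_node_outputs,
                                std::vector<NodeOut>* x_grad_node_outputs) {
  // A real implementation would add gradient nodes to the graph; here we just
  // return slot 0 of each x to show how callers consume the result.
  x_grad_node_outputs->clear();
  for (const NodeOut& x : x_node_outputs) {
    x_grad_node_outputs->push_back({x.node, 0});
  }
  (void)y_node_outputs;
  (void)y_grad_node_outputs;
}

int main() {
  Node y{"ret_0"}, x{"arg_0"}, dy{"arg_1"};
  std::vector<NodeOut> y_out = {{&y, 0}};    // function outputs
  std::vector<NodeOut> x_out = {{&x, 0}};    // function inputs
  std::vector<NodeOut> dy_out = {{&dy, 0}};  // seed gradients, one per y
  std::vector<NodeOut> dx_out;               // filled by the call
  AddSymbolicGradientsSketch(y_out, x_out, dy_out, &dx_out);
  std::printf("grad of %s arrives at %s:%d\n", x.name.c_str(),
              dx_out[0].node->name.c_str(), dx_out[0].index);
  return 0;
}

As in the patched Compute(), the caller then reads x_grad_node_outputs[i].node and .index to wire each gradient to a return node.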