diff options
author | Igor Ganichev <iga@google.com> | 2017-09-18 20:19:13 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2017-09-18 20:22:15 -0700 |
commit | 1da763a1cc94fc5e4ad1822788b444b77623538c (patch) | |
tree | 414b168e2f3b1dc8be71c58b71e36c34818ff1c6 /tensorflow/c/c_api_function.cc | |
parent | d10902f0a947da40f80479d74e9a487617759085 (diff) |
Add function gradient support to C API
Also, change the internal representation of TF_Function and
rename TF_GraphAddFunction to TF_GraphAddFunctionCopy to make it
clear that a copy of the function is added to the graph. Any
subsequent modifications to the function will not be reflected
in the copy added to the graph.
PiperOrigin-RevId: 169187793
Diffstat (limited to 'tensorflow/c/c_api_function.cc')
-rw-r--r-- | tensorflow/c/c_api_function.cc | 60 |
1 files changed, 36 insertions, 24 deletions
diff --git a/tensorflow/c/c_api_function.cc b/tensorflow/c/c_api_function.cc index b4c6397d0b..739d5ce986 100644 --- a/tensorflow/c/c_api_function.cc +++ b/tensorflow/c/c_api_function.cc @@ -27,6 +27,8 @@ limitations under the License. #include "tensorflow/core/graph/graph.h" #include "tensorflow/core/lib/strings/strcat.h" +using tensorflow::errors::InvalidArgument; + namespace tensorflow { namespace { @@ -125,10 +127,10 @@ Status ValidateNoRefOutputs(const Node* node) { for (int i = 0; i < node->num_outputs(); ++i) { const DataType& dt = node->output_type(i); if (IsRefType(dt)) { - return errors::InvalidArgument("Output ", i, " of node '", node->name(), - "' has a reference " - "type ", - DataTypeString(dt)); + return InvalidArgument("Output ", i, " of node '", node->name(), + "' has a reference " + "type ", + DataTypeString(dt)); } } return Status::OK(); @@ -178,7 +180,7 @@ Status FillFunctionBody( // A backedge might not appear as a regular Edge, but be only present // in the node_def. Such edges are referred to as requested_inputs(). if (i >= node->requested_inputs().size()) { - return errors::InvalidArgument( + return InvalidArgument( "Graph to be converted to function appears to be malformed. ", "Node ", node->name(), " is missing input edge ", i); } @@ -191,7 +193,7 @@ Status FillFunctionBody( const auto iter = tensor_renaming.find(original_input_name); if (iter == tensor_renaming.end()) { - return errors::InvalidArgument( + return InvalidArgument( "Input ", i, ", '", original_input_name, "', of node '", node->name(), "' in function '", fn_name, "' is not available. You might need to include it in inputs " @@ -207,7 +209,7 @@ Status FillFunctionBody( // If we did not find a name for the source of control edge, this // source must be outside of the body. Raise an error. if (normalized.empty()) { - return errors::InvalidArgument( + return InvalidArgument( "The source of control edge ", edge->DebugString(), " is not in the body. Encountered while creating function '", fn_name, "'"); @@ -308,7 +310,7 @@ Status GraphToFunctionDef(const Graph& fn_body, const string& fn_name, strings::StrCat(outputs[r].node->name(), ":", outputs[r].index); const auto iter = tensor_renaming.find(return_value); if (iter == tensor_renaming.end()) { - return errors::InvalidArgument( + return InvalidArgument( "TF_Output ", return_value, " is neither in the function body ", "nor among function inputs. Encountered while creating function '", fn_name, "'"); @@ -349,9 +351,8 @@ Status ProcessInputs( } else { auto& indices = iter->second; if (std::find(indices.begin(), indices.end(), idx) != indices.end()) { - return errors::InvalidArgument( - "TF_Output ", node.name(), ":", idx, - " appears more than once in the input list"); + return InvalidArgument("TF_Output ", node.name(), ":", idx, + " appears more than once in the input list"); } indices.push_back(idx); } @@ -400,7 +401,7 @@ Status ComputeBodyNodes( // artificial restriction and require that when num_opers=-1, such // nodes must have a single output. if (node->num_outputs() != 1) { - return errors::InvalidArgument( + return InvalidArgument( "When `num_opers` is set to -1, nodes referenced in `inputs` " "must have a single output. Node ", node->name(), " has ", node->num_outputs(), @@ -468,7 +469,7 @@ TF_Function* TF_GraphToFunction(const TF_Graph* fn_body, const char* fn_name, TF_Function* tf_function = new TF_Function(); status->status = tensorflow::GraphToFunctionDef( fn_body->graph, fn_name, body_nodes, input_tensors, output_tensors, - output_names_vec, tf_function->fdef_lib.add_function()); + output_names_vec, &tf_function->fdef); if (!status->status.ok()) { TF_DeleteFunction(tf_function); return nullptr; @@ -476,21 +477,32 @@ TF_Function* TF_GraphToFunction(const TF_Graph* fn_body, const char* fn_name, return tf_function; } -void TF_GraphAddFunction(TF_Graph* g, const TF_Function* function, - TF_Status* status) { - tensorflow::mutex_lock l(g->mu); +void TF_GraphCopyFunction(TF_Graph* g, const TF_Function* func, + const TF_Function* grad, TF_Status* status) { + if (func == nullptr) { + status->status = InvalidArgument( + "'func' argument to TF_GraphCopyFunction cannot be null"); + return; + } - // At the moment, we have only one function and no gradients in fdef_lib. - // This makes the following operation atomic. - // TODO(iga): Add an atomic version of AddFunctionLibrary when we support - // gradients - status->status = g->graph.AddFunctionLibrary(function->fdef_lib); + // TODO(iga): Add AddFunctionDef() and AddGradientDef() methods to graph + // to avoid the extra copy here. + tensorflow::FunctionDefLibrary fdef_lib; + *fdef_lib.add_function() = func->fdef; + if (grad) { + *fdef_lib.add_function() = grad->fdef; + tensorflow::GradientDef* gdef = fdef_lib.add_gradient(); + gdef->set_function_name(func->fdef.signature().name()); + gdef->set_gradient_func(grad->fdef.signature().name()); + } + + tensorflow::mutex_lock l(g->mu); + status->status = g->graph.AddFunctionLibrary(fdef_lib); } void TF_FunctionToFunctionDef(TF_Function* func, TF_Buffer* output_func_def, TF_Status* status) { - DCHECK_EQ(1, func->fdef_lib.function_size()); - status->status = MessageToBuffer(func->fdef_lib.function(0), output_func_def); + status->status = MessageToBuffer(func->fdef, output_func_def); } -void TF_DeleteFunction(TF_Function* function) { delete function; } +void TF_DeleteFunction(TF_Function* func) { delete func; } |