aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/c/c_api_function.cc
diff options
context:
space:
mode:
authorGravatar Igor Ganichev <iga@google.com>2017-09-18 20:19:13 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2017-09-18 20:22:15 -0700
commit1da763a1cc94fc5e4ad1822788b444b77623538c (patch)
tree414b168e2f3b1dc8be71c58b71e36c34818ff1c6 /tensorflow/c/c_api_function.cc
parentd10902f0a947da40f80479d74e9a487617759085 (diff)
Add function gradient support to C API
Also, change the internal representation of TF_Function and rename TF_GraphAddFunction to TF_GraphAddFunctionCopy to make it clear that a copy of the function is added to the graph. Any subsequent modifications to the function will not be reflected in the copy added to the graph. PiperOrigin-RevId: 169187793
Diffstat (limited to 'tensorflow/c/c_api_function.cc')
-rw-r--r--tensorflow/c/c_api_function.cc60
1 files changed, 36 insertions, 24 deletions
diff --git a/tensorflow/c/c_api_function.cc b/tensorflow/c/c_api_function.cc
index b4c6397d0b..739d5ce986 100644
--- a/tensorflow/c/c_api_function.cc
+++ b/tensorflow/c/c_api_function.cc
@@ -27,6 +27,8 @@ limitations under the License.
#include "tensorflow/core/graph/graph.h"
#include "tensorflow/core/lib/strings/strcat.h"
+using tensorflow::errors::InvalidArgument;
+
namespace tensorflow {
namespace {
@@ -125,10 +127,10 @@ Status ValidateNoRefOutputs(const Node* node) {
for (int i = 0; i < node->num_outputs(); ++i) {
const DataType& dt = node->output_type(i);
if (IsRefType(dt)) {
- return errors::InvalidArgument("Output ", i, " of node '", node->name(),
- "' has a reference "
- "type ",
- DataTypeString(dt));
+ return InvalidArgument("Output ", i, " of node '", node->name(),
+ "' has a reference "
+ "type ",
+ DataTypeString(dt));
}
}
return Status::OK();
@@ -178,7 +180,7 @@ Status FillFunctionBody(
// A backedge might not appear as a regular Edge, but be only present
// in the node_def. Such edges are referred to as requested_inputs().
if (i >= node->requested_inputs().size()) {
- return errors::InvalidArgument(
+ return InvalidArgument(
"Graph to be converted to function appears to be malformed. ",
"Node ", node->name(), " is missing input edge ", i);
}
@@ -191,7 +193,7 @@ Status FillFunctionBody(
const auto iter = tensor_renaming.find(original_input_name);
if (iter == tensor_renaming.end()) {
- return errors::InvalidArgument(
+ return InvalidArgument(
"Input ", i, ", '", original_input_name, "', of node '",
node->name(), "' in function '", fn_name,
"' is not available. You might need to include it in inputs "
@@ -207,7 +209,7 @@ Status FillFunctionBody(
// If we did not find a name for the source of control edge, this
// source must be outside of the body. Raise an error.
if (normalized.empty()) {
- return errors::InvalidArgument(
+ return InvalidArgument(
"The source of control edge ", edge->DebugString(),
" is not in the body. Encountered while creating function '",
fn_name, "'");
@@ -308,7 +310,7 @@ Status GraphToFunctionDef(const Graph& fn_body, const string& fn_name,
strings::StrCat(outputs[r].node->name(), ":", outputs[r].index);
const auto iter = tensor_renaming.find(return_value);
if (iter == tensor_renaming.end()) {
- return errors::InvalidArgument(
+ return InvalidArgument(
"TF_Output ", return_value, " is neither in the function body ",
"nor among function inputs. Encountered while creating function '",
fn_name, "'");
@@ -349,9 +351,8 @@ Status ProcessInputs(
} else {
auto& indices = iter->second;
if (std::find(indices.begin(), indices.end(), idx) != indices.end()) {
- return errors::InvalidArgument(
- "TF_Output ", node.name(), ":", idx,
- " appears more than once in the input list");
+ return InvalidArgument("TF_Output ", node.name(), ":", idx,
+ " appears more than once in the input list");
}
indices.push_back(idx);
}
@@ -400,7 +401,7 @@ Status ComputeBodyNodes(
// artificial restriction and require that when num_opers=-1, such
// nodes must have a single output.
if (node->num_outputs() != 1) {
- return errors::InvalidArgument(
+ return InvalidArgument(
"When `num_opers` is set to -1, nodes referenced in `inputs` "
"must have a single output. Node ",
node->name(), " has ", node->num_outputs(),
@@ -468,7 +469,7 @@ TF_Function* TF_GraphToFunction(const TF_Graph* fn_body, const char* fn_name,
TF_Function* tf_function = new TF_Function();
status->status = tensorflow::GraphToFunctionDef(
fn_body->graph, fn_name, body_nodes, input_tensors, output_tensors,
- output_names_vec, tf_function->fdef_lib.add_function());
+ output_names_vec, &tf_function->fdef);
if (!status->status.ok()) {
TF_DeleteFunction(tf_function);
return nullptr;
@@ -476,21 +477,32 @@ TF_Function* TF_GraphToFunction(const TF_Graph* fn_body, const char* fn_name,
return tf_function;
}
-void TF_GraphAddFunction(TF_Graph* g, const TF_Function* function,
- TF_Status* status) {
- tensorflow::mutex_lock l(g->mu);
+void TF_GraphCopyFunction(TF_Graph* g, const TF_Function* func,
+ const TF_Function* grad, TF_Status* status) {
+ if (func == nullptr) {
+ status->status = InvalidArgument(
+ "'func' argument to TF_GraphCopyFunction cannot be null");
+ return;
+ }
- // At the moment, we have only one function and no gradients in fdef_lib.
- // This makes the following operation atomic.
- // TODO(iga): Add an atomic version of AddFunctionLibrary when we support
- // gradients
- status->status = g->graph.AddFunctionLibrary(function->fdef_lib);
+ // TODO(iga): Add AddFunctionDef() and AddGradientDef() methods to graph
+ // to avoid the extra copy here.
+ tensorflow::FunctionDefLibrary fdef_lib;
+ *fdef_lib.add_function() = func->fdef;
+ if (grad) {
+ *fdef_lib.add_function() = grad->fdef;
+ tensorflow::GradientDef* gdef = fdef_lib.add_gradient();
+ gdef->set_function_name(func->fdef.signature().name());
+ gdef->set_gradient_func(grad->fdef.signature().name());
+ }
+
+ tensorflow::mutex_lock l(g->mu);
+ status->status = g->graph.AddFunctionLibrary(fdef_lib);
}
void TF_FunctionToFunctionDef(TF_Function* func, TF_Buffer* output_func_def,
TF_Status* status) {
- DCHECK_EQ(1, func->fdef_lib.function_size());
- status->status = MessageToBuffer(func->fdef_lib.function(0), output_func_def);
+ status->status = MessageToBuffer(func->fdef, output_func_def);
}
-void TF_DeleteFunction(TF_Function* function) { delete function; }
+void TF_DeleteFunction(TF_Function* func) { delete func; }