aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/c/c_api_function.cc
diff options
context:
space:
mode:
authorGravatar Igor Ganichev <iga@google.com>2017-08-30 21:05:14 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2017-08-30 21:08:53 -0700
commit9624d165f1f2c717eda96464fee8bf7229cc14f5 (patch)
tree8024d708b58b0c78f19d4c3cfc9f7c4b0c24b70c /tensorflow/c/c_api_function.cc
parent424aa9aa9559f6fa29d8ccf3d74ff25528b39209 (diff)
Add function support to Tensorflow C API
This change adds minimal functionality. Support for FunctionOptions, attributes, output name rewriting, function name generation, etc. is coming next. PiperOrigin-RevId: 167091238
Diffstat (limited to 'tensorflow/c/c_api_function.cc')
-rw-r--r--tensorflow/c/c_api_function.cc496
1 files changed, 496 insertions, 0 deletions
diff --git a/tensorflow/c/c_api_function.cc b/tensorflow/c/c_api_function.cc
new file mode 100644
index 0000000000..b4c6397d0b
--- /dev/null
+++ b/tensorflow/c/c_api_function.cc
@@ -0,0 +1,496 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/c/c_api_internal.h"
+
+#include <algorithm>
+#include <unordered_map>
+#include <unordered_set>
+
+#include "tensorflow/core/framework/attr_value_util.h"
+#include "tensorflow/core/framework/function.pb.h"
+#include "tensorflow/core/framework/node_def.pb.h"
+#include "tensorflow/core/framework/node_def_util.h"
+#include "tensorflow/core/framework/types.h"
+#include "tensorflow/core/graph/graph.h"
+#include "tensorflow/core/lib/strings/strcat.h"
+
+namespace tensorflow {
+namespace {
+
+// Class that maintains a one-to-one original node name -> new node name
+// mapping. We normalize the names used as input and output arguments to match
+// regexp "[a-z][a-z0-9_]*" specified in definition of ArgDef.name.
+// Once we rename them, we risk creating a name collision with the other
+// node names, so if necessary we add a suffix to make
+// names unique. If we have an input named "A" and a node in the function
+// body named "a", they will be renamed to "a" and "a_0".
class NodeNameMapping {
 public:
  NodeNameMapping() = default;

  // Normalize the input/output name and make it unique.
  // The result is reserved in used_names_ but NOT recorded in
  // name_mapping_, since it names a signature argument, not a node.
  string GetIOName(const string& name);

  // Make the node name unique and record the original -> unique mapping.
  string Uniquify(const string& name);

  // Look up how a node name was previously normalized/uniquified.
  // Returns empty if name was never seen.
  string Lookup(const string& name) const;

 private:
  // Returns `name` unchanged if unused, otherwise `name` with a "_<n>"
  // suffix appended to make it unique. Does not modify any state.
  string UniquifyHelper(const string& name) const;
  // Lowercases letters, replaces non-alphanumerics with '_', and strips
  // leading non-letters; returns "unknown" if no letters remain.
  static string Normalize(string name);

  // The normalized/uniquified names already used as
  // input names (in signature), output names (in signature), and node names
  // (in node_def).
  // This is a superset of values in name_mapping_.
  std::unordered_set<string> used_names_;
  // Mapping from original node name from the graph to the normalized
  // and uniquified version of it.
  std::unordered_map<string, string> name_mapping_;
};
+
+string NodeNameMapping::Normalize(string name) {
+ // Convert letters to lowercase and non-alphanumeric characters to '_'.
+ if (name.empty()) return "unknown";
+ const int n = name.size();
+ for (int i = 0; i < n; ++i) {
+ char c = name[i];
+ if (isalnum(c)) {
+ if (isupper(c)) {
+ name[i] = tolower(c);
+ }
+ } else {
+ name[i] = '_';
+ }
+ }
+
+ // Find the first letter and start with it.
+ int i = 0;
+ for (; i < n; ++i) {
+ if (isalpha(name[i])) break;
+ }
+
+ // Return "unknown" if none of the name's chars were letters.
+ return i == n ? "unknown" : name.substr(i);
+}
+
+string NodeNameMapping::UniquifyHelper(const string& name) const {
+ // If the name hasn't been used yet, use it as-is.
+ if (used_names_.find(name) == used_names_.end()) return name;
+ // Add a suffix to name to make it unique.
+ for (int i = 0;; ++i) {
+ const string candidate = strings::StrCat(name, "_", i);
+ if (used_names_.find(candidate) == used_names_.end()) return candidate;
+ }
+}
+
+string NodeNameMapping::GetIOName(const string& name) {
+ const string& input_name = UniquifyHelper(Normalize(name));
+ // Record that we used this name, but don't add it to name_mapping_
+ // since this name is not for a node.
+ used_names_.insert(input_name);
+ return input_name;
+}
+
+string NodeNameMapping::Uniquify(const string& name) {
+ const string uniqued = UniquifyHelper(name);
+ name_mapping_[name] = uniqued;
+ used_names_.insert(uniqued);
+ return uniqued;
+}
+
+string NodeNameMapping::Lookup(const string& name) const {
+ const auto iter = name_mapping_.find(name);
+ if (iter == name_mapping_.end()) return string();
+ return iter->second;
+}
+
+Status ValidateNoRefOutputs(const Node* node) {
+ for (int i = 0; i < node->num_outputs(); ++i) {
+ const DataType& dt = node->output_type(i);
+ if (IsRefType(dt)) {
+ return errors::InvalidArgument("Output ", i, " of node '", node->name(),
+ "' has a reference "
+ "type ",
+ DataTypeString(dt));
+ }
+ }
+ return Status::OK();
+}
+
// Appends one NodeDef to `fdef` for each node in `body_nodes`, renaming
// nodes via `node_names` and rewriting each input from its flat graph
// name ("node:idx") to the nested FunctionDef name via `tensor_renaming`.
// Returns InvalidArgument if a body node consumes a tensor that is neither
// a function input nor produced by another body node, or if a control
// edge originates outside the body.
Status FillFunctionBody(
    const string& fn_name, const NodeNameMapping& node_names,
    const std::vector<const Node*>& body_nodes,
    const std::unordered_map<string, string>& tensor_renaming,
    FunctionDef* fdef) {
  // Reused across iterations to avoid reallocating per node.
  std::vector<const Edge*> in_edges;
  std::vector<const Edge*> control_edges;
  for (const Node* node : body_nodes) {
    NodeDef* node_def = fdef->add_node_def();
    // First, copy the node_def as is. We will patch it next.
    *node_def = node->def();
    if (!node->assigned_device_name().empty()) {
      node_def->set_device(node->assigned_device_name());
    }
    node_def->set_name(node_names.Lookup(node->name()));

    // Input names must be set based on nested names in tensor_renaming.
    // Clear the flat input names we got from the original node_def
    // from the graph.
    node_def->clear_input();

    // Collect regular and control inputs. Regular inputs are indexed
    // by the index at which they come into the `node`. Control inputs
    // don't follow any order.
    in_edges.clear();
    in_edges.resize(node->num_inputs(), nullptr);
    control_edges.clear();
    for (const Edge* edge : node->in_edges()) {
      if (edge->src()->IsSource()) continue;
      if (edge->IsControlEdge()) {
        control_edges.push_back(edge);
      } else {
        in_edges[edge->dst_input()] = edge;
      }
    }

    // Add regular inputs.
    for (size_t i = 0; i < in_edges.size(); ++i) {
      const Edge* edge = in_edges[i];
      string original_input_name;
      if (edge == nullptr) {
        // A backedge might not appear as a regular Edge, but be only present
        // in the node_def. Such edges are referred to as requested_inputs().
        if (i >= node->requested_inputs().size()) {
          return errors::InvalidArgument(
              "Graph to be converted to function appears to be malformed. ",
              "Node ", node->name(), " is missing input edge ", i);
        }
        original_input_name =
            ParseTensorName(node->requested_inputs()[i]).ToString();
      } else {
        original_input_name =
            strings::StrCat(edge->src()->name(), ":", edge->src_output());
      }

      // Every consumable tensor must already have a nested name recorded
      // by GraphToFunctionDef; a miss means the tensor's producer is
      // outside the chosen inputs/body.
      const auto iter = tensor_renaming.find(original_input_name);
      if (iter == tensor_renaming.end()) {
        return errors::InvalidArgument(
            "Input ", i, ", '", original_input_name, "', of node '",
            node->name(), "' in function '", fn_name,
            "' is not available. You might need to include it in inputs "
            "or include its source node in the body");
      }
      node_def->add_input(iter->second);
    }

    // Add control inputs.
    for (const Edge* edge : control_edges) {
      // Add this control input only if the src node is in the body.
      const string normalized = node_names.Lookup(edge->src()->name());
      // If we did not find a name for the source of control edge, this
      // source must be outside of the body. Raise an error.
      if (normalized.empty()) {
        return errors::InvalidArgument(
            "The source of control edge ", edge->DebugString(),
            " is not in the body. Encountered while creating function '",
            fn_name, "'");
      }
      node_def->add_input(strings::StrCat("^", normalized));
    }
  }
  return Status::OK();
}
+
+// Graph to FunctionDef conversion. This code is closely modeled on the Python
+// code in third_party/tensorflow/python/framework/function.py.
// Graph to FunctionDef conversion. This code is closely modeled on the Python
// code in third_party/tensorflow/python/framework/function.py.
//
// Builds `fdef` (named `fn_name`) from the subgraph given by `body_nodes`,
// `inputs` and `outputs` of `fn_body`.
// NOTE(review): `output_names` is accepted but not referenced in this body —
// signature output names are derived from the producing nodes' names;
// output-name rewriting appears to be unimplemented here. Confirm against
// callers before relying on it.
Status GraphToFunctionDef(const Graph& fn_body, const string& fn_name,
                          const std::vector<const Node*>& body_nodes,
                          const std::vector<OutputTensor>& inputs,
                          const std::vector<OutputTensor>& outputs,
                          const std::vector<string>& output_names,
                          FunctionDef* fdef) {
  fdef->mutable_signature()->set_name(fn_name);

  // Keep track of names we used and how we normalized them.
  NodeNameMapping node_names;

  // Mapping from original names of tensors (i.e. "<node_name>:<idx>") to the
  // name we used in the function:
  //  - For input tensors:
  //    {flat_tensor_name -> normalized_name_of_src_node}
  //    e.g. {In:3 -> in}
  //  - For tensors produced by nodes in function's body:
  //    {flat_tensor_name -> nested_tensor_name}
  //    e.g. {Add:3 -> add_0:z:1}
  std::unordered_map<string, string> tensor_renaming;

  // Fill inputs in function's signature.
  for (size_t i = 0; i < inputs.size(); ++i) {
    const Node* node = inputs[i].node;
    int idx = inputs[i].index;
    OpDef::ArgDef* argdef = fdef->mutable_signature()->add_input_arg();
    argdef->set_type(node->output_type(idx));
    const string& input_name = node_names.GetIOName(node->name());
    argdef->set_name(input_name);
    // Input tensors map to the bare signature arg name (no nesting).
    tensor_renaming[strings::StrCat(node->name(), ":", idx)] = input_name;
  }

  // Fill outputs in function's signature.
  for (size_t i = 0; i < outputs.size(); ++i) {
    const Node* node = outputs[i].node;
    int idx = outputs[i].index;
    OpDef::ArgDef* argdef = fdef->mutable_signature()->add_output_arg();
    argdef->set_type(node->output_type(idx));
    argdef->set_name(node_names.GetIOName(node->name()));
  }

  // Populate tensor_renaming and node_names.
  // Generate the new output names for every node in the function.
  // The NodeDefs in FunctionDefs use a different naming scheme for
  // their inputs than the NodeDefs in a graph (see the comment for
  // FunctionDef.node_def in function.proto). We do the
  // graph tensor name -> function tensor name conversion for every
  // possible input (i.e. every node's outputs) and store the result
  // in tensor_renaming.
  for (const Node* node : body_nodes) {
    // Make sure node_name does not collide with an input or output name.
    const string& node_name = node_names.Uniquify(node->name());
    // For each output_arg in the op_def, the output_ranges
    // map will have [start, end] range of indices that this arg produces
    // among all the output tensors of this op.
    NameRangeMap output_ranges;
    TF_RETURN_IF_ERROR(
        NameRangesForNode(*node, node->op_def(), nullptr, &output_ranges));
    for (const auto& output : output_ranges) {
      const string& output_name = output.first;
      int index_start = output.second.first;
      int index_end = output.second.second;
      for (int i = index_start; i < index_end; ++i) {
        const string& original_name = strings::StrCat(node->name(), ":", i);
        const string& new_name =
            strings::StrCat(node_name, ":", output_name, ":", i - index_start);
        // Record the mapping if this tensor is not already mapped.
        // Tensor can be already mapped if it is used as an input.
        if (tensor_renaming.find(original_name) == tensor_renaming.end()) {
          tensor_renaming[original_name] = new_name;
        }
      }
    }
  }

  TF_RETURN_IF_ERROR(
      FillFunctionBody(fn_name, node_names, body_nodes, tensor_renaming, fdef));

  // Remap return values.
  for (int r = 0; r < fdef->signature().output_arg_size(); ++r) {
    const string& ret_name = fdef->signature().output_arg(r).name();

    // We convert this flat tensor name to the nested value
    // (e.g. `add:z:1`) that we stored in tensor_renaming.
    const string& return_value =
        strings::StrCat(outputs[r].node->name(), ":", outputs[r].index);
    const auto iter = tensor_renaming.find(return_value);
    if (iter == tensor_renaming.end()) {
      return errors::InvalidArgument(
          "TF_Output ", return_value, " is neither in the function body ",
          "nor among function inputs. Encountered while creating function '",
          fn_name, "'");
    }
    (*fdef->mutable_ret())[ret_name] = iter->second;
  }

  return Status::OK();
}
+
+// Converts `ninputs` and `inputs` into `inputs_tensors` and `input_nodes` and
+// does various checks while doing so. `input_nodes` will contain the same
+// information as input_tensors just in a different structure to make
+// following processing easier. TODO(iga): Simplify this nested structure.
+Status ProcessInputs(
+ const TF_Graph* fn_body, const char* fn_name, int ninputs,
+ const TF_Output* inputs, std::vector<OutputTensor>* input_tensors,
+ std::unordered_map<const Node*, std::vector<int>>* input_nodes)
+ EXCLUSIVE_LOCKS_REQUIRED(fn_body->mu) {
+ input_tensors->reserve(ninputs);
+ for (int i = 0; i < ninputs; ++i) {
+ const Node& node = inputs[i].oper->node;
+ int idx = inputs[i].index;
+
+ TF_RETURN_WITH_CONTEXT_IF_ERROR(
+ fn_body->graph.IsValidOutputTensor(&node, idx),
+ "Encountered while processing input ", i, " into function '", fn_name,
+ "'");
+ TF_RETURN_WITH_CONTEXT_IF_ERROR(ValidateNoRefOutputs(&node),
+ "Encountered while processing input ", i,
+ " into function '", fn_name, "'");
+
+ input_tensors->emplace_back(&node, idx);
+
+ const auto& iter = input_nodes->find(&node);
+ if (iter == input_nodes->end()) {
+ input_nodes->insert({&node, {idx}});
+ } else {
+ auto& indices = iter->second;
+ if (std::find(indices.begin(), indices.end(), idx) != indices.end()) {
+ return errors::InvalidArgument(
+ "TF_Output ", node.name(), ":", idx,
+ " appears more than once in the input list");
+ }
+ indices.push_back(idx);
+ }
+ }
+ return Status::OK();
+}
+
+// Converts `noutputs` and `outputs` into `outputs_tensors` and does various
+// checks while doing so.
+Status ProcessOutputs(const TF_Graph* fn_body, const char* fn_name,
+ int noutputs, const TF_Output* outputs,
+ std::vector<OutputTensor>* output_tensors)
+ EXCLUSIVE_LOCKS_REQUIRED(fn_body->mu) {
+ output_tensors->reserve(noutputs);
+ for (int i = 0; i < noutputs; ++i) {
+ const Node& node = outputs[i].oper->node;
+ int idx = outputs[i].index;
+ TF_RETURN_WITH_CONTEXT_IF_ERROR(
+ fn_body->graph.IsValidOutputTensor(&node, idx),
+ "Encountered while processing output ", i, " from function '", fn_name,
+ "'");
+ output_tensors->emplace_back(&node, idx);
+ }
+ return Status::OK();
+}
+
+// Populates `body_nodes` with the nodes that will become function's body.
+// Performs various checks.
// Populates `body_nodes` with the nodes that will become function's body.
// Performs various checks.
// When `num_opers` is -1, every op node of the graph that is not listed in
// `input_nodes` becomes part of the body; otherwise exactly the `num_opers`
// nodes in `opers` do. In both modes, body nodes must not produce
// reference-typed outputs.
Status ComputeBodyNodes(
    const TF_Graph* fn_body, const char* fn_name, int num_opers,
    const TF_Operation* const* opers,
    const std::unordered_map<const Node*, std::vector<int>>& input_nodes,
    std::vector<const Node*>* body_nodes)
    EXCLUSIVE_LOCKS_REQUIRED(fn_body->mu) {
  if (num_opers == -1) {
    for (const Node* node : fn_body->graph.op_nodes()) {
      const auto& iter = input_nodes.find(node);
      if (iter == input_nodes.end()) {
        // This node is not referenced in inputs. Add it to the body.
        TF_RETURN_WITH_CONTEXT_IF_ERROR(ValidateNoRefOutputs(node),
                                        "Encountered while creating function '",
                                        fn_name, "'");
        body_nodes->push_back(node);
      } else {
        // This node is referenced in inputs. Currently, we place an
        // artificial restriction and require that when num_opers=-1, such
        // nodes must have a single output.
        if (node->num_outputs() != 1) {
          return errors::InvalidArgument(
              "When `num_opers` is set to -1, nodes referenced in `inputs` "
              "must have a single output. Node ",
              node->name(), " has ", node->num_outputs(),
              " outputs. Encountered while creating function '", fn_name, "'");
        }
      }
    }
  } else {
    body_nodes->reserve(num_opers);
    for (int i = 0; i < num_opers; ++i) {
      const Node* node = &opers[i]->node;
      TF_RETURN_WITH_CONTEXT_IF_ERROR(ValidateNoRefOutputs(node),
                                      "Encountered while creating function '",
                                      fn_name, "'");
      body_nodes->push_back(node);
    }
  }
  return Status::OK();
}
+
+} // anonymous namespace
+} // namespace tensorflow
+
+using tensorflow::Node;
+using tensorflow::string;
+
// C API entry point: creates a TF_Function from a subgraph of `fn_body`.
// When `num_opers` is -1, all nodes not referenced in `inputs` form the body.
// Returns nullptr and sets `status` on failure.
// NOTE(review): `output_names` is collected into a vector but the vector is
// not used by GraphToFunctionDef yet; `opts` is never read — both appear to
// be placeholders for upcoming features. Confirm before documenting them
// as functional in the public header.
TF_Function* TF_GraphToFunction(const TF_Graph* fn_body, const char* fn_name,
                                int num_opers, const TF_Operation* const* opers,
                                int ninputs, const TF_Output* inputs,
                                int noutputs, const TF_Output* outputs,
                                const char* const* output_names,
                                const TF_FunctionOptions* opts,
                                TF_Status* status) {
  // NOTE(review): const_cast is needed because fn_body is const but
  // mutex_lock requires a mutable mutex; assumes locking is treated as
  // logically const for TF_Graph.
  tensorflow::mutex_lock l(*const_cast<tensorflow::mutex*>(&fn_body->mu));

  // Process inputs.
  std::vector<tensorflow::OutputTensor> input_tensors;
  std::unordered_map<const Node*, std::vector<int>> input_nodes;
  status->status = tensorflow::ProcessInputs(fn_body, fn_name, ninputs, inputs,
                                             &input_tensors, &input_nodes);
  if (!status->status.ok()) return nullptr;

  // Process outputs.
  std::vector<tensorflow::OutputTensor> output_tensors;
  status->status = tensorflow::ProcessOutputs(fn_body, fn_name, noutputs,
                                              outputs, &output_tensors);
  if (!status->status.ok()) return nullptr;

  // Process output names.
  std::vector<string> output_names_vec;
  if (output_names) {
    output_names_vec.reserve(noutputs);
    for (int i = 0; i < noutputs; ++i) {
      output_names_vec.push_back(string(output_names[i]));
    }
  }

  // Compute body nodes.
  std::vector<const Node*> body_nodes;
  status->status = tensorflow::ComputeBodyNodes(
      fn_body, fn_name, num_opers, opers, input_nodes, &body_nodes);
  if (!status->status.ok()) return nullptr;

  // Do the actual function creation.
  TF_Function* tf_function = new TF_Function();
  status->status = tensorflow::GraphToFunctionDef(
      fn_body->graph, fn_name, body_nodes, input_tensors, output_tensors,
      output_names_vec, tf_function->fdef_lib.add_function());
  if (!status->status.ok()) {
    // Avoid leaking the partially-built function on error.
    TF_DeleteFunction(tf_function);
    return nullptr;
  }
  return tf_function;
}
+
// Registers `function`'s FunctionDefLibrary with graph `g` so that nodes in
// `g` can call the function. Reports errors (e.g. name conflicts) via
// `status`.
void TF_GraphAddFunction(TF_Graph* g, const TF_Function* function,
                         TF_Status* status) {
  tensorflow::mutex_lock l(g->mu);

  // At the moment, we have only one function and no gradients in fdef_lib.
  // This makes the following operation atomic.
  // TODO(iga): Add an atomic version of AddFunctionLibrary when we support
  // gradients
  status->status = g->graph.AddFunctionLibrary(function->fdef_lib);
}
+
// Serializes the (single) FunctionDef held by `func` into
// `output_func_def`. The DCHECK encodes the current invariant that a
// TF_Function holds exactly one function and no gradients.
void TF_FunctionToFunctionDef(TF_Function* func, TF_Buffer* output_func_def,
                              TF_Status* status) {
  DCHECK_EQ(1, func->fdef_lib.function_size());
  status->status = MessageToBuffer(func->fdef_lib.function(0), output_func_def);
}
+
+void TF_DeleteFunction(TF_Function* function) { delete function; }