author     Igor Ganichev <iga@google.com>                    2017-08-30 21:05:14 -0700
committer  TensorFlower Gardener <gardener@tensorflow.org>   2017-08-30 21:08:53 -0700
commit     9624d165f1f2c717eda96464fee8bf7229cc14f5
tree       8024d708b58b0c78f19d4c3cfc9f7c4b0c24b70c /tensorflow
parent     424aa9aa9559f6fa29d8ccf3d74ff25528b39209
Add function support to TensorFlow C API

This change adds minimal functionality. Support for FunctionOptions, attributes, output name rewriting, function name generation, etc. is coming next.

PiperOrigin-RevId: 167091238
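As a rough usage sketch of the new API surface (not part of the change itself; the op and function names "feed", "neg", "MyNeg", "host_feed", and "call" are made up for illustration), the snippet below builds a tiny function body graph, converts it with TF_GraphToFunction, registers it in a host graph with TF_GraphAddFunction, and instantiates it as an ordinary operation:

#include <cstdio>
#include "tensorflow/c/c_api.h"

static void CheckOK(TF_Status* s) {
  if (TF_GetCode(s) != TF_OK) std::fprintf(stderr, "%s\n", TF_Message(s));
}

int main() {
  TF_Status* s = TF_NewStatus();

  // Function body graph: a single Neg op fed by a placeholder.
  TF_Graph* fn_body = TF_NewGraph();
  TF_OperationDescription* d = TF_NewOperation(fn_body, "Placeholder", "feed");
  TF_SetAttrType(d, "dtype", TF_INT32);
  TF_Operation* feed = TF_FinishOperation(d, s);
  CheckOK(s);

  d = TF_NewOperation(fn_body, "Neg", "neg");
  TF_AddInput(d, TF_Output{feed, 0});
  TF_Operation* neg = TF_FinishOperation(d, s);
  CheckOK(s);

  // Convert the graph to a function. num_opers == -1 means "use all ops
  // except the ones referenced in `inputs`", so the placeholder is skipped.
  TF_Output inputs[] = {{feed, 0}};
  TF_Output outputs[] = {{neg, 0}};
  TF_Function* func = TF_GraphToFunction(
      fn_body, "MyNeg", /*num_opers=*/-1, /*opers=*/nullptr,
      /*ninputs=*/1, inputs, /*noutputs=*/1, outputs,
      /*output_names=*/nullptr, /*opts=*/nullptr, s);
  CheckOK(s);

  // Register the function in a host graph and call it by name.
  TF_Graph* host = TF_NewGraph();
  TF_GraphAddFunction(host, func, s);
  CheckOK(s);

  d = TF_NewOperation(host, "Placeholder", "host_feed");
  TF_SetAttrType(d, "dtype", TF_INT32);
  TF_Operation* host_feed = TF_FinishOperation(d, s);
  CheckOK(s);

  d = TF_NewOperation(host, "MyNeg", "call");
  TF_AddInput(d, TF_Output{host_feed, 0});
  TF_Operation* call_op = TF_FinishOperation(d, s);
  CheckOK(s);
  (void)call_op;

  TF_DeleteFunction(func);
  TF_DeleteGraph(host);
  TF_DeleteGraph(fn_body);
  TF_DeleteStatus(s);
  return 0;
}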
Diffstat (limited to 'tensorflow')
-rw-r--r--  tensorflow/c/BUILD                              |   24
-rw-r--r--  tensorflow/c/c_api.cc                           |   37
-rw-r--r--  tensorflow/c/c_api.h                            |  116
-rw-r--r--  tensorflow/c/c_api_function.cc                  |  496
-rw-r--r--  tensorflow/c/c_api_function_test.cc             | 1039
-rw-r--r--  tensorflow/c/c_api_internal.h                   |    8
-rw-r--r--  tensorflow/c/c_api_test.cc                      |    2
-rw-r--r--  tensorflow/c/c_test_util.cc                     |  131
-rw-r--r--  tensorflow/c/c_test_util.h                      |   20
-rw-r--r--  tensorflow/contrib/cmake/tf_c.cmake             |    1
-rw-r--r--  tensorflow/core/graph/graph.cc                  |   13
-rw-r--r--  tensorflow/core/graph/graph.h                   |    4
-rw-r--r--  tensorflow/python/client/tf_session.i           |   27
-rw-r--r--  tensorflow/python/client/tf_session_helper.cc   |   34
-rw-r--r--  tensorflow/python/client/tf_session_helper.h    |   10
-rw-r--r--  tensorflow/python/framework/function.py         |   19
-rw-r--r--  tensorflow/python/framework/function_test.py    |  112
-rw-r--r--  tensorflow/python/framework/ops.py              |    8
18 files changed, 2072 insertions, 29 deletions
diff --git a/tensorflow/c/BUILD b/tensorflow/c/BUILD
index 604dfab148..1822e235eb 100644
--- a/tensorflow/c/BUILD
+++ b/tensorflow/c/BUILD
@@ -45,8 +45,13 @@ tf_cuda_library(
tf_cuda_library(
name = "c_api",
- srcs = ["c_api.cc"],
- hdrs = ["c_api.h"],
+ srcs = [
+ "c_api.cc",
+ "c_api_function.cc",
+ ],
+ hdrs = [
+ "c_api.h",
+ ],
copts = tf_copts(),
visibility = ["//visibility:public"],
deps = select({
@@ -158,6 +163,21 @@ tf_cc_test(
)
tf_cc_test(
+ name = "c_api_function_test",
+ size = "small",
+ srcs = ["c_api_function_test.cc"],
+ deps = [
+ ":c_api",
+ ":c_test_util",
+ "//tensorflow/core:lib",
+ "//tensorflow/core:lib_internal",
+ "//tensorflow/core:protos_all_cc",
+ "//tensorflow/core:test",
+ "//tensorflow/core:test_main",
+ ],
+)
+
+tf_cc_test(
name = "while_loop_test",
size = "small",
srcs = ["while_loop_test.cc"],
diff --git a/tensorflow/c/c_api.cc b/tensorflow/c/c_api.cc
index 07c8277a6f..c454c94249 100644
--- a/tensorflow/c/c_api.cc
+++ b/tensorflow/c/c_api.cc
@@ -165,22 +165,6 @@ void deallocate_buffer(void* data, size_t len, void* arg) {
tensorflow::cpu_allocator()->DeallocateRaw(data);
}
-Status MessageToBuffer(const tensorflow::protobuf::Message& in,
- TF_Buffer* out) {
- if (out->data != nullptr) {
- return InvalidArgument("Passing non-empty TF_Buffer is invalid.");
- }
- const auto proto_size = in.ByteSizeLong();
- void* buf = tensorflow::port::Malloc(proto_size);
- in.SerializeToArray(buf, proto_size);
- out->data = buf;
- out->length = proto_size;
- out->data_deallocator = [](void* data, size_t length) {
- tensorflow::port::Free(data);
- };
- return Status::OK();
-}
-
} // namespace
TF_Tensor::~TF_Tensor() { buffer->Unref(); }
@@ -559,6 +543,27 @@ TF_Tensor* TF_TensorFromTensor(const tensorflow::Tensor& src,
dimvec.size(), base, size, DeleteArray, base);
}
+Status MessageToBuffer(const tensorflow::protobuf::Message& in,
+ TF_Buffer* out) {
+ if (out->data != nullptr) {
+ return InvalidArgument("Passing non-empty TF_Buffer is invalid.");
+ }
+ const size_t proto_size = in.ByteSizeLong();
+ void* buf = tensorflow::port::Malloc(proto_size);
+ if (buf == nullptr) {
+ return tensorflow::errors::ResourceExhausted(
+ "Failed to allocate memory to serialize message of type '",
+ in.GetTypeName(), "' and size ", proto_size);
+ }
+ in.SerializeToArray(buf, proto_size);
+ out->data = buf;
+ out->length = proto_size;
+ out->data_deallocator = [](void* data, size_t length) {
+ tensorflow::port::Free(data);
+ };
+ return Status::OK();
+}
+
// Helpers for loading a TensorFlow plugin (a .so file).
Status LoadLibrary(const char* library_filename, void** result,
const void** buf, size_t* len);
diff --git a/tensorflow/c/c_api.h b/tensorflow/c/c_api.h
index 43b5078013..ee110d88ce 100644
--- a/tensorflow/c/c_api.h
+++ b/tensorflow/c/c_api.h
@@ -357,6 +357,14 @@ typedef struct TF_Output {
int index; // The index of the output within oper.
} TF_Output;
+// TF_Function is a grouping of operations with defined inputs and outputs.
+// Once created and added to graphs, functions can be invoked by creating an
+// operation whose operation type matches the function name.
+typedef struct TF_Function TF_Function;
+
+// Function definition options. TODO(iga): Define and implement
+typedef struct TF_FunctionOptions TF_FunctionOptions;
+
// Sets the shape of the Tensor referenced by `output` in `graph` to
// the shape described by `dims` and `num_dims`.
//
@@ -914,6 +922,15 @@ TF_CAPI_EXPORT extern void TF_GraphImportGraphDef(
TF_Graph* graph, const TF_Buffer* graph_def,
const TF_ImportGraphDefOptions* options, TF_Status* status);
+// Add `function` to graph `g`. Once `function` is added to `g`,
+// it can be called by creating an operation using the function's name.
+//
+// If successful, status is set to OK and `function` is added to `g`.
+// Otherwise, status is set to the encountered error and `g` is unmodified.
+TF_CAPI_EXPORT extern void TF_GraphAddFunction(TF_Graph* g,
+ const TF_Function* function,
+ TF_Status* status);
+
// Note: The following function may fail on very large protos in the future.
TF_CAPI_EXPORT extern void TF_OperationToNodeDef(TF_Operation* oper,
@@ -1001,6 +1018,105 @@ TF_CAPI_EXPORT void TF_AddGradients(TF_Graph* g, TF_Output* y, int ny,
TF_Output* x, int nx, TF_Output* dx,
TF_Status* status, TF_Output* dy);
+// Create a TF_Function from a TF_Graph
+//
+// Params:
+// fn_body - the graph whose operations (or subset of whose operations) will be
+// converted to TF_Function.
+// fn_name - the name of the new TF_Function. Should match the operation
+// name (OpDef.name) regexp [A-Z][A-Za-z0-9_.\\-/]* and be distinct
+// from other operation names (at least those registered in graphs
+// where this function will be used).
+// TODO(iga): Allow null in here and have C API come up with
+// a unique name with high probability (similarly to
+// _create_hash_str in function.py)
+// num_opers - `num_opers` contains the number of elements in the `opers` array
+// or a special value of -1 meaning that no array is given.
+// The distinction between an empty array of operations and no
+// array of operations is necessary to distinguish the case of
+// creating a function with no body (e.g. identity or permutation)
+// and the case of creating a function whose body contains all
+// the nodes in the graph (except for the automatic skipping, see
+// below).
+// opers - Array of operations to become the body of the function or null.
+// - If no array is given (`num_opers` = -1), all the
+// operations in `fn_body` will become part of the function
+// except operations referenced in `inputs`. These operations
+// must have a single output (these operations are typically
+// placeholders created for the sole purpose of representing
+// an input. We can relax this constraint if there are
+// compelling use cases).
+// - If an array is given (`num_opers` >= 0), all operations
+// in it will become part of the function. In particular, no
+// automatic skipping of dummy input operations is performed.
+// ninputs - number of elements in `inputs` array
+// inputs - array of TF_Outputs that specify the inputs to the function.
+// If `ninputs` is zero (the function takes no inputs), `inputs`
+// can be null. The names used for function inputs are normalized
+// names of the operations (usually placeholders) pointed to by
+// `inputs`. These operation names should start with a letter.
+// Normalization will convert all letters to lowercase and
+// non-alphanumeric characters to '_' to make resulting names match
+// the "[a-z][a-z0-9_]*" pattern for operation argument names.
+// `inputs` cannot contain the same tensor twice.
+// noutputs - number of elements in `outputs` array
+// outputs - array of TF_Outputs that specify the outputs of the function.
+// If `noutputs` is zero (the function returns no outputs), `outputs`
+// can be null. `outputs` can contain the same tensor more than once.
+// output_names - The names of the function's outputs. `output_names` array
+// must either have the same length as `outputs`
+// (i.e. `noutputs`) or be null. In the former case,
+// the names should match the regular expression for ArgDef
+// names - "[a-z][a-z0-9_]*". In the latter case,
+// names for outputs will be generated automatically.
+// opts - various options for the function, e.g. XLA's inlining control.
+// status - Set to OK on success and an appropriate error on failure.
+//
+// Note that when the same TF_Output is listed as both an input and an output,
+// the corresponding function's output will be equal to this input,
+// instead of the original node's output.
+//
+// Callers must also satisfy the following constraints:
+// - `inputs` cannot refer to TF_Outputs within a control flow context. For
+// example, one cannot use the output of "switch" node as input.
+// - No TF_Output of a function (inside any of `inputs`, `outputs`, `fn_body`)
+// is allowed to have a reference type. Reference types are not exposed
+// through C API and are being deprecated.
+// - Every node in the function's body must have all of its inputs (including
+// control inputs). In other words, for every node in the body, each input
+// must be either listed in `inputs` or must come from another node in
+// the body. In particular, it is an error to have a control edge going from
+// a node outside of the body into a node in the body. This applies to control
+// edges going from nodes referenced in `inputs` to nodes in the body when
+// the former nodes are not in the body (automatically skipped or not
+// included in explicitly specified body).
+//
+// Returns:
+// On success, a newly created TF_Function instance. It must be deleted by
+// calling TF_DeleteFunction.
+//
+// On failure, null.
+//
+// TODO(iga): Add input_names argument and get output_names working (they are
+// currently ignored)
+TF_CAPI_EXPORT extern TF_Function* TF_GraphToFunction(
+ const TF_Graph* fn_body, const char* fn_name, int num_opers,
+ const TF_Operation* const* opers, int ninputs, const TF_Output* inputs,
+ int noutputs, const TF_Output* outputs, const char* const* output_names,
+ const TF_FunctionOptions* opts, TF_Status* status);
+
+// Write out a serialized representation of `func` (as a FunctionDef protocol
+// message) to `output_func_def` (allocated by TF_NewBuffer()).
+// `output_func_def`'s underlying buffer will be freed when TF_DeleteBuffer()
+// is called.
+//
+// May fail on very large graphs in the future.
+TF_CAPI_EXPORT extern void TF_FunctionToFunctionDef(TF_Function* func,
+ TF_Buffer* output_func_def,
+ TF_Status* status);
+
+TF_CAPI_EXPORT extern void TF_DeleteFunction(TF_Function*);
+
// TODO(josh11b): Register OpDef, available to all operations added
// to this graph.
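As a usage sketch for the serialization entry point above (assuming `func` and `s` from a successful TF_GraphToFunction call; the parsing step additionally assumes the C++ FunctionDef proto is available in the build):

// Serialize the function into a freshly allocated TF_Buffer.
TF_Buffer* buf = TF_NewBuffer();
TF_FunctionToFunctionDef(func, buf, s);
if (TF_GetCode(s) == TF_OK) {
  // buf->data / buf->length now hold a serialized FunctionDef proto,
  // which can be parsed with the C++ proto API if desired:
  //   tensorflow::FunctionDef fdef;
  //   fdef.ParseFromArray(buf->data, buf->length);
}
TF_DeleteBuffer(buf);  // releases the buffer and its data via data_deallocator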
diff --git a/tensorflow/c/c_api_function.cc b/tensorflow/c/c_api_function.cc
new file mode 100644
index 0000000000..b4c6397d0b
--- /dev/null
+++ b/tensorflow/c/c_api_function.cc
@@ -0,0 +1,496 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/c/c_api_internal.h"
+
+#include <algorithm>
+#include <unordered_map>
+#include <unordered_set>
+
+#include "tensorflow/core/framework/attr_value_util.h"
+#include "tensorflow/core/framework/function.pb.h"
+#include "tensorflow/core/framework/node_def.pb.h"
+#include "tensorflow/core/framework/node_def_util.h"
+#include "tensorflow/core/framework/types.h"
+#include "tensorflow/core/graph/graph.h"
+#include "tensorflow/core/lib/strings/strcat.h"
+
+namespace tensorflow {
+namespace {
+
+// Class that maintains a one-to-one original node name -> new node name
+// mapping. We normalize the names used as input and output arguments to match
+// regexp "[a-z][a-z0-9_]*" specified in definition of ArgDef.name.
+// Once we rename them, we risk creating a name collision with the other
+// node names, so if necessary we add a suffix to make
+// names unique. If we have an input named "A" and a node in the function
+// body named "a", they will be renamed to "a" and "a_0".
+class NodeNameMapping {
+ public:
+ NodeNameMapping() = default;
+
+ // Normalize the input/output name and make it unique.
+ string GetIOName(const string& name);
+
+ // Make the node name unique.
+ string Uniquify(const string& name);
+
+ // Look up how a node name was previously normalized/uniquified.
+ // Returns empty if name was never seen.
+ string Lookup(const string& name) const;
+
+ private:
+ string UniquifyHelper(const string& name) const;
+ static string Normalize(string name);
+
+ // The normalized/uniquified names already used as
+ // input names (in signature), output names (in signature), and node names
+ // (in node_def).
+ // This is a superset of values in name_mapping_.
+ std::unordered_set<string> used_names_;
+ // Mapping from original node name from the graph to the normalized
+// and uniquified version of it.
+ std::unordered_map<string, string> name_mapping_;
+};
+
+string NodeNameMapping::Normalize(string name) {
+ // Convert letters to lowercase and non-alphanumeric characters to '_'.
+ if (name.empty()) return "unknown";
+ const int n = name.size();
+ for (int i = 0; i < n; ++i) {
+ char c = name[i];
+ if (isalnum(c)) {
+ if (isupper(c)) {
+ name[i] = tolower(c);
+ }
+ } else {
+ name[i] = '_';
+ }
+ }
+
+ // Find the first letter and start with it.
+ int i = 0;
+ for (; i < n; ++i) {
+ if (isalpha(name[i])) break;
+ }
+
+ // Return "unknown" if none of the name's chars were letters.
+ return i == n ? "unknown" : name.substr(i);
+}
+
+string NodeNameMapping::UniquifyHelper(const string& name) const {
+ // If the name hasn't been used yet, use it as-is.
+ if (used_names_.find(name) == used_names_.end()) return name;
+ // Add a suffix to name to make it unique.
+ for (int i = 0;; ++i) {
+ const string candidate = strings::StrCat(name, "_", i);
+ if (used_names_.find(candidate) == used_names_.end()) return candidate;
+ }
+}
+
+string NodeNameMapping::GetIOName(const string& name) {
+ const string& input_name = UniquifyHelper(Normalize(name));
+ // Record that we used this name, but don't add it to name_mapping_
+ // since this name is not for a node.
+ used_names_.insert(input_name);
+ return input_name;
+}
+
+string NodeNameMapping::Uniquify(const string& name) {
+ const string uniqued = UniquifyHelper(name);
+ name_mapping_[name] = uniqued;
+ used_names_.insert(uniqued);
+ return uniqued;
+}
+
+string NodeNameMapping::Lookup(const string& name) const {
+ const auto iter = name_mapping_.find(name);
+ if (iter == name_mapping_.end()) return string();
+ return iter->second;
+}
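// A minimal illustration of the mapping behavior described in the class
// comment above (names here are made up; NodeNameMapping lives in the
// anonymous namespace of this file, so such code would have to live in the
// same translation unit):
//
//   NodeNameMapping m;
//   m.GetIOName("A");       // returns "a": normalized to lowercase, then uniquified
//   m.Uniquify("a");        // returns "a_0": "a" is already taken by the input arg
//   m.Uniquify("My-Node");  // returns "My-Node": node names are only uniquified,
//                           // never normalized
//   m.Lookup("a");          // returns "a_0": node renames are recorded
//   m.Lookup("A");          // returns "": GetIOName does not record a node mapping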
+
+Status ValidateNoRefOutputs(const Node* node) {
+ for (int i = 0; i < node->num_outputs(); ++i) {
+ const DataType& dt = node->output_type(i);
+ if (IsRefType(dt)) {
+ return errors::InvalidArgument("Output ", i, " of node '", node->name(),
+ "' has a reference "
+ "type ",
+ DataTypeString(dt));
+ }
+ }
+ return Status::OK();
+}
+
+Status FillFunctionBody(
+ const string& fn_name, const NodeNameMapping& node_names,
+ const std::vector<const Node*>& body_nodes,
+ const std::unordered_map<string, string>& tensor_renaming,
+ FunctionDef* fdef) {
+ std::vector<const Edge*> in_edges;
+ std::vector<const Edge*> control_edges;
+ for (const Node* node : body_nodes) {
+ NodeDef* node_def = fdef->add_node_def();
+ // First, copy the node_def as is. We will patch it next.
+ *node_def = node->def();
+ if (!node->assigned_device_name().empty()) {
+ node_def->set_device(node->assigned_device_name());
+ }
+ node_def->set_name(node_names.Lookup(node->name()));
+
+ // Input names must be set based on nested names in tensor_renaming.
+ // Clear the flat input names we got from the original node_def
+ // from the graph.
+ node_def->clear_input();
+
+ // Collect regular and control inputs. Regular inputs are indexed
+ // by the index at which they come into the `node`. Control inputs
+ // don't follow any order.
+ in_edges.clear();
+ in_edges.resize(node->num_inputs(), nullptr);
+ control_edges.clear();
+ for (const Edge* edge : node->in_edges()) {
+ if (edge->src()->IsSource()) continue;
+ if (edge->IsControlEdge()) {
+ control_edges.push_back(edge);
+ } else {
+ in_edges[edge->dst_input()] = edge;
+ }
+ }
+
+ // Add regular inputs.
+ for (size_t i = 0; i < in_edges.size(); ++i) {
+ const Edge* edge = in_edges[i];
+ string original_input_name;
+ if (edge == nullptr) {
+ // A backedge might not appear as a regular Edge, but be only present
+ // in the node_def. Such edges are referred to as requested_inputs().
+ if (i >= node->requested_inputs().size()) {
+ return errors::InvalidArgument(
+ "Graph to be converted to function appears to be malformed. ",
+ "Node ", node->name(), " is missing input edge ", i);
+ }
+ original_input_name =
+ ParseTensorName(node->requested_inputs()[i]).ToString();
+ } else {
+ original_input_name =
+ strings::StrCat(edge->src()->name(), ":", edge->src_output());
+ }
+
+ const auto iter = tensor_renaming.find(original_input_name);
+ if (iter == tensor_renaming.end()) {
+ return errors::InvalidArgument(
+ "Input ", i, ", '", original_input_name, "', of node '",
+ node->name(), "' in function '", fn_name,
+ "' is not available. You might need to include it in inputs "
+ "or include its source node in the body");
+ }
+ node_def->add_input(iter->second);
+ }
+
+ // Add control inputs.
+ for (const Edge* edge : control_edges) {
+ // Add this control input only if the src node is in the body.
+ const string normalized = node_names.Lookup(edge->src()->name());
+ // If we did not find a name for the source of control edge, this
+ // source must be outside of the body. Raise an error.
+ if (normalized.empty()) {
+ return errors::InvalidArgument(
+ "The source of control edge ", edge->DebugString(),
+ " is not in the body. Encountered while creating function '",
+ fn_name, "'");
+ }
+ node_def->add_input(strings::StrCat("^", normalized));
+ }
+ }
+ return Status::OK();
+}
+
+// Graph to FunctionDef conversion. This code is closely modeled on the Python
+// code in third_party/tensorflow/python/framework/function.py.
+Status GraphToFunctionDef(const Graph& fn_body, const string& fn_name,
+ const std::vector<const Node*>& body_nodes,
+ const std::vector<OutputTensor>& inputs,
+ const std::vector<OutputTensor>& outputs,
+ const std::vector<string>& output_names,
+ FunctionDef* fdef) {
+ fdef->mutable_signature()->set_name(fn_name);
+
+ // Keep track of names we used and how we normalized them.
+ NodeNameMapping node_names;
+
+ // Mapping from original names of tensors (i.e. "<node_name>:<idx>") to the
+ // name we used in the function:
+ // - For input tensors:
+ // {flat_tensor_name -> normalized_name_of_src_node}
+ // e.g. {In:3 -> in}
+ // - For tensors produced by nodes in function's body:
+ // {flat_tensor_name -> nested_tensor_name}
+ // e.g. {Add:3 -> add_0:z:1}
+ std::unordered_map<string, string> tensor_renaming;
+
+ // Fill inputs in function's signature.
+ for (size_t i = 0; i < inputs.size(); ++i) {
+ const Node* node = inputs[i].node;
+ int idx = inputs[i].index;
+ OpDef::ArgDef* argdef = fdef->mutable_signature()->add_input_arg();
+ argdef->set_type(node->output_type(idx));
+ const string& input_name = node_names.GetIOName(node->name());
+ argdef->set_name(input_name);
+ tensor_renaming[strings::StrCat(node->name(), ":", idx)] = input_name;
+ }
+
+ // Fill outputs in function's signature.
+ for (size_t i = 0; i < outputs.size(); ++i) {
+ const Node* node = outputs[i].node;
+ int idx = outputs[i].index;
+ OpDef::ArgDef* argdef = fdef->mutable_signature()->add_output_arg();
+ argdef->set_type(node->output_type(idx));
+ argdef->set_name(node_names.GetIOName(node->name()));
+ }
+
+ // Populate tensor_renaming and node_names.
+ // Generate the new output names for every node in the function.
+ // The NodeDefs in FunctionDefs use a different naming scheme for
+ // their inputs than the NodeDefs in a graph (see the comment for
+ // FunctionDef.node_def in function.proto). We do the
+ // graph tensor name -> function tensor name conversion for every
+ // possible input (i.e. every node's outputs) and store the result
+ // in tensor_renaming.
+ for (const Node* node : body_nodes) {
+ // Make sure node_name does not collide with an input or output name.
+ const string& node_name = node_names.Uniquify(node->name());
+ // For each output_arg in the op_def, the output_ranges
+ // map will have [start, end] range of indices that this arg produces
+ // among all the output tensors of this op.
+ NameRangeMap output_ranges;
+ TF_RETURN_IF_ERROR(
+ NameRangesForNode(*node, node->op_def(), nullptr, &output_ranges));
+ for (const auto& output : output_ranges) {
+ const string& output_name = output.first;
+ int index_start = output.second.first;
+ int index_end = output.second.second;
+ for (int i = index_start; i < index_end; ++i) {
+ const string& original_name = strings::StrCat(node->name(), ":", i);
+ const string& new_name =
+ strings::StrCat(node_name, ":", output_name, ":", i - index_start);
+ // Record the mapping if this tensor is not already mapped.
+ // Tensor can be already mapped if it is used as an input.
+ if (tensor_renaming.find(original_name) == tensor_renaming.end()) {
+ tensor_renaming[original_name] = new_name;
+ }
+ }
+ }
+ }
+
+ TF_RETURN_IF_ERROR(
+ FillFunctionBody(fn_name, node_names, body_nodes, tensor_renaming, fdef));
+
+ // Remap return values.
+ for (int r = 0; r < fdef->signature().output_arg_size(); ++r) {
+ const string& ret_name = fdef->signature().output_arg(r).name();
+
+ // We convert this flat tensor name to the nested value
+ // (e.g. `add:z:1`) that we stored in tensor_renaming.
+ const string& return_value =
+ strings::StrCat(outputs[r].node->name(), ":", outputs[r].index);
+ const auto iter = tensor_renaming.find(return_value);
+ if (iter == tensor_renaming.end()) {
+ return errors::InvalidArgument(
+ "TF_Output ", return_value, " is neither in the function body ",
+ "nor among function inputs. Encountered while creating function '",
+ fn_name, "'");
+ }
+ (*fdef->mutable_ret())[ret_name] = iter->second;
+ }
+
+ return Status::OK();
+}
+
+// Converts `ninputs` and `inputs` into `inputs_tensors` and `input_nodes` and
+// does various checks while doing so. `input_nodes` will contain the same
+// information as `input_tensors`, just in a different structure to make
+// following processing easier. TODO(iga): Simplify this nested structure.
+Status ProcessInputs(
+ const TF_Graph* fn_body, const char* fn_name, int ninputs,
+ const TF_Output* inputs, std::vector<OutputTensor>* input_tensors,
+ std::unordered_map<const Node*, std::vector<int>>* input_nodes)
+ EXCLUSIVE_LOCKS_REQUIRED(fn_body->mu) {
+ input_tensors->reserve(ninputs);
+ for (int i = 0; i < ninputs; ++i) {
+ const Node& node = inputs[i].oper->node;
+ int idx = inputs[i].index;
+
+ TF_RETURN_WITH_CONTEXT_IF_ERROR(
+ fn_body->graph.IsValidOutputTensor(&node, idx),
+ "Encountered while processing input ", i, " into function '", fn_name,
+ "'");
+ TF_RETURN_WITH_CONTEXT_IF_ERROR(ValidateNoRefOutputs(&node),
+ "Encountered while processing input ", i,
+ " into function '", fn_name, "'");
+
+ input_tensors->emplace_back(&node, idx);
+
+ const auto& iter = input_nodes->find(&node);
+ if (iter == input_nodes->end()) {
+ input_nodes->insert({&node, {idx}});
+ } else {
+ auto& indices = iter->second;
+ if (std::find(indices.begin(), indices.end(), idx) != indices.end()) {
+ return errors::InvalidArgument(
+ "TF_Output ", node.name(), ":", idx,
+ " appears more than once in the input list");
+ }
+ indices.push_back(idx);
+ }
+ }
+ return Status::OK();
+}
+
+// Converts `noutputs` and `outputs` into `outputs_tensors` and does various
+// checks while doing so.
+Status ProcessOutputs(const TF_Graph* fn_body, const char* fn_name,
+ int noutputs, const TF_Output* outputs,
+ std::vector<OutputTensor>* output_tensors)
+ EXCLUSIVE_LOCKS_REQUIRED(fn_body->mu) {
+ output_tensors->reserve(noutputs);
+ for (int i = 0; i < noutputs; ++i) {
+ const Node& node = outputs[i].oper->node;
+ int idx = outputs[i].index;
+ TF_RETURN_WITH_CONTEXT_IF_ERROR(
+ fn_body->graph.IsValidOutputTensor(&node, idx),
+ "Encountered while processing output ", i, " from function '", fn_name,
+ "'");
+ output_tensors->emplace_back(&node, idx);
+ }
+ return Status::OK();
+}
+
+// Populates `body_nodes` with the nodes that will become function's body.
+// Performs various checks.
+Status ComputeBodyNodes(
+ const TF_Graph* fn_body, const char* fn_name, int num_opers,
+ const TF_Operation* const* opers,
+ const std::unordered_map<const Node*, std::vector<int>>& input_nodes,
+ std::vector<const Node*>* body_nodes)
+ EXCLUSIVE_LOCKS_REQUIRED(fn_body->mu) {
+ if (num_opers == -1) {
+ for (const Node* node : fn_body->graph.op_nodes()) {
+ const auto& iter = input_nodes.find(node);
+ if (iter == input_nodes.end()) {
+ // This node is not referenced in inputs. Add it to the body.
+ TF_RETURN_WITH_CONTEXT_IF_ERROR(ValidateNoRefOutputs(node),
+ "Encountered while creating function '",
+ fn_name, "'");
+ body_nodes->push_back(node);
+ } else {
+ // This node is referenced in inputs. Currently, we place an
+ // artificial restriction and require that when num_opers=-1, such
+ // nodes must have a single output.
+ if (node->num_outputs() != 1) {
+ return errors::InvalidArgument(
+ "When `num_opers` is set to -1, nodes referenced in `inputs` "
+ "must have a single output. Node ",
+ node->name(), " has ", node->num_outputs(),
+ " outputs. Encountered while creating function '", fn_name, "'");
+ }
+ }
+ }
+ } else {
+ body_nodes->reserve(num_opers);
+ for (int i = 0; i < num_opers; ++i) {
+ const Node* node = &opers[i]->node;
+ TF_RETURN_WITH_CONTEXT_IF_ERROR(ValidateNoRefOutputs(node),
+ "Encountered while creating function '",
+ fn_name, "'");
+ body_nodes->push_back(node);
+ }
+ }
+ return Status::OK();
+}
+
+} // anonymous namespace
+} // namespace tensorflow
+
+using tensorflow::Node;
+using tensorflow::string;
+
+TF_Function* TF_GraphToFunction(const TF_Graph* fn_body, const char* fn_name,
+ int num_opers, const TF_Operation* const* opers,
+ int ninputs, const TF_Output* inputs,
+ int noutputs, const TF_Output* outputs,
+ const char* const* output_names,
+ const TF_FunctionOptions* opts,
+ TF_Status* status) {
+ tensorflow::mutex_lock l(*const_cast<tensorflow::mutex*>(&fn_body->mu));
+
+ // Process inputs.
+ std::vector<tensorflow::OutputTensor> input_tensors;
+ std::unordered_map<const Node*, std::vector<int>> input_nodes;
+ status->status = tensorflow::ProcessInputs(fn_body, fn_name, ninputs, inputs,
+ &input_tensors, &input_nodes);
+ if (!status->status.ok()) return nullptr;
+
+ // Process outputs.
+ std::vector<tensorflow::OutputTensor> output_tensors;
+ status->status = tensorflow::ProcessOutputs(fn_body, fn_name, noutputs,
+ outputs, &output_tensors);
+ if (!status->status.ok()) return nullptr;
+
+ // Process output names.
+ std::vector<string> output_names_vec;
+ if (output_names) {
+ output_names_vec.reserve(noutputs);
+ for (int i = 0; i < noutputs; ++i) {
+ output_names_vec.push_back(string(output_names[i]));
+ }
+ }
+
+ // Compute body nodes.
+ std::vector<const Node*> body_nodes;
+ status->status = tensorflow::ComputeBodyNodes(
+ fn_body, fn_name, num_opers, opers, input_nodes, &body_nodes);
+ if (!status->status.ok()) return nullptr;
+
+ // Do the actual function creation.
+ TF_Function* tf_function = new TF_Function();
+ status->status = tensorflow::GraphToFunctionDef(
+ fn_body->graph, fn_name, body_nodes, input_tensors, output_tensors,
+ output_names_vec, tf_function->fdef_lib.add_function());
+ if (!status->status.ok()) {
+ TF_DeleteFunction(tf_function);
+ return nullptr;
+ }
+ return tf_function;
+}
+
+void TF_GraphAddFunction(TF_Graph* g, const TF_Function* function,
+ TF_Status* status) {
+ tensorflow::mutex_lock l(g->mu);
+
+ // At the moment, we have only one function and no gradients in fdef_lib.
+ // This makes the following operation atomic.
+ // TODO(iga): Add an atomic version of AddFunctionLibrary when we support
+ // gradients
+ status->status = g->graph.AddFunctionLibrary(function->fdef_lib);
+}
+
+void TF_FunctionToFunctionDef(TF_Function* func, TF_Buffer* output_func_def,
+ TF_Status* status) {
+ DCHECK_EQ(1, func->fdef_lib.function_size());
+ status->status = MessageToBuffer(func->fdef_lib.function(0), output_func_def);
+}
+
+void TF_DeleteFunction(TF_Function* function) { delete function; }
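To close the loop on the sketch after the commit message (still illustrative; `host`, `host_feed`, `call_op`, and `s` are the hypothetical names introduced there, and <cstdio> is assumed to be included), running the function call goes through the ordinary session API:

// Create a session on the host graph that now contains the function.
TF_SessionOptions* opts = TF_NewSessionOptions();
TF_Session* session = TF_NewSession(host, opts, s);
TF_DeleteSessionOptions(opts);

// Feed a scalar int32 and fetch the function call's single output.
TF_Tensor* in_tensor =
    TF_AllocateTensor(TF_INT32, /*dims=*/nullptr, /*num_dims=*/0, sizeof(int32_t));
*static_cast<int32_t*>(TF_TensorData(in_tensor)) = 3;

TF_Output feeds[] = {{host_feed, 0}};
TF_Tensor* feed_values[] = {in_tensor};
TF_Output fetches[] = {{call_op, 0}};
TF_Tensor* fetch_values[1] = {nullptr};

TF_SessionRun(session, /*run_options=*/nullptr,
              feeds, feed_values, 1,
              fetches, fetch_values, 1,
              /*target_opers=*/nullptr, 0,
              /*run_metadata=*/nullptr, s);
if (TF_GetCode(s) == TF_OK) {
  // For the MyNeg example this prints -3.
  std::printf("%d\n", *static_cast<int32_t*>(TF_TensorData(fetch_values[0])));
  TF_DeleteTensor(fetch_values[0]);
}

TF_DeleteTensor(in_tensor);
TF_CloseSession(session, s);
TF_DeleteSession(session, s);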
diff --git a/tensorflow/c/c_api_function_test.cc b/tensorflow/c/c_api_function_test.cc
new file mode 100644
index 0000000000..c9dd38ea15
--- /dev/null
+++ b/tensorflow/c/c_api_function_test.cc
@@ -0,0 +1,1039 @@
+/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/c/c_api.h"
+
+#include "tensorflow/c/c_test_util.h"
+#include "tensorflow/core/framework/function.pb.h"
+#include "tensorflow/core/framework/op_def.pb.h"
+#include "tensorflow/core/lib/strings/str_util.h"
+#include "tensorflow/core/lib/strings/strcat.h"
+#include "tensorflow/core/platform/logging.h"
+#include "tensorflow/core/platform/test.h"
+
+namespace tensorflow {
+namespace {
+
+// Specification for expected input/output and its type.
+// DataType value of DT_INVALID signifies that we don't want to
+// check the data type.
+typedef std::pair<string, DataType> IOSpec;
+
+std::vector<IOSpec> M(const std::initializer_list<string>& names) {
+ std::vector<IOSpec> v;
+ for (const string& name : names) {
+ v.push_back(IOSpec(name, DT_INVALID));
+ }
+ return v;
+}
+
+// Specification for an expected edge.
+// src is either:
+// - input name (as it appears in FunctionDef)
+// - name of output tensor (in nested "add:z:0" format)
+// dst is either:
+// - output name (as it appears in FunctionDef)
+// - <name_of_node>:<index_of_this_input_into_node> (this looks the same as
+// output tensor naming, but the index is actually an input index)
+struct EdgeSpec : public std::pair<string, string> {
+ typedef std::pair<string, string> Base;
+
+ // Inherit the set of constructors
+ using Base::pair;
+
+ string ToString() const { return strings::StrCat(first, "->", second); }
+};
+
+class CApiFunctionTest : public ::testing::Test {
+ protected:
+ CApiFunctionTest()
+ : s_(TF_NewStatus()),
+ func_graph_(TF_NewGraph()),
+ host_graph_(TF_NewGraph()),
+ func_(nullptr) {}
+
+ void SetUp() override {}
+
+ ~CApiFunctionTest() override {
+ TF_DeleteFunction(func_);
+ TF_DeleteGraph(host_graph_);
+ TF_DeleteGraph(func_graph_);
+ TF_DeleteStatus(s_);
+ }
+
+ void Run(const std::vector<std::pair<TF_Operation*, TF_Tensor*>>& inputs,
+ TF_Operation* output, int32_t expected_result) {
+ Run(inputs, {{output, 0}}, {expected_result});
+ }
+
+ // Run the host graph, which now contains a function and check that
+ // outputs are as expected.
+ // 'T' stands for 'tensor' since the outputs are tensors, not scalars.
+ void RunT(const std::vector<std::pair<TF_Operation*, TF_Tensor*>>& inputs,
+ std::initializer_list<TF_Output> outputs,
+ const std::vector<std::vector<int32_t>>& expected_results) {
+ // Create a session for this graph
+ CSession csession(host_graph_, s_);
+ ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);
+
+ // Run
+ csession.SetInputs(inputs);
+ csession.SetOutputs(outputs);
+ csession.Run(s_);
+ ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);
+
+ // Check results
+ for (int i = 0; i < expected_results.size(); ++i) {
+ TF_Tensor* out = csession.output_tensor(i);
+ ASSERT_TRUE(out != nullptr);
+ EXPECT_EQ(TF_INT32, TF_TensorType(out));
+ EXPECT_EQ(1, TF_NumDims(out));
+ CompareInt32Tensor(expected_results[i], out);
+ }
+ }
+
+ // Run the host graph, which now contains a function and check that
+ // outputs are as expected.
+ void Run(const std::vector<std::pair<TF_Operation*, TF_Tensor*>>& inputs,
+ std::initializer_list<TF_Output> outputs,
+ const std::vector<int32_t>& expected_results) {
+ // Create a session for this graph.
+ CSession csession(host_graph_, s_);
+ ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);
+
+ csession.SetInputs(inputs);
+ csession.SetOutputs(outputs);
+ csession.Run(s_);
+ ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);
+
+ for (int i = 0; i < expected_results.size(); ++i) {
+ TF_Tensor* out = csession.output_tensor(i);
+ ASSERT_TRUE(out != nullptr);
+ EXPECT_EQ(TF_INT32, TF_TensorType(out));
+ EXPECT_EQ(0, TF_NumDims(out)); // scalar
+ ASSERT_EQ(sizeof(int32_t), TF_TensorByteSize(out));
+ int32_t* output_contents = static_cast<int32_t*>(TF_TensorData(out));
+ EXPECT_EQ(expected_results[i], *output_contents);
+ }
+ }
+
+ void CompareInt32Tensor(const std::vector<int32_t>& expected, TF_Tensor* t) {
+ int32_t* data = static_cast<int32_t*>(TF_TensorData(t));
+ size_t size = TF_TensorByteSize(t);
+ ASSERT_EQ(expected.size() * sizeof(int32_t), size);
+ for (int i = 0; i < expected.size(); ++i) {
+ ASSERT_EQ(expected[i], data[i]) << "Different data at index " << i;
+ }
+ }
+
+ std::vector<TF_Output> ToOutput(const std::vector<TF_Operation*> ops) {
+ std::vector<TF_Output> out;
+ for (auto op : ops) {
+ out.push_back({op, 0});
+ }
+ return out;
+ }
+
+ void Define(int num_opers, const std::vector<TF_Operation*>& opers,
+ const std::vector<TF_Operation*>& inputs,
+ const std::vector<TF_Operation*>& outputs,
+ const char** output_names, bool expect_failure = false) {
+ DefineT(num_opers, opers, ToOutput(inputs), ToOutput(outputs), output_names,
+ expect_failure);
+ }
+
+ // An explicit `num_opers` is needed so that we can distinguish between the
+ // case of no operations specified (-1) and the case of an empty set of
+ // operations specified (0).
+ void DefineT(int num_opers, const std::vector<TF_Operation*>& opers,
+ const std::vector<TF_Output>& inputs,
+ const std::vector<TF_Output>& outputs, const char** output_names,
+ bool expect_failure = false) {
+ ASSERT_EQ(func_, nullptr);
+ func_ = TF_GraphToFunction(func_graph_, func_name_, num_opers,
+ num_opers == -1 ? nullptr : opers.data(),
+ inputs.size(), inputs.data(), outputs.size(),
+ outputs.data(), output_names,
+ /*opts=*/nullptr, s_);
+ if (expect_failure) {
+ ASSERT_EQ(func_, nullptr);
+ return;
+ }
+
+ ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);
+ ASSERT_NE(func_, nullptr);
+ TF_GraphAddFunction(host_graph_, func_, s_);
+ ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);
+ }
+
+ TF_Operation* Use(const std::vector<TF_Operation*>& inputs) {
+ return UseT(ToOutput(inputs));
+ }
+
+ TF_Operation* UseT(const std::vector<TF_Output>& inputs) {
+ TF_Operation* op;
+ UseHelper(inputs, &op);
+ return op;
+ }
+
+ // All the *Helper methods are used as a workaround for the restrictions that
+ // one cannot call ASSERT_* methods in non-void-returning functions (when
+ // exceptions are disabled during compilation)
+ void UseHelper(const std::vector<TF_Output>& inputs, TF_Operation** op) {
+ TF_OperationDescription* desc =
+ TF_NewOperation(host_graph_, func_name_, func_node_name_);
+ for (auto input : inputs) {
+ TF_AddInput(desc, input);
+ }
+ // Set device to CPU because some ops inside the function might not be
+ // available on GPU.
+ TF_SetDevice(desc, "/cpu:0");
+ *op = TF_FinishOperation(desc, s_);
+ ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);
+ ASSERT_NE(*op, nullptr);
+ }
+
+ FunctionDef fdef() {
+ tensorflow::FunctionDef fdef;
+ EXPECT_TRUE(GetFunctionDef(func_, &fdef));
+ return fdef;
+ }
+
+ // logging utility
+ template <class Container>
+ string ToString(const Container& v) {
+ std::stringstream ss;
+ ss << "{";
+ size_t i = 0;
+ for (const auto& e : v) {
+ if (i != 0) {
+ ss << ", ";
+ }
+ ss << e.ToString();
+ ++i;
+ }
+ ss << "}";
+ return ss.str();
+ }
+
+ void VerifyFDefNodes(const tensorflow::FunctionDef& fdef,
+ const std::unordered_set<string>& nodes) {
+ ASSERT_EQ(nodes.size(), fdef.node_def_size())
+ << "Got unexpected number of nodes. Expected: ["
+ << str_util::Join(nodes, ", ")
+ << "] Actual nodes in fdef: " << fdef.DebugString();
+ for (const NodeDef& node_def : fdef.node_def()) {
+ ASSERT_TRUE(nodes.find(node_def.name()) != nodes.end())
+ << "Got unexpected node: " << node_def.name()
+ << " in fdef: " << fdef.DebugString();
+ }
+ }
+
+ void VerifyFDefInputs(const tensorflow::FunctionDef& fdef,
+ const std::vector<IOSpec>& inputs) {
+ const OpDef& signature = fdef.signature();
+ ASSERT_EQ(inputs.size(), signature.input_arg_size());
+ for (int i = 0; i < inputs.size(); ++i) {
+ const OpDef::ArgDef& arg = signature.input_arg(i);
+ const IOSpec& in = inputs[i];
+ if (in.second != DT_INVALID) {
+ ASSERT_EQ(arg.type(), in.second)
+ << "Got unexpected type for input " << i
+ << ". fdef: " << fdef.DebugString();
+ }
+ ASSERT_EQ(arg.name(), in.first) << "Got unexpected name for input " << i
+ << ". fdef: " << fdef.DebugString();
+ }
+ }
+
+ void VerifyFDefOutputs(const tensorflow::FunctionDef& fdef,
+ const std::vector<IOSpec>& outputs) {
+ const OpDef& signature = fdef.signature();
+ ASSERT_EQ(outputs.size(), signature.output_arg_size());
+ for (int i = 0; i < outputs.size(); ++i) {
+ const OpDef::ArgDef& arg = signature.output_arg(i);
+ const IOSpec& out = outputs[i];
+ if (out.second != DT_INVALID) {
+ ASSERT_EQ(arg.type(), out.second)
+ << "Got unexpected type for output " << i
+ << ". fdef: " << fdef.DebugString();
+ }
+ ASSERT_EQ(arg.name(), out.first) << "Got unexpected name for output " << i
+ << ". fdef: " << fdef.DebugString();
+ }
+ }
+
+ void VerifyFDefEdges(
+ const tensorflow::FunctionDef& fdef,
+ const std::vector<EdgeSpec>& e_edges, // expected edges
+ const std::vector<EdgeSpec>& c_edges, // expected ctrl edges
+ bool is_exact_edges = true) {
+ // Build a set of edges from fdef
+ std::set<EdgeSpec> a_edges; // actual edges
+ // Get edges from inputs to body nodes and between body nodes
+ for (const NodeDef& node_def : fdef.node_def()) {
+ for (int i = 0; i < node_def.input_size(); ++i) {
+ const string& in = node_def.input(i);
+ const auto& v =
+ a_edges.insert({in, strings::StrCat(node_def.name(), ":", i)});
+ ASSERT_TRUE(v.second) << "Duplicate edge " << in << " -> "
+ << strings::StrCat(node_def.name(), ":", i)
+ << ". fdef: " << fdef.DebugString();
+ }
+ }
+ // Get edges from body nodes to outputs and from inputs to outputs
+ for (const OpDef::ArgDef& arg : fdef.signature().output_arg()) {
+ const auto& iter = fdef.ret().find(arg.name());
+ if (iter != fdef.ret().end()) {
+ const auto& v = a_edges.insert({iter->second, arg.name()});
+ ASSERT_TRUE(v.second) << "Duplicate edge " << iter->second << " -> "
+ << arg.name() << ". fdef: " << fdef.DebugString();
+ } else {
+ const auto& v = a_edges.insert({arg.name(), arg.name()});
+ ASSERT_TRUE(v.second) << "Duplicate edge " << arg.name() << " -> "
+ << arg.name() << ". fdef: " << fdef.DebugString();
+ }
+ }
+
+ // Verify edges
+ for (const EdgeSpec& e : e_edges) {
+ ASSERT_TRUE(a_edges.find(e) != a_edges.end())
+ << "Failed to find expected edge " << e.ToString()
+ << " in fdef: " << fdef.DebugString();
+ }
+
+ // If caller specified all edges, check that we have seen all
+ if (is_exact_edges) {
+ ASSERT_EQ(e_edges.size() + c_edges.size(), a_edges.size())
+ << "Expected edges: " << ToString(e_edges)
+ << " Expected Control edges: " << ToString(c_edges)
+ << " Actual edges: " << ToString(a_edges)
+ << " in fdef: " << fdef.DebugString();
+ }
+ }
+
+ void VerifyFDef(const std::unordered_set<string>& nodes,
+ const std::vector<IOSpec>& inputs,
+ const std::vector<IOSpec>& outputs,
+ const std::vector<EdgeSpec>& e_edges, // expected edges
+ const std::vector<EdgeSpec>& c_edges, // expected ctrl edges
+ bool is_exact_edges = true) {
+ tensorflow::FunctionDef fdef;
+ ASSERT_TRUE(GetFunctionDef(func_, &fdef));
+ VerifyFDefNodes(fdef, nodes);
+ VerifyFDefInputs(fdef, inputs);
+ VerifyFDefOutputs(fdef, outputs);
+ VerifyFDefEdges(fdef, e_edges, c_edges, is_exact_edges);
+ }
+
+ const char* func_name_ = "MyFunc";
+ const char* func_node_name_ = "MyFunc_0";
+ TF_Status* s_;
+ TF_Graph* func_graph_;
+ TF_Graph* host_graph_;
+ TF_Function* func_;
+
+ // Workaround for not being able to initialize an empty set using {}
+ std::unordered_set<string> empty_;
+};
+
+TEST_F(CApiFunctionTest, OneOp_ZeroInputs_OneOutput) {
+ /*
+ * constant
+ * |
+ * v
+ */
+ // Define
+ TF_Operation* c = ScalarConst(10, func_graph_, s_, "scalar10");
+ Define(-1, {}, {}, {c}, nullptr);
+
+ // Use, run, and verify
+ TF_Operation* func_op = Use({});
+ Run({}, func_op, 10);
+ VerifyFDef({"scalar10_0"}, {}, {{"scalar10", DT_INT32}},
+ {{"scalar10_0:output:0", "scalar10"}}, {});
+}
+
+TEST_F(CApiFunctionTest, OneOp_OneInput_OneOutput) {
+ /*
+ * |
+ * v
+ * negate
+ * |
+ * v
+ */
+ // Define
+ TF_Operation* feed = Placeholder(func_graph_, s_);
+ TF_Operation* neg = Neg(feed, func_graph_, s_);
+ Define(-1, {}, {feed}, {neg}, nullptr);
+
+ // Use, run, and verify
+ TF_Operation* func_feed = Placeholder(host_graph_, s_);
+ TF_Operation* func_op = Use({func_feed});
+ Run({{func_feed, Int32Tensor(3)}}, func_op, -3);
+ VerifyFDef({"neg_0"}, {{"feed", DT_INT32}}, {{"neg", DT_INT32}},
+ {{"feed", "neg_0:0"}, {"neg_0:y:0", "neg"}}, {});
+}
+
+TEST_F(CApiFunctionTest, ZeroOps_Identity) {
+ /*
+ * |
+ * |
+ * |
+ * v
+ */
+ // Define
+ TF_Operation* feed = Placeholder(func_graph_, s_);
+ Define(-1, {}, {feed}, {feed}, nullptr);
+
+ // Use, run, and verify
+ TF_Operation* func_feed = Placeholder(host_graph_, s_);
+ TF_Operation* func_op = Use({func_feed});
+ Run({{func_feed, Int32Tensor(3)}}, func_op, 3);
+ VerifyFDef(empty_, {{"feed", DT_INT32}}, {{"feed_0", DT_INT32}},
+ {{"feed", "feed_0"}}, {});
+}
+
+TEST_F(CApiFunctionTest, ZeroOps_Permutation) {
+ /*
+ * | |
+ * \ /
+ * \/
+ * x
+ * /\
+ * / \
+ * | |
+ * v v
+ */
+ // Define
+ TF_Operation* feed1 = Placeholder(func_graph_, s_, "feed1");
+ TF_Operation* feed2 = Placeholder(func_graph_, s_, "feed2");
+ Define(-1, {}, {feed1, feed2}, {feed2, feed1}, nullptr);
+
+ // Use, run, and verify
+ TF_Operation* two = ScalarConst(2, host_graph_, s_);
+ TF_Operation* func_feed = Placeholder(host_graph_, s_);
+ TF_Operation* func_op = Use({two, func_feed});
+ Run({{func_feed, Int32Tensor(3)}}, {{func_op, 0}, {func_op, 1}}, {3, 2});
+ VerifyFDef(empty_, M({{"feed1"}, {"feed2"}}), M({{"feed2_0"}, {"feed1_0"}}),
+ {{"feed1", "feed1_0"}, {"feed2", "feed2_0"}}, {});
+}
+
+TEST_F(CApiFunctionTest, OneOp_TwoInputs_OneOutput) {
+ /*
+ * | |
+ * v v
+ * add
+ * |
+ * v
+ */
+ // Define
+ TF_Operation* feed1 = Placeholder(func_graph_, s_, "feed1");
+ TF_Operation* feed2 = Placeholder(func_graph_, s_, "feed2");
+ TF_Operation* add = Add(feed1, feed2, func_graph_, s_);
+ Define(-1, {}, {feed1, feed2}, {add}, nullptr);
+
+ // Use, run, and verify
+ TF_Operation* two = ScalarConst(2, host_graph_, s_);
+ TF_Operation* func_feed = Placeholder(host_graph_, s_);
+ TF_Operation* func_op = Use({two, func_feed});
+ Run({{func_feed, Int32Tensor(3)}}, func_op, 2 + 3);
+ VerifyFDef(
+ {"add_0"}, M({{"feed1"}, {"feed2"}}), M({{"add"}}),
+ {{"feed1", "add_0:0"}, {"feed2", "add_0:1"}, {"add_0:sum:0", "add"}}, {});
+}
+
+TEST_F(CApiFunctionTest, OneOp_TwoInputs_ZeroOutputs) {
+ /*
+ * | |
+ * v v
+ * add
+ *
+ * (output ignored)
+ */
+ // Define
+ TF_Operation* feed1 = Placeholder(func_graph_, s_, "feed1");
+ TF_Operation* feed2 = Placeholder(func_graph_, s_, "feed2");
+ Add(feed1, feed2, func_graph_, s_);
+ Define(-1, {}, {feed1, feed2}, {}, nullptr);
+
+ // Use, run, and verify
+ TF_Operation* two = ScalarConst(2, host_graph_, s_);
+ TF_Operation* func_feed = Placeholder(host_graph_, s_);
+ Use({two, func_feed});
+ VerifyFDef({"add"}, M({{"feed1"}, {"feed2"}}), {},
+ {{"feed1", "add:0"}, {"feed2", "add:1"}}, {});
+}
+
+TEST_F(CApiFunctionTest, TwoOps_ThreeInputs_OneOutput) {
+ /*
+ * | | |
+ * v v /
+ * add1 /
+ * | |
+ * v v
+ * add2
+ * |
+ * v
+ */
+ // Define
+ TF_Operation* feed1 = Placeholder(func_graph_, s_, "feed1");
+ TF_Operation* feed2 = Placeholder(func_graph_, s_, "feed2");
+ TF_Operation* feed3 = Placeholder(func_graph_, s_, "feed3");
+ TF_Operation* add1 = Add(feed1, feed2, func_graph_, s_, "add1");
+ TF_Operation* add2 = Add(add1, feed3, func_graph_, s_, "add2");
+ Define(-1, {}, {feed1, feed2, feed3}, {add2}, nullptr);
+
+ // Use, run, and verify
+ TF_Operation* two = ScalarConst(2, host_graph_, s_, "two");
+ TF_Operation* ten = ScalarConst(10, host_graph_, s_, "ten");
+ TF_Operation* func_feed = Placeholder(host_graph_, s_);
+ TF_Operation* func_op = Use({two, ten, func_feed});
+ Run({{func_feed, Int32Tensor(3)}}, func_op, 2 + 10 + 3);
+ VerifyFDef({"add1", "add2_0"}, M({{"feed1"}, {"feed2"}, {"feed3"}}),
+ M({{"add2"}}),
+ {{"feed1", "add1:0"},
+ {"feed2", "add1:1"},
+ {"add1:sum:0", "add2_0:0"},
+ {"feed3", "add2_0:1"},
+ {"add2_0:sum:0", "add2"}},
+ {});
+}
+
+TEST_F(CApiFunctionTest, OneOp_TwoInputs_TwoDuplicateOutputs) {
+ /*
+ * | |
+ * v v
+ * add
+ * |
+ * +-+-+
+ * | |
+ * v v
+ */
+ // Define
+ TF_Operation* feed1 = Placeholder(func_graph_, s_, "feed1");
+ TF_Operation* feed2 = Placeholder(func_graph_, s_, "feed2");
+ TF_Operation* add = Add(feed1, feed2, func_graph_, s_);
+ Define(-1, {}, {feed1, feed2}, {add, add}, nullptr);
+
+ // Use, run, and verify
+ TF_Operation* two = ScalarConst(2, host_graph_, s_);
+ TF_Operation* func_feed = Placeholder(host_graph_, s_);
+ TF_Operation* func_op = Use({two, func_feed});
+ Run({{func_feed, Int32Tensor(3)}}, {{func_op, 0}, {func_op, 1}}, {5, 5});
+ VerifyFDef({"add_1"}, M({{"feed1"}, {"feed2"}}), M({{"add"}, {"add_0"}}),
+ {{"feed1", "add_1:0"},
+ {"feed2", "add_1:1"},
+ {"add_1:sum:0", "add"},
+ {"add_1:sum:0", "add_0"}},
+ {});
+}
+
+TEST_F(CApiFunctionTest, TwoOps_ThreeInputs_TwoOutputs) {
+ /*
+ * | | |
+ * v v /
+ * add /
+ * | |
+ * +-+ |
+ * | | |
+ * | v v
+ * | add
+ * | |
+ * v v
+ */
+ // Define
+ TF_Operation* feed1 = Placeholder(func_graph_, s_, "feed1");
+ TF_Operation* feed2 = Placeholder(func_graph_, s_, "feed2");
+ TF_Operation* feed3 = Placeholder(func_graph_, s_, "feed3");
+ TF_Operation* add1 = Add(feed1, feed2, func_graph_, s_, "add1");
+ TF_Operation* add2 = Add(add1, feed3, func_graph_, s_, "add2");
+ Define(-1, {}, {feed1, feed2, feed3}, {add1, add2}, nullptr);
+
+ // Use, run, and verify
+ TF_Operation* two = ScalarConst(2, host_graph_, s_, "two");
+ TF_Operation* ten = ScalarConst(10, host_graph_, s_, "ten");
+ TF_Operation* func_feed = Placeholder(host_graph_, s_);
+ TF_Operation* func_op = Use({two, ten, func_feed});
+ Run({{func_feed, Int32Tensor(3)}}, {{func_op, 0}, {func_op, 1}}, {12, 15});
+ VerifyFDef({"add1_0", "add2_0"}, M({{"feed1"}, {"feed2"}, {"feed3"}}),
+ M({{"add1"}, {"add2"}}),
+ {{"feed1", "add1_0:0"},
+ {"feed2", "add1_0:1"},
+ {"add1_0:sum:0", "add2_0:0"},
+ {"feed3", "add2_0:1"},
+ {"add1_0:sum:0", "add1"},
+ {"add2_0:sum:0", "add2"}},
+ {});
+}
+
+TEST_F(CApiFunctionTest, FromSubsetOfOps) {
+ /*
+ * | | |
+ * v v /
+ * add /
+ * | |
+ * +---+--+---+
+ * Ops used | | | |
+ * for func | v v |
+ * | | add |
+ * +-------> | | |
+ * | v |
+ * | |
+ * +----------+
+ */
+ // Define
+ TF_Operation* feed1 = Placeholder(func_graph_, s_, "feed1");
+ TF_Operation* feed2 = Placeholder(func_graph_, s_, "feed2");
+ TF_Operation* feed3 = Placeholder(func_graph_, s_, "feed3");
+ TF_Operation* add1 = Add(feed1, feed2, func_graph_, s_, "add1");
+ TF_Operation* add2 = Add(add1, feed3, func_graph_, s_, "add2");
+ Define(1, {add2}, {add1, feed3}, {add2}, nullptr);
+
+ // Use, run, and verify
+ TF_Operation* two = ScalarConst(2, host_graph_, s_, "two");
+ TF_Operation* func_feed = Placeholder(host_graph_, s_);
+ TF_Operation* func_op = Use({two, func_feed});
+ Run({{func_feed, Int32Tensor(3)}}, func_op, 2 + 3);
+ VerifyFDef(
+ {"add2_0"}, M({{"add1"}, {"feed3"}}), M({{"add2"}}),
+ {{"add1", "add2_0:0"}, {"feed3", "add2_0:1"}, {"add2_0:sum:0", "add2"}},
+ {});
+}
+
+TEST_F(CApiFunctionTest, UsingOneOutputOfSplit) {
+ /*
+ * feed
+ * |
+ * +---------+---+
+ * | const0 | |
+ * | | | |
+ * | v / |
+ * | split |
+ * | | | | |
+ * | v | v |
+ * | | |
+ * +------+------+
+ * |
+ * v
+ *
+ * Only the second output from split is used as function output
+ */
+ // Define
+ TF_Operation* feed = Placeholder(func_graph_, s_);
+ TF_Operation* split = Split3(feed, func_graph_, s_);
+ DefineT(-1, {}, {{feed, 0}}, {{split, 1}}, nullptr);
+
+ // Use, run, and verify
+ TF_Operation* func_feed = Placeholder(host_graph_, s_);
+ TF_Operation* func_op = Use({func_feed});
+ RunT({{func_feed, Int32Tensor({1, 2, 3, 4, 5, 6})}}, {{func_op, 0}},
+ {{3, 4}});
+ VerifyFDef({"split3_const0", "split3_0"}, M({{"feed"}}), M({{"split3"}}),
+ {{"split3_const0:output:0", "split3_0:0"},
+ {"feed", "split3_0:1"},
+ {"split3_0:output:1", "split3"}},
+ {});
+}
+
+TEST_F(CApiFunctionTest, UsingTwoOutputsOfSplit) {
+ /*
+ * feed
+ * |
+ * +---------+---+
+ * | const0 | |
+ * | | | |
+ * | v / |
+ * | split |
+ * | | | | |
+ * | | v | |
+ * | | | |
+ * +---+-----+---+
+ * | |
+ * v v
+ *
+ * Second output from split is not used as function output
+ */
+ // Define
+ TF_Operation* feed = Placeholder(func_graph_, s_);
+ TF_Operation* split = Split3(feed, func_graph_, s_);
+ DefineT(-1, {}, {{feed, 0}}, {{split, 0}, {split, 2}}, nullptr);
+
+ // Use, run, and verify
+ TF_Operation* func_feed = Placeholder(host_graph_, s_);
+ TF_Operation* func_op = Use({func_feed});
+ RunT({{func_feed, Int32Tensor({1, 2, 3, 4, 5, 6})}},
+ {{func_op, 0}, {func_op, 1}}, {{1, 2}, {5, 6}});
+ VerifyFDef({"split3_const0", "split3_1"}, M({{"feed"}}),
+ M({{"split3"}, {"split3_0"}}),
+ {{"split3_const0:output:0", "split3_1:0"},
+ {"feed", "split3_1:1"},
+ {"split3_1:output:0", "split3"},
+ {"split3_1:output:2", "split3_0"}},
+ {});
+}
+
+TEST_F(CApiFunctionTest, UsingTwoOutputsOfSplitAsInputs) {
+ /*
+ * |
+ * v
+ * split
+ * | | |
+ * | v |
+ * | |
+ * +---+-----+---+
+ * | | | |
+ * | v v |
+ * | add |
+ * | | |
+ * | | |
+ * +------+------+
+ * |
+ * v
+ */
+ // Define
+ TF_Operation* feed = Placeholder(func_graph_, s_);
+ TF_Operation* split = Split3(feed, func_graph_, s_);
+ TF_Operation* add = Add({split, 0}, {split, 2}, func_graph_, s_);
+ ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);
+ DefineT(1, {add}, {{split, 0}, {split, 2}}, {{add, 0}}, nullptr);
+
+ // Use, run, and verify
+ TF_Operation* two = ScalarConst(2, host_graph_, s_, "two");
+ TF_Operation* func_feed = Placeholder(host_graph_, s_);
+ TF_Operation* func_op = Use({two, func_feed});
+ Run({{func_feed, Int32Tensor(3)}}, func_op, 2 + 3);
+ VerifyFDef(
+ {"add_0"}, M({{"split3"}, {"split3_0"}}), M({{"add"}}),
+ {{"split3", "add_0:0"}, {"split3_0", "add_0:1"}, {"add_0:sum:0", "add"}},
+ {});
+}
+
+TEST_F(CApiFunctionTest, NodesUsedInInputsMustHaveSingleOutput) {
+ /*
+ * |
+ * v
+ * split
+ * | | |
+ * | v |
+ * | |
+ * input --->| |<--- input
+ * | |
+ * v v
+ * add
+ * |
+ * |
+ * v
+ */
+ // Define
+ TF_Tensor* tensor_123 = Int32Tensor({1, 2, 3});
+ TF_Operation* c = Const(tensor_123, func_graph_, s_, "const_array");
+ ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);
+ TF_Operation* split = Split3(c, func_graph_, s_);
+ TF_Operation* add = Add({split, 0}, {split, 2}, func_graph_, s_);
+ ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);
+ DefineT(-1, {}, {{split, 0}, {split, 2}}, {{add, 0}}, nullptr, true);
+ EXPECT_EQ(TF_INVALID_ARGUMENT, TF_GetCode(s_));
+ EXPECT_EQ(string("When `num_opers` is set to -1, nodes referenced in "
+ "`inputs` must have a single output. Node split3 has "
+ "3 outputs. Encountered while creating function 'MyFunc'"),
+ string(TF_Message(s_)));
+
+ TF_DeleteTensor(tensor_123);
+}
+
+TEST_F(CApiFunctionTest, FunctionWithWhileLoop) {
+ // Inputs to the while loop and the function as a whole
+ TF_Operation* feed1 = Placeholder(func_graph_, s_, "feed1");
+ TF_Operation* feed2 = Placeholder(func_graph_, s_, "feed2");
+
+ // Outputs of the while loop corresponding to the two inputs above
+ // The first one will be the function's output
+ std::vector<TF_Output> outputs;
+
+ // Add while loop to func_graph_
+ {
+ // The inputs to the while loop
+ std::vector<TF_Output> inputs = {{feed1, 0}, {feed2, 0}};
+ std::unique_ptr<TF_WhileParams> params(new TF_WhileParams(
+ TF_NewWhile(func_graph_, &inputs[0], inputs.size(), s_)));
+ ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);
+ params->name = "test_loop";
+
+ // Initialize outputs so we can easily detect errors/bugs
+ outputs.resize(2, {nullptr, -1});
+
+ // Create loop: while (input1 < input2) input1 += input2 + 1
+ TF_Operation* less_than = LessThan(
+ params->cond_inputs[0], params->cond_inputs[1], params->cond_graph, s_);
+ ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);
+ params->cond_output = {less_than, 0};
+
+ TF_Operation* add1 = Add(params->body_inputs[0], params->body_inputs[1],
+ params->body_graph, s_, "add1");
+ ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);
+ TF_Operation* one = ScalarConst(1, params->body_graph, s_);
+ ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);
+ TF_Operation* add2 = Add(add1, one, params->body_graph, s_, "add2");
+ ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);
+ params->body_outputs[0] = {add2, 0};
+ params->body_outputs[1] = params->body_inputs[1];
+
+ // Finalize while loop
+ TF_FinishWhile(params.get(), s_, &outputs[0]);
+ EXPECT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);
+ }
+
+ // Define function, use it in graph, and run
+ DefineT(-1, {}, {{feed1, 0}, {feed2, 0}}, {outputs[0]}, nullptr);
+ TF_Operation* five = ScalarConst(5, host_graph_, s_, "five");
+ TF_Operation* func_feed = Placeholder(host_graph_, s_);
+ TF_Operation* func_op = Use({func_feed, five});
+ Run({{func_feed, Int32Tensor(2)}}, func_op, 2 /*+=*/ + 5 + 1);
+
+ // Verify input, output, and subset of edges in fdef.
+ // The subset of edges we verify is a chain between feed1 and output to
+ // make sure that the correct output is picked.
+ tensorflow::FunctionDef fdef;
+ ASSERT_TRUE(GetFunctionDef(func_, &fdef));
+ VerifyFDefInputs(fdef, M({{"feed1"}, {"feed2"}}));
+ VerifyFDefOutputs(fdef, M({{"test_loop_exit"}}));
+ VerifyFDefEdges(fdef,
+ {{"feed1", "test_loop/Enter:0"},
+ {"test_loop/Enter:output:0", "test_loop/Merge:0"},
+ {"test_loop/Merge:output:0", "test_loop/Switch:0"},
+ {"test_loop/Switch:output_false:0", "test_loop/Exit:0"},
+ {"test_loop/Exit:output:0", "test_loop_exit"}},
+ {}, false);
+}
+
+TEST_F(CApiFunctionTest, ControlDependency) {
+ /*
+ * | | scalar
+ * | | .
+ * v v . <---- control dependency
+ * add < -
+ * |
+ * v
+ */
+ // Define
+ TF_Operation* feed1 = Placeholder(func_graph_, s_, "feed1");
+ TF_Operation* feed2 = Placeholder(func_graph_, s_, "feed2");
+ TF_Operation* five = ScalarConst(5, func_graph_, s_);
+ TF_Operation* add =
+ AddWithCtrlDependency(feed1, feed2, func_graph_, five, s_);
+ EXPECT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);
+ Define(-1, {}, {feed1, feed2}, {add}, nullptr);
+
+ // Use, run, and verify
+ TF_Operation* two = ScalarConst(2, host_graph_, s_);
+ TF_Operation* func_feed = Placeholder(host_graph_, s_);
+ TF_Operation* func_op = Use({two, func_feed});
+ Run({{func_feed, Int32Tensor(3)}}, func_op, 2 + 3);
+ VerifyFDef(
+ {"add_0", "scalar"}, M({{"feed1"}, {"feed2"}}), M({{"add"}}),
+ {{"feed1", "add_0:0"}, {"feed2", "add_0:1"}, {"add_0:sum:0", "add"}},
+ {{"scalar", "add_0"}});
+}
+
+TEST_F(CApiFunctionTest, ControlDependencyOutsideOfBody) {
+ /*
+ * | | scalar
+ * | | .
+ * v v . <---- control dependency
+ * add < -
+ * |
+ * v
+ */
+ // Define
+ TF_Operation* feed1 = Placeholder(func_graph_, s_, "feed1");
+ TF_Operation* feed2 = Placeholder(func_graph_, s_, "feed2");
+ TF_Operation* five = ScalarConst(5, func_graph_, s_);
+ TF_Operation* add =
+ AddWithCtrlDependency(feed1, feed2, func_graph_, five, s_);
+ EXPECT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);
+ Define(1, {add}, {feed1, feed2}, {add}, nullptr, true);
+ EXPECT_EQ(TF_INVALID_ARGUMENT, TF_GetCode(s_));
+ EXPECT_EQ(string("The source of control edge [id=3 scalar:-1 -> add:-1] "
+ "is not in the body. Encountered while creating "
+ "function 'MyFunc'"),
+ string(TF_Message(s_)));
+}
+
+TEST_F(CApiFunctionTest, ControlDependencyOutsideOfBody_FromInputNode) {
+ /*
+ * | |.
+ * | | .
+ * | | .
+ * v v . <---- control dependency
+ * add < -
+ * |
+ * v
+ */
+ // Define
+ TF_Operation* feed1 = Placeholder(func_graph_, s_, "feed1");
+ TF_Operation* feed2 = Placeholder(func_graph_, s_, "feed2");
+ TF_Operation* add =
+ AddWithCtrlDependency(feed1, feed2, func_graph_, feed1, s_);
+ EXPECT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);
+ Define(-1, {}, {feed1, feed2}, {add}, nullptr, true);
+ EXPECT_EQ(TF_INVALID_ARGUMENT, TF_GetCode(s_));
+ EXPECT_EQ(string("The source of control edge [id=3 feed1:-1 -> add:-1] "
+ "is not in the body. Encountered while creating "
+ "function 'MyFunc'"),
+ string(TF_Message(s_)));
+}
+
+TEST_F(CApiFunctionTest, DuplicateInputsAreNotAllowed) {
+ /*
+ * feed
+ * |
+ * +++
+ * | |
+ * +---+-+---+
+ * | | | |
+ * | v v |
+ * | add |
+ * | | |
+ * | | |
+ * +----+----+
+ * |
+ * v
+ */
+ TF_Operation* feed1 = Placeholder(func_graph_, s_, "feed1");
+ TF_Operation* add = Add(feed1, feed1, func_graph_, s_);
+ Define(-1, {}, {feed1, feed1}, {add}, nullptr, true);
+ EXPECT_EQ(TF_INVALID_ARGUMENT, TF_GetCode(s_));
+ EXPECT_EQ(
+ string("TF_Output feed1:0 appears more than once in the input list"),
+ string(TF_Message(s_)));
+}
+
+TEST_F(CApiFunctionTest, InvalidInputTensor_HighIndex) {
+ /*
+ * | |
+ * v v
+ * add
+ * |
+ * v
+ */
+ TF_Operation* feed1 = Placeholder(func_graph_, s_, "feed1");
+ TF_Operation* feed2 = Placeholder(func_graph_, s_, "feed2");
+ TF_Operation* add = Add(feed1, feed2, func_graph_, s_);
+ DefineT(-1, {}, {{feed1, 0}, {feed2, 2}}, {{add, 0}}, nullptr, true);
+ EXPECT_EQ(TF_INVALID_ARGUMENT, TF_GetCode(s_));
+ EXPECT_EQ(string("Node 'feed2' (type: 'Placeholder', num of outputs: 1) does "
+ "not have output 2\n\tEncountered while processing "
+ "input 1 into function 'MyFunc'"),
+ string(TF_Message(s_)));
+}
+
+TEST_F(CApiFunctionTest, InvalidInputTensor_BadNodePtr) {
+ /*
+ * | |
+ * v v
+ * add
+ * |
+ * v
+ */
+ TF_Operation* feed1 = Placeholder(func_graph_, s_, "feed1");
+ TF_Operation* feed2 = Placeholder(func_graph_, s_, "feed2");
+ TF_Operation* add = Add(feed1, feed2, func_graph_, s_);
+ DefineT(-1, {}, {{feed1, 0}, {nullptr, 0}}, {{add, 0}}, nullptr, true);
+ EXPECT_EQ(TF_INVALID_ARGUMENT, TF_GetCode(s_));
+ EXPECT_EQ(string("Node is null\n\tEncountered while processing input 1 "
+ "into function 'MyFunc'"),
+ string(TF_Message(s_)));
+}
+
+TEST_F(CApiFunctionTest, InvalidOutputTensor_HighIndex) {
+ /*
+ * | |
+ * v v
+ * add
+ * |
+ * v
+ */
+ TF_Operation* feed1 = Placeholder(func_graph_, s_, "feed1");
+ TF_Operation* feed2 = Placeholder(func_graph_, s_, "feed2");
+ TF_Operation* add = Add(feed1, feed2, func_graph_, s_);
+ DefineT(-1, {}, {{feed1, 0}, {feed2, 0}}, {{add, 3}}, nullptr, true);
+ EXPECT_EQ(TF_INVALID_ARGUMENT, TF_GetCode(s_));
+ EXPECT_EQ(string("Node 'add' (type: 'AddN', num of outputs: 1) does "
+ "not have output 3\n\tEncountered while processing "
+ "output 0 from function 'MyFunc'"),
+ string(TF_Message(s_)));
+}
+
+TEST_F(CApiFunctionTest, InvalidOutputTensor_BadNodePtr) {
+ /*
+ * | |
+ * v v
+ * add
+ * |
+ * v
+ */
+ TF_Operation* feed1 = Placeholder(func_graph_, s_, "feed1");
+ TF_Operation* feed2 = Placeholder(func_graph_, s_, "feed2");
+ Add(feed1, feed2, func_graph_, s_);
+ DefineT(-1, {}, {{feed1, 0}, {feed2, 0}}, {{nullptr, 3}}, nullptr, true);
+ EXPECT_EQ(TF_INVALID_ARGUMENT, TF_GetCode(s_));
+ EXPECT_EQ(string("Node is null\n\tEncountered while processing output 0 "
+ "from function 'MyFunc'"),
+ string(TF_Message(s_)));
+}
+
+TEST_F(CApiFunctionTest, NodeMissingInput) {
+ /*
+ * input---> | | <----missing input
+ * v v
+ * body----> add
+ * |
+ * v
+ */
+ TF_Operation* feed1 = Placeholder(func_graph_, s_, "feed1");
+ TF_Operation* feed2 = Placeholder(func_graph_, s_, "feed2");
+ TF_Operation* add = Add(feed1, feed2, func_graph_, s_);
+ DefineT(1, {add}, {{feed1, 0}}, {{add, 0}}, nullptr, true);
+ EXPECT_EQ(TF_INVALID_ARGUMENT, TF_GetCode(s_));
+ EXPECT_EQ(string("Input 1, 'feed2:0', of node 'add' in function 'MyFunc' "
+ "is not available. You might need to include it in inputs "
+ "or include its source node in the body"),
+ string(TF_Message(s_)));
+}
+
+TEST_F(CApiFunctionTest, OutputOpNotInBody) {
+ /*
+ * | |
+ * v v
+ * add scalar (scalar not included in body)
+ * | |
+ * v v (function has two outputs)
+ */
+ // Define
+ TF_Operation* feed1 = Placeholder(func_graph_, s_, "feed1");
+ TF_Operation* feed2 = Placeholder(func_graph_, s_, "feed2");
+ TF_Operation* scalar = ScalarConst(2, func_graph_, s_);
+ TF_Operation* add = Add(feed1, feed2, func_graph_, s_);
+ Define(1, {add}, {feed1, feed2}, {add, scalar}, nullptr, true);
+ EXPECT_EQ(TF_INVALID_ARGUMENT, TF_GetCode(s_));
+ EXPECT_EQ(string("TF_Output scalar:0 is neither in the function body nor "
+ "among function inputs. Encountered while creating "
+ "function 'MyFunc'"),
+ string(TF_Message(s_)));
+}
+
+} // namespace
+} // namespace tensorflow
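
For reference, a minimal sketch (not part of this change) of the workflow the tests above exercise, assuming the declarations added to c_api.h match the calls made from tf_session_helper.cc and ops.py later in this patch; MakePlaceholder is an illustrative helper mirroring the test utilities:

#include "tensorflow/c/c_api.h"

// Hypothetical helper mirroring Placeholder() from c_test_util.cc.
static TF_Operation* MakePlaceholder(TF_Graph* g, TF_Status* s,
                                     const char* name) {
  TF_OperationDescription* desc = TF_NewOperation(g, "Placeholder", name);
  TF_SetAttrType(desc, "dtype", TF_INT32);
  return TF_FinishOperation(desc, s);
}

int main() {
  TF_Status* s = TF_NewStatus();
  TF_Graph* fn_body = TF_NewGraph();

  // Build the function body: add = feed1 + feed2.
  TF_Operation* feed1 = MakePlaceholder(fn_body, s, "feed1");
  TF_Operation* feed2 = MakePlaceholder(fn_body, s, "feed2");
  TF_OperationDescription* desc = TF_NewOperation(fn_body, "AddN", "add");
  TF_Output add_inputs[2] = {{feed1, 0}, {feed2, 0}};
  TF_AddInputList(desc, add_inputs, 2);
  TF_Operation* add = TF_FinishOperation(desc, s);

  // Turn the body graph into a function with two inputs and one output.
  TF_Output inputs[2] = {{feed1, 0}, {feed2, 0}};
  TF_Output outputs[1] = {{add, 0}};
  TF_Function* func = TF_GraphToFunction(
      fn_body, "MyFunc", -1 /* num_opers: use all ops in fn_body */, nullptr,
      2, inputs, 1, outputs, nullptr /* default output names */,
      nullptr /* opts */, s);

  if (TF_GetCode(s) == TF_OK) {
    // Register the function in another graph so it can be called there.
    TF_Graph* host_graph = TF_NewGraph();
    TF_GraphAddFunction(host_graph, func, s);
  }
  // Cleanup elided for brevity.
  return 0;
}
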
diff --git a/tensorflow/c/c_api_internal.h b/tensorflow/c/c_api_internal.h
index f7d25dce8f..6e44a72e2b 100644
--- a/tensorflow/c/c_api_internal.h
+++ b/tensorflow/c/c_api_internal.h
@@ -130,6 +130,11 @@ struct TF_DeviceList {
std::vector<tensorflow::DeviceAttributes> response;
};
+struct TF_Function {
+ // Currently contains a single function and no gradients
+ tensorflow::FunctionDefLibrary fdef_lib;
+};
+
namespace tensorflow {
class TensorCApi {
@@ -142,6 +147,9 @@ class TensorCApi {
};
TF_Tensor* TF_TensorFromTensor(const Tensor& src, TF_Status* status);
+
+Status MessageToBuffer(const tensorflow::protobuf::Message& in, TF_Buffer* out);
+
} // end namespace tensorflow
#endif // TENSORFLOW_C_C_API_INTERNAL_H_
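
A rough sketch (not from this patch) of how a C API entry point can use the newly exported MessageToBuffer() to hand a proto back to the caller; the actual body of TF_FunctionToFunctionDef in c_api_function.cc may differ, and ExampleFunctionToBuffer is a hypothetical name:

#include "tensorflow/c/c_api_internal.h"

void ExampleFunctionToBuffer(TF_Function* func, TF_Buffer* out,
                             TF_Status* status) {
  // In this change TF_Function::fdef_lib holds exactly one FunctionDef.
  const tensorflow::FunctionDef& fdef = func->fdef_lib.function(0);
  // Serialize the proto into the caller-provided (empty) TF_Buffer.
  status->status = tensorflow::MessageToBuffer(fdef, out);
}
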
diff --git a/tensorflow/c/c_api_test.cc b/tensorflow/c/c_api_test.cc
index 0aa60fb45d..c442029009 100644
--- a/tensorflow/c/c_api_test.cc
+++ b/tensorflow/c/c_api_test.cc
@@ -829,7 +829,7 @@ TEST(CAPI, ShapeInferenceError) {
TF_Operation* vec3 = Const(vec3_tensor.get(), graph, status, "vec3");
ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
- TF_Operation* add = Add(vec2, vec3, graph, status);
+ TF_Operation* add = AddNoCheck(vec2, vec3, graph, status);
ASSERT_NE(TF_OK, TF_GetCode(status));
ASSERT_TRUE(add == nullptr);
diff --git a/tensorflow/c/c_test_util.cc b/tensorflow/c/c_test_util.cc
index 21603c1a07..9cd978c97e 100644
--- a/tensorflow/c/c_test_util.cc
+++ b/tensorflow/c/c_test_util.cc
@@ -15,7 +15,9 @@ limitations under the License.
#include "tensorflow/c/c_test_util.h"
+#include "tensorflow/core/framework/function.pb.h"
#include "tensorflow/core/framework/tensor.pb.h"
+#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/platform/logging.h"
using tensorflow::GraphDef;
@@ -36,6 +38,23 @@ TF_Tensor* Int8Tensor(const int64_t* dims, int num_dims, const char* values) {
return t;
}
+TF_Tensor* Int32Tensor(const int64_t* dims, int num_dims,
+ const int32_t* values) {
+ int64_t num_values = 1;
+ for (int i = 0; i < num_dims; ++i) {
+ num_values *= dims[i];
+ }
+ TF_Tensor* t =
+ TF_AllocateTensor(TF_INT32, dims, num_dims, sizeof(int32_t) * num_values);
+ memcpy(TF_TensorData(t), values, sizeof(int32_t) * num_values);
+ return t;
+}
+
+TF_Tensor* Int32Tensor(const std::vector<int32_t>& values) {
+ int64_t dims = values.size();
+ return Int32Tensor(&dims, 1, values.data());
+}
+
TF_Tensor* Int32Tensor(int32_t v) {
const int num_bytes = sizeof(int32_t);
int32_t* values = new int32_t[1];
@@ -44,19 +63,40 @@ TF_Tensor* Int32Tensor(int32_t v) {
&Int32Deallocator, nullptr);
}
-TF_Operation* Placeholder(TF_Graph* graph, TF_Status* s, const char* name) {
+// All the *Helper methods are used as a workaround for the restriction that
+// one cannot call ASSERT_* macros in non-void-returning functions (when
+// exceptions are disabled during compilation)
+void PlaceholderHelper(TF_Graph* graph, TF_Status* s, const char* name,
+ TF_Operation** op) {
TF_OperationDescription* desc = TF_NewOperation(graph, "Placeholder", name);
TF_SetAttrType(desc, "dtype", TF_INT32);
- return TF_FinishOperation(desc, s);
+ *op = TF_FinishOperation(desc, s);
+ ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
+ ASSERT_NE(*op, nullptr);
}
-TF_Operation* Const(TF_Tensor* t, TF_Graph* graph, TF_Status* s,
- const char* name) {
+TF_Operation* Placeholder(TF_Graph* graph, TF_Status* s, const char* name) {
+ TF_Operation* op;
+ PlaceholderHelper(graph, s, name, &op);
+ return op;
+}
+
+void ConstHelper(TF_Tensor* t, TF_Graph* graph, TF_Status* s, const char* name,
+ TF_Operation** op) {
TF_OperationDescription* desc = TF_NewOperation(graph, "Const", name);
TF_SetAttrTensor(desc, "value", t, s);
- if (TF_GetCode(s) != TF_OK) return nullptr;
+ ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
TF_SetAttrType(desc, "dtype", TF_TensorType(t));
- return TF_FinishOperation(desc, s);
+ *op = TF_FinishOperation(desc, s);
+ ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
+ ASSERT_NE(*op, nullptr);
+}
+
+TF_Operation* Const(TF_Tensor* t, TF_Graph* graph, TF_Status* s,
+ const char* name) {
+ TF_Operation* op;
+ ConstHelper(t, graph, s, name, &op);
+ return op;
}
TF_Operation* ScalarConst(int32_t v, TF_Graph* graph, TF_Status* s,
@@ -65,11 +105,39 @@ TF_Operation* ScalarConst(int32_t v, TF_Graph* graph, TF_Status* s,
return Const(tensor.get(), graph, s, name);
}
+void AddHelper(TF_Operation* l, TF_Operation* r, TF_Graph* graph, TF_Status* s,
+ const char* name, TF_Operation** op, bool check) {
+ TF_OperationDescription* desc = TF_NewOperation(graph, "AddN", name);
+ TF_Output add_inputs[2] = {{l, 0}, {r, 0}};
+ TF_AddInputList(desc, add_inputs, 2);
+ *op = TF_FinishOperation(desc, s);
+ if (check) {
+ ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
+ ASSERT_NE(*op, nullptr);
+ }
+}
+
TF_Operation* Add(TF_Operation* l, TF_Operation* r, TF_Graph* graph,
TF_Status* s, const char* name) {
+ TF_Operation* op;
+ AddHelper(l, r, graph, s, name, &op, true);
+ return op;
+}
+
+TF_Operation* AddNoCheck(TF_Operation* l, TF_Operation* r, TF_Graph* graph,
+ TF_Status* s, const char* name) {
+ TF_Operation* op;
+ AddHelper(l, r, graph, s, name, &op, false);
+ return op;
+}
+
+TF_Operation* AddWithCtrlDependency(TF_Operation* l, TF_Operation* r,
+ TF_Graph* graph, TF_Operation* ctrl_op,
+ TF_Status* s, const char* name) {
TF_OperationDescription* desc = TF_NewOperation(graph, "AddN", name);
TF_Output add_inputs[2] = {{l, 0}, {r, 0}};
TF_AddInputList(desc, add_inputs, 2);
+ TF_AddControlInput(desc, ctrl_op);
return TF_FinishOperation(desc, s);
}
@@ -81,11 +149,20 @@ TF_Operation* Add(TF_Output l, TF_Output r, TF_Graph* graph, TF_Status* s,
return TF_FinishOperation(desc, s);
}
-TF_Operation* Neg(TF_Operation* n, TF_Graph* graph, TF_Status* s) {
+void NegHelper(TF_Operation* n, TF_Graph* graph, TF_Status* s,
+ TF_Operation** op) {
TF_OperationDescription* desc = TF_NewOperation(graph, "Neg", "neg");
TF_Output neg_input = {n, 0};
TF_AddInput(desc, neg_input);
- return TF_FinishOperation(desc, s);
+ *op = TF_FinishOperation(desc, s);
+ ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
+ ASSERT_NE(*op, nullptr);
+}
+
+TF_Operation* Neg(TF_Operation* n, TF_Graph* graph, TF_Status* s) {
+ TF_Operation* op;
+ NegHelper(n, graph, s, &op);
+ return op;
}
TF_Operation* LessThan(TF_Output l, TF_Output r, TF_Graph* graph,
@@ -96,6 +173,32 @@ TF_Operation* LessThan(TF_Output l, TF_Output r, TF_Graph* graph,
return TF_FinishOperation(desc, s);
}
+void Split3Helper(TF_Operation* input, TF_Graph* graph, TF_Status* s,
+ const char* name, TF_Operation** op) {
+ TF_Operation* zero = ScalarConst(
+ 0, graph, s, ::tensorflow::strings::StrCat(name, "_const0").c_str());
+ TF_OperationDescription* desc = TF_NewOperation(graph, "Split", name);
+ TF_AddInput(desc, {zero, 0});
+ TF_AddInput(desc, {input, 0});
+ TF_SetAttrInt(desc, "num_split", 3);
+ TF_SetAttrType(desc, "T", TF_INT32);
+ // Set device to CPU since there is no version of split for int32 on GPU
+ // TODO(iga): Convert all these helpers and tests to use floats because
+ // they are usually available on GPUs. After doing this, remove TF_SetDevice
+ // call in c_api_function_test.cc
+ TF_SetDevice(desc, "/cpu:0");
+ *op = TF_FinishOperation(desc, s);
+ ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
+ ASSERT_NE(*op, nullptr);
+}
+
+TF_Operation* Split3(TF_Operation* input, TF_Graph* graph, TF_Status* s,
+ const char* name) {
+ TF_Operation* op;
+ Split3Helper(input, graph, s, name, &op);
+ return op;
+}
+
bool IsPlaceholder(const tensorflow::NodeDef& node_def) {
if (node_def.op() != "Placeholder" || node_def.name() != "feed") {
return false;
@@ -196,6 +299,18 @@ bool GetNodeDef(TF_Operation* oper, tensorflow::NodeDef* node_def) {
return ret;
}
+bool GetFunctionDef(TF_Function* func, tensorflow::FunctionDef* func_def) {
+ TF_Status* s = TF_NewStatus();
+ TF_Buffer* buffer = TF_NewBuffer();
+ TF_FunctionToFunctionDef(func, buffer, s);
+ bool ret = TF_GetCode(s) == TF_OK;
+ EXPECT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
+ if (ret) ret = func_def->ParseFromArray(buffer->data, buffer->length);
+ TF_DeleteBuffer(buffer);
+ TF_DeleteStatus(s);
+ return ret;
+}
+
bool GetAttrValue(TF_Operation* oper, const char* attr_name,
tensorflow::AttrValue* attr_value, TF_Status* s) {
TF_Buffer* buffer = TF_NewBuffer();
diff --git a/tensorflow/c/c_test_util.h b/tensorflow/c/c_test_util.h
index 0c0ba667bd..a927739d46 100644
--- a/tensorflow/c/c_test_util.h
+++ b/tensorflow/c/c_test_util.h
@@ -33,6 +33,13 @@ typedef std::unique_ptr<TF_Tensor, decltype(&TF_DeleteTensor)>
// Create a tensor with values of type TF_INT8 provided by `values`.
TF_Tensor* Int8Tensor(const int64_t* dims, int num_dims, const char* values);
+// Create a tensor with values of type TF_INT32 provided by `values`.
+TF_Tensor* Int32Tensor(const int64_t* dims, int num_dims,
+ const int32_t* values);
+
+// Create a 1-dimensional tensor with values from `values`
+TF_Tensor* Int32Tensor(const std::vector<int32_t>& values);
+
TF_Tensor* Int32Tensor(int32_t v);
TF_Operation* Placeholder(TF_Graph* graph, TF_Status* s,
@@ -47,6 +54,13 @@ TF_Operation* ScalarConst(int32_t v, TF_Graph* graph, TF_Status* s,
TF_Operation* Add(TF_Operation* l, TF_Operation* r, TF_Graph* graph,
TF_Status* s, const char* name = "add");
+TF_Operation* AddNoCheck(TF_Operation* l, TF_Operation* r, TF_Graph* graph,
+ TF_Status* s, const char* name = "add");
+
+TF_Operation* AddWithCtrlDependency(TF_Operation* l, TF_Operation* r,
+ TF_Graph* graph, TF_Operation* ctrl_op,
+ TF_Status* s, const char* name = "add");
+
TF_Operation* Add(TF_Output l, TF_Output r, TF_Graph* graph, TF_Status* s,
const char* name = "add");
@@ -54,6 +68,10 @@ TF_Operation* Neg(TF_Operation* n, TF_Graph* graph, TF_Status* s);
TF_Operation* LessThan(TF_Output l, TF_Output r, TF_Graph* graph, TF_Status* s);
+// Split `input` along the first dimension into 3 tensors
+TF_Operation* Split3(TF_Operation* input, TF_Graph* graph, TF_Status* s,
+ const char* name = "split3");
+
bool IsPlaceholder(const tensorflow::NodeDef& node_def);
bool IsScalarConst(const tensorflow::NodeDef& node_def, int v);
@@ -66,6 +84,8 @@ bool GetGraphDef(TF_Graph* graph, tensorflow::GraphDef* graph_def);
bool GetNodeDef(TF_Operation* oper, tensorflow::NodeDef* node_def);
+bool GetFunctionDef(TF_Function* func, tensorflow::FunctionDef* func_def);
+
bool GetAttrValue(TF_Operation* oper, const char* attr_name,
tensorflow::AttrValue* attr_value, TF_Status* s);
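
A short sketch (not part of the patch) of how the new test helpers might be combined in a function-body graph; ExampleUsingNewHelpers is a hypothetical name and error handling is elided:

#include "tensorflow/c/c_test_util.h"

void ExampleUsingNewHelpers() {
  TF_Status* s = TF_NewStatus();
  TF_Graph* g = TF_NewGraph();
  TF_Operation* feed = Placeholder(g, s, "feed");
  TF_Operation* one = ScalarConst(1, g, s, "one");
  // Splits `feed` into 3 tensors along the first dimension (on CPU).
  TF_Operation* split = Split3(feed, g, s);
  // `add` will not run before `split` because of the control dependency.
  TF_Operation* add = AddWithCtrlDependency(feed, one, g, split, s);
  // 1-D int32 tensor built from a std::vector, e.g. to feed `feed`.
  unique_tensor_ptr t(Int32Tensor({1, 2, 3, 4, 5, 6}), TF_DeleteTensor);
  (void)add;
  TF_DeleteGraph(g);
  TF_DeleteStatus(s);
}
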
diff --git a/tensorflow/contrib/cmake/tf_c.cmake b/tensorflow/contrib/cmake/tf_c.cmake
index 87d946c346..c5a1018127 100644
--- a/tensorflow/contrib/cmake/tf_c.cmake
+++ b/tensorflow/contrib/cmake/tf_c.cmake
@@ -18,6 +18,7 @@
set(tf_c_srcs
"${tensorflow_source_dir}/tensorflow/c/c_api.cc"
"${tensorflow_source_dir}/tensorflow/c/c_api.h"
+ "${tensorflow_source_dir}/tensorflow/c/c_api_function.cc"
"${tensorflow_source_dir}/tensorflow/c/eager/c_api.cc"
"${tensorflow_source_dir}/tensorflow/c/eager/c_api.h"
"${tensorflow_source_dir}/tensorflow/c/eager/runtime.cc"
diff --git a/tensorflow/core/graph/graph.cc b/tensorflow/core/graph/graph.cc
index 7d938365c5..a274c79970 100644
--- a/tensorflow/core/graph/graph.cc
+++ b/tensorflow/core/graph/graph.cc
@@ -523,6 +523,17 @@ Status Graph::IsValidNode(const Node* node) const {
return Status::OK();
}
+Status Graph::IsValidOutputTensor(const Node* node, int idx) const {
+ TF_RETURN_IF_ERROR(IsValidNode(node));
+ if (idx >= node->num_outputs()) {
+ return errors::InvalidArgument("Node '", node->name(), "' (type: '",
+ node->op_def().name(),
+ "', num of outputs: ", node->num_outputs(),
+ ") does not have ", "output ", idx);
+ }
+ return Status::OK();
+}
+
Node* Graph::AllocateNode(std::shared_ptr<NodeProperties> props,
const Node* cost_node) {
Node* node = nullptr;
@@ -572,7 +583,7 @@ int Graph::InternDeviceName(const string& device_name) {
}
string Edge::DebugString() const {
- return strings::Printf("Edge %d %s:%d -> %s:%d", id_, src_->name().c_str(),
+ return strings::Printf("[id=%d %s:%d -> %s:%d]", id_, src_->name().c_str(),
src_output_, dst_->name().c_str(), dst_input_);
}
diff --git a/tensorflow/core/graph/graph.h b/tensorflow/core/graph/graph.h
index 51ede642d2..25875185e4 100644
--- a/tensorflow/core/graph/graph.h
+++ b/tensorflow/core/graph/graph.h
@@ -519,6 +519,10 @@ class Graph {
// Returns OK if `node` is non-null and belongs to this graph
Status IsValidNode(const Node* node) const;
+ // Returns OK if IsValidNode(`node`) and `idx` is less than
+ // node->num_outputs()
+ Status IsValidOutputTensor(const Node* node, int idx) const;
+
// TODO(josh11b): uint64 hash() const;
private:
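
For context, a sketch (not from this change) of the call-site pattern the new Graph::IsValidOutputTensor() enables, e.g. when validating a TF_Output coming from the C API; ValidateOutput is a hypothetical helper:

#include "tensorflow/core/graph/graph.h"
#include "tensorflow/core/lib/core/errors.h"

tensorflow::Status ValidateOutput(const tensorflow::Graph& graph,
                                  const tensorflow::Node* node, int index) {
  // Fails with the "does not have output <idx>" message seen in the tests
  // above if `node` is null, not in `graph`, or `index` is out of range.
  TF_RETURN_IF_ERROR(graph.IsValidOutputTensor(node, index));
  // Safe to use {node, index} as an output tensor past this point.
  return tensorflow::Status::OK();
}
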
diff --git a/tensorflow/python/client/tf_session.i b/tensorflow/python/client/tf_session.i
index 08dd3922db..fa49e66e87 100644
--- a/tensorflow/python/client/tf_session.i
+++ b/tensorflow/python/client/tf_session.i
@@ -373,6 +373,33 @@ def TF_Reset(target, containers=None, config=None):
TF_DeleteSessionOptions(opts)
%}
+// We use TF_GraphToFunction_wrapper instead of TF_GraphToFunction
+%ignore TF_GraphToFunction;
+// TF_GraphToFunction_wrapper does not use any Python methods and
+// does not require the GIL to be held.
+%unignore TF_GraphToFunction_wrapper;
+
+// $input is a Python list of wrapped TF_Operations
+%typemap(in) (const std::vector<TF_Operation*>* opers)
+ (std::vector<TF_Operation*> opers) {
+ if ($input != Py_None) {
+ if (!PyList_Check($input)) {
+ SWIG_exception_fail(SWIG_TypeError, "$symname: expected list");
+ }
+ size_t size = PyList_Size($input);
+ for (size_t i = 0; i < size; ++i) {
+ PyObject* item = PyList_GetItem($input, i);
+ TF_Operation* oper_ptr;
+ SWIG_ConvertPtr(item, reinterpret_cast<void**>(&oper_ptr),
+ $descriptor(TF_Operation*), 0);
+ opers.push_back(oper_ptr);
+ }
+ $1 = &opers;
+ } else {
+ $1 = nullptr;
+ }
+}
+
%include "tensorflow/python/client/tf_session_helper.h"
%unignoreall
diff --git a/tensorflow/python/client/tf_session_helper.cc b/tensorflow/python/client/tf_session_helper.cc
index 60a589fa8b..72f560fa87 100644
--- a/tensorflow/python/client/tf_session_helper.cc
+++ b/tensorflow/python/client/tf_session_helper.cc
@@ -337,4 +337,38 @@ std::vector<TF_Operation*> TF_OperationGetControlInputs_wrapper(
return control_inputs;
}
+TF_Function* TF_GraphToFunction_wrapper(const TF_Graph* fn_body,
+ const char* fn_name,
+ const std::vector<TF_Operation*>* opers,
+ const std::vector<TF_Output>& inputs,
+ const std::vector<TF_Output>& outputs,
+ const NameVector& output_names,
+ const TF_FunctionOptions* opts,
+ TF_Status* out_status) {
+ if (!output_names.empty() && output_names.size() != outputs.size()) {
+ Set_TF_Status_from_Status(
+ out_status,
+ errors::InvalidArgument(
+ "output names must be either empty or equal in size to outputs. ",
+ "output names size = ", output_names.size(),
+ " outputs size = ", outputs.size()));
+ return nullptr;
+ }
+
+ int nopers = -1;
+ const TF_Operation* const* opers_array = nullptr;
+ if (opers != nullptr) {
+ nopers = opers->size();
+ opers_array = opers->data();
+ }
+
+ const char** output_names_ptr =
+ output_names.empty() ? nullptr
+ : const_cast<const char**>(output_names.data());
+
+ return TF_GraphToFunction(fn_body, fn_name, nopers, opers_array,
+ inputs.size(), inputs.data(), outputs.size(),
+ outputs.data(), output_names_ptr, opts, out_status);
+}
+
} // namespace tensorflow
diff --git a/tensorflow/python/client/tf_session_helper.h b/tensorflow/python/client/tf_session_helper.h
index 3bc63f822f..8fae6206c0 100644
--- a/tensorflow/python/client/tf_session_helper.h
+++ b/tensorflow/python/client/tf_session_helper.h
@@ -148,6 +148,16 @@ void TF_SessionPRun_wrapper(TF_Session* session, const char* handle,
std::vector<TF_Operation*> TF_OperationGetControlInputs_wrapper(
TF_Operation* oper);
+// A NULL `opers` is converted to `nopers = -1`.
+// `output_names` must be empty or have the same length as `outputs`.
+TF_Function* TF_GraphToFunction_wrapper(const TF_Graph* fn_body,
+ const char* fn_name,
+ const std::vector<TF_Operation*>* opers,
+ const std::vector<TF_Output>& inputs,
+ const std::vector<TF_Output>& outputs,
+ const NameVector& output_names,
+ const TF_FunctionOptions* opts,
+ TF_Status* out_status);
} // namespace tensorflow
#endif // TENSORFLOW_PYTHON_CLIENT_TF_SESSION_HELPER_H_
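
A sketch (not part of the patch) of what the SWIG-generated call into TF_GraphToFunction_wrapper boils down to on the C++ side; BuildMyFunc is a hypothetical helper, and fn_body/feed1/feed2/add are assumed to have been built elsewhere (e.g. as in the C API tests above):

#include <vector>
#include "tensorflow/python/client/tf_session_helper.h"

TF_Function* BuildMyFunc(TF_Graph* fn_body, TF_Operation* feed1,
                         TF_Operation* feed2, TF_Operation* add,
                         TF_Status* status) {
  std::vector<TF_Output> inputs = {{feed1, 0}, {feed2, 0}};
  std::vector<TF_Output> outputs = {{add, 0}};
  tensorflow::NameVector output_names;  // empty => keep default output names
  return tensorflow::TF_GraphToFunction_wrapper(
      fn_body, "MyFunc",
      nullptr,  // opers == NULL is translated to nopers = -1 (all ops)
      inputs, outputs, output_names,
      nullptr,  // opts
      status);
}
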
diff --git a/tensorflow/python/framework/function.py b/tensorflow/python/framework/function.py
index 2f35f0e04b..7a866ee6e8 100644
--- a/tensorflow/python/framework/function.py
+++ b/tensorflow/python/framework/function.py
@@ -26,7 +26,9 @@ import hashlib
from tensorflow.core.framework import attr_value_pb2
from tensorflow.core.framework import op_def_pb2
+from tensorflow.python import pywrap_tensorflow as c_api
from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import errors
from tensorflow.python.framework import graph_to_function_def
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
@@ -290,6 +292,7 @@ class _DefinedFunction(object):
self._shape_func = shape_func
self._extra_kwargs = kwargs
self._definition = None # Constructed lazily.
+ self._c_func = None # Constructed with definition.
self._sub_functions = dict() # Constructed with definition.
self._args = []
@@ -396,6 +399,22 @@ class _DefinedFunction(object):
if self._func.__doc__:
self._definition.signature.description = self._func.__doc__
+ # pylint: disable=protected-access
+ if temp_graph._c_graph:
+ with errors.raise_exception_on_not_ok_status() as status:
+ output_names = ([compat.as_bytes(x) for x in self._out_names]
+ if self._out_names else [])
+ self._c_func = c_api.TF_GraphToFunction_wrapper(
+ temp_graph._c_graph,
+ self._func_name,
+ None, # opers
+ [t._as_tf_output() for t in inputs],
+ [t._as_tf_output() for t in outputs],
+ output_names,
+ None, # opts
+ status)
+ # pylint: enable=protected-access
+
def _create_hash_str(self, input_arg, output_arg, node_def):
"""Creates an 8-character string unique to this input.
diff --git a/tensorflow/python/framework/function_test.py b/tensorflow/python/framework/function_test.py
index 589db9ef4d..40205ddf05 100644
--- a/tensorflow/python/framework/function_test.py
+++ b/tensorflow/python/framework/function_test.py
@@ -33,6 +33,7 @@ from tensorflow.python.framework import function
from tensorflow.python.framework import graph_to_function_def
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_shape
+from tensorflow.python.framework import test_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import clip_ops
from tensorflow.python.ops import control_flow_ops
@@ -63,7 +64,51 @@ def _OptimizerOptions():
do_constant_folding=cfold)))
-class FunctionTest(test.TestCase):
+class FunctionTestMethods(object):
+ """Test methods for verifying Function support.
+
+ These test methods are used as mix-ins in two test cases: with
+ and without C API support.
+ """
+
+ def testIdentity(self):
+
+ @function.Defun(dtypes.float32, func_name="MyIdentity")
+ def MyIdentityFunc(a):
+ return a
+
+ with ops.Graph().as_default():
+ call = MyIdentityFunc([18.0])
+ self.assertEqual("MyIdentity", call.op.name)
+ with session.Session() as sess:
+ self.assertAllEqual([18.0], sess.run(call))
+
+ def testIdentityOutputName(self):
+
+ @function.Defun(
+ dtypes.float32, func_name="MyIdentity", out_names=["my_result_name"])
+ def MyIdentityFunc(a):
+ return a
+
+ with ops.Graph().as_default():
+ call = MyIdentityFunc([18.0])
+ self.assertEqual("MyIdentity", call.op.name)
+ with session.Session() as sess:
+ self.assertAllEqual([18.0], sess.run(call))
+
+ def testTooManyOutputNames(self):
+
+ @function.Defun(
+ dtypes.float32, func_name="MyIdentity",
+ out_names=["my_result1", "my_result2"])
+ def MyIdentityFunc(a):
+ return a
+
+ with ops.Graph().as_default():
+ with self.assertRaisesRegexp(
+ ValueError, (r"Length of out_names \(2\) does not match number of "
+ r"outputs \(1\): my_result1, my_result2")):
+ MyIdentityFunc([18.0])
def testDefineFunction2Args(self):
@@ -77,6 +122,35 @@ class FunctionTest(test.TestCase):
with session.Session() as sess:
self.assertAllEqual([5.0], sess.run(call))
+ def testValueErrorOnFunctionWithNoOutput(self):
+ # TODO(iga): Remove this restriction and this test
+
+ @function.Defun(dtypes.float32, dtypes.float32)
+ def APlus2B(a, b):
+ print(a + b * 2) # Create some ops to have nodes in the body
+ # Using 'print' to make lint happy
+
+ with ops.Graph().as_default():
+ with self.assertRaisesRegexp(ValueError,
+ "Function can not return None"):
+ APlus2B([1.0], [2.0])
+
+ def testDefineFunction2ArgsOutputName(self):
+
+ @function.Defun(
+ dtypes.float32,
+ dtypes.float32,
+ func_name="APlus2B",
+ out_names=["my_result_name"])
+ def APlus2B(a, b):
+ return a + b * 2
+
+ with ops.Graph().as_default():
+ call = APlus2B([1.0], [2.0])
+ self.assertEqual("APlus2B", call.op.name)
+ with session.Session() as sess:
+ self.assertAllEqual([5.0], sess.run(call))
+
def testDefineFunctionDuplicateOutputs(self):
@function.Defun(dtypes.float32, func_name="Duplicate")
@@ -137,6 +211,7 @@ class FunctionTest(test.TestCase):
out, = sess.run(dx, feed)
self.assertAllClose(1 - np.square(np.tanh(inp)), out)
+ @test_util.disable_c_api # Function gradients don't work with C API
def testCustomGradient(self):
dtype = dtypes.float32
@@ -169,6 +244,7 @@ class FunctionTest(test.TestCase):
out, = sess.run(dlogits, {logits: x, labels: y})
self.assertAllClose(out, np.exp(prob - y))
+ @test_util.disable_c_api # Function gradients don't work with C API
def testCustomGradientError(self):
dtype = dtypes.float32
@@ -194,6 +270,7 @@ class FunctionTest(test.TestCase):
"SymGrad expects to return 1.*but get 2.*instead"):
_ = sess.run(dinp, {inp: x})
+ @test_util.disable_c_api # Function gradients don't work with C API
def testSymGradShape(self):
g = ops.Graph()
with g.as_default():
@@ -209,6 +286,7 @@ class FunctionTest(test.TestCase):
self.assertEqual(x.get_shape(), dx.get_shape())
self.assertEqual(y.get_shape(), dy.get_shape())
+ @test_util.disable_c_api # Function gradients don't work with C API
def testSymGradAttr(self):
@function.Defun(noinline=True)
@@ -312,6 +390,7 @@ class FunctionTest(test.TestCase):
"assertion failed.*-3"):
self.assertAllEqual(Foo(constant_op.constant(-3.0)).eval(), 6.0)
+ @test_util.disable_c_api # Op._add_control_inputs doesn't work with C API
def testAssertWrapper(self):
@function.Defun(dtypes.float32)
@@ -326,6 +405,7 @@ class FunctionTest(test.TestCase):
"assertion"):
_ = MyFn(100.0).eval()
+ @test_util.disable_c_api # Op._add_control_inputs doesn't work with C API
def testWhileLoopCallsFunc(self):
with self.test_session(use_gpu=True) as sess:
@@ -345,6 +425,7 @@ class FunctionTest(test.TestCase):
ans = sess.run(loop)
self.assertAllClose(ans, 131072.)
+ @test_util.disable_c_api # Op._add_control_inputs doesn't work with C API
def testControlFlowStrictness(self):
"""Inlined functions must not execute in a untaken control flow branch."""
@@ -607,6 +688,7 @@ class FunctionTest(test.TestCase):
self.assertAllClose(vals[0], vals[1])
self.assertAllClose(vals[2], vals[3])
+ @test_util.disable_c_api # Function Declaration doesn't work with C API
def testDeclare(self):
foo = function.Declare("Foo", [("x", dtypes.float32)], [("y",
dtypes.float32)])
@@ -626,6 +708,7 @@ class FunctionTest(test.TestCase):
expected = rand * rand + 1.0
self.assertAllClose(expected, y.eval(feed_dict={x: rand}))
+ @test_util.disable_c_api # Function Declaration doesn't work with C API
def testDeclareUsedInDefun(self):
foo = function.Declare("Foo", [("x", dtypes.float32)], [("y",
dtypes.float32)])
@@ -649,6 +732,7 @@ class FunctionTest(test.TestCase):
expected = rand * rand + 1.0
self.assertAllClose(expected, y.eval(feed_dict={x: rand}))
+ @test_util.disable_c_api # Function Declaration doesn't work with C API
def testDeclareTypeMistake(self):
foo = function.Declare("Foo", [("x", dtypes.float32)], [("y",
dtypes.float32)])
@@ -861,6 +945,32 @@ class FunctionTest(test.TestCase):
self.assertEqual(len(f.signature.input_arg), 3)
+class FunctionTest(FunctionTestMethods, test.TestCase):
+ """Test case that invokes test methods with _USE_C_API=False."""
+
+ def setUp(self):
+ self.prev_use_c_api = ops._USE_C_API
+ ops._USE_C_API = False
+ super(FunctionTest, self).setUp()
+
+ def tearDown(self):
+ ops._USE_C_API = self.prev_use_c_api
+ super(FunctionTest, self).tearDown()
+
+
+class FunctionWithCApiTest(FunctionTestMethods, test.TestCase):
+ """Test case that invokes test methods with _USE_C_API=True."""
+
+ def setUp(self):
+ self.prev_use_c_api = ops._USE_C_API
+ ops._USE_C_API = True
+ super(FunctionWithCApiTest, self).setUp()
+
+ def tearDown(self):
+ ops._USE_C_API = self.prev_use_c_api
+ super(FunctionWithCApiTest, self).tearDown()
+
+
class FunctionsFromProtos(test.TestCase):
def expectFunctionsEqual(self, func, grad_func=None, new_func=None):
diff --git a/tensorflow/python/framework/ops.py b/tensorflow/python/framework/ops.py
index ccaa2141b5..659bc394b9 100644
--- a/tensorflow/python/framework/ops.py
+++ b/tensorflow/python/framework/ops.py
@@ -2948,6 +2948,14 @@ class Graph(object):
if self._graph_def_versions.min_consumer < 12:
self._graph_def_versions.min_consumer = 12
self._functions[name] = function
+ if self._c_graph:
+ # pylint: disable=protected-access
+ assert function._c_func, (
+ "Cannot add function created without C API support to graph "
+ "created with C API support")
+ with errors.raise_exception_on_not_ok_status() as status:
+ c_api.TF_GraphAddFunction(self._c_graph, function._c_func, status)
+ # pylint: enable=protected-access
@property
def building_function(self):