diff options
author | Igor Ganichev <iga@google.com> | 2017-08-30 21:05:14 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2017-08-30 21:08:53 -0700 |
commit | 9624d165f1f2c717eda96464fee8bf7229cc14f5 (patch) | |
tree | 8024d708b58b0c78f19d4c3cfc9f7c4b0c24b70c /tensorflow | |
parent | 424aa9aa9559f6fa29d8ccf3d74ff25528b39209 (diff) |
Add function support to Tensorflow C API
This change adds minimal functionality. Support for FunctionOptions,
attributes, output name rewriting, function name generation, etc is
comming next.
PiperOrigin-RevId: 167091238
Diffstat (limited to 'tensorflow')
-rw-r--r-- | tensorflow/c/BUILD | 24 | ||||
-rw-r--r-- | tensorflow/c/c_api.cc | 37 | ||||
-rw-r--r-- | tensorflow/c/c_api.h | 116 | ||||
-rw-r--r-- | tensorflow/c/c_api_function.cc | 496 | ||||
-rw-r--r-- | tensorflow/c/c_api_function_test.cc | 1039 | ||||
-rw-r--r-- | tensorflow/c/c_api_internal.h | 8 | ||||
-rw-r--r-- | tensorflow/c/c_api_test.cc | 2 | ||||
-rw-r--r-- | tensorflow/c/c_test_util.cc | 131 | ||||
-rw-r--r-- | tensorflow/c/c_test_util.h | 20 | ||||
-rw-r--r-- | tensorflow/contrib/cmake/tf_c.cmake | 1 | ||||
-rw-r--r-- | tensorflow/core/graph/graph.cc | 13 | ||||
-rw-r--r-- | tensorflow/core/graph/graph.h | 4 | ||||
-rw-r--r-- | tensorflow/python/client/tf_session.i | 27 | ||||
-rw-r--r-- | tensorflow/python/client/tf_session_helper.cc | 34 | ||||
-rw-r--r-- | tensorflow/python/client/tf_session_helper.h | 10 | ||||
-rw-r--r-- | tensorflow/python/framework/function.py | 19 | ||||
-rw-r--r-- | tensorflow/python/framework/function_test.py | 112 | ||||
-rw-r--r-- | tensorflow/python/framework/ops.py | 8 |
18 files changed, 2072 insertions, 29 deletions
diff --git a/tensorflow/c/BUILD b/tensorflow/c/BUILD index 604dfab148..1822e235eb 100644 --- a/tensorflow/c/BUILD +++ b/tensorflow/c/BUILD @@ -45,8 +45,13 @@ tf_cuda_library( tf_cuda_library( name = "c_api", - srcs = ["c_api.cc"], - hdrs = ["c_api.h"], + srcs = [ + "c_api.cc", + "c_api_function.cc", + ], + hdrs = [ + "c_api.h", + ], copts = tf_copts(), visibility = ["//visibility:public"], deps = select({ @@ -158,6 +163,21 @@ tf_cc_test( ) tf_cc_test( + name = "c_api_function_test", + size = "small", + srcs = ["c_api_function_test.cc"], + deps = [ + ":c_api", + ":c_test_util", + "//tensorflow/core:lib", + "//tensorflow/core:lib_internal", + "//tensorflow/core:protos_all_cc", + "//tensorflow/core:test", + "//tensorflow/core:test_main", + ], +) + +tf_cc_test( name = "while_loop_test", size = "small", srcs = ["while_loop_test.cc"], diff --git a/tensorflow/c/c_api.cc b/tensorflow/c/c_api.cc index 07c8277a6f..c454c94249 100644 --- a/tensorflow/c/c_api.cc +++ b/tensorflow/c/c_api.cc @@ -165,22 +165,6 @@ void deallocate_buffer(void* data, size_t len, void* arg) { tensorflow::cpu_allocator()->DeallocateRaw(data); } -Status MessageToBuffer(const tensorflow::protobuf::Message& in, - TF_Buffer* out) { - if (out->data != nullptr) { - return InvalidArgument("Passing non-empty TF_Buffer is invalid."); - } - const auto proto_size = in.ByteSizeLong(); - void* buf = tensorflow::port::Malloc(proto_size); - in.SerializeToArray(buf, proto_size); - out->data = buf; - out->length = proto_size; - out->data_deallocator = [](void* data, size_t length) { - tensorflow::port::Free(data); - }; - return Status::OK(); -} - } // namespace TF_Tensor::~TF_Tensor() { buffer->Unref(); } @@ -559,6 +543,27 @@ TF_Tensor* TF_TensorFromTensor(const tensorflow::Tensor& src, dimvec.size(), base, size, DeleteArray, base); } +Status MessageToBuffer(const tensorflow::protobuf::Message& in, + TF_Buffer* out) { + if (out->data != nullptr) { + return InvalidArgument("Passing non-empty TF_Buffer is invalid."); + } + const size_t proto_size = in.ByteSizeLong(); + void* buf = tensorflow::port::Malloc(proto_size); + if (buf == nullptr) { + return tensorflow::errors::ResourceExhausted( + "Failed to allocate memory to serialize message of type '", + in.GetTypeName(), "' and size ", proto_size); + } + in.SerializeToArray(buf, proto_size); + out->data = buf; + out->length = proto_size; + out->data_deallocator = [](void* data, size_t length) { + tensorflow::port::Free(data); + }; + return Status::OK(); +} + // Helpers for loading a TensorFlow plugin (a .so file). Status LoadLibrary(const char* library_filename, void** result, const void** buf, size_t* len); diff --git a/tensorflow/c/c_api.h b/tensorflow/c/c_api.h index 43b5078013..ee110d88ce 100644 --- a/tensorflow/c/c_api.h +++ b/tensorflow/c/c_api.h @@ -357,6 +357,14 @@ typedef struct TF_Output { int index; // The index of the output within oper. } TF_Output; +// TF_Function is a grouping of operations with defined inputs and outputs. +// Once created and added to graphs, functions can be invoked by creating an +// operation whose operation type matches the function name. +typedef struct TF_Function TF_Function; + +// Function definition options. TODO(iga): Define and implement +typedef struct TF_FunctionOptions TF_FunctionOptions; + // Sets the shape of the Tensor referenced by `output` in `graph` to // the shape described by `dims` and `num_dims`. // @@ -914,6 +922,15 @@ TF_CAPI_EXPORT extern void TF_GraphImportGraphDef( TF_Graph* graph, const TF_Buffer* graph_def, const TF_ImportGraphDefOptions* options, TF_Status* status); +// Add `function` to graph `g`. Once `function` is added to `g`, +// it can be called by creating an operation using the function's name. +// +// If successful, status is set to OK and function is added to g +// Otherwise, status is set to the encountered error and g is unmodified +TF_CAPI_EXPORT extern void TF_GraphAddFunction(TF_Graph* g, + const TF_Function* function, + TF_Status* status); + // Note: The following function may fail on very large protos in the future. TF_CAPI_EXPORT extern void TF_OperationToNodeDef(TF_Operation* oper, @@ -1001,6 +1018,105 @@ TF_CAPI_EXPORT void TF_AddGradients(TF_Graph* g, TF_Output* y, int ny, TF_Output* x, int nx, TF_Output* dx, TF_Status* status, TF_Output* dy); +// Create a TF_Function from a TF_Graph +// +// Params: +// fn_body - the graph whose operations (or subset of whose operations) will be +// converted to TF_Function. +// fn_name - the name of the new TF_Function. Should match the operation +// name (OpDef.name) regexp [A-Z][A-Za-z0-9_.\\-/]* and be distinct +// from other operation names (at least those registered in graphs +// where this function will be used). +// TODO(iga): Allow null in here and have C API come up with +// a unique name with high probability (similarly to +// _create_hash_str in function.py) +// num_opers - `num_opers` contains the number of elements in the `opers` array +// or a special value of -1 meaning that no array is given. +// The distinction between an empty array of operations and no +// array of operations is necessary to distinguish the case of +// creating a function with no body (e.g. identity or permutation) +// and the case of creating a function whose body contains all +// the nodes in the graph (except for the automatic skipping, see +// below). +// opers - Array of operations to become the body of the function or null. +// - If no array is given (`num_opers` = -1), all the +// operations in `fn_body` will become part of the function +// except operations referenced in `inputs`. These operations +// must have a single output (these operations are typically +// placeholders created for the sole purpose of representing +// an input. We can relax this constraint if there are +// compelling use cases). +// - If an array is given (`num_opers` >= 0), all operations +// in it will become part of the function. In particular, no +// automatic skipping of dummy input operations is performed. +// ninputs - number of elements in `inputs` array +// inputs - array of TF_Outputs that specify the inputs to the function. +// If `ninputs` is zero (the function takes no inputs), `inputs` +// can be null. The names used for function inputs are normalized +// names of the operations (usually placeholders) pointed to by +// `inputs`. These operation names should start with a letter. +// Normalization will convert all letters to lowercase and +// non-alphanumeric characters to '_' to make resulting names match +// the "[a-z][a-z0-9_]*" pattern for operation argument names. +// `inputs` cannot contain the same tensor twice. +// noutputs - number of elements in `outputs` array +// outputs - array of TF_Outputs that specify the outputs of the function. +// If `noutputs` is zero (the function returns no outputs), `outputs` +// can be null. `outputs` can contain the same tensor more than once. +// output_names - The names of the function's outputs. `output_names` array +// must either have the same length as `outputs` +// (i.e. `noutputs`) or be null. In the former case, +// the names should match the regular expression for ArgDef +// names - "[a-z][a-z0-9_]*". In the latter case, +// names for outputs will be generated automatically. +// opts - various options for the function, e.g. XLA's inlining control. +// status - Set to OK on success and an appropriate error on failure. +// +// Note that when the same TF_Output is listed as both an input and an output, +// the corresponding function's output will equal to this input, +// instead of the original node's output. +// +// Callers must also satisfy the following constraints: +// - `inputs` cannot refer to TF_Outputs within a control flow context. For +// example, one cannot use the output of "switch" node as input. +// - No TF_Output of a function (inside any of `inputs`, `outputs`, `fn_body`) +// is allowed to have a reference type. Reference types are not exposed +// through C API and are being deprecated. +// - Every node in the function's body must have all of its inputs (including +// control inputs). In other words, for every node in the body, each input +// must be either listed in `inputs` or must come from another node in +// the body. In particular, it is an error to have a control edge going from +// a node outside of the body into a node in the body. This applies to control +// edges going from nodes referenced in `inputs` to nodes in the body when +// the former nodes are not in the body (automatically skipped or not +// included in explicitly specified body). +// +// Returns: +// On successful, a newly created TF_Function instance. It must be deleted by +// calling TF_DeleteFunction. +// +// On failure, null. +// +// TODO(iga): Add input_names argument and get output_names working (they are +// currently ignored) +TF_CAPI_EXPORT extern TF_Function* TF_GraphToFunction( + const TF_Graph* fn_body, const char* fn_name, int num_opers, + const TF_Operation* const* opers, int ninputs, const TF_Output* inputs, + int noutputs, const TF_Output* outputs, const char* const* output_names, + const TF_FunctionOptions* opts, TF_Status* status); + +// Write out a serialized representation of `func` (as a FunctionDef protocol +// message) to `output_func_def` (allocated by TF_NewBuffer()). +// `output_func_def`'s underlying buffer will be freed when TF_DeleteBuffer() +// is called. +// +// May fail on very large graphs in the future. +TF_CAPI_EXPORT extern void TF_FunctionToFunctionDef(TF_Function* func, + TF_Buffer* output_func_def, + TF_Status* status); + +TF_CAPI_EXPORT extern void TF_DeleteFunction(TF_Function*); + // TODO(josh11b): Register OpDef, available to all operations added // to this graph. diff --git a/tensorflow/c/c_api_function.cc b/tensorflow/c/c_api_function.cc new file mode 100644 index 0000000000..b4c6397d0b --- /dev/null +++ b/tensorflow/c/c_api_function.cc @@ -0,0 +1,496 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/c/c_api_internal.h" + +#include <algorithm> +#include <unordered_map> +#include <unordered_set> + +#include "tensorflow/core/framework/attr_value_util.h" +#include "tensorflow/core/framework/function.pb.h" +#include "tensorflow/core/framework/node_def.pb.h" +#include "tensorflow/core/framework/node_def_util.h" +#include "tensorflow/core/framework/types.h" +#include "tensorflow/core/graph/graph.h" +#include "tensorflow/core/lib/strings/strcat.h" + +namespace tensorflow { +namespace { + +// Class that maintains a one-to-one original node name -> new node name +// mapping. We normalize the names used as input and output arguments to match +// regexp "[a-z][a-z0-9_]*" specified in definition of ArgDef.name. +// Once we rename them, we risk creating a name collision with the other +// node names, so if necessary we add a suffix to make +// names unique. If we have an input named "A" and a node in the function +// body named "a", they will be renamed to "a" and "a_0". +class NodeNameMapping { + public: + NodeNameMapping() = default; + + // Normalize the input/output name and make it unique. + string GetIOName(const string& name); + + // Make the node name unique. + string Uniquify(const string& name); + + // Look up how a node name was previously normalized/uniquified. + // Returns empty if name was never seen. + string Lookup(const string& name) const; + + private: + string UniquifyHelper(const string& name) const; + static string Normalize(string name); + + // The normalized/uniquified names already used as + // input names (in signature), output names (in signature), and node names + // (in node_def). + // This is a superset of values in name_mapping_. + std::unordered_set<string> used_names_; + // Mapping from original node name from the graph to the normalized + // and uniqified version of it. + std::unordered_map<string, string> name_mapping_; +}; + +string NodeNameMapping::Normalize(string name) { + // Convert letters to lowercase and non-alphanumeric characters to '_'. + if (name.empty()) return "unknown"; + const int n = name.size(); + for (int i = 0; i < n; ++i) { + char c = name[i]; + if (isalnum(c)) { + if (isupper(c)) { + name[i] = tolower(c); + } + } else { + name[i] = '_'; + } + } + + // Find the first letter and start with it. + int i = 0; + for (; i < n; ++i) { + if (isalpha(name[i])) break; + } + + // Return "unknown" if none of the name's chars were letters. + return i == n ? "unknown" : name.substr(i); +} + +string NodeNameMapping::UniquifyHelper(const string& name) const { + // If the name hasn't been used yet, use it as-is. + if (used_names_.find(name) == used_names_.end()) return name; + // Add a suffix to name to make it unique. + for (int i = 0;; ++i) { + const string candidate = strings::StrCat(name, "_", i); + if (used_names_.find(candidate) == used_names_.end()) return candidate; + } +} + +string NodeNameMapping::GetIOName(const string& name) { + const string& input_name = UniquifyHelper(Normalize(name)); + // Record that we used this name, but don't add it to name_mapping_ + // since this name is not for a node. + used_names_.insert(input_name); + return input_name; +} + +string NodeNameMapping::Uniquify(const string& name) { + const string uniqued = UniquifyHelper(name); + name_mapping_[name] = uniqued; + used_names_.insert(uniqued); + return uniqued; +} + +string NodeNameMapping::Lookup(const string& name) const { + const auto iter = name_mapping_.find(name); + if (iter == name_mapping_.end()) return string(); + return iter->second; +} + +Status ValidateNoRefOutputs(const Node* node) { + for (int i = 0; i < node->num_outputs(); ++i) { + const DataType& dt = node->output_type(i); + if (IsRefType(dt)) { + return errors::InvalidArgument("Output ", i, " of node '", node->name(), + "' has a reference " + "type ", + DataTypeString(dt)); + } + } + return Status::OK(); +} + +Status FillFunctionBody( + const string& fn_name, const NodeNameMapping& node_names, + const std::vector<const Node*>& body_nodes, + const std::unordered_map<string, string>& tensor_renaming, + FunctionDef* fdef) { + std::vector<const Edge*> in_edges; + std::vector<const Edge*> control_edges; + for (const Node* node : body_nodes) { + NodeDef* node_def = fdef->add_node_def(); + // First, copy the node_def as is. We will patch it next. + *node_def = node->def(); + if (!node->assigned_device_name().empty()) { + node_def->set_device(node->assigned_device_name()); + } + node_def->set_name(node_names.Lookup(node->name())); + + // Input names must be set based on nested names in tensor_renaming. + // Clear the flat input names we got from the original node_def + // from the graph. + node_def->clear_input(); + + // Collect regular and control inputs. Regular inputs are indexed + // by the index at which they come into the `node`. Control inputs + // don't follow any order. + in_edges.clear(); + in_edges.resize(node->num_inputs(), nullptr); + control_edges.clear(); + for (const Edge* edge : node->in_edges()) { + if (edge->src()->IsSource()) continue; + if (edge->IsControlEdge()) { + control_edges.push_back(edge); + } else { + in_edges[edge->dst_input()] = edge; + } + } + + // Add regular inputs. + for (size_t i = 0; i < in_edges.size(); ++i) { + const Edge* edge = in_edges[i]; + string original_input_name; + if (edge == nullptr) { + // A backedge might not appear as a regular Edge, but be only present + // in the node_def. Such edges are referred to as requested_inputs(). + if (i >= node->requested_inputs().size()) { + return errors::InvalidArgument( + "Graph to be converted to function appears to be malformed. ", + "Node ", node->name(), " is missing input edge ", i); + } + original_input_name = + ParseTensorName(node->requested_inputs()[i]).ToString(); + } else { + original_input_name = + strings::StrCat(edge->src()->name(), ":", edge->src_output()); + } + + const auto iter = tensor_renaming.find(original_input_name); + if (iter == tensor_renaming.end()) { + return errors::InvalidArgument( + "Input ", i, ", '", original_input_name, "', of node '", + node->name(), "' in function '", fn_name, + "' is not available. You might need to include it in inputs " + "or include its source node in the body"); + } + node_def->add_input(iter->second); + } + + // Add control inputs. + for (const Edge* edge : control_edges) { + // Add this control input only if the src node is in the body. + const string normalized = node_names.Lookup(edge->src()->name()); + // If we did not find a name for the source of control edge, this + // source must be outside of the body. Raise an error. + if (normalized.empty()) { + return errors::InvalidArgument( + "The source of control edge ", edge->DebugString(), + " is not in the body. Encountered while creating function '", + fn_name, "'"); + } + node_def->add_input(strings::StrCat("^", normalized)); + } + } + return Status::OK(); +} + +// Graph to FunctionDef conversion. This code is closely modeled on the Python +// code in third_party/tensorflow/python/framework/function.py. +Status GraphToFunctionDef(const Graph& fn_body, const string& fn_name, + const std::vector<const Node*>& body_nodes, + const std::vector<OutputTensor>& inputs, + const std::vector<OutputTensor>& outputs, + const std::vector<string>& output_names, + FunctionDef* fdef) { + fdef->mutable_signature()->set_name(fn_name); + + // Keep track of names we used and how we normalized them. + NodeNameMapping node_names; + + // Mapping from original names of tensors (i.e. "<node_name>:<idx>") to the + // name we used in the function: + // - For input tensors: + // {flat_tensor_name -> normalized_name_of_src_node} + // e.g. {In:3 -> in} + // - For tensors produced by nodes in function's body: + // {flat_tensor_name -> nested_tensor_name} + // e.g. {Add:3 -> add_0:z:1} + std::unordered_map<string, string> tensor_renaming; + + // Fill inputs in function's signature. + for (size_t i = 0; i < inputs.size(); ++i) { + const Node* node = inputs[i].node; + int idx = inputs[i].index; + OpDef::ArgDef* argdef = fdef->mutable_signature()->add_input_arg(); + argdef->set_type(node->output_type(idx)); + const string& input_name = node_names.GetIOName(node->name()); + argdef->set_name(input_name); + tensor_renaming[strings::StrCat(node->name(), ":", idx)] = input_name; + } + + // Fill outputs in function's signature. + for (size_t i = 0; i < outputs.size(); ++i) { + const Node* node = outputs[i].node; + int idx = outputs[i].index; + OpDef::ArgDef* argdef = fdef->mutable_signature()->add_output_arg(); + argdef->set_type(node->output_type(idx)); + argdef->set_name(node_names.GetIOName(node->name())); + } + + // Populate tensor_renaming and node_names. + // Generate the new output names for every node in the function. + // The NodeDefs in FunctionDefs use a different naming scheme for + // their inputs than the NodeDefs in a graph (see the comment for + // FunctionDef.node_def in function.proto). We do the + // graph tensor name -> function tensor name conversion for every + // possible input (i.e. every node's outputs) and store the result + // in tensor_renaming. + for (const Node* node : body_nodes) { + // Make sure node_name does not collide with an input or output name. + const string& node_name = node_names.Uniquify(node->name()); + // For each output_arg in the op_def, the output_ranges + // map will have [start, end] range of indices that this arg produces + // among all the output tensors of this op. + NameRangeMap output_ranges; + TF_RETURN_IF_ERROR( + NameRangesForNode(*node, node->op_def(), nullptr, &output_ranges)); + for (const auto& output : output_ranges) { + const string& output_name = output.first; + int index_start = output.second.first; + int index_end = output.second.second; + for (int i = index_start; i < index_end; ++i) { + const string& original_name = strings::StrCat(node->name(), ":", i); + const string& new_name = + strings::StrCat(node_name, ":", output_name, ":", i - index_start); + // Record the mapping if this tensor is not already mapped. + // Tensor can be already mapped if it is used as an input. + if (tensor_renaming.find(original_name) == tensor_renaming.end()) { + tensor_renaming[original_name] = new_name; + } + } + } + } + + TF_RETURN_IF_ERROR( + FillFunctionBody(fn_name, node_names, body_nodes, tensor_renaming, fdef)); + + // Remap return values. + for (int r = 0; r < fdef->signature().output_arg_size(); ++r) { + const string& ret_name = fdef->signature().output_arg(r).name(); + + // We convert this flat tensor name to the nested value + // (e.g. `add:z:1`) that we stored in tensor_renaming. + const string& return_value = + strings::StrCat(outputs[r].node->name(), ":", outputs[r].index); + const auto iter = tensor_renaming.find(return_value); + if (iter == tensor_renaming.end()) { + return errors::InvalidArgument( + "TF_Output ", return_value, " is neither in the function body ", + "nor among function inputs. Encountered while creating function '", + fn_name, "'"); + } + (*fdef->mutable_ret())[ret_name] = iter->second; + } + + return Status::OK(); +} + +// Converts `ninputs` and `inputs` into `inputs_tensors` and `input_nodes` and +// does various checks while doing so. `input_nodes` will contain the same +// information as input_tensors just in a different structure to make +// following processing easier. TODO(iga): Simplify this nested structure. +Status ProcessInputs( + const TF_Graph* fn_body, const char* fn_name, int ninputs, + const TF_Output* inputs, std::vector<OutputTensor>* input_tensors, + std::unordered_map<const Node*, std::vector<int>>* input_nodes) + EXCLUSIVE_LOCKS_REQUIRED(fn_body->mu) { + input_tensors->reserve(ninputs); + for (int i = 0; i < ninputs; ++i) { + const Node& node = inputs[i].oper->node; + int idx = inputs[i].index; + + TF_RETURN_WITH_CONTEXT_IF_ERROR( + fn_body->graph.IsValidOutputTensor(&node, idx), + "Encountered while processing input ", i, " into function '", fn_name, + "'"); + TF_RETURN_WITH_CONTEXT_IF_ERROR(ValidateNoRefOutputs(&node), + "Encountered while processing input ", i, + " into function '", fn_name, "'"); + + input_tensors->emplace_back(&node, idx); + + const auto& iter = input_nodes->find(&node); + if (iter == input_nodes->end()) { + input_nodes->insert({&node, {idx}}); + } else { + auto& indices = iter->second; + if (std::find(indices.begin(), indices.end(), idx) != indices.end()) { + return errors::InvalidArgument( + "TF_Output ", node.name(), ":", idx, + " appears more than once in the input list"); + } + indices.push_back(idx); + } + } + return Status::OK(); +} + +// Converts `noutputs` and `outputs` into `outputs_tensors` and does various +// checks while doing so. +Status ProcessOutputs(const TF_Graph* fn_body, const char* fn_name, + int noutputs, const TF_Output* outputs, + std::vector<OutputTensor>* output_tensors) + EXCLUSIVE_LOCKS_REQUIRED(fn_body->mu) { + output_tensors->reserve(noutputs); + for (int i = 0; i < noutputs; ++i) { + const Node& node = outputs[i].oper->node; + int idx = outputs[i].index; + TF_RETURN_WITH_CONTEXT_IF_ERROR( + fn_body->graph.IsValidOutputTensor(&node, idx), + "Encountered while processing output ", i, " from function '", fn_name, + "'"); + output_tensors->emplace_back(&node, idx); + } + return Status::OK(); +} + +// Populates `body_nodes` with the nodes that will become function's body. +// Performs various checks. +Status ComputeBodyNodes( + const TF_Graph* fn_body, const char* fn_name, int num_opers, + const TF_Operation* const* opers, + const std::unordered_map<const Node*, std::vector<int>>& input_nodes, + std::vector<const Node*>* body_nodes) + EXCLUSIVE_LOCKS_REQUIRED(fn_body->mu) { + if (num_opers == -1) { + for (const Node* node : fn_body->graph.op_nodes()) { + const auto& iter = input_nodes.find(node); + if (iter == input_nodes.end()) { + // This node is not referenced in inputs. Add it to the body. + TF_RETURN_WITH_CONTEXT_IF_ERROR(ValidateNoRefOutputs(node), + "Encountered while creating function '", + fn_name, "'"); + body_nodes->push_back(node); + } else { + // This node is referenced in inputs. Currently, we place an + // artificial restriction and require that when num_opers=-1, such + // nodes must have a single output. + if (node->num_outputs() != 1) { + return errors::InvalidArgument( + "When `num_opers` is set to -1, nodes referenced in `inputs` " + "must have a single output. Node ", + node->name(), " has ", node->num_outputs(), + " outputs. Encountered while creating function '", fn_name, "'"); + } + } + } + } else { + body_nodes->reserve(num_opers); + for (int i = 0; i < num_opers; ++i) { + const Node* node = &opers[i]->node; + TF_RETURN_WITH_CONTEXT_IF_ERROR(ValidateNoRefOutputs(node), + "Encountered while creating function '", + fn_name, "'"); + body_nodes->push_back(node); + } + } + return Status::OK(); +} + +} // anonymous namespace +} // namespace tensorflow + +using tensorflow::Node; +using tensorflow::string; + +TF_Function* TF_GraphToFunction(const TF_Graph* fn_body, const char* fn_name, + int num_opers, const TF_Operation* const* opers, + int ninputs, const TF_Output* inputs, + int noutputs, const TF_Output* outputs, + const char* const* output_names, + const TF_FunctionOptions* opts, + TF_Status* status) { + tensorflow::mutex_lock l(*const_cast<tensorflow::mutex*>(&fn_body->mu)); + + // Process inputs. + std::vector<tensorflow::OutputTensor> input_tensors; + std::unordered_map<const Node*, std::vector<int>> input_nodes; + status->status = tensorflow::ProcessInputs(fn_body, fn_name, ninputs, inputs, + &input_tensors, &input_nodes); + if (!status->status.ok()) return nullptr; + + // Process outputs. + std::vector<tensorflow::OutputTensor> output_tensors; + status->status = tensorflow::ProcessOutputs(fn_body, fn_name, noutputs, + outputs, &output_tensors); + if (!status->status.ok()) return nullptr; + + // Process output names. + std::vector<string> output_names_vec; + if (output_names) { + output_names_vec.reserve(noutputs); + for (int i = 0; i < noutputs; ++i) { + output_names_vec.push_back(string(output_names[i])); + } + } + + // Compute body nodes. + std::vector<const Node*> body_nodes; + status->status = tensorflow::ComputeBodyNodes( + fn_body, fn_name, num_opers, opers, input_nodes, &body_nodes); + if (!status->status.ok()) return nullptr; + + // Do the actual function creation. + TF_Function* tf_function = new TF_Function(); + status->status = tensorflow::GraphToFunctionDef( + fn_body->graph, fn_name, body_nodes, input_tensors, output_tensors, + output_names_vec, tf_function->fdef_lib.add_function()); + if (!status->status.ok()) { + TF_DeleteFunction(tf_function); + return nullptr; + } + return tf_function; +} + +void TF_GraphAddFunction(TF_Graph* g, const TF_Function* function, + TF_Status* status) { + tensorflow::mutex_lock l(g->mu); + + // At the moment, we have only one function and no gradients in fdef_lib. + // This makes the following operation atomic. + // TODO(iga): Add an atomic version of AddFunctionLibrary when we support + // gradients + status->status = g->graph.AddFunctionLibrary(function->fdef_lib); +} + +void TF_FunctionToFunctionDef(TF_Function* func, TF_Buffer* output_func_def, + TF_Status* status) { + DCHECK_EQ(1, func->fdef_lib.function_size()); + status->status = MessageToBuffer(func->fdef_lib.function(0), output_func_def); +} + +void TF_DeleteFunction(TF_Function* function) { delete function; } diff --git a/tensorflow/c/c_api_function_test.cc b/tensorflow/c/c_api_function_test.cc new file mode 100644 index 0000000000..c9dd38ea15 --- /dev/null +++ b/tensorflow/c/c_api_function_test.cc @@ -0,0 +1,1039 @@ +/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/c/c_api.h" + +#include "tensorflow/c/c_test_util.h" +#include "tensorflow/core/framework/function.pb.h" +#include "tensorflow/core/framework/op_def.pb.h" +#include "tensorflow/core/lib/strings/str_util.h" +#include "tensorflow/core/lib/strings/strcat.h" +#include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/platform/test.h" + +namespace tensorflow { +namespace { + +// Specification for expected input/output and its type. +// DataType value of DT_INVALID signifies that we don't want to +// check the data type. +typedef std::pair<string, DataType> IOSpec; + +std::vector<IOSpec> M(const std::initializer_list<string>& names) { + std::vector<IOSpec> v; + for (const string& name : names) { + v.push_back(IOSpec(name, DT_INVALID)); + } + return v; +} + +// Specification for an expected edge. +// src is either: +// - input name (as it appears in FunctionDef) +// - name of output tensor (in nested "add:z:0" format) +// dst is either: +// - output name (as it appears in FunctionDef) +// - <name_of_node>:<index_of_this_input_into_node> (this looks the same as +// output tensor naming, but it the index is actually an input index) +struct EdgeSpec : public std::pair<string, string> { + typedef std::pair<string, string> Base; + + // Inherit the set of constructors + using Base::pair; + + string ToString() const { return strings::StrCat(first, "->", second); } +}; + +class CApiFunctionTest : public ::testing::Test { + protected: + CApiFunctionTest() + : s_(TF_NewStatus()), + func_graph_(TF_NewGraph()), + host_graph_(TF_NewGraph()), + func_(nullptr) {} + + void SetUp() override {} + + ~CApiFunctionTest() override { + TF_DeleteFunction(func_); + TF_DeleteGraph(host_graph_); + TF_DeleteGraph(func_graph_); + TF_DeleteStatus(s_); + } + + void Run(const std::vector<std::pair<TF_Operation*, TF_Tensor*>>& inputs, + TF_Operation* output, int32_t expected_result) { + Run(inputs, {{output, 0}}, {expected_result}); + } + + // Run the host graph, which now contains a function and check that + // outputs are as expected. + // 'T' stands for 'tensor' since the outputs are tensors, not scalars. + void RunT(const std::vector<std::pair<TF_Operation*, TF_Tensor*>>& inputs, + std::initializer_list<TF_Output> outputs, + const std::vector<std::vector<int32_t>>& expected_results) { + // Create a session for this graph + CSession csession(host_graph_, s_); + ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_); + + // Run + csession.SetInputs(inputs); + csession.SetOutputs(outputs); + csession.Run(s_); + ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_); + + // Check results + for (int i = 0; i < expected_results.size(); ++i) { + TF_Tensor* out = csession.output_tensor(i); + ASSERT_TRUE(out != nullptr); + EXPECT_EQ(TF_INT32, TF_TensorType(out)); + EXPECT_EQ(1, TF_NumDims(out)); + CompareInt32Tensor(expected_results[i], out); + } + } + + // Run the host graph, which now contains a function and check that + // outputs are as expected. + void Run(const std::vector<std::pair<TF_Operation*, TF_Tensor*>>& inputs, + std::initializer_list<TF_Output> outputs, + const std::vector<int32_t>& expected_results) { + // Create a session for this graph. + CSession csession(host_graph_, s_); + ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_); + + csession.SetInputs(inputs); + csession.SetOutputs(outputs); + csession.Run(s_); + ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_); + + for (int i = 0; i < expected_results.size(); ++i) { + TF_Tensor* out = csession.output_tensor(i); + ASSERT_TRUE(out != nullptr); + EXPECT_EQ(TF_INT32, TF_TensorType(out)); + EXPECT_EQ(0, TF_NumDims(out)); // scalar + ASSERT_EQ(sizeof(int32_t), TF_TensorByteSize(out)); + int32_t* output_contents = static_cast<int32_t*>(TF_TensorData(out)); + EXPECT_EQ(expected_results[i], *output_contents); + } + } + + void CompareInt32Tensor(const std::vector<int32_t>& expected, TF_Tensor* t) { + int32_t* data = static_cast<int32_t*>(TF_TensorData(t)); + size_t size = TF_TensorByteSize(t); + ASSERT_EQ(expected.size() * sizeof(int32_t), size); + for (int i = 0; i < expected.size(); ++i) { + ASSERT_EQ(expected[i], data[i]) << "Different data at index " << i; + } + } + + std::vector<TF_Output> ToOutput(const std::vector<TF_Operation*> ops) { + std::vector<TF_Output> out; + for (auto op : ops) { + out.push_back({op, 0}); + } + return out; + } + + void Define(int num_opers, const std::vector<TF_Operation*>& opers, + const std::vector<TF_Operation*>& inputs, + const std::vector<TF_Operation*>& outputs, + const char** output_names, bool expect_failure = false) { + DefineT(num_opers, opers, ToOutput(inputs), ToOutput(outputs), output_names, + expect_failure); + } + + // An explicit `num_opers` is needed so that we can distinguish between the + // case of no operations specified (-1) and the case of an empty set of + // operations specified (0). + void DefineT(int num_opers, const std::vector<TF_Operation*>& opers, + const std::vector<TF_Output>& inputs, + const std::vector<TF_Output>& outputs, const char** output_names, + bool expect_failure = false) { + ASSERT_EQ(func_, nullptr); + func_ = TF_GraphToFunction(func_graph_, func_name_, num_opers, + num_opers == -1 ? nullptr : opers.data(), + inputs.size(), inputs.data(), outputs.size(), + outputs.data(), output_names, + /*opts=*/nullptr, s_); + if (expect_failure) { + ASSERT_EQ(func_, nullptr); + return; + } + + ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_); + ASSERT_NE(func_, nullptr); + TF_GraphAddFunction(host_graph_, func_, s_); + ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_); + } + + TF_Operation* Use(const std::vector<TF_Operation*>& inputs) { + return UseT(ToOutput(inputs)); + } + + TF_Operation* UseT(const std::vector<TF_Output>& inputs) { + TF_Operation* op; + UseHelper(inputs, &op); + return op; + } + + // All the *Helper methods are used as a workaround for the restrictions that + // one cannot call ASSERT_* methods in non-void-returning functions (when + // exceptions are disabled during compilation) + void UseHelper(const std::vector<TF_Output>& inputs, TF_Operation** op) { + TF_OperationDescription* desc = + TF_NewOperation(host_graph_, func_name_, func_node_name_); + for (auto input : inputs) { + TF_AddInput(desc, input); + } + // Set device to CPU because some ops inside the function might not be + // available on GPU. + TF_SetDevice(desc, "/cpu:0"); + *op = TF_FinishOperation(desc, s_); + ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_); + ASSERT_NE(*op, nullptr); + } + + FunctionDef fdef() { + tensorflow::FunctionDef fdef; + EXPECT_TRUE(GetFunctionDef(func_, &fdef)); + return fdef; + } + + // logging utility + template <class Container> + string ToString(const Container& v) { + std::stringstream ss; + ss << "{"; + size_t i = 0; + for (const auto& e : v) { + if (i != 0) { + ss << ", "; + } + ss << e.ToString(); + ++i; + } + ss << "}"; + return ss.str(); + } + + void VerifyFDefNodes(const tensorflow::FunctionDef& fdef, + const std::unordered_set<string>& nodes) { + ASSERT_EQ(nodes.size(), fdef.node_def_size()) + << "Got unexpected number of nodes. Expected: [" + << str_util::Join(nodes, ", ") + << "] Actual nodes in fdef: " << fdef.DebugString(); + for (const NodeDef& node_def : fdef.node_def()) { + ASSERT_TRUE(nodes.find(node_def.name()) != nodes.end()) + << "Got unexpected node: " << node_def.name() + << " in fdef: " << fdef.DebugString(); + } + } + + void VerifyFDefInputs(const tensorflow::FunctionDef& fdef, + const std::vector<IOSpec>& inputs) { + const OpDef& signature = fdef.signature(); + ASSERT_EQ(inputs.size(), signature.input_arg_size()); + for (int i = 0; i < inputs.size(); ++i) { + const OpDef::ArgDef& arg = signature.input_arg(i); + const IOSpec& in = inputs[i]; + if (in.second != DT_INVALID) { + ASSERT_EQ(arg.type(), in.second) + << "Got unexpected type for input " << i + << ". fdef: " << fdef.DebugString(); + } + ASSERT_EQ(arg.name(), in.first) << "Got unexpected name for input " << i + << ". fdef: " << fdef.DebugString(); + } + } + + void VerifyFDefOutputs(const tensorflow::FunctionDef& fdef, + const std::vector<IOSpec>& outputs) { + const OpDef& signature = fdef.signature(); + ASSERT_EQ(outputs.size(), signature.output_arg_size()); + for (int i = 0; i < outputs.size(); ++i) { + const OpDef::ArgDef& arg = signature.output_arg(i); + const IOSpec& out = outputs[i]; + if (out.second != DT_INVALID) { + ASSERT_EQ(arg.type(), out.second) + << "Got unexpected type for output " << i + << ". fdef: " << fdef.DebugString(); + } + ASSERT_EQ(arg.name(), out.first) << "Got unexpected name for output " << i + << ". fdef: " << fdef.DebugString(); + } + } + + void VerifyFDefEdges( + const tensorflow::FunctionDef& fdef, + const std::vector<EdgeSpec>& e_edges, // expected edges + const std::vector<EdgeSpec>& c_edges, // expected ctrl edges + bool is_exact_edges = true) { + // Build a set of edges from fdef + std::set<EdgeSpec> a_edges; // actual edges + // Get edges from inputs to body nodes and between body nodes + for (const NodeDef& node_def : fdef.node_def()) { + for (int i = 0; i < node_def.input_size(); ++i) { + const string& in = node_def.input(i); + const auto& v = + a_edges.insert({in, strings::StrCat(node_def.name(), ":", i)}); + ASSERT_TRUE(v.second) << "Duplicate edge " << in << " -> " + << strings::StrCat(node_def.name(), ":", i) + << ". fdef: " << fdef.DebugString(); + } + } + // Get edges from body nodes to outputs and from inputs to outputs + for (const OpDef::ArgDef& arg : fdef.signature().output_arg()) { + const auto& iter = fdef.ret().find(arg.name()); + if (iter != fdef.ret().end()) { + const auto& v = a_edges.insert({iter->second, arg.name()}); + ASSERT_TRUE(v.second) << "Duplicate edge " << iter->second << " -> " + << arg.name() << ". fdef: " << fdef.DebugString(); + } else { + const auto& v = a_edges.insert({arg.name(), arg.name()}); + ASSERT_TRUE(v.second) << "Duplicate edge " << arg.name() << " -> " + << arg.name() << ". fdef: " << fdef.DebugString(); + } + } + + // Verify edges + for (const EdgeSpec& e : e_edges) { + ASSERT_TRUE(a_edges.find(e) != a_edges.end()) + << "Failed to find expected edge " << e.ToString() + << " in fdef: " << fdef.DebugString(); + } + + // If caller specified all edges, check that we have seen all + if (is_exact_edges) { + ASSERT_EQ(e_edges.size() + c_edges.size(), a_edges.size()) + << "Expected edges: " << ToString(e_edges) + << " Expected Control edges: " << ToString(c_edges) + << " Actual edges: " << ToString(a_edges) + << " in fdef: " << fdef.DebugString(); + } + } + + void VerifyFDef(const std::unordered_set<string>& nodes, + const std::vector<IOSpec>& inputs, + const std::vector<IOSpec>& outputs, + const std::vector<EdgeSpec>& e_edges, // expected edges + const std::vector<EdgeSpec>& c_edges, // expected ctrl edges + bool is_exact_edges = true) { + tensorflow::FunctionDef fdef; + ASSERT_TRUE(GetFunctionDef(func_, &fdef)); + VerifyFDefNodes(fdef, nodes); + VerifyFDefInputs(fdef, inputs); + VerifyFDefOutputs(fdef, outputs); + VerifyFDefEdges(fdef, e_edges, c_edges, is_exact_edges); + } + + const char* func_name_ = "MyFunc"; + const char* func_node_name_ = "MyFunc_0"; + TF_Status* s_; + TF_Graph* func_graph_; + TF_Graph* host_graph_; + TF_Function* func_; + + // Workaround for not being able to initialize empty map using {} + std::unordered_set<string> empty_; +}; + +TEST_F(CApiFunctionTest, OneOp_ZeroInputs_OneOutput) { + /* + * constant + * | + * v + */ + // Define + TF_Operation* c = ScalarConst(10, func_graph_, s_, "scalar10"); + Define(-1, {}, {}, {c}, nullptr); + + // Use, run, and verify + TF_Operation* func_op = Use({}); + Run({}, func_op, 10); + VerifyFDef({"scalar10_0"}, {}, {{"scalar10", DT_INT32}}, + {{"scalar10_0:output:0", "scalar10"}}, {}); +} + +TEST_F(CApiFunctionTest, OneOp_OneInput_OneOutput) { + /* + * | + * v + * negate + * | + * v + */ + // Define + TF_Operation* feed = Placeholder(func_graph_, s_); + TF_Operation* neg = Neg(feed, func_graph_, s_); + Define(-1, {}, {feed}, {neg}, nullptr); + + // Use, run, and verify + TF_Operation* func_feed = Placeholder(host_graph_, s_); + TF_Operation* func_op = Use({func_feed}); + Run({{func_feed, Int32Tensor(3)}}, func_op, -3); + VerifyFDef({"neg_0"}, {{"feed", DT_INT32}}, {{"neg", DT_INT32}}, + {{"feed", "neg_0:0"}, {"neg_0:y:0", "neg"}}, {}); +} + +TEST_F(CApiFunctionTest, ZeroOps_Identity) { + /* + * | + * | + * | + * v + */ + // Define + TF_Operation* feed = Placeholder(func_graph_, s_); + Define(-1, {}, {feed}, {feed}, nullptr); + + // Use, run, and verify + TF_Operation* func_feed = Placeholder(host_graph_, s_); + TF_Operation* func_op = Use({func_feed}); + Run({{func_feed, Int32Tensor(3)}}, func_op, 3); + VerifyFDef(empty_, {{"feed", DT_INT32}}, {{"feed_0", DT_INT32}}, + {{"feed", "feed_0"}}, {}); +} + +TEST_F(CApiFunctionTest, ZeroOps_Permutation) { + /* + * | | + * \ / + * \/ + * x + * /\ + * / \ + * | | + * v v + */ + // Define + TF_Operation* feed1 = Placeholder(func_graph_, s_, "feed1"); + TF_Operation* feed2 = Placeholder(func_graph_, s_, "feed2"); + Define(-1, {}, {feed1, feed2}, {feed2, feed1}, nullptr); + + // Use, run, and verify + TF_Operation* two = ScalarConst(2, host_graph_, s_); + TF_Operation* func_feed = Placeholder(host_graph_, s_); + TF_Operation* func_op = Use({two, func_feed}); + Run({{func_feed, Int32Tensor(3)}}, {{func_op, 0}, {func_op, 1}}, {3, 2}); + VerifyFDef(empty_, M({{"feed1"}, {"feed2"}}), M({{"feed2_0"}, {"feed1_0"}}), + {{"feed1", "feed1_0"}, {"feed2", "feed2_0"}}, {}); +} + +TEST_F(CApiFunctionTest, OneOp_TwoInputs_OneOutput) { + /* + * | | + * v v + * add + * | + * v + */ + // Define + TF_Operation* feed1 = Placeholder(func_graph_, s_, "feed1"); + TF_Operation* feed2 = Placeholder(func_graph_, s_, "feed2"); + TF_Operation* add = Add(feed1, feed2, func_graph_, s_); + Define(-1, {}, {feed1, feed2}, {add}, nullptr); + + // Use, run, and verify + TF_Operation* two = ScalarConst(2, host_graph_, s_); + TF_Operation* func_feed = Placeholder(host_graph_, s_); + TF_Operation* func_op = Use({two, func_feed}); + Run({{func_feed, Int32Tensor(3)}}, func_op, 2 + 3); + VerifyFDef( + {"add_0"}, M({{"feed1"}, {"feed2"}}), M({{"add"}}), + {{"feed1", "add_0:0"}, {"feed2", "add_0:1"}, {"add_0:sum:0", "add"}}, {}); +} + +TEST_F(CApiFunctionTest, OneOp_TwoInputs_ZeroOutputs) { + /* + * | | + * v v + * add + * + * (output ignored) + */ + // Define + TF_Operation* feed1 = Placeholder(func_graph_, s_, "feed1"); + TF_Operation* feed2 = Placeholder(func_graph_, s_, "feed2"); + Add(feed1, feed2, func_graph_, s_); + Define(-1, {}, {feed1, feed2}, {}, nullptr); + + // Use, run, and verify + TF_Operation* two = ScalarConst(2, host_graph_, s_); + TF_Operation* func_feed = Placeholder(host_graph_, s_); + Use({two, func_feed}); + VerifyFDef({"add"}, M({{"feed1"}, {"feed2"}}), {}, + {{"feed1", "add:0"}, {"feed2", "add:1"}}, {}); +} + +TEST_F(CApiFunctionTest, TwoOps_ThreeInputs_OneOutput) { + /* + * | | | + * v v / + * add1 / + * | | + * v v + * add2 + * | + * v + */ + // Define + TF_Operation* feed1 = Placeholder(func_graph_, s_, "feed1"); + TF_Operation* feed2 = Placeholder(func_graph_, s_, "feed2"); + TF_Operation* feed3 = Placeholder(func_graph_, s_, "feed3"); + TF_Operation* add1 = Add(feed1, feed2, func_graph_, s_, "add1"); + TF_Operation* add2 = Add(add1, feed3, func_graph_, s_, "add2"); + Define(-1, {}, {feed1, feed2, feed3}, {add2}, nullptr); + + // Use, run, and verify + TF_Operation* two = ScalarConst(2, host_graph_, s_, "two"); + TF_Operation* ten = ScalarConst(10, host_graph_, s_, "ten"); + TF_Operation* func_feed = Placeholder(host_graph_, s_); + TF_Operation* func_op = Use({two, ten, func_feed}); + Run({{func_feed, Int32Tensor(3)}}, func_op, 2 + 10 + 3); + VerifyFDef({"add1", "add2_0"}, M({{"feed1"}, {"feed2"}, {"feed3"}}), + M({{"add2"}}), + {{"feed1", "add1:0"}, + {"feed2", "add1:1"}, + {"add1:sum:0", "add2_0:0"}, + {"feed3", "add2_0:1"}, + {"add2_0:sum:0", "add2"}}, + {}); +} + +TEST_F(CApiFunctionTest, OneOp_TwoInputs_TwoDuplicateOutputs) { + /* + * | | + * v v + * add + * | + * +-+-+ + * | | + * v v + */ + // Define + TF_Operation* feed1 = Placeholder(func_graph_, s_, "feed1"); + TF_Operation* feed2 = Placeholder(func_graph_, s_, "feed2"); + TF_Operation* add = Add(feed1, feed2, func_graph_, s_); + Define(-1, {}, {feed1, feed2}, {add, add}, nullptr); + + // Use, run, and verify + TF_Operation* two = ScalarConst(2, host_graph_, s_); + TF_Operation* func_feed = Placeholder(host_graph_, s_); + TF_Operation* func_op = Use({two, func_feed}); + Run({{func_feed, Int32Tensor(3)}}, {{func_op, 0}, {func_op, 1}}, {5, 5}); + VerifyFDef({"add_1"}, M({{"feed1"}, {"feed2"}}), M({{"add"}, {"add_0"}}), + {{"feed1", "add_1:0"}, + {"feed2", "add_1:1"}, + {"add_1:sum:0", "add"}, + {"add_1:sum:0", "add_0"}}, + {}); +} + +TEST_F(CApiFunctionTest, TwoOps_ThreeInputs_TwoOutputs) { + /* + * | | | + * v v / + * add / + * | | + * +-+ | + * | | | + * | v v + * | add + * | | + * v v + */ + // Define + TF_Operation* feed1 = Placeholder(func_graph_, s_, "feed1"); + TF_Operation* feed2 = Placeholder(func_graph_, s_, "feed2"); + TF_Operation* feed3 = Placeholder(func_graph_, s_, "feed3"); + TF_Operation* add1 = Add(feed1, feed2, func_graph_, s_, "add1"); + TF_Operation* add2 = Add(add1, feed3, func_graph_, s_, "add2"); + Define(-1, {}, {feed1, feed2, feed3}, {add1, add2}, nullptr); + + // Use, run, and verify + TF_Operation* two = ScalarConst(2, host_graph_, s_, "two"); + TF_Operation* ten = ScalarConst(10, host_graph_, s_, "ten"); + TF_Operation* func_feed = Placeholder(host_graph_, s_); + TF_Operation* func_op = Use({two, ten, func_feed}); + Run({{func_feed, Int32Tensor(3)}}, {{func_op, 0}, {func_op, 1}}, {12, 15}); + VerifyFDef({"add1_0", "add2_0"}, M({{"feed1"}, {"feed2"}, {"feed3"}}), + M({{"add1"}, {"add2"}}), + {{"feed1", "add1_0:0"}, + {"feed2", "add1_0:1"}, + {"add1_0:sum:0", "add2_0:0"}, + {"feed3", "add2_0:1"}, + {"add1_0:sum:0", "add1"}, + {"add2_0:sum:0", "add2"}}, + {}); +} + +TEST_F(CApiFunctionTest, FromSubsetOfOps) { + /* + * | | | + * v v / + * add / + * | | + * +---+--+---+ + * Ops used | | | | + * for func | v v | + * | | add | + * +-------> | | | + * | v | + * | | + * +----------+ + */ + // Define + TF_Operation* feed1 = Placeholder(func_graph_, s_, "feed1"); + TF_Operation* feed2 = Placeholder(func_graph_, s_, "feed2"); + TF_Operation* feed3 = Placeholder(func_graph_, s_, "feed3"); + TF_Operation* add1 = Add(feed1, feed2, func_graph_, s_, "add1"); + TF_Operation* add2 = Add(add1, feed3, func_graph_, s_, "add2"); + Define(1, {add2}, {add1, feed3}, {add2}, nullptr); + + // Use, run, and verify + TF_Operation* two = ScalarConst(2, host_graph_, s_, "two"); + TF_Operation* func_feed = Placeholder(host_graph_, s_); + TF_Operation* func_op = Use({two, func_feed}); + Run({{func_feed, Int32Tensor(3)}}, func_op, 2 + 3); + VerifyFDef( + {"add2_0"}, M({{"add1"}, {"feed3"}}), M({{"add2"}}), + {{"add1", "add2_0:0"}, {"feed3", "add2_0:1"}, {"add2_0:sum:0", "add2"}}, + {}); +} + +TEST_F(CApiFunctionTest, UsingOneOutputOfSplit) { + /* + * feed + * | + * +---------+---+ + * | const0 | | + * | | | | + * | v / | + * | split | + * | | | | | + * | v | v | + * | | | + * +------+------+ + * | + * v + * + * Only the second output from split is used as function output + */ + // Define + TF_Operation* feed = Placeholder(func_graph_, s_); + TF_Operation* split = Split3(feed, func_graph_, s_); + DefineT(-1, {}, {{feed, 0}}, {{split, 1}}, nullptr); + + // Use, run, and verify + TF_Operation* func_feed = Placeholder(host_graph_, s_); + TF_Operation* func_op = Use({func_feed}); + RunT({{func_feed, Int32Tensor({1, 2, 3, 4, 5, 6})}}, {{func_op, 0}}, + {{3, 4}}); + VerifyFDef({"split3_const0", "split3_0"}, M({{"feed"}}), M({{"split3"}}), + {{"split3_const0:output:0", "split3_0:0"}, + {"feed", "split3_0:1"}, + {"split3_0:output:1", "split3"}}, + {}); +} + +TEST_F(CApiFunctionTest, UsingTwoOutputsOfSplit) { + /* + * feed + * | + * +---------+---+ + * | const0 | | + * | | | | + * | v / | + * | split | + * | | | | | + * | | v | | + * | | | | + * +---+-----+---+ + * | | + * v v + * + * Second output from split is not used as function output + */ + // Define + TF_Operation* feed = Placeholder(func_graph_, s_); + TF_Operation* split = Split3(feed, func_graph_, s_); + DefineT(-1, {}, {{feed, 0}}, {{split, 0}, {split, 2}}, nullptr); + + // Use, run, and verify + TF_Operation* func_feed = Placeholder(host_graph_, s_); + TF_Operation* func_op = Use({func_feed}); + RunT({{func_feed, Int32Tensor({1, 2, 3, 4, 5, 6})}}, + {{func_op, 0}, {func_op, 1}}, {{1, 2}, {5, 6}}); + VerifyFDef({"split3_const0", "split3_1"}, M({{"feed"}}), + M({{"split3"}, {"split3_0"}}), + {{"split3_const0:output:0", "split3_1:0"}, + {"feed", "split3_1:1"}, + {"split3_1:output:0", "split3"}, + {"split3_1:output:2", "split3_0"}}, + {}); +} + +TEST_F(CApiFunctionTest, UsingTwoOutputsOfSplitAsInputs) { + /* + * | + * v + * split + * | | | + * | v | + * | | + * +---+-----+---+ + * | | | | + * | v v | + * | add | + * | | | + * | | | + * +------+------+ + * | + * v + */ + // Define + TF_Operation* feed = Placeholder(func_graph_, s_); + TF_Operation* split = Split3(feed, func_graph_, s_); + TF_Operation* add = Add({split, 0}, {split, 2}, func_graph_, s_); + ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_); + DefineT(1, {add}, {{split, 0}, {split, 2}}, {{add, 0}}, nullptr); + + // Use, run, and verify + TF_Operation* two = ScalarConst(2, host_graph_, s_, "two"); + TF_Operation* func_feed = Placeholder(host_graph_, s_); + TF_Operation* func_op = Use({two, func_feed}); + Run({{func_feed, Int32Tensor(3)}}, func_op, 2 + 3); + VerifyFDef( + {"add_0"}, M({{"split3"}, {"split3_0"}}), M({{"add"}}), + {{"split3", "add_0:0"}, {"split3_0", "add_0:1"}, {"add_0:sum:0", "add"}}, + {}); +} + +TEST_F(CApiFunctionTest, NodesUsedInInputsMustHaveSingleOutput) { + /* + * | + * v + * split + * | | | + * | v | + * | | + * input --->| |<--- input + * | | + * v v + * add + * | + * | + * v + */ + // Define + TF_Tensor* tensor_123 = Int32Tensor({1, 2, 3}); + TF_Operation* c = Const(tensor_123, func_graph_, s_, "const_array"); + ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_); + TF_Operation* split = Split3(c, func_graph_, s_); + TF_Operation* add = Add({split, 0}, {split, 2}, func_graph_, s_); + ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_); + DefineT(-1, {}, {{split, 0}, {split, 2}}, {{add, 0}}, nullptr, true); + EXPECT_EQ(TF_INVALID_ARGUMENT, TF_GetCode(s_)); + EXPECT_EQ(string("When `num_opers` is set to -1, nodes referenced in " + "`inputs` must have a single output. Node split3 has " + "3 outputs. Encountered while creating function 'MyFunc'"), + string(TF_Message(s_))); + + TF_DeleteTensor(tensor_123); +} + +TEST_F(CApiFunctionTest, FunctionWithWhileLoop) { + // Inputs to the while loop and the function as a whole + TF_Operation* feed1 = Placeholder(func_graph_, s_, "feed1"); + TF_Operation* feed2 = Placeholder(func_graph_, s_, "feed2"); + + // Outputs of the while loop corresponding to the two inputs above + // The first one will the function's output + std::vector<TF_Output> outputs; + + // Add while loop to func_graph_ + { + // The inputs to the while loop + std::vector<TF_Output> inputs = {{feed1, 0}, {feed2, 0}}; + std::unique_ptr<TF_WhileParams> params(new TF_WhileParams( + TF_NewWhile(func_graph_, &inputs[0], inputs.size(), s_))); + ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_); + params->name = "test_loop"; + + // Initialize outputs so we can easily detect errors/bugs + outputs.resize(2, {nullptr, -1}); + + // Create loop: while (input1 < input2) input1 += input2 + 1 + TF_Operation* less_than = LessThan( + params->cond_inputs[0], params->cond_inputs[1], params->cond_graph, s_); + ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_); + params->cond_output = {less_than, 0}; + + TF_Operation* add1 = Add(params->body_inputs[0], params->body_inputs[1], + params->body_graph, s_, "add1"); + ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_); + TF_Operation* one = ScalarConst(1, params->body_graph, s_); + ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_); + TF_Operation* add2 = Add(add1, one, params->body_graph, s_, "add2"); + ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_); + params->body_outputs[0] = {add2, 0}; + params->body_outputs[1] = params->body_inputs[1]; + + // Finalize while loop + TF_FinishWhile(params.get(), s_, &outputs[0]); + EXPECT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_); + } + + // Define function, use it in graph, and run + DefineT(-1, {}, {{feed1, 0}, {feed2, 0}}, {outputs[0]}, nullptr); + TF_Operation* five = ScalarConst(5, host_graph_, s_, "five"); + TF_Operation* func_feed = Placeholder(host_graph_, s_); + TF_Operation* func_op = Use({func_feed, five}); + Run({{func_feed, Int32Tensor(2)}}, func_op, 2 /*+=*/ + 5 + 1); + + // Verify input, output, and subset of edges in fdef. + // The subset of edges we verify is a chain between feed1 and output to + // make sure that the correct output is picked. + tensorflow::FunctionDef fdef; + ASSERT_TRUE(GetFunctionDef(func_, &fdef)); + VerifyFDefInputs(fdef, M({{"feed1"}, {"feed2"}})); + VerifyFDefOutputs(fdef, M({{"test_loop_exit"}})); + VerifyFDefEdges(fdef, + {{"feed1", "test_loop/Enter:0"}, + {"test_loop/Enter:output:0", "test_loop/Merge:0"}, + {"test_loop/Merge:output:0", "test_loop/Switch:0"}, + {"test_loop/Switch:output_false:0", "test_loop/Exit:0"}, + {"test_loop/Exit:output:0", "test_loop_exit"}}, + {}, false); +} + +TEST_F(CApiFunctionTest, ControlDependency) { + /* + * | | scalar + * | | . + * v v . <---- control dependency + * add < - + * | + * v + */ + // Define + TF_Operation* feed1 = Placeholder(func_graph_, s_, "feed1"); + TF_Operation* feed2 = Placeholder(func_graph_, s_, "feed2"); + TF_Operation* five = ScalarConst(5, func_graph_, s_); + TF_Operation* add = + AddWithCtrlDependency(feed1, feed2, func_graph_, five, s_); + EXPECT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_); + Define(-1, {}, {feed1, feed2}, {add}, nullptr); + + // Use, run, and verify + TF_Operation* two = ScalarConst(2, host_graph_, s_); + TF_Operation* func_feed = Placeholder(host_graph_, s_); + TF_Operation* func_op = Use({two, func_feed}); + Run({{func_feed, Int32Tensor(3)}}, func_op, 2 + 3); + VerifyFDef( + {"add_0", "scalar"}, M({{"feed1"}, {"feed2"}}), M({{"add"}}), + {{"feed1", "add_0:0"}, {"feed2", "add_0:1"}, {"add_0:sum:0", "add"}}, + {{"scalar", "add_0"}}); +} + +TEST_F(CApiFunctionTest, ControlDependencyOutsideOfBody) { + /* + * | | scalar + * | | . + * v v . <---- control dependency + * add < - + * | + * v + */ + // Define + TF_Operation* feed1 = Placeholder(func_graph_, s_, "feed1"); + TF_Operation* feed2 = Placeholder(func_graph_, s_, "feed2"); + TF_Operation* five = ScalarConst(5, func_graph_, s_); + TF_Operation* add = + AddWithCtrlDependency(feed1, feed2, func_graph_, five, s_); + EXPECT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_); + Define(1, {add}, {feed1, feed2}, {add}, nullptr, true); + EXPECT_EQ(TF_INVALID_ARGUMENT, TF_GetCode(s_)); + EXPECT_EQ(string("The source of control edge [id=3 scalar:-1 -> add:-1] " + "is not in the body. Encountered while creating " + "function 'MyFunc'"), + string(TF_Message(s_))); +} + +TEST_F(CApiFunctionTest, ControlDependencyOutsideOfBody_FromInputNode) { + /* + * | |. + * | | . + * | | . + * v v . <---- control dependency + * add < - + * | + * v + */ + // Define + TF_Operation* feed1 = Placeholder(func_graph_, s_, "feed1"); + TF_Operation* feed2 = Placeholder(func_graph_, s_, "feed2"); + TF_Operation* add = + AddWithCtrlDependency(feed1, feed2, func_graph_, feed1, s_); + EXPECT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_); + Define(-1, {}, {feed1, feed2}, {add}, nullptr, true); + EXPECT_EQ(TF_INVALID_ARGUMENT, TF_GetCode(s_)); + EXPECT_EQ(string("The source of control edge [id=3 feed1:-1 -> add:-1] " + "is not in the body. Encountered while creating " + "function 'MyFunc'"), + string(TF_Message(s_))); +} + +TEST_F(CApiFunctionTest, DuplicateInputsAreNotAllowed) { + /* + * feed + * | + * +++ + * | | + * +---+-+---+ + * | | | | + * | v v | + * | add | + * | | | + * | | | + * +----+----+ + * | + * v + */ + TF_Operation* feed1 = Placeholder(func_graph_, s_, "feed1"); + TF_Operation* add = Add(feed1, feed1, func_graph_, s_); + Define(-1, {}, {feed1, feed1}, {add}, nullptr, true); + EXPECT_EQ(TF_INVALID_ARGUMENT, TF_GetCode(s_)); + EXPECT_EQ( + string("TF_Output feed1:0 appears more than once in the input list"), + string(TF_Message(s_))); +} + +TEST_F(CApiFunctionTest, InvalidInputTensor_HighIndex) { + /* + * | | + * v v + * add + * | + * v + */ + TF_Operation* feed1 = Placeholder(func_graph_, s_, "feed1"); + TF_Operation* feed2 = Placeholder(func_graph_, s_, "feed2"); + TF_Operation* add = Add(feed1, feed2, func_graph_, s_); + DefineT(-1, {}, {{feed1, 0}, {feed2, 2}}, {{add, 0}}, nullptr, true); + EXPECT_EQ(TF_INVALID_ARGUMENT, TF_GetCode(s_)); + EXPECT_EQ(string("Node 'feed2' (type: 'Placeholder', num of outputs: 1) does " + "not have output 2\n\tEncountered while processing " + "input 1 into function 'MyFunc'"), + string(TF_Message(s_))); +} + +TEST_F(CApiFunctionTest, InvalidInputTensor_BadNodePtr) { + /* + * | | + * v v + * add + * | + * v + */ + TF_Operation* feed1 = Placeholder(func_graph_, s_, "feed1"); + TF_Operation* feed2 = Placeholder(func_graph_, s_, "feed2"); + TF_Operation* add = Add(feed1, feed2, func_graph_, s_); + DefineT(-1, {}, {{feed1, 0}, {nullptr, 0}}, {{add, 0}}, nullptr, true); + EXPECT_EQ(TF_INVALID_ARGUMENT, TF_GetCode(s_)); + EXPECT_EQ(string("Node is null\n\tEncountered while processing input 1 " + "into function 'MyFunc'"), + string(TF_Message(s_))); +} + +TEST_F(CApiFunctionTest, InvalidOutputTensor_HighIndex) { + /* + * | | + * v v + * add + * | + * v + */ + TF_Operation* feed1 = Placeholder(func_graph_, s_, "feed1"); + TF_Operation* feed2 = Placeholder(func_graph_, s_, "feed2"); + TF_Operation* add = Add(feed1, feed2, func_graph_, s_); + DefineT(-1, {}, {{feed1, 0}, {feed2, 0}}, {{add, 3}}, nullptr, true); + EXPECT_EQ(TF_INVALID_ARGUMENT, TF_GetCode(s_)); + EXPECT_EQ(string("Node 'add' (type: 'AddN', num of outputs: 1) does " + "not have output 3\n\tEncountered while processing " + "output 0 from function 'MyFunc'"), + string(TF_Message(s_))); +} + +TEST_F(CApiFunctionTest, InvalidOutputTensor_BadNodePtr) { + /* + * | | + * v v + * add + * | + * v + */ + TF_Operation* feed1 = Placeholder(func_graph_, s_, "feed1"); + TF_Operation* feed2 = Placeholder(func_graph_, s_, "feed2"); + Add(feed1, feed2, func_graph_, s_); + DefineT(-1, {}, {{feed1, 0}, {feed2, 0}}, {{nullptr, 3}}, nullptr, true); + EXPECT_EQ(TF_INVALID_ARGUMENT, TF_GetCode(s_)); + EXPECT_EQ(string("Node is null\n\tEncountered while processing output 0 " + "from function 'MyFunc'"), + string(TF_Message(s_))); +} + +TEST_F(CApiFunctionTest, NodeMissingInput) { + /* + * input---> | | <----missing input + * v v + * body----> add + * | + * v + */ + TF_Operation* feed1 = Placeholder(func_graph_, s_, "feed1"); + TF_Operation* feed2 = Placeholder(func_graph_, s_, "feed2"); + TF_Operation* add = Add(feed1, feed2, func_graph_, s_); + DefineT(1, {add}, {{feed1, 0}}, {{add, 0}}, nullptr, true); + EXPECT_EQ(TF_INVALID_ARGUMENT, TF_GetCode(s_)); + EXPECT_EQ(string("Input 1, 'feed2:0', of node 'add' in function 'MyFunc' " + "is not available. You might need to include it in inputs " + "or include its source node in the body"), + string(TF_Message(s_))); +} + +TEST_F(CApiFunctionTest, OutputOpNotInBody) { + /* + * | | + * v v + * add scalar (scalar not included in body) + * | | + * v v (function has two outputs) + */ + // Define + TF_Operation* feed1 = Placeholder(func_graph_, s_, "feed1"); + TF_Operation* feed2 = Placeholder(func_graph_, s_, "feed2"); + TF_Operation* scalar = ScalarConst(2, func_graph_, s_); + TF_Operation* add = Add(feed1, feed2, func_graph_, s_); + Define(1, {add}, {feed1, feed2}, {add, scalar}, nullptr, true); + EXPECT_EQ(TF_INVALID_ARGUMENT, TF_GetCode(s_)); + EXPECT_EQ(string("TF_Output scalar:0 is neither in the function body nor " + "among function inputs. Encountered while creating " + "function 'MyFunc'"), + string(TF_Message(s_))); +} + +} // namespace +} // namespace tensorflow diff --git a/tensorflow/c/c_api_internal.h b/tensorflow/c/c_api_internal.h index f7d25dce8f..6e44a72e2b 100644 --- a/tensorflow/c/c_api_internal.h +++ b/tensorflow/c/c_api_internal.h @@ -130,6 +130,11 @@ struct TF_DeviceList { std::vector<tensorflow::DeviceAttributes> response; }; +struct TF_Function { + // Currently contains a single function and no gradients + tensorflow::FunctionDefLibrary fdef_lib; +}; + namespace tensorflow { class TensorCApi { @@ -142,6 +147,9 @@ class TensorCApi { }; TF_Tensor* TF_TensorFromTensor(const Tensor& src, TF_Status* status); + +Status MessageToBuffer(const tensorflow::protobuf::Message& in, TF_Buffer* out); + } // end namespace tensorflow #endif // TENSORFLOW_C_C_API_INTERNAL_H_ diff --git a/tensorflow/c/c_api_test.cc b/tensorflow/c/c_api_test.cc index 0aa60fb45d..c442029009 100644 --- a/tensorflow/c/c_api_test.cc +++ b/tensorflow/c/c_api_test.cc @@ -829,7 +829,7 @@ TEST(CAPI, ShapeInferenceError) { TF_Operation* vec3 = Const(vec3_tensor.get(), graph, status, "vec3"); ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); - TF_Operation* add = Add(vec2, vec3, graph, status); + TF_Operation* add = AddNoCheck(vec2, vec3, graph, status); ASSERT_NE(TF_OK, TF_GetCode(status)); ASSERT_TRUE(add == nullptr); diff --git a/tensorflow/c/c_test_util.cc b/tensorflow/c/c_test_util.cc index 21603c1a07..9cd978c97e 100644 --- a/tensorflow/c/c_test_util.cc +++ b/tensorflow/c/c_test_util.cc @@ -15,7 +15,9 @@ limitations under the License. #include "tensorflow/c/c_test_util.h" +#include "tensorflow/core/framework/function.pb.h" #include "tensorflow/core/framework/tensor.pb.h" +#include "tensorflow/core/lib/strings/strcat.h" #include "tensorflow/core/platform/logging.h" using tensorflow::GraphDef; @@ -36,6 +38,23 @@ TF_Tensor* Int8Tensor(const int64_t* dims, int num_dims, const char* values) { return t; } +TF_Tensor* Int32Tensor(const int64_t* dims, int num_dims, + const int32_t* values) { + int64_t num_values = 1; + for (int i = 0; i < num_dims; ++i) { + num_values *= dims[i]; + } + TF_Tensor* t = + TF_AllocateTensor(TF_INT32, dims, num_dims, sizeof(int32_t) * num_values); + memcpy(TF_TensorData(t), values, sizeof(int32_t) * num_values); + return t; +} + +TF_Tensor* Int32Tensor(const std::vector<int32_t>& values) { + int64_t dims = values.size(); + return Int32Tensor(&dims, 1, values.data()); +} + TF_Tensor* Int32Tensor(int32_t v) { const int num_bytes = sizeof(int32_t); int32_t* values = new int32_t[1]; @@ -44,19 +63,40 @@ TF_Tensor* Int32Tensor(int32_t v) { &Int32Deallocator, nullptr); } -TF_Operation* Placeholder(TF_Graph* graph, TF_Status* s, const char* name) { +// All the *Helper methods are used as a workaround for the restrictions that +// one cannot call ASSERT_* methods in non-void-returning functions (when +// exceptions are disabled during compilation) +void PlaceholderHelper(TF_Graph* graph, TF_Status* s, const char* name, + TF_Operation** op) { TF_OperationDescription* desc = TF_NewOperation(graph, "Placeholder", name); TF_SetAttrType(desc, "dtype", TF_INT32); - return TF_FinishOperation(desc, s); + *op = TF_FinishOperation(desc, s); + ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s); + ASSERT_NE(*op, nullptr); } -TF_Operation* Const(TF_Tensor* t, TF_Graph* graph, TF_Status* s, - const char* name) { +TF_Operation* Placeholder(TF_Graph* graph, TF_Status* s, const char* name) { + TF_Operation* op; + PlaceholderHelper(graph, s, name, &op); + return op; +} + +void ConstHelper(TF_Tensor* t, TF_Graph* graph, TF_Status* s, const char* name, + TF_Operation** op) { TF_OperationDescription* desc = TF_NewOperation(graph, "Const", name); TF_SetAttrTensor(desc, "value", t, s); - if (TF_GetCode(s) != TF_OK) return nullptr; + ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s); TF_SetAttrType(desc, "dtype", TF_TensorType(t)); - return TF_FinishOperation(desc, s); + *op = TF_FinishOperation(desc, s); + ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s); + ASSERT_NE(*op, nullptr); +} + +TF_Operation* Const(TF_Tensor* t, TF_Graph* graph, TF_Status* s, + const char* name) { + TF_Operation* op; + ConstHelper(t, graph, s, name, &op); + return op; } TF_Operation* ScalarConst(int32_t v, TF_Graph* graph, TF_Status* s, @@ -65,11 +105,39 @@ TF_Operation* ScalarConst(int32_t v, TF_Graph* graph, TF_Status* s, return Const(tensor.get(), graph, s, name); } +void AddHelper(TF_Operation* l, TF_Operation* r, TF_Graph* graph, TF_Status* s, + const char* name, TF_Operation** op, bool check) { + TF_OperationDescription* desc = TF_NewOperation(graph, "AddN", name); + TF_Output add_inputs[2] = {{l, 0}, {r, 0}}; + TF_AddInputList(desc, add_inputs, 2); + *op = TF_FinishOperation(desc, s); + if (check) { + ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s); + ASSERT_NE(*op, nullptr); + } +} + TF_Operation* Add(TF_Operation* l, TF_Operation* r, TF_Graph* graph, TF_Status* s, const char* name) { + TF_Operation* op; + AddHelper(l, r, graph, s, name, &op, true); + return op; +} + +TF_Operation* AddNoCheck(TF_Operation* l, TF_Operation* r, TF_Graph* graph, + TF_Status* s, const char* name) { + TF_Operation* op; + AddHelper(l, r, graph, s, name, &op, false); + return op; +} + +TF_Operation* AddWithCtrlDependency(TF_Operation* l, TF_Operation* r, + TF_Graph* graph, TF_Operation* ctrl_op, + TF_Status* s, const char* name) { TF_OperationDescription* desc = TF_NewOperation(graph, "AddN", name); TF_Output add_inputs[2] = {{l, 0}, {r, 0}}; TF_AddInputList(desc, add_inputs, 2); + TF_AddControlInput(desc, ctrl_op); return TF_FinishOperation(desc, s); } @@ -81,11 +149,20 @@ TF_Operation* Add(TF_Output l, TF_Output r, TF_Graph* graph, TF_Status* s, return TF_FinishOperation(desc, s); } -TF_Operation* Neg(TF_Operation* n, TF_Graph* graph, TF_Status* s) { +void NegHelper(TF_Operation* n, TF_Graph* graph, TF_Status* s, + TF_Operation** op) { TF_OperationDescription* desc = TF_NewOperation(graph, "Neg", "neg"); TF_Output neg_input = {n, 0}; TF_AddInput(desc, neg_input); - return TF_FinishOperation(desc, s); + *op = TF_FinishOperation(desc, s); + ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s); + ASSERT_NE(*op, nullptr); +} + +TF_Operation* Neg(TF_Operation* n, TF_Graph* graph, TF_Status* s) { + TF_Operation* op; + NegHelper(n, graph, s, &op); + return op; } TF_Operation* LessThan(TF_Output l, TF_Output r, TF_Graph* graph, @@ -96,6 +173,32 @@ TF_Operation* LessThan(TF_Output l, TF_Output r, TF_Graph* graph, return TF_FinishOperation(desc, s); } +void Split3Helper(TF_Operation* input, TF_Graph* graph, TF_Status* s, + const char* name, TF_Operation** op) { + TF_Operation* zero = ScalarConst( + 0, graph, s, ::tensorflow::strings::StrCat(name, "_const0").c_str()); + TF_OperationDescription* desc = TF_NewOperation(graph, "Split", name); + TF_AddInput(desc, {zero, 0}); + TF_AddInput(desc, {input, 0}); + TF_SetAttrInt(desc, "num_split", 3); + TF_SetAttrType(desc, "T", TF_INT32); + // Set device to CPU since there is no version of split for int32 on GPU + // TODO(iga): Convert all these helpers and tests to use floats because + // they are usually available on GPUs. After doing this, remove TF_SetDevice + // call in c_api_function_test.cc + TF_SetDevice(desc, "/cpu:0"); + *op = TF_FinishOperation(desc, s); + ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s); + ASSERT_NE(*op, nullptr); +} + +TF_Operation* Split3(TF_Operation* input, TF_Graph* graph, TF_Status* s, + const char* name) { + TF_Operation* op; + Split3Helper(input, graph, s, name, &op); + return op; +} + bool IsPlaceholder(const tensorflow::NodeDef& node_def) { if (node_def.op() != "Placeholder" || node_def.name() != "feed") { return false; @@ -196,6 +299,18 @@ bool GetNodeDef(TF_Operation* oper, tensorflow::NodeDef* node_def) { return ret; } +bool GetFunctionDef(TF_Function* func, tensorflow::FunctionDef* func_def) { + TF_Status* s = TF_NewStatus(); + TF_Buffer* buffer = TF_NewBuffer(); + TF_FunctionToFunctionDef(func, buffer, s); + bool ret = TF_GetCode(s) == TF_OK; + EXPECT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s); + if (ret) ret = func_def->ParseFromArray(buffer->data, buffer->length); + TF_DeleteBuffer(buffer); + TF_DeleteStatus(s); + return ret; +} + bool GetAttrValue(TF_Operation* oper, const char* attr_name, tensorflow::AttrValue* attr_value, TF_Status* s) { TF_Buffer* buffer = TF_NewBuffer(); diff --git a/tensorflow/c/c_test_util.h b/tensorflow/c/c_test_util.h index 0c0ba667bd..a927739d46 100644 --- a/tensorflow/c/c_test_util.h +++ b/tensorflow/c/c_test_util.h @@ -33,6 +33,13 @@ typedef std::unique_ptr<TF_Tensor, decltype(&TF_DeleteTensor)> // Create a tensor with values of type TF_INT8 provided by `values`. TF_Tensor* Int8Tensor(const int64_t* dims, int num_dims, const char* values); +// Create a tensor with values of type TF_INT32 provided by `values`. +TF_Tensor* Int32Tensor(const int64_t* dims, int num_dims, + const int32_t* values); + +// Create 1 dimensional tensor with values from `values` +TF_Tensor* Int32Tensor(const std::vector<int32_t>& values); + TF_Tensor* Int32Tensor(int32_t v); TF_Operation* Placeholder(TF_Graph* graph, TF_Status* s, @@ -47,6 +54,13 @@ TF_Operation* ScalarConst(int32_t v, TF_Graph* graph, TF_Status* s, TF_Operation* Add(TF_Operation* l, TF_Operation* r, TF_Graph* graph, TF_Status* s, const char* name = "add"); +TF_Operation* AddNoCheck(TF_Operation* l, TF_Operation* r, TF_Graph* graph, + TF_Status* s, const char* name = "add"); + +TF_Operation* AddWithCtrlDependency(TF_Operation* l, TF_Operation* r, + TF_Graph* graph, TF_Operation* ctrl_op, + TF_Status* s, const char* name = "add"); + TF_Operation* Add(TF_Output l, TF_Output r, TF_Graph* graph, TF_Status* s, const char* name = "add"); @@ -54,6 +68,10 @@ TF_Operation* Neg(TF_Operation* n, TF_Graph* graph, TF_Status* s); TF_Operation* LessThan(TF_Output l, TF_Output r, TF_Graph* graph, TF_Status* s); +// Split `input` along the first dimention into 3 tensors +TF_Operation* Split3(TF_Operation* input, TF_Graph* graph, TF_Status* s, + const char* name = "split3"); + bool IsPlaceholder(const tensorflow::NodeDef& node_def); bool IsScalarConst(const tensorflow::NodeDef& node_def, int v); @@ -66,6 +84,8 @@ bool GetGraphDef(TF_Graph* graph, tensorflow::GraphDef* graph_def); bool GetNodeDef(TF_Operation* oper, tensorflow::NodeDef* node_def); +bool GetFunctionDef(TF_Function* func, tensorflow::FunctionDef* func_def); + bool GetAttrValue(TF_Operation* oper, const char* attr_name, tensorflow::AttrValue* attr_value, TF_Status* s); diff --git a/tensorflow/contrib/cmake/tf_c.cmake b/tensorflow/contrib/cmake/tf_c.cmake index 87d946c346..c5a1018127 100644 --- a/tensorflow/contrib/cmake/tf_c.cmake +++ b/tensorflow/contrib/cmake/tf_c.cmake @@ -18,6 +18,7 @@ set(tf_c_srcs "${tensorflow_source_dir}/tensorflow/c/c_api.cc" "${tensorflow_source_dir}/tensorflow/c/c_api.h" + "${tensorflow_source_dir}/tensorflow/c/c_api_function.cc" "${tensorflow_source_dir}/tensorflow/c/eager/c_api.cc" "${tensorflow_source_dir}/tensorflow/c/eager/c_api.h" "${tensorflow_source_dir}/tensorflow/c/eager/runtime.cc" diff --git a/tensorflow/core/graph/graph.cc b/tensorflow/core/graph/graph.cc index 7d938365c5..a274c79970 100644 --- a/tensorflow/core/graph/graph.cc +++ b/tensorflow/core/graph/graph.cc @@ -523,6 +523,17 @@ Status Graph::IsValidNode(const Node* node) const { return Status::OK(); } +Status Graph::IsValidOutputTensor(const Node* node, int idx) const { + TF_RETURN_IF_ERROR(IsValidNode(node)); + if (idx >= node->num_outputs()) { + return errors::InvalidArgument("Node '", node->name(), "' (type: '", + node->op_def().name(), + "', num of outputs: ", node->num_outputs(), + ") does not have ", "output ", idx); + } + return Status::OK(); +} + Node* Graph::AllocateNode(std::shared_ptr<NodeProperties> props, const Node* cost_node) { Node* node = nullptr; @@ -572,7 +583,7 @@ int Graph::InternDeviceName(const string& device_name) { } string Edge::DebugString() const { - return strings::Printf("Edge %d %s:%d -> %s:%d", id_, src_->name().c_str(), + return strings::Printf("[id=%d %s:%d -> %s:%d]", id_, src_->name().c_str(), src_output_, dst_->name().c_str(), dst_input_); } diff --git a/tensorflow/core/graph/graph.h b/tensorflow/core/graph/graph.h index 51ede642d2..25875185e4 100644 --- a/tensorflow/core/graph/graph.h +++ b/tensorflow/core/graph/graph.h @@ -519,6 +519,10 @@ class Graph { // Returns OK if `node` is non-null and belongs to this graph Status IsValidNode(const Node* node) const; + // Returns OK if IsValidNode(`node`) and `idx` is less than + // node->num_outputs() + Status IsValidOutputTensor(const Node* node, int idx) const; + // TODO(josh11b): uint64 hash() const; private: diff --git a/tensorflow/python/client/tf_session.i b/tensorflow/python/client/tf_session.i index 08dd3922db..fa49e66e87 100644 --- a/tensorflow/python/client/tf_session.i +++ b/tensorflow/python/client/tf_session.i @@ -373,6 +373,33 @@ def TF_Reset(target, containers=None, config=None): TF_DeleteSessionOptions(opts) %} +// We use TF_GraphToFunction_wrapper instead of TF_GraphToFunction +%ignore TF_GraphToFunction; +// TF_GraphToFunction_wrapper does not use any Python methods and +// does not require GIL to be held. +%unignore TF_GraphToFunction_wrapper; + +// $input is a Python list of wrapped TF_Operations +%typemap(in) (const std::vector<TF_Operation*>* opers) + (std::vector<TF_Operation*> opers) { + if ($input != Py_None) { + if (!PyList_Check($input)) { + SWIG_exception_fail(SWIG_TypeError, "$symname: expected list"); + } + size_t size = PyList_Size($input); + for (int i = 0; i < size; ++i) { + PyObject* item = PyList_GetItem($input, i); + TF_Operation* oper_ptr; + SWIG_ConvertPtr(item, reinterpret_cast<void**>(&oper_ptr), + $descriptor(TF_Operation*), 0); + opers.push_back(oper_ptr); + } + $1 = &opers; + } else { + $1 = nullptr; + } +} + %include "tensorflow/python/client/tf_session_helper.h" %unignoreall diff --git a/tensorflow/python/client/tf_session_helper.cc b/tensorflow/python/client/tf_session_helper.cc index 60a589fa8b..72f560fa87 100644 --- a/tensorflow/python/client/tf_session_helper.cc +++ b/tensorflow/python/client/tf_session_helper.cc @@ -337,4 +337,38 @@ std::vector<TF_Operation*> TF_OperationGetControlInputs_wrapper( return control_inputs; } +TF_Function* TF_GraphToFunction_wrapper(const TF_Graph* fn_body, + const char* fn_name, + const std::vector<TF_Operation*>* opers, + const std::vector<TF_Output>& inputs, + const std::vector<TF_Output>& outputs, + const NameVector& output_names, + const TF_FunctionOptions* opts, + TF_Status* out_status) { + if (!output_names.empty() && output_names.size() != outputs.size()) { + Set_TF_Status_from_Status( + out_status, + errors::InvalidArgument( + "output names must be either empty or equal in size to outputs. ", + "output names size = ", output_names.size(), + " outputs size = ", outputs.size())); + return nullptr; + } + + int nopers = -1; + const TF_Operation* const* opers_array = nullptr; + if (opers != nullptr) { + nopers = opers->size(); + opers_array = opers->data(); + } + + const char** output_names_ptr = + output_names.empty() ? nullptr + : const_cast<const char**>(output_names.data()); + + return TF_GraphToFunction(fn_body, fn_name, nopers, opers_array, + inputs.size(), inputs.data(), outputs.size(), + outputs.data(), output_names_ptr, opts, out_status); +} + } // namespace tensorflow diff --git a/tensorflow/python/client/tf_session_helper.h b/tensorflow/python/client/tf_session_helper.h index 3bc63f822f..8fae6206c0 100644 --- a/tensorflow/python/client/tf_session_helper.h +++ b/tensorflow/python/client/tf_session_helper.h @@ -148,6 +148,16 @@ void TF_SessionPRun_wrapper(TF_Session* session, const char* handle, std::vector<TF_Operation*> TF_OperationGetControlInputs_wrapper( TF_Operation* oper); +// `opers` equaling NULL are converted to `nopers = -1`. +// `output_names` must be empty or have the same length as `outputs`. +TF_Function* TF_GraphToFunction_wrapper(const TF_Graph* fn_body, + const char* fn_name, + const std::vector<TF_Operation*>* opers, + const std::vector<TF_Output>& inputs, + const std::vector<TF_Output>& outputs, + const NameVector& output_names, + const TF_FunctionOptions* opts, + TF_Status* out_status); } // namespace tensorflow #endif // TENSORFLOW_PYTHON_CLIENT_TF_SESSION_HELPER_H_ diff --git a/tensorflow/python/framework/function.py b/tensorflow/python/framework/function.py index 2f35f0e04b..7a866ee6e8 100644 --- a/tensorflow/python/framework/function.py +++ b/tensorflow/python/framework/function.py @@ -26,7 +26,9 @@ import hashlib from tensorflow.core.framework import attr_value_pb2 from tensorflow.core.framework import op_def_pb2 +from tensorflow.python import pywrap_tensorflow as c_api from tensorflow.python.framework import dtypes +from tensorflow.python.framework import errors from tensorflow.python.framework import graph_to_function_def from tensorflow.python.framework import ops from tensorflow.python.ops import array_ops @@ -290,6 +292,7 @@ class _DefinedFunction(object): self._shape_func = shape_func self._extra_kwargs = kwargs self._definition = None # Constructed lazily. + self._c_func = None # Constructed with definition. self._sub_functions = dict() # Constructed with definition. self._args = [] @@ -396,6 +399,22 @@ class _DefinedFunction(object): if self._func.__doc__: self._definition.signature.description = self._func.__doc__ + # pylint: disable=protected-access + if temp_graph._c_graph: + with errors.raise_exception_on_not_ok_status() as status: + output_names = ([compat.as_bytes(x) for x in self._out_names] + if self._out_names else []) + self._c_func = c_api.TF_GraphToFunction_wrapper( + temp_graph._c_graph, + self._func_name, + None, # opers + [t._as_tf_output() for t in inputs], + [t._as_tf_output() for t in outputs], + output_names, + None, # opts + status) + # pylint: enable=protected-access + def _create_hash_str(self, input_arg, output_arg, node_def): """Creates an 8-character string unique to this input. diff --git a/tensorflow/python/framework/function_test.py b/tensorflow/python/framework/function_test.py index 589db9ef4d..40205ddf05 100644 --- a/tensorflow/python/framework/function_test.py +++ b/tensorflow/python/framework/function_test.py @@ -33,6 +33,7 @@ from tensorflow.python.framework import function from tensorflow.python.framework import graph_to_function_def from tensorflow.python.framework import ops from tensorflow.python.framework import tensor_shape +from tensorflow.python.framework import test_util from tensorflow.python.ops import array_ops from tensorflow.python.ops import clip_ops from tensorflow.python.ops import control_flow_ops @@ -63,7 +64,51 @@ def _OptimizerOptions(): do_constant_folding=cfold))) -class FunctionTest(test.TestCase): +class FunctionTestMethods(object): + """Test methods for verifying Function support. + + These test methods are used as mix-ins in two test cases: with + and without C API support. + """ + + def testIdentity(self): + + @function.Defun(dtypes.float32, func_name="MyIdentity") + def MyIdentityFunc(a): + return a + + with ops.Graph().as_default(): + call = MyIdentityFunc([18.0]) + self.assertEqual("MyIdentity", call.op.name) + with session.Session() as sess: + self.assertAllEqual([18.0], sess.run(call)) + + def testIdentityOutputName(self): + + @function.Defun( + dtypes.float32, func_name="MyIdentity", out_names=["my_result_name"]) + def MyIdentityFunc(a): + return a + + with ops.Graph().as_default(): + call = MyIdentityFunc([18.0]) + self.assertEqual("MyIdentity", call.op.name) + with session.Session() as sess: + self.assertAllEqual([18.0], sess.run(call)) + + def testTooManyOutputNames(self): + + @function.Defun( + dtypes.float32, func_name="MyIdentity", + out_names=["my_result1", "my_result2"]) + def MyIdentityFunc(a): + return a + + with ops.Graph().as_default(): + with self.assertRaisesRegexp( + ValueError, (r"Length of out_names \(2\) does not match number of " + r"outputs \(1\): my_result1, my_result2")): + MyIdentityFunc([18.0]) def testDefineFunction2Args(self): @@ -77,6 +122,35 @@ class FunctionTest(test.TestCase): with session.Session() as sess: self.assertAllEqual([5.0], sess.run(call)) + def testValueErrorOnFunctionWithNoOutput(self): + # TODO(iga): Remove this restriction and this test + + @function.Defun(dtypes.float32, dtypes.float32) + def APlus2B(a, b): + print(a + b * 2) # Create some ops to have nodes in the body + # Using 'print' to make lint happy + + with ops.Graph().as_default(): + with self.assertRaisesRegexp(ValueError, + "Function can not return None"): + APlus2B([1.0], [2.0]) + + def testDefineFunction2ArgsOutputName(self): + + @function.Defun( + dtypes.float32, + dtypes.float32, + func_name="APlus2B", + out_names=["my_result_name"]) + def APlus2B(a, b): + return a + b * 2 + + with ops.Graph().as_default(): + call = APlus2B([1.0], [2.0]) + self.assertEqual("APlus2B", call.op.name) + with session.Session() as sess: + self.assertAllEqual([5.0], sess.run(call)) + def testDefineFunctionDuplicateOutputs(self): @function.Defun(dtypes.float32, func_name="Duplicate") @@ -137,6 +211,7 @@ class FunctionTest(test.TestCase): out, = sess.run(dx, feed) self.assertAllClose(1 - np.square(np.tanh(inp)), out) + @test_util.disable_c_api # Function gradients don't work with C API def testCustomGradient(self): dtype = dtypes.float32 @@ -169,6 +244,7 @@ class FunctionTest(test.TestCase): out, = sess.run(dlogits, {logits: x, labels: y}) self.assertAllClose(out, np.exp(prob - y)) + @test_util.disable_c_api # Function gradients don't work with C API def testCustomGradientError(self): dtype = dtypes.float32 @@ -194,6 +270,7 @@ class FunctionTest(test.TestCase): "SymGrad expects to return 1.*but get 2.*instead"): _ = sess.run(dinp, {inp: x}) + @test_util.disable_c_api # Function gradients don't work with C API def testSymGradShape(self): g = ops.Graph() with g.as_default(): @@ -209,6 +286,7 @@ class FunctionTest(test.TestCase): self.assertEqual(x.get_shape(), dx.get_shape()) self.assertEqual(y.get_shape(), dy.get_shape()) + @test_util.disable_c_api # Function gradients don't work with C API def testSymGradAttr(self): @function.Defun(noinline=True) @@ -312,6 +390,7 @@ class FunctionTest(test.TestCase): "assertion failed.*-3"): self.assertAllEqual(Foo(constant_op.constant(-3.0)).eval(), 6.0) + @test_util.disable_c_api # Op._add_control_inputs doesn't work with C API def testAssertWrapper(self): @function.Defun(dtypes.float32) @@ -326,6 +405,7 @@ class FunctionTest(test.TestCase): "assertion"): _ = MyFn(100.0).eval() + @test_util.disable_c_api # Op._add_control_inputs doesn't work with C API def testWhileLoopCallsFunc(self): with self.test_session(use_gpu=True) as sess: @@ -345,6 +425,7 @@ class FunctionTest(test.TestCase): ans = sess.run(loop) self.assertAllClose(ans, 131072.) + @test_util.disable_c_api # Op._add_control_inputs doesn't work with C API def testControlFlowStrictness(self): """Inlined functions must not execute in a untaken control flow branch.""" @@ -607,6 +688,7 @@ class FunctionTest(test.TestCase): self.assertAllClose(vals[0], vals[1]) self.assertAllClose(vals[2], vals[3]) + @test_util.disable_c_api # Function Declaration doesn't work with C API def testDeclare(self): foo = function.Declare("Foo", [("x", dtypes.float32)], [("y", dtypes.float32)]) @@ -626,6 +708,7 @@ class FunctionTest(test.TestCase): expected = rand * rand + 1.0 self.assertAllClose(expected, y.eval(feed_dict={x: rand})) + @test_util.disable_c_api # Function Declaration doesn't work with C API def testDeclareUsedInDefun(self): foo = function.Declare("Foo", [("x", dtypes.float32)], [("y", dtypes.float32)]) @@ -649,6 +732,7 @@ class FunctionTest(test.TestCase): expected = rand * rand + 1.0 self.assertAllClose(expected, y.eval(feed_dict={x: rand})) + @test_util.disable_c_api # Function Declaration doesn't work with C API def testDeclareTypeMistake(self): foo = function.Declare("Foo", [("x", dtypes.float32)], [("y", dtypes.float32)]) @@ -861,6 +945,32 @@ class FunctionTest(test.TestCase): self.assertEqual(len(f.signature.input_arg), 3) +class FunctionTest(FunctionTestMethods, test.TestCase): + """Test case that invokes test methods with _USE_C_API=False.""" + + def setUp(self): + self.prev_use_c_api = ops._USE_C_API + ops._USE_C_API = False + super(FunctionTest, self).setUp() + + def tearDown(self): + ops._USE_C_API = self.prev_use_c_api + super(FunctionTest, self).tearDown() + + +class FunctionWithCApiTest(FunctionTestMethods, test.TestCase): + """Test case that invokes test methods with _USE_C_API=True.""" + + def setUp(self): + self.prev_use_c_api = ops._USE_C_API + ops._USE_C_API = True + super(FunctionWithCApiTest, self).setUp() + + def tearDown(self): + ops._USE_C_API = self.prev_use_c_api + super(FunctionWithCApiTest, self).tearDown() + + class FunctionsFromProtos(test.TestCase): def expectFunctionsEqual(self, func, grad_func=None, new_func=None): diff --git a/tensorflow/python/framework/ops.py b/tensorflow/python/framework/ops.py index ccaa2141b5..659bc394b9 100644 --- a/tensorflow/python/framework/ops.py +++ b/tensorflow/python/framework/ops.py @@ -2948,6 +2948,14 @@ class Graph(object): if self._graph_def_versions.min_consumer < 12: self._graph_def_versions.min_consumer = 12 self._functions[name] = function + if self._c_graph: + # pylint: disable=protected-access + assert function._c_func, ( + "Cannot add function created without C API support to graph " + "created with C API support") + with errors.raise_exception_on_not_ok_status() as status: + c_api.TF_GraphAddFunction(self._c_graph, function._c_func, status) + # pylint: enable=protected-access @property def building_function(self): |