diff options
author | Skye Wanderman-Milne <skyewm@google.com> | 2017-04-04 16:53:37 -0800 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2017-04-04 18:05:56 -0700 |
commit | 39c74da44a8fe40e83d3b0ec24904bc9ce303581 (patch) | |
tree | 59bef5ce34c9b068137b30ce8ed759dbd2adc44c /tensorflow/core/graph/graph.cc | |
parent | c354ba3470a56e271cff63e037509d9fbb253110 (diff) |
Make ImportGraphDef() work with functions.
In addition to modify graph_constructor.cc, this patch adds some other
functionality to enable importing fucntions:
* Ability to add FunctionDefLibraries to Graphs and
FunctionLibraryDefinitions (in addition to existing functions)
* FunctionDefsEqual() utility function
Change: 152205258
Diffstat (limited to 'tensorflow/core/graph/graph.cc')
-rw-r--r-- | tensorflow/core/graph/graph.cc | 42 |
1 files changed, 41 insertions, 1 deletions
diff --git a/tensorflow/core/graph/graph.cc b/tensorflow/core/graph/graph.cc index 65baf4cd85..c19764d082 100644 --- a/tensorflow/core/graph/graph.cc +++ b/tensorflow/core/graph/graph.cc @@ -360,6 +360,45 @@ void Graph::RemoveEdge(const Edge* e) { free_edges_.push_back(del); } +Status Graph::AddFunctionLibrary(const FunctionDefLibrary& fdef_lib) { + for (const FunctionDef& fdef : fdef_lib.function()) { + const FunctionDef* preexisting_fdef = ops_.Find(fdef.signature().name()); + if (preexisting_fdef != nullptr) { + if (!FunctionDefsEqual(*preexisting_fdef, fdef)) { + return errors::InvalidArgument( + "Cannot add function '", fdef.signature().name(), + "' because a different function with the same name already " + "exists."); + } + // Ignore duplicate FunctionDefs + continue; + } + // TODO(skyewm): fix test breakages and reenable this check + // const OpDef* op_def; + // if (ops_.LookUpOpDef(fdef.signature().name(), &op_def).ok()) { + // return errors::InvalidArgument( + // "Cannot add function '", fdef.signature().name(), + // "' because an op with the same name already exists."); + // } + TF_RETURN_IF_ERROR(ops_.AddFunctionDef(fdef)); + } + for (const GradientDef& grad : fdef_lib.gradient()) { + string preexisting_grad_func = ops_.FindGradient(grad.function_name()); + if (!preexisting_grad_func.empty()) { + if (preexisting_grad_func != grad.gradient_func()) { + return errors::InvalidArgument( + "Cannot assign gradient function '", grad.gradient_func(), "' to '", + grad.function_name(), "' because it already has gradient function ", + "'", preexisting_grad_func, "'"); + } + // Ignore duplicate GradientDefs + continue; + } + TF_RETURN_IF_ERROR(ops_.AddGradientDef(grad)); + } + return Status::OK(); +} + namespace { void AddInput(NodeDef* dst, StringPiece src_name, int src_slot) { @@ -380,7 +419,8 @@ void Graph::ToGraphDef(GraphDef* graph_def) const { void Graph::ToGraphDefSubRange(GraphDef* graph_def, int from_node_id) const { graph_def->Clear(); - graph_def->mutable_versions()->CopyFrom(versions()); + *graph_def->mutable_versions() = versions(); + *graph_def->mutable_library() = ops_.ToProto(); std::vector<const Edge*> inputs; // Construct this outside the loop for speed. for (auto id = from_node_id; id < num_node_ids(); ++id) { |