aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core/graph/graph.cc
diff options
context:
space:
mode:
authorGravatar Skye Wanderman-Milne <skyewm@google.com>2017-04-04 16:53:37 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2017-04-04 18:05:56 -0700
commit39c74da44a8fe40e83d3b0ec24904bc9ce303581 (patch)
tree59bef5ce34c9b068137b30ce8ed759dbd2adc44c /tensorflow/core/graph/graph.cc
parentc354ba3470a56e271cff63e037509d9fbb253110 (diff)
Make ImportGraphDef() work with functions.
In addition to modify graph_constructor.cc, this patch adds some other functionality to enable importing fucntions: * Ability to add FunctionDefLibraries to Graphs and FunctionLibraryDefinitions (in addition to existing functions) * FunctionDefsEqual() utility function Change: 152205258
Diffstat (limited to 'tensorflow/core/graph/graph.cc')
-rw-r--r--tensorflow/core/graph/graph.cc42
1 files changed, 41 insertions, 1 deletions
diff --git a/tensorflow/core/graph/graph.cc b/tensorflow/core/graph/graph.cc
index 65baf4cd85..c19764d082 100644
--- a/tensorflow/core/graph/graph.cc
+++ b/tensorflow/core/graph/graph.cc
@@ -360,6 +360,45 @@ void Graph::RemoveEdge(const Edge* e) {
free_edges_.push_back(del);
}
+Status Graph::AddFunctionLibrary(const FunctionDefLibrary& fdef_lib) {
+ for (const FunctionDef& fdef : fdef_lib.function()) {
+ const FunctionDef* preexisting_fdef = ops_.Find(fdef.signature().name());
+ if (preexisting_fdef != nullptr) {
+ if (!FunctionDefsEqual(*preexisting_fdef, fdef)) {
+ return errors::InvalidArgument(
+ "Cannot add function '", fdef.signature().name(),
+ "' because a different function with the same name already "
+ "exists.");
+ }
+ // Ignore duplicate FunctionDefs
+ continue;
+ }
+ // TODO(skyewm): fix test breakages and reenable this check
+ // const OpDef* op_def;
+ // if (ops_.LookUpOpDef(fdef.signature().name(), &op_def).ok()) {
+ // return errors::InvalidArgument(
+ // "Cannot add function '", fdef.signature().name(),
+ // "' because an op with the same name already exists.");
+ // }
+ TF_RETURN_IF_ERROR(ops_.AddFunctionDef(fdef));
+ }
+ for (const GradientDef& grad : fdef_lib.gradient()) {
+ string preexisting_grad_func = ops_.FindGradient(grad.function_name());
+ if (!preexisting_grad_func.empty()) {
+ if (preexisting_grad_func != grad.gradient_func()) {
+ return errors::InvalidArgument(
+ "Cannot assign gradient function '", grad.gradient_func(), "' to '",
+ grad.function_name(), "' because it already has gradient function ",
+ "'", preexisting_grad_func, "'");
+ }
+ // Ignore duplicate GradientDefs
+ continue;
+ }
+ TF_RETURN_IF_ERROR(ops_.AddGradientDef(grad));
+ }
+ return Status::OK();
+}
+
namespace {
void AddInput(NodeDef* dst, StringPiece src_name, int src_slot) {
@@ -380,7 +419,8 @@ void Graph::ToGraphDef(GraphDef* graph_def) const {
void Graph::ToGraphDefSubRange(GraphDef* graph_def, int from_node_id) const {
graph_def->Clear();
- graph_def->mutable_versions()->CopyFrom(versions());
+ *graph_def->mutable_versions() = versions();
+ *graph_def->mutable_library() = ops_.ToProto();
std::vector<const Edge*>
inputs; // Construct this outside the loop for speed.
for (auto id = from_node_id; id < num_node_ids(); ++id) {