path: root/tensorflow/contrib/copy_graph
author     A. Unique TensorFlower <nobody@tensorflow.org>  2016-05-05 08:36:05 -0800
committer  TensorFlower Gardener <gardener@tensorflow.org>  2016-05-05 09:41:47 -0700
commit     8bf6ef1337359993a8be057c0dc90da8f5a6e4fa (patch)
tree       c7367050bf36d6f4b17a93d06700dc7169012ac1 /tensorflow/contrib/copy_graph
parent     931e848c28e97e8cae410af242f8e09d75663ee4 (diff)
Merge changes from github.
Change: 121586635
Diffstat (limited to 'tensorflow/contrib/copy_graph')
-rw-r--r--  tensorflow/contrib/copy_graph/BUILD                           42
-rw-r--r--  tensorflow/contrib/copy_graph/__init__.py                     26
-rw-r--r--  tensorflow/contrib/copy_graph/python/__init__.py              15
-rw-r--r--  tensorflow/contrib/copy_graph/python/util/__init__.py         15
-rw-r--r--  tensorflow/contrib/copy_graph/python/util/copy_elements.py   261
-rw-r--r--  tensorflow/contrib/copy_graph/python/util/copy_test.py       110
6 files changed, 469 insertions, 0 deletions
diff --git a/tensorflow/contrib/copy_graph/BUILD b/tensorflow/contrib/copy_graph/BUILD
new file mode 100644
index 0000000000..5a775c2022
--- /dev/null
+++ b/tensorflow/contrib/copy_graph/BUILD
@@ -0,0 +1,42 @@
+# Description:
+#   Contains parts of TensorFlow that are experimental or unstable and are not supported.
+
+licenses(["notice"]) # Apache 2.0
+
+exports_files(["LICENSE"])
+
+package(default_visibility = ["//tensorflow:__subpackages__"])
+
+py_library(
+ name = "copy_graph_py",
+ srcs = [
+ "__init__.py",
+ "python/util/__init__.py",
+ "python/util/copy_elements.py",
+ ],
+ srcs_version = "PY2AND3",
+)
+
+py_test(
+ name = "copy_test",
+ srcs = glob(["python/util/copy_test.py"]),
+ srcs_version = "PY2AND3",
+ deps = [
+ ":copy_graph_py",
+ "//tensorflow:tensorflow_py",
+ "//tensorflow/python:framework_test_lib",
+ "//tensorflow/python:platform_test",
+ ],
+)
+
+filegroup(
+ name = "all_files",
+ srcs = glob(
+ ["**/*"],
+ exclude = [
+ "**/METADATA",
+ "**/OWNERS",
+ ],
+ ),
+ visibility = ["//tensorflow:__subpackages__"],
+)
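For reference (not part of this change): given the py_test rule above, the new test can be run from a TensorFlow source checkout with a working Bazel setup via `bazel test //tensorflow/contrib/copy_graph:copy_test`.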
diff --git a/tensorflow/contrib/copy_graph/__init__.py b/tensorflow/contrib/copy_graph/__init__.py
new file mode 100644
index 0000000000..1b15f3eb73
--- /dev/null
+++ b/tensorflow/contrib/copy_graph/__init__.py
@@ -0,0 +1,26 @@
+# Copyright 2015 Google Inc. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Functions for copying elements from one graph to another.
+
+@@copy_op_to_graph
+@@copy_variable_to_graph
+@@get_copied_op
+"""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from tensorflow.contrib.copy_graph.python.util.copy_elements import *
diff --git a/tensorflow/contrib/copy_graph/python/__init__.py b/tensorflow/contrib/copy_graph/python/__init__.py
new file mode 100644
index 0000000000..1dd1cb72be
--- /dev/null
+++ b/tensorflow/contrib/copy_graph/python/__init__.py
@@ -0,0 +1,15 @@
+# Copyright 2015 Google Inc. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
diff --git a/tensorflow/contrib/copy_graph/python/util/__init__.py b/tensorflow/contrib/copy_graph/python/util/__init__.py
new file mode 100644
index 0000000000..1dd1cb72be
--- /dev/null
+++ b/tensorflow/contrib/copy_graph/python/util/__init__.py
@@ -0,0 +1,15 @@
+# Copyright 2015 Google Inc. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
diff --git a/tensorflow/contrib/copy_graph/python/util/copy_elements.py b/tensorflow/contrib/copy_graph/python/util/copy_elements.py
new file mode 100644
index 0000000000..9cfff05756
--- /dev/null
+++ b/tensorflow/contrib/copy_graph/python/util/copy_elements.py
@@ -0,0 +1,261 @@
+# Copyright 2015 Google Inc. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""## Functions for copying elements from one graph to another.
+
+These functions allow for recursive copying of elements (ops and variables)
+from one graph to another. The copied elements are initialized inside a
+user-specified scope in the other graph. There are separate functions to
+copy ops and variables.
+There is also a function to retrieve the copied version of an op from the
+first graph inside a scope in the second graph.
+
+@@copy_op_to_graph
+@@copy_variable_to_graph
+@@get_copied_op
+"""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from copy import deepcopy
+from tensorflow.python.ops.variables import Variable
+from tensorflow.python.client.session import Session
+from tensorflow.python.framework import ops
+
+__all__ = ["copy_op_to_graph", "copy_variable_to_graph", "get_copied_op"]
+
+
+def copy_variable_to_graph(org_instance, to_graph, scope=""):
+  """Given a `Variable` instance from one `Graph`, initializes and returns
+  a copy of it from another `Graph`, under the specified scope
+  (default `""`).
+
+  Args:
+    org_instance: A `Variable` from some `Graph`.
+    to_graph: The `Graph` to copy the `Variable` to.
+    scope: A scope for the new `Variable` (default `""`).
+
+  Returns:
+    The copied `Variable` from `to_graph`.
+
+  Raises:
+    TypeError: If `org_instance` is not a `Variable`.
+  """
+
+  if not isinstance(org_instance, Variable):
+    raise TypeError(str(org_instance) + " is not a Variable")
+
+  #The name of the new variable
+  if scope != "":
+    new_name = (scope + '/' +
+                org_instance.name[:org_instance.name.index(':')])
+  else:
+    new_name = org_instance.name[:org_instance.name.index(':')]
+
+  #Get the collections that the new instance needs to be added to.
+  #The new collections will also be a part of the given scope,
+  #except the special ones required for variable initialization and
+  #training.
+  collections = []
+  for name, collection in org_instance.graph._collections.items():
+    if org_instance in collection:
+      if (name == ops.GraphKeys.VARIABLES or
+          name == ops.GraphKeys.TRAINABLE_VARIABLES or
+          scope == ''):
+        collections.append(name)
+      else:
+        collections.append(scope + '/' + name)
+
+  #See if it's trainable.
+  trainable = (org_instance in org_instance.graph.get_collection(
+      ops.GraphKeys.TRAINABLE_VARIABLES))
+  #Get the initial value
+  with org_instance.graph.as_default():
+    temp_session = Session()
+    init_value = temp_session.run(org_instance.initialized_value())
+
+  #Initialize the new variable
+  with to_graph.as_default():
+    new_var = Variable(init_value,
+                       trainable,
+                       name=new_name,
+                       collections=collections,
+                       validate_shape=False)
+
+  return new_var
+
+
+def copy_op_to_graph(org_instance, to_graph, variables,
+                     scope=""):
+  """Given an `Operation` `org_instance` from one `Graph`,
+  initializes and returns a copy of it from another `Graph`,
+  under the specified scope (default `""`).
+
+  The copying is done recursively, so any `Operation` whose output
+  is required to evaluate `org_instance` is also copied (unless
+  already done).
+
+  Since `Variable` instances are copied separately, those required
+  to evaluate `org_instance` must be provided as input.
+
+  Args:
+    org_instance: An `Operation` from some `Graph`. Could be a
+      `Placeholder` as well.
+    to_graph: The `Graph` to copy `org_instance` to.
+    variables: Copies in `to_graph` of the `Variable`s `org_instance` needs.
+    scope: A scope for the copy (default `""`).
+
+  Returns:
+    The copied `Operation` from `to_graph`.
+
+  Raises:
+    TypeError: If `org_instance` is not an `Operation` or `Tensor`.
+  """
+
+  #The name of the new instance
+  if scope != '':
+    new_name = scope + '/' + org_instance.name
+  else:
+    new_name = org_instance.name
+
+  #Extract names of variables
+  copied_variables = dict((x.name, x) for x in variables)
+
+  #If a variable by the new name already exists, return the
+  #corresponding tensor that will act as an input
+  if new_name in copied_variables:
+    return to_graph.get_tensor_by_name(
+        copied_variables[new_name].name)
+
+  #If an instance of the same name exists, return appropriately
+  try:
+    already_present = to_graph.as_graph_element(new_name,
+                                                allow_tensor=True,
+                                                allow_operation=True)
+    return already_present
+  except:
+    pass
+
+  #Get the collections that the new instance needs to be added to.
+  #The new collections will also be a part of the given scope.
+  collections = []
+  for name, collection in org_instance.graph._collections.items():
+    if org_instance in collection:
+      if scope == '':
+        collections.append(name)
+      else:
+        collections.append(scope + '/' + name)
+
+  #Take action based on the class of the instance
+
+  if isinstance(org_instance, ops.Tensor):
+
+    #If it's a Tensor, it is one of the outputs of the underlying
+    #op. Therefore, copy the op itself and return the appropriate
+    #output.
+    op = org_instance.op
+    new_op = copy_op_to_graph(op, to_graph, variables, scope)
+    output_index = op.outputs.index(org_instance)
+    new_tensor = new_op.outputs[output_index]
+    #Add to collections if any
+    for collection in collections:
+      to_graph.add_to_collection(collection, new_tensor)
+
+    return new_tensor
+
+  elif isinstance(org_instance, ops.Operation):
+
+    op = org_instance
+
+    #If it has an original_op parameter, copy it
+    if op._original_op is not None:
+      new_original_op = copy_op_to_graph(op._original_op, to_graph,
+                                         variables, scope)
+    else:
+      new_original_op = None
+
+    #If it has control inputs, call this function recursively on each.
+    new_control_inputs = [copy_op_to_graph(x, to_graph, variables,
+                                           scope)
+                          for x in op.control_inputs]
+
+    #If it has inputs, call this function recursively on each.
+    new_inputs = [copy_op_to_graph(x, to_graph, variables,
+                                   scope)
+                  for x in op.inputs]
+
+    #Make a new node_def based on that of the original.
+    #An instance of tensorflow.core.framework.graph_pb2.NodeDef, it
+    #stores String-based info such as name, device and type of the op.
+    #Unique to every Operation instance.
+    new_node_def = deepcopy(op._node_def)
+    #Change the name
+    new_node_def.name = new_name
+
+    #Copy the other inputs needed for initialization
+    output_types = op._output_types[:]
+    input_types = op._input_types[:]
+
+    #Make a copy of the op_def too.
+    #It's unique to every _type_ of Operation.
+    op_def = deepcopy(op._op_def)
+
+    #Initialize a new Operation instance
+    new_op = ops.Operation(new_node_def,
+                           to_graph,
+                           new_inputs,
+                           output_types,
+                           new_control_inputs,
+                           input_types,
+                           new_original_op,
+                           op_def)
+    #Use Graph's hidden methods to add the op
+    to_graph._add_op(new_op)
+    to_graph._record_op_seen_by_control_dependencies(new_op)
+    for device_function in reversed(to_graph._device_function_stack):
+      new_op._set_device(device_function(new_op))
+
+    return new_op
+
+  else:
+    raise TypeError("Could not copy instance: " + str(org_instance))
+
+
+def get_copied_op(org_instance, graph, scope=""):
+  """Given an `Operation` instance from some `Graph`, returns
+  its namesake from `graph`, under the specified scope
+  (default `""`).
+
+  If a copy of `org_instance` is present in `graph` under the given
+  `scope`, it will be returned.
+
+  Args:
+    org_instance: An `Operation` from some `Graph`.
+    graph: The `Graph` to be searched for a copy of `org_instance`.
+    scope: The scope `org_instance` is present in.
+
+  Returns:
+    The `Operation` copy from `graph`.
+  """
+
+  #The name of the copied instance
+  if scope != '':
+    new_name = scope + '/' + org_instance.name
+  else:
+    new_name = org_instance.name
+
+  return graph.as_graph_element(new_name, allow_tensor=True,
+                                allow_operation=True)
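For orientation (not part of this change), here is a minimal usage sketch of the three functions added above, mirroring the pattern in copy_test.py below. It assumes a TensorFlow build of this vintage, where tf.contrib.copy_graph, tf.initialize_all_variables and tf.mul are available; the graph names src/dst, the scope "copied", and the *_copy names are illustrative only:

import tensorflow as tf

src, dst = tf.Graph(), tf.Graph()

with src.as_default():
  w = tf.Variable(3.0)         # variable the expression depends on
  x = tf.placeholder("float")  # fed at run time
  y = tf.mul(w, x)             # op to reproduce in dst

# Variables are not copied recursively, so copy them first and pass the
# copies to copy_op_to_graph.
w_copy = tf.contrib.copy_graph.copy_variable_to_graph(w, dst, scope="copied")

# Copies y's op and, recursively, everything else needed to evaluate it
# (here the placeholder), resolving the variable through w_copy.
y_copy = tf.contrib.copy_graph.copy_op_to_graph(y, dst, [w_copy], scope="copied")

# Look up the copied placeholder by its scoped name so it can be fed.
x_copy = tf.contrib.copy_graph.get_copied_op(x, dst, "copied")

with dst.as_default():
  sess = tf.Session()
  tf.initialize_all_variables().run(session=sess)

print(y_copy.eval({x_copy: 2.0}, session=sess))  # expected: 6.0

Note that copy_variable_to_graph captures the variable's initial value by running initialized_value() in a temporary Session on the source graph, rather than reading the variable's current value from any existing session.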
diff --git a/tensorflow/contrib/copy_graph/python/util/copy_test.py b/tensorflow/contrib/copy_graph/python/util/copy_test.py
new file mode 100644
index 0000000000..68a3f90d26
--- /dev/null
+++ b/tensorflow/contrib/copy_graph/python/util/copy_test.py
@@ -0,0 +1,110 @@
+# Copyright 2016 Google Inc. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+"""Tests for contrib.copy_graph.python.util.copy."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import numpy as np
+import tensorflow as tf
+from tensorflow.contrib.framework.python.framework import tensor_util
+
+graph1 = tf.Graph()
+graph2 = tf.Graph()
+
+
+class CopyVariablesTest(tf.test.TestCase):
+
+  def testVariableCopy(self):
+
+    with graph1.as_default():
+      #Define a Variable in graph1
+      some_var = tf.Variable(2)
+      #Initialize session
+      sess1 = tf.Session()
+      #Initialize the Variable
+      tf.initialize_all_variables().run(session=sess1)
+
+    #Make a copy of some_var in the default scope in graph2
+    copy1 = tf.contrib.copy_graph.copy_variable_to_graph(
+        some_var, graph2)
+
+    #Make another copy with different scope
+    copy2 = tf.contrib.copy_graph.copy_variable_to_graph(
+        some_var, graph2, "test_scope")
+
+    #Initialize both the copies
+    with graph2.as_default():
+      #Initialize Session
+      sess2 = tf.Session()
+      #Initialize the Variables
+      tf.initialize_all_variables().run(session=sess2)
+
+    #Ensure values in all three variables are the same
+    v1 = some_var.eval(session=sess1)
+    v2 = copy1.eval(session=sess2)
+    v3 = copy2.eval(session=sess2)
+
+    assert isinstance(copy1, tf.Variable)
+    assert isinstance(copy2, tf.Variable)
+    assert v1 == v2 == v3 == 2
+
+
+class CopyOpsTest(tf.test.TestCase):
+
+  def testOpsCopy(self):
+
+    with graph1.as_default():
+      #Initialize a basic expression y = ax + b
+      x = tf.placeholder("float")
+      a = tf.Variable(3.0)
+      b = tf.constant(4.0)
+      ax = tf.mul(x, a)
+      y = tf.add(ax, b)
+      #Initialize session
+      sess1 = tf.Session()
+      #Initialize the Variable
+      tf.initialize_all_variables().run(session=sess1)
+
+    #First, initialize a as a Variable in graph2
+    a1 = tf.contrib.copy_graph.copy_variable_to_graph(
+        a, graph2)
+
+    #Initialize a1 in graph2
+    with graph2.as_default():
+      #Initialize session
+      sess2 = tf.Session()
+      #Initialize the Variable
+      tf.initialize_all_variables().run(session=sess2)
+
+    #Initialize a copy of y in graph2
+    y1 = tf.contrib.copy_graph.copy_op_to_graph(
+        y, graph2, [a1])
+
+    #Now that y has been copied, x must be copied too.
+    #Get that instance
+    x1 = tf.contrib.copy_graph.get_copied_op(x, graph2)
+
+    #Compare values of y & y1 for a sample input
+    #and check if they match
+    v1 = y.eval({x: 5}, session=sess1)
+    v2 = y1.eval({x1: 5}, session=sess2)
+
+    assert v1 == v2
+
+
+if __name__ == "__main__":
+ tf.test.main()