aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/copy_graph
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2016-09-11 21:55:18 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2016-09-11 23:03:15 -0700
commit45e6676b1acb99a19052459a23ffb5437742d8e9 (patch)
treea96fd7984c7c5677a00e9a654bce2118bd75c0bc /tensorflow/contrib/copy_graph
parentfe5a2833133645045d522f5daa4813974b86e751 (diff)
Add pylint indentation check to sanity and fix existing indentation
Change: 132840696
Diffstat (limited to 'tensorflow/contrib/copy_graph')
-rw-r--r--tensorflow/contrib/copy_graph/python/util/copy_elements.py424
-rw-r--r--tensorflow/contrib/copy_graph/python/util/copy_test.py120
2 files changed, 272 insertions, 272 deletions
diff --git a/tensorflow/contrib/copy_graph/python/util/copy_elements.py b/tensorflow/contrib/copy_graph/python/util/copy_elements.py
index 3c80a17633..a45620e773 100644
--- a/tensorflow/contrib/copy_graph/python/util/copy_elements.py
+++ b/tensorflow/contrib/copy_graph/python/util/copy_elements.py
@@ -39,223 +39,223 @@ __all__ = ["copy_op_to_graph", "copy_variable_to_graph", "get_copied_op"]
def copy_variable_to_graph(org_instance, to_graph, scope=""):
- """Given a `Variable` instance from one `Graph`, initializes and returns
- a copy of it from another `Graph`, under the specified scope
- (default `""`).
-
- Args:
- org_instance: A `Variable` from some `Graph`.
- to_graph: The `Graph` to copy the `Variable` to.
- scope: A scope for the new `Variable` (default `""`).
-
- Returns:
- The copied `Variable` from `to_graph`.
-
- Raises:
- TypeError: If `org_instance` is not a `Variable`.
- """
-
- if not isinstance(org_instance, Variable):
- raise TypeError(str(org_instance) + " is not a Variable")
-
- #The name of the new variable
- if scope != "":
- new_name = (scope + '/' +
- org_instance.name[:org_instance.name.index(':')])
- else:
- new_name = org_instance.name[:org_instance.name.index(':')]
-
- #Get the collections that the new instance needs to be added to.
- #The new collections will also be a part of the given scope,
- #except the special ones required for variable initialization and
- #training.
- collections = []
- for name, collection in org_instance.graph._collections.items():
- if org_instance in collection:
- if (name == ops.GraphKeys.VARIABLES or
- name == ops.GraphKeys.TRAINABLE_VARIABLES or
- scope == ''):
- collections.append(name)
- else:
- collections.append(scope + '/' + name)
-
- #See if it's trainable.
- trainable = (org_instance in org_instance.graph.get_collection(
- ops.GraphKeys.TRAINABLE_VARIABLES))
- #Get the initial value
- with org_instance.graph.as_default():
- temp_session = Session()
- init_value = temp_session.run(org_instance.initialized_value())
-
- #Initialize the new variable
- with to_graph.as_default():
- new_var = Variable(init_value,
- trainable,
- name=new_name,
- collections=collections,
- validate_shape=False)
-
- return new_var
+ """Given a `Variable` instance from one `Graph`, initializes and returns
+ a copy of it from another `Graph`, under the specified scope
+ (default `""`).
+
+ Args:
+ org_instance: A `Variable` from some `Graph`.
+ to_graph: The `Graph` to copy the `Variable` to.
+ scope: A scope for the new `Variable` (default `""`).
+
+ Returns:
+ The copied `Variable` from `to_graph`.
+
+ Raises:
+ TypeError: If `org_instance` is not a `Variable`.
+ """
+
+ if not isinstance(org_instance, Variable):
+ raise TypeError(str(org_instance) + " is not a Variable")
+
+ #The name of the new variable
+ if scope != "":
+ new_name = (scope + '/' +
+ org_instance.name[:org_instance.name.index(':')])
+ else:
+ new_name = org_instance.name[:org_instance.name.index(':')]
+
+ #Get the collections that the new instance needs to be added to.
+ #The new collections will also be a part of the given scope,
+ #except the special ones required for variable initialization and
+ #training.
+ collections = []
+ for name, collection in org_instance.graph._collections.items():
+ if org_instance in collection:
+ if (name == ops.GraphKeys.VARIABLES or
+ name == ops.GraphKeys.TRAINABLE_VARIABLES or
+ scope == ''):
+ collections.append(name)
+ else:
+ collections.append(scope + '/' + name)
+
+ #See if it's trainable.
+ trainable = (org_instance in org_instance.graph.get_collection(
+ ops.GraphKeys.TRAINABLE_VARIABLES))
+ #Get the initial value
+ with org_instance.graph.as_default():
+ temp_session = Session()
+ init_value = temp_session.run(org_instance.initialized_value())
+
+ #Initialize the new variable
+ with to_graph.as_default():
+ new_var = Variable(init_value,
+ trainable,
+ name=new_name,
+ collections=collections,
+ validate_shape=False)
+
+ return new_var
def copy_op_to_graph(org_instance, to_graph, variables,
scope=""):
- """Given an `Operation` 'org_instance` from one `Graph`,
- initializes and returns a copy of it from another `Graph`,
- under the specified scope (default `""`).
-
- The copying is done recursively, so any `Operation` whose output
- is required to evaluate the `org_instance`, is also copied (unless
- already done).
-
- Since `Variable` instances are copied separately, those required
- to evaluate `org_instance` must be provided as input.
-
- Args:
- org_instance: An `Operation` from some `Graph`. Could be a
- `Placeholder` as well.
- to_graph: The `Graph` to copy `org_instance` to.
- variables: An iterable of `Variable` instances to copy `org_instance` to.
- scope: A scope for the new `Variable` (default `""`).
-
- Returns:
- The copied `Operation` from `to_graph`.
-
- Raises:
- TypeError: If `org_instance` is not an `Operation` or `Tensor`.
- """
-
- #The name of the new instance
- if scope != '':
- new_name = scope + '/' + org_instance.name
- else:
- new_name = org_instance.name
-
- #Extract names of variables
- copied_variables = dict((x.name, x) for x in variables)
-
- #If a variable by the new name already exists, return the
- #corresponding tensor that will act as an input
- if new_name in copied_variables:
- return to_graph.get_tensor_by_name(
- copied_variables[new_name].name)
-
- #If an instance of the same name exists, return appropriately
- try:
- already_present = to_graph.as_graph_element(new_name,
- allow_tensor=True,
- allow_operation=True)
- return already_present
- except:
- pass
-
- #Get the collections that the new instance needs to be added to.
- #The new collections will also be a part of the given scope.
- collections = []
- for name, collection in org_instance.graph._collections.items():
- if org_instance in collection:
- if scope == '':
- collections.append(name)
- else:
- collections.append(scope + '/' + name)
-
- #Take action based on the class of the instance
-
- if isinstance(org_instance, ops.Tensor):
-
- #If it's a Tensor, it is one of the outputs of the underlying
- #op. Therefore, copy the op itself and return the appropriate
- #output.
- op = org_instance.op
- new_op = copy_op_to_graph(op, to_graph, variables, scope)
- output_index = op.outputs.index(org_instance)
- new_tensor = new_op.outputs[output_index]
- #Add to collections if any
- for collection in collections:
- to_graph.add_to_collection(collection, new_tensor)
-
- return new_tensor
-
- elif isinstance(org_instance, ops.Operation):
-
- op = org_instance
-
- #If it has an original_op parameter, copy it
- if op._original_op is not None:
- new_original_op = copy_op_to_graph(op._original_op, to_graph,
- variables, scope)
- else:
- new_original_op = None
-
- #If it has control inputs, call this function recursively on each.
- new_control_inputs = [copy_op_to_graph(x, to_graph, variables,
- scope)
- for x in op.control_inputs]
-
- #If it has inputs, call this function recursively on each.
- new_inputs = [copy_op_to_graph(x, to_graph, variables,
- scope)
- for x in op.inputs]
-
- #Make a new node_def based on that of the original.
- #An instance of tensorflow.core.framework.node_def_pb2.NodeDef, it
- #stores String-based info such as name, device and type of the op.
- #Unique to every Operation instance.
- new_node_def = deepcopy(op._node_def)
- #Change the name
- new_node_def.name = new_name
-
- #Copy the other inputs needed for initialization
- output_types = op._output_types[:]
- input_types = op._input_types[:]
-
- #Make a copy of the op_def too.
- #It's unique to every _type_ of Operation.
- op_def = deepcopy(op._op_def)
-
- #Initialize a new Operation instance
- new_op = ops.Operation(new_node_def,
- to_graph,
- new_inputs,
- output_types,
- new_control_inputs,
- input_types,
- new_original_op,
- op_def)
- #Use Graph's hidden methods to add the op
- to_graph._add_op(new_op)
- to_graph._record_op_seen_by_control_dependencies(new_op)
- for device_function in reversed(to_graph._device_function_stack):
- new_op._set_device(device_function(new_op))
-
- return new_op
-
+ """Given an `Operation` 'org_instance` from one `Graph`,
+ initializes and returns a copy of it from another `Graph`,
+ under the specified scope (default `""`).
+
+ The copying is done recursively, so any `Operation` whose output
+ is required to evaluate the `org_instance`, is also copied (unless
+ already done).
+
+ Since `Variable` instances are copied separately, those required
+ to evaluate `org_instance` must be provided as input.
+
+ Args:
+ org_instance: An `Operation` from some `Graph`. Could be a
+ `Placeholder` as well.
+ to_graph: The `Graph` to copy `org_instance` to.
+ variables: An iterable of `Variable` instances to copy `org_instance` to.
+ scope: A scope for the new `Variable` (default `""`).
+
+ Returns:
+ The copied `Operation` from `to_graph`.
+
+ Raises:
+ TypeError: If `org_instance` is not an `Operation` or `Tensor`.
+ """
+
+ #The name of the new instance
+ if scope != '':
+ new_name = scope + '/' + org_instance.name
+ else:
+ new_name = org_instance.name
+
+ #Extract names of variables
+ copied_variables = dict((x.name, x) for x in variables)
+
+ #If a variable by the new name already exists, return the
+ #corresponding tensor that will act as an input
+ if new_name in copied_variables:
+ return to_graph.get_tensor_by_name(
+ copied_variables[new_name].name)
+
+ #If an instance of the same name exists, return appropriately
+ try:
+ already_present = to_graph.as_graph_element(new_name,
+ allow_tensor=True,
+ allow_operation=True)
+ return already_present
+ except:
+ pass
+
+ #Get the collections that the new instance needs to be added to.
+ #The new collections will also be a part of the given scope.
+ collections = []
+ for name, collection in org_instance.graph._collections.items():
+ if org_instance in collection:
+ if scope == '':
+ collections.append(name)
+ else:
+ collections.append(scope + '/' + name)
+
+ #Take action based on the class of the instance
+
+ if isinstance(org_instance, ops.Tensor):
+
+ #If it's a Tensor, it is one of the outputs of the underlying
+ #op. Therefore, copy the op itself and return the appropriate
+ #output.
+ op = org_instance.op
+ new_op = copy_op_to_graph(op, to_graph, variables, scope)
+ output_index = op.outputs.index(org_instance)
+ new_tensor = new_op.outputs[output_index]
+ #Add to collections if any
+ for collection in collections:
+ to_graph.add_to_collection(collection, new_tensor)
+
+ return new_tensor
+
+ elif isinstance(org_instance, ops.Operation):
+
+ op = org_instance
+
+ #If it has an original_op parameter, copy it
+ if op._original_op is not None:
+ new_original_op = copy_op_to_graph(op._original_op, to_graph,
+ variables, scope)
else:
- raise TypeError("Could not copy instance: " + str(org_instance))
+ new_original_op = None
+
+ #If it has control inputs, call this function recursively on each.
+ new_control_inputs = [copy_op_to_graph(x, to_graph, variables,
+ scope)
+ for x in op.control_inputs]
+
+ #If it has inputs, call this function recursively on each.
+ new_inputs = [copy_op_to_graph(x, to_graph, variables,
+ scope)
+ for x in op.inputs]
+
+ #Make a new node_def based on that of the original.
+ #An instance of tensorflow.core.framework.node_def_pb2.NodeDef, it
+ #stores String-based info such as name, device and type of the op.
+ #Unique to every Operation instance.
+ new_node_def = deepcopy(op._node_def)
+ #Change the name
+ new_node_def.name = new_name
+
+ #Copy the other inputs needed for initialization
+ output_types = op._output_types[:]
+ input_types = op._input_types[:]
+
+ #Make a copy of the op_def too.
+ #It's unique to every _type_ of Operation.
+ op_def = deepcopy(op._op_def)
+
+ #Initialize a new Operation instance
+ new_op = ops.Operation(new_node_def,
+ to_graph,
+ new_inputs,
+ output_types,
+ new_control_inputs,
+ input_types,
+ new_original_op,
+ op_def)
+ #Use Graph's hidden methods to add the op
+ to_graph._add_op(new_op)
+ to_graph._record_op_seen_by_control_dependencies(new_op)
+ for device_function in reversed(to_graph._device_function_stack):
+ new_op._set_device(device_function(new_op))
+
+ return new_op
+
+ else:
+ raise TypeError("Could not copy instance: " + str(org_instance))
def get_copied_op(org_instance, graph, scope=""):
- """Given an `Operation` instance from some `Graph`, returns
- its namesake from `graph`, under the specified scope
- (default `""`).
-
- If a copy of `org_instance` is present in `graph` under the given
- `scope`, it will be returned.
-
- Args:
- org_instance: An `Operation` from some `Graph`.
- graph: The `Graph` to be searched for a copy of `org_instance`.
- scope: The scope `org_instance` is present in.
-
- Returns:
- The `Operation` copy from `graph`.
- """
-
- #The name of the copied instance
- if scope != '':
- new_name = scope + '/' + org_instance.name
- else:
- new_name = org_instance.name
-
- return graph.as_graph_element(new_name, allow_tensor=True,
- allow_operation=True)
+ """Given an `Operation` instance from some `Graph`, returns
+ its namesake from `graph`, under the specified scope
+ (default `""`).
+
+ If a copy of `org_instance` is present in `graph` under the given
+ `scope`, it will be returned.
+
+ Args:
+ org_instance: An `Operation` from some `Graph`.
+ graph: The `Graph` to be searched for a copy of `org_instance`.
+ scope: The scope `org_instance` is present in.
+
+ Returns:
+ The `Operation` copy from `graph`.
+ """
+
+ #The name of the copied instance
+ if scope != '':
+ new_name = scope + '/' + org_instance.name
+ else:
+ new_name = org_instance.name
+
+ return graph.as_graph_element(new_name, allow_tensor=True,
+ allow_operation=True)
diff --git a/tensorflow/contrib/copy_graph/python/util/copy_test.py b/tensorflow/contrib/copy_graph/python/util/copy_test.py
index d812ad1567..0f5a9f04cf 100644
--- a/tensorflow/contrib/copy_graph/python/util/copy_test.py
+++ b/tensorflow/contrib/copy_graph/python/util/copy_test.py
@@ -28,83 +28,83 @@ graph2 = tf.Graph()
class CopyVariablesTest(tf.test.TestCase):
- def testVariableCopy(self):
+ def testVariableCopy(self):
- with graph1.as_default():
- #Define a Variable in graph1
- some_var = tf.Variable(2)
- #Initialize session
- sess1 = tf.Session()
- #Initialize the Variable
- tf.initialize_all_variables().run(session=sess1)
+ with graph1.as_default():
+ #Define a Variable in graph1
+ some_var = tf.Variable(2)
+ #Initialize session
+ sess1 = tf.Session()
+ #Initialize the Variable
+ tf.initialize_all_variables().run(session=sess1)
- #Make a copy of some_var in the default scope in graph2
- copy1 = tf.contrib.copy_graph.copy_variable_to_graph(
- some_var, graph2)
+ #Make a copy of some_var in the default scope in graph2
+ copy1 = tf.contrib.copy_graph.copy_variable_to_graph(
+ some_var, graph2)
- #Make another copy with different scope
- copy2 = tf.contrib.copy_graph.copy_variable_to_graph(
- some_var, graph2, "test_scope")
+ #Make another copy with different scope
+ copy2 = tf.contrib.copy_graph.copy_variable_to_graph(
+ some_var, graph2, "test_scope")
- #Initialize both the copies
- with graph2.as_default():
- #Initialize Session
- sess2 = tf.Session()
- #Initialize the Variables
- tf.initialize_all_variables().run(session=sess2)
+ #Initialize both the copies
+ with graph2.as_default():
+ #Initialize Session
+ sess2 = tf.Session()
+ #Initialize the Variables
+ tf.initialize_all_variables().run(session=sess2)
- #Ensure values in all three variables are the same
- v1 = some_var.eval(session=sess1)
- v2 = copy1.eval(session=sess2)
- v3 = copy2.eval(session=sess2)
+ #Ensure values in all three variables are the same
+ v1 = some_var.eval(session=sess1)
+ v2 = copy1.eval(session=sess2)
+ v3 = copy2.eval(session=sess2)
- assert isinstance(copy1, tf.Variable)
- assert isinstance(copy2, tf.Variable)
- assert v1 == v2 == v3 == 2
+ assert isinstance(copy1, tf.Variable)
+ assert isinstance(copy2, tf.Variable)
+ assert v1 == v2 == v3 == 2
class CopyOpsTest(tf.test.TestCase):
- def testOpsCopy(self):
+ def testOpsCopy(self):
- with graph1.as_default():
- #Initialize a basic expression y = ax + b
- x = tf.placeholder("float")
- a = tf.Variable(3.0)
- b = tf.constant(4.0)
- ax = tf.mul(x, a)
- y = tf.add(ax, b)
- #Initialize session
- sess1 = tf.Session()
- #Initialize the Variable
- tf.initialize_all_variables().run(session=sess1)
+ with graph1.as_default():
+ #Initialize a basic expression y = ax + b
+ x = tf.placeholder("float")
+ a = tf.Variable(3.0)
+ b = tf.constant(4.0)
+ ax = tf.mul(x, a)
+ y = tf.add(ax, b)
+ #Initialize session
+ sess1 = tf.Session()
+ #Initialize the Variable
+ tf.initialize_all_variables().run(session=sess1)
- #First, initialize a as a Variable in graph2
- a1 = tf.contrib.copy_graph.copy_variable_to_graph(
- a, graph2)
+ #First, initialize a as a Variable in graph2
+ a1 = tf.contrib.copy_graph.copy_variable_to_graph(
+ a, graph2)
- #Initialize a1 in graph2
- with graph2.as_default():
- #Initialize session
- sess2 = tf.Session()
- #Initialize the Variable
- tf.initialize_all_variables().run(session=sess2)
+ #Initialize a1 in graph2
+ with graph2.as_default():
+ #Initialize session
+ sess2 = tf.Session()
+ #Initialize the Variable
+ tf.initialize_all_variables().run(session=sess2)
- #Initialize a copy of y in graph2
- y1 = tf.contrib.copy_graph.copy_op_to_graph(
- y, graph2, [a1])
+ #Initialize a copy of y in graph2
+ y1 = tf.contrib.copy_graph.copy_op_to_graph(
+ y, graph2, [a1])
- #Now that y has been copied, x must be copied too.
- #Get that instance
- x1 = tf.contrib.copy_graph.get_copied_op(x, graph2)
+ #Now that y has been copied, x must be copied too.
+ #Get that instance
+ x1 = tf.contrib.copy_graph.get_copied_op(x, graph2)
- #Compare values of y & y1 for a sample input
- #and check if they match
- v1 = y.eval({x: 5}, session=sess1)
- v2 = y1.eval({x1: 5}, session=sess2)
+ #Compare values of y & y1 for a sample input
+ #and check if they match
+ v1 = y.eval({x: 5}, session=sess1)
+ v2 = y1.eval({x1: 5}, session=sess2)
- assert v1 == v2
+ assert v1 == v2
if __name__ == "__main__":
- tf.test.main()
+ tf.test.main()