about summary refs log tree commit diff homepage
path: root/tensorflow/contrib/copy_graph
diff options
context:
space:
mode:
authorGravatar Yifei Feng <yifeif@google.com>2018-01-25 12:02:36 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-01-25 12:07:22 -0800
commit351c0a533a111636333b4ebeede16485cf679ca9 (patch)
treea0786bc9a8fe7432d69d8095b10586e3ef515b93 /tensorflow/contrib/copy_graph
parenta8c4e8d96de7c0978851a5f9718bbd6b8056d862 (diff)
Add C0330 bad-continuation check to pylint.
PiperOrigin-RevId: 183270896
Diffstat (limited to 'tensorflow/contrib/copy_graph')
-rw-r--r--tensorflow/contrib/copy_graph/python/util/copy_elements.py75
1 files changed, 34 insertions, 41 deletions
diff --git a/tensorflow/contrib/copy_graph/python/util/copy_elements.py b/tensorflow/contrib/copy_graph/python/util/copy_elements.py
index bae66ffd42..b806799202 100644
--- a/tensorflow/contrib/copy_graph/python/util/copy_elements.py
+++ b/tensorflow/contrib/copy_graph/python/util/copy_elements.py
@@ -35,10 +35,10 @@ from tensorflow.python.ops.variables import Variable
from tensorflow.python.client.session import Session
from tensorflow.python.framework import ops
-__all__ = ["copy_op_to_graph", "copy_variable_to_graph", "get_copied_op"]
+__all__ = ['copy_op_to_graph', 'copy_variable_to_graph', 'get_copied_op']
-def copy_variable_to_graph(org_instance, to_graph, scope=""):
+def copy_variable_to_graph(org_instance, to_graph, scope=''):
"""Given a `Variable` instance from one `Graph`, initializes and returns
a copy of it from another `Graph`, under the specified scope
(default `""`).
@@ -56,12 +56,11 @@ def copy_variable_to_graph(org_instance, to_graph, scope=""):
"""
if not isinstance(org_instance, Variable):
- raise TypeError(str(org_instance) + " is not a Variable")
+ raise TypeError(str(org_instance) + ' is not a Variable')
#The name of the new variable
- if scope != "":
- new_name = (scope + '/' +
- org_instance.name[:org_instance.name.index(':')])
+ if scope != '':
+ new_name = (scope + '/' + org_instance.name[:org_instance.name.index(':')])
else:
new_name = org_instance.name[:org_instance.name.index(':')]
@@ -73,15 +72,15 @@ def copy_variable_to_graph(org_instance, to_graph, scope=""):
for name, collection in org_instance.graph._collections.items():
if org_instance in collection:
if (name == ops.GraphKeys.GLOBAL_VARIABLES or
- name == ops.GraphKeys.TRAINABLE_VARIABLES or
- scope == ''):
+ name == ops.GraphKeys.TRAINABLE_VARIABLES or scope == ''):
collections.append(name)
else:
collections.append(scope + '/' + name)
#See if its trainable.
- trainable = (org_instance in org_instance.graph.get_collection(
- ops.GraphKeys.TRAINABLE_VARIABLES))
+ trainable = (
+ org_instance in org_instance.graph.get_collection(
+ ops.GraphKeys.TRAINABLE_VARIABLES))
#Get the initial value
with org_instance.graph.as_default():
temp_session = Session()
@@ -89,17 +88,17 @@ def copy_variable_to_graph(org_instance, to_graph, scope=""):
#Initialize the new variable
with to_graph.as_default():
- new_var = Variable(init_value,
- trainable,
- name=new_name,
- collections=collections,
- validate_shape=False)
+ new_var = Variable(
+ init_value,
+ trainable,
+ name=new_name,
+ collections=collections,
+ validate_shape=False)
return new_var
-def copy_op_to_graph(org_instance, to_graph, variables,
- scope=""):
+def copy_op_to_graph(org_instance, to_graph, variables, scope=''):
"""Returns a copy of an operation from another Graph under a specified scope.
Given an `Operation` `org_instance` from one `Graph`,
@@ -139,14 +138,12 @@ def copy_op_to_graph(org_instance, to_graph, variables,
#If a variable by the new name already exists, return the
#correspondng tensor that will act as an input
if new_name in copied_variables:
- return to_graph.get_tensor_by_name(
- copied_variables[new_name].name)
+ return to_graph.get_tensor_by_name(copied_variables[new_name].name)
#If an instance of the same name exists, return appropriately
try:
- already_present = to_graph.as_graph_element(new_name,
- allow_tensor=True,
- allow_operation=True)
+ already_present = to_graph.as_graph_element(
+ new_name, allow_tensor=True, allow_operation=True)
return already_present
except:
pass
@@ -184,20 +181,21 @@ def copy_op_to_graph(org_instance, to_graph, variables,
#If it has an original_op parameter, copy it
if op._original_op is not None:
- new_original_op = copy_op_to_graph(op._original_op, to_graph,
- variables, scope)
+ new_original_op = copy_op_to_graph(op._original_op, to_graph, variables,
+ scope)
else:
new_original_op = None
#If it has control inputs, call this function recursively on each.
- new_control_inputs = [copy_op_to_graph(x, to_graph, variables,
- scope)
- for x in op.control_inputs]
+ new_control_inputs = [
+ copy_op_to_graph(x, to_graph, variables, scope)
+ for x in op.control_inputs
+ ]
#If it has inputs, call this function recursively on each.
- new_inputs = [copy_op_to_graph(x, to_graph, variables,
- scope)
- for x in op.inputs]
+ new_inputs = [
+ copy_op_to_graph(x, to_graph, variables, scope) for x in op.inputs
+ ]
#Make a new node_def based on that of the original.
#An instance of tensorflow.core.framework.node_def_pb2.NodeDef, it
@@ -216,13 +214,8 @@ def copy_op_to_graph(org_instance, to_graph, variables,
op_def = deepcopy(op._op_def)
#Initialize a new Operation instance
- new_op = ops.Operation(new_node_def,
- to_graph,
- new_inputs,
- output_types,
- new_control_inputs,
- input_types,
- new_original_op,
+ new_op = ops.Operation(new_node_def, to_graph, new_inputs, output_types,
+ new_control_inputs, input_types, new_original_op,
op_def)
#Use Graph's hidden methods to add the op
to_graph._add_op(new_op) # pylint: disable=protected-access
@@ -233,10 +226,10 @@ def copy_op_to_graph(org_instance, to_graph, variables,
return new_op
else:
- raise TypeError("Could not copy instance: " + str(org_instance))
+ raise TypeError('Could not copy instance: ' + str(org_instance))
-def get_copied_op(org_instance, graph, scope=""):
+def get_copied_op(org_instance, graph, scope=''):
"""Given an `Operation` instance from some `Graph`, returns
its namesake from `graph`, under the specified scope
(default `""`).
@@ -259,5 +252,5 @@ def get_copied_op(org_instance, graph, scope=""):
else:
new_name = org_instance.name
- return graph.as_graph_element(new_name, allow_tensor=True,
- allow_operation=True)
+ return graph.as_graph_element(
+ new_name, allow_tensor=True, allow_operation=True)