From 1f556d3a4172c30cf461e7e66334b70ffad2d559 Mon Sep 17 00:00:00 2001 From: Gunhan Gulsoy Date: Tue, 9 Oct 2018 14:03:23 -0700 Subject: Do not create a graph as a global variable in tests. PiperOrigin-RevId: 216418324 --- .../contrib/copy_graph/python/util/copy_test.py | 31 +++++++++++++--------- 1 file changed, 19 insertions(+), 12 deletions(-) diff --git a/tensorflow/contrib/copy_graph/python/util/copy_test.py b/tensorflow/contrib/copy_graph/python/util/copy_test.py index ba97c78456..4d8651a79f 100644 --- a/tensorflow/contrib/copy_graph/python/util/copy_test.py +++ b/tensorflow/contrib/copy_graph/python/util/copy_test.py @@ -26,15 +26,16 @@ from tensorflow.python.ops import math_ops from tensorflow.python.ops import variables from tensorflow.python.platform import test -graph1 = ops.Graph() -graph2 = ops.Graph() - class CopyVariablesTest(test.TestCase): + def setUp(self): + self.graph1 = ops.Graph() + self.graph2 = ops.Graph() + def testVariableCopy(self): - with graph1.as_default(): + with self.graph1.as_default(): #Define a Variable in graph1 some_var = variables.VariableV1(2) #Initialize session @@ -43,13 +44,15 @@ class CopyVariablesTest(test.TestCase): variables.global_variables_initializer().run(session=sess1) #Make a copy of some_var in the defsult scope in graph2 - copy1 = copy_elements.copy_variable_to_graph(some_var, graph2) + copy1 = copy_elements.copy_variable_to_graph(some_var, self.graph2) #Make another copy with different scope - copy2 = copy_elements.copy_variable_to_graph(some_var, graph2, "test_scope") + copy2 = copy_elements.copy_variable_to_graph(some_var, + self.graph2, + "test_scope") #Initialize both the copies - with graph2.as_default(): + with self.graph2.as_default(): #Initialize Session sess2 = session_lib.Session() #Initialize the Variables @@ -67,9 +70,13 @@ class CopyVariablesTest(test.TestCase): class CopyOpsTest(test.TestCase): + def setUp(self): + self.graph1 = ops.Graph() + self.graph2 = ops.Graph() + def testOpsCopy(self): - with graph1.as_default(): + with self.graph1.as_default(): #Initialize a basic expression y = ax + b x = array_ops.placeholder("float") a = variables.VariableV1(3.0) @@ -82,21 +89,21 @@ class CopyOpsTest(test.TestCase): variables.global_variables_initializer().run(session=sess1) #First, initialize a as a Variable in graph2 - a1 = copy_elements.copy_variable_to_graph(a, graph2) + a1 = copy_elements.copy_variable_to_graph(a, self.graph2) #Initialize a1 in graph2 - with graph2.as_default(): + with self.graph2.as_default(): #Initialize session sess2 = session_lib.Session() #Initialize the Variable variables.global_variables_initializer().run(session=sess2) #Initialize a copy of y in graph2 - y1 = copy_elements.copy_op_to_graph(y, graph2, [a1]) + y1 = copy_elements.copy_op_to_graph(y, self.graph2, [a1]) #Now that y has been copied, x must be copied too. #Get that instance - x1 = copy_elements.get_copied_op(x, graph2) + x1 = copy_elements.get_copied_op(x, self.graph2) #Compare values of y & y1 for a sample input #and check if they match -- cgit v1.2.3