aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorGravatar Gunhan Gulsoy <gunan@google.com>2018-10-09 14:03:23 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-10-09 14:07:45 -0700
commit1f556d3a4172c30cf461e7e66334b70ffad2d559 (patch)
tree728c86acf4a1d7d49be8ccf34848c0687d97fa66
parent7b2f26280df8dee266d66e01a7ffac7a7eb25247 (diff)
Do not create a graph as a global variable in tests.
PiperOrigin-RevId: 216418324
-rw-r--r--tensorflow/contrib/copy_graph/python/util/copy_test.py31
1 files changed, 19 insertions, 12 deletions
diff --git a/tensorflow/contrib/copy_graph/python/util/copy_test.py b/tensorflow/contrib/copy_graph/python/util/copy_test.py
index ba97c78456..4d8651a79f 100644
--- a/tensorflow/contrib/copy_graph/python/util/copy_test.py
+++ b/tensorflow/contrib/copy_graph/python/util/copy_test.py
@@ -26,15 +26,16 @@ from tensorflow.python.ops import math_ops
from tensorflow.python.ops import variables
from tensorflow.python.platform import test
-graph1 = ops.Graph()
-graph2 = ops.Graph()
-
class CopyVariablesTest(test.TestCase):
+ def setUp(self):
+ self.graph1 = ops.Graph()
+ self.graph2 = ops.Graph()
+
def testVariableCopy(self):
- with graph1.as_default():
+ with self.graph1.as_default():
#Define a Variable in graph1
some_var = variables.VariableV1(2)
#Initialize session
@@ -43,13 +44,15 @@ class CopyVariablesTest(test.TestCase):
variables.global_variables_initializer().run(session=sess1)
#Make a copy of some_var in the defsult scope in graph2
- copy1 = copy_elements.copy_variable_to_graph(some_var, graph2)
+ copy1 = copy_elements.copy_variable_to_graph(some_var, self.graph2)
#Make another copy with different scope
- copy2 = copy_elements.copy_variable_to_graph(some_var, graph2, "test_scope")
+ copy2 = copy_elements.copy_variable_to_graph(some_var,
+ self.graph2,
+ "test_scope")
#Initialize both the copies
- with graph2.as_default():
+ with self.graph2.as_default():
#Initialize Session
sess2 = session_lib.Session()
#Initialize the Variables
@@ -67,9 +70,13 @@ class CopyVariablesTest(test.TestCase):
class CopyOpsTest(test.TestCase):
+ def setUp(self):
+ self.graph1 = ops.Graph()
+ self.graph2 = ops.Graph()
+
def testOpsCopy(self):
- with graph1.as_default():
+ with self.graph1.as_default():
#Initialize a basic expression y = ax + b
x = array_ops.placeholder("float")
a = variables.VariableV1(3.0)
@@ -82,21 +89,21 @@ class CopyOpsTest(test.TestCase):
variables.global_variables_initializer().run(session=sess1)
#First, initialize a as a Variable in graph2
- a1 = copy_elements.copy_variable_to_graph(a, graph2)
+ a1 = copy_elements.copy_variable_to_graph(a, self.graph2)
#Initialize a1 in graph2
- with graph2.as_default():
+ with self.graph2.as_default():
#Initialize session
sess2 = session_lib.Session()
#Initialize the Variable
variables.global_variables_initializer().run(session=sess2)
#Initialize a copy of y in graph2
- y1 = copy_elements.copy_op_to_graph(y, graph2, [a1])
+ y1 = copy_elements.copy_op_to_graph(y, self.graph2, [a1])
#Now that y has been copied, x must be copied too.
#Get that instance
- x1 = copy_elements.get_copied_op(x, graph2)
+ x1 = copy_elements.get_copied_op(x, self.graph2)
#Compare values of y & y1 for a sample input
#and check if they match