aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python/grappler
diff options
context:
space:
mode:
authorGravatar Alexandre Passos <apassos@google.com>2018-09-27 13:18:33 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-09-27 13:23:04 -0700
commit4cedc8b6e738b7a188c9c091cf667bacafae44b7 (patch)
tree56de35940e5f9daedd5f39a82d2cd90cf374e4e4 /tensorflow/python/grappler
parentc898e63d07fc63315be98f0772736e5d7f2fb44c (diff)
Updating the V2 variables API.
PiperOrigin-RevId: 214824023
Diffstat (limited to 'tensorflow/python/grappler')
-rw-r--r--tensorflow/python/grappler/item_test.py2
-rw-r--r--tensorflow/python/grappler/memory_optimizer_test.py10
-rw-r--r--tensorflow/python/grappler/tf_optimizer_test.py2
3 files changed, 7 insertions, 7 deletions
diff --git a/tensorflow/python/grappler/item_test.py b/tensorflow/python/grappler/item_test.py
index c40de9da0a..d3d96c646c 100644
--- a/tensorflow/python/grappler/item_test.py
+++ b/tensorflow/python/grappler/item_test.py
@@ -110,7 +110,7 @@ class ItemTest(test.TestCase):
def testColocationContraints(self):
with ops.Graph().as_default() as g:
c = constant_op.constant([10])
- v = variables.Variable([3], dtype=dtypes.int32)
+ v = variables.VariableV1([3], dtype=dtypes.int32)
i = gen_array_ops.ref_identity(v)
a = state_ops.assign(i, c)
train_op = ops.get_collection_ref(ops.GraphKeys.TRAIN_OP)
diff --git a/tensorflow/python/grappler/memory_optimizer_test.py b/tensorflow/python/grappler/memory_optimizer_test.py
index b658edff2d..03b42f6453 100644
--- a/tensorflow/python/grappler/memory_optimizer_test.py
+++ b/tensorflow/python/grappler/memory_optimizer_test.py
@@ -39,8 +39,8 @@ class MemoryOptimizerSwapTest(test.TestCase):
def testNoSwapping(self):
"""Make sure the graph is preserved when there is nothing to swap."""
- a = variables.Variable(10, name='a')
- b = variables.Variable(20, name='b')
+ a = variables.VariableV1(10, name='a')
+ b = variables.VariableV1(20, name='b')
c = math_ops.add_n([a, b], name='c')
d = math_ops.add_n([b, c], name='d')
train_op = ops.get_collection_ref(ops.GraphKeys.TRAIN_OP)
@@ -60,8 +60,8 @@ class MemoryOptimizerSwapTest(test.TestCase):
def testSimpleSwap(self):
"""Check that the swap annotations are followed."""
- a = variables.Variable(10, name='a')
- b = variables.Variable(20, name='b')
+ a = variables.VariableV1(10, name='a')
+ b = variables.VariableV1(20, name='b')
c = math_ops.add_n([a, b], name='c')
d = math_ops.add_n([b, c], name='d')
train_op = ops.get_collection_ref(ops.GraphKeys.TRAIN_OP)
@@ -244,7 +244,7 @@ class MemoryOptimizerRecomputeTest(test.TestCase):
init_op_name=init_op_name,
train_op_name=train_op_name,
loss_op_name=loss_op_name)
- self.assertAllClose(original_loss, memory_optimized_loss, rtol=1e-4)
+ self.assertAllClose(original_loss, memory_optimized_loss, rtol=1e-2)
def _annotated_graph(self):
graph = ops.Graph()
diff --git a/tensorflow/python/grappler/tf_optimizer_test.py b/tensorflow/python/grappler/tf_optimizer_test.py
index 5a9afe7257..eca0f67982 100644
--- a/tensorflow/python/grappler/tf_optimizer_test.py
+++ b/tensorflow/python/grappler/tf_optimizer_test.py
@@ -57,7 +57,7 @@ class PyWrapOptimizeGraphTest(test.TestCase):
def testKeepNodes(self):
g = ops.Graph()
with g.as_default():
- a1 = variables.Variable(
+ a1 = variables.VariableV1(
1.0) # Must be preserved since it's in the collection 'variables'.
a2 = constant_op.constant(0, shape=[50, 50], name='keep')
ops.add_to_collection('a2', a2) # Explicitly add to collection.