diff options
author | Alexandre Passos <apassos@google.com> | 2018-09-27 13:18:33 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-09-27 13:23:04 -0700 |
commit | 4cedc8b6e738b7a188c9c091cf667bacafae44b7 (patch) | |
tree | 56de35940e5f9daedd5f39a82d2cd90cf374e4e4 /tensorflow/python/grappler | |
parent | c898e63d07fc63315be98f0772736e5d7f2fb44c (diff) |
Updating the V2 variables API.
PiperOrigin-RevId: 214824023
Diffstat (limited to 'tensorflow/python/grappler')
-rw-r--r-- | tensorflow/python/grappler/item_test.py | 2 | ||||
-rw-r--r-- | tensorflow/python/grappler/memory_optimizer_test.py | 10 | ||||
-rw-r--r-- | tensorflow/python/grappler/tf_optimizer_test.py | 2 |
3 files changed, 7 insertions, 7 deletions
diff --git a/tensorflow/python/grappler/item_test.py b/tensorflow/python/grappler/item_test.py index c40de9da0a..d3d96c646c 100644 --- a/tensorflow/python/grappler/item_test.py +++ b/tensorflow/python/grappler/item_test.py @@ -110,7 +110,7 @@ class ItemTest(test.TestCase): def testColocationContraints(self): with ops.Graph().as_default() as g: c = constant_op.constant([10]) - v = variables.Variable([3], dtype=dtypes.int32) + v = variables.VariableV1([3], dtype=dtypes.int32) i = gen_array_ops.ref_identity(v) a = state_ops.assign(i, c) train_op = ops.get_collection_ref(ops.GraphKeys.TRAIN_OP) diff --git a/tensorflow/python/grappler/memory_optimizer_test.py b/tensorflow/python/grappler/memory_optimizer_test.py index b658edff2d..03b42f6453 100644 --- a/tensorflow/python/grappler/memory_optimizer_test.py +++ b/tensorflow/python/grappler/memory_optimizer_test.py @@ -39,8 +39,8 @@ class MemoryOptimizerSwapTest(test.TestCase): def testNoSwapping(self): """Make sure the graph is preserved when there is nothing to swap.""" - a = variables.Variable(10, name='a') - b = variables.Variable(20, name='b') + a = variables.VariableV1(10, name='a') + b = variables.VariableV1(20, name='b') c = math_ops.add_n([a, b], name='c') d = math_ops.add_n([b, c], name='d') train_op = ops.get_collection_ref(ops.GraphKeys.TRAIN_OP) @@ -60,8 +60,8 @@ class MemoryOptimizerSwapTest(test.TestCase): def testSimpleSwap(self): """Check that the swap annotations are followed.""" - a = variables.Variable(10, name='a') - b = variables.Variable(20, name='b') + a = variables.VariableV1(10, name='a') + b = variables.VariableV1(20, name='b') c = math_ops.add_n([a, b], name='c') d = math_ops.add_n([b, c], name='d') train_op = ops.get_collection_ref(ops.GraphKeys.TRAIN_OP) @@ -244,7 +244,7 @@ class MemoryOptimizerRecomputeTest(test.TestCase): init_op_name=init_op_name, train_op_name=train_op_name, loss_op_name=loss_op_name) - self.assertAllClose(original_loss, memory_optimized_loss, rtol=1e-4) + self.assertAllClose(original_loss, memory_optimized_loss, rtol=1e-2) def _annotated_graph(self): graph = ops.Graph() diff --git a/tensorflow/python/grappler/tf_optimizer_test.py b/tensorflow/python/grappler/tf_optimizer_test.py index 5a9afe7257..eca0f67982 100644 --- a/tensorflow/python/grappler/tf_optimizer_test.py +++ b/tensorflow/python/grappler/tf_optimizer_test.py @@ -57,7 +57,7 @@ class PyWrapOptimizeGraphTest(test.TestCase): def testKeepNodes(self): g = ops.Graph() with g.as_default(): - a1 = variables.Variable( + a1 = variables.VariableV1( 1.0) # Must be preserved since it's in the collection 'variables'. a2 = constant_op.constant(0, shape=[50, 50], name='keep') ops.add_to_collection('a2', a2) # Explicitly add to collection. |