aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/framework
diff options
context:
space:
mode:
authorGravatar Alexandre Passos <apassos@google.com>2018-09-27 13:18:33 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-09-27 13:23:04 -0700
commit4cedc8b6e738b7a188c9c091cf667bacafae44b7 (patch)
tree56de35940e5f9daedd5f39a82d2cd90cf374e4e4 /tensorflow/contrib/framework
parentc898e63d07fc63315be98f0772736e5d7f2fb44c (diff)
Updating the V2 variables API.
PiperOrigin-RevId: 214824023
Diffstat (limited to 'tensorflow/contrib/framework')
-rw-r--r--tensorflow/contrib/framework/python/ops/variables_test.py28
1 files changed, 14 insertions, 14 deletions
diff --git a/tensorflow/contrib/framework/python/ops/variables_test.py b/tensorflow/contrib/framework/python/ops/variables_test.py
index f9b0efd1da..c223df5b6e 100644
--- a/tensorflow/contrib/framework/python/ops/variables_test.py
+++ b/tensorflow/contrib/framework/python/ops/variables_test.py
@@ -192,7 +192,7 @@ class GlobalStepTest(test.TestCase):
def test_invalid_dtype(self):
with ops.Graph().as_default() as g:
self.assertEquals(None, variables_lib2.get_global_step())
- variables_lib.Variable(
+ variables_lib.VariableV1(
0.0,
trainable=False,
dtype=dtypes.float32,
@@ -205,7 +205,7 @@ class GlobalStepTest(test.TestCase):
def test_invalid_shape(self):
with ops.Graph().as_default() as g:
self.assertEquals(None, variables_lib2.get_global_step())
- variables_lib.Variable(
+ variables_lib.VariableV1(
[0],
trainable=False,
dtype=dtypes.int32,
@@ -229,7 +229,7 @@ class GlobalStepTest(test.TestCase):
def test_get_global_step(self):
with ops.Graph().as_default() as g:
self.assertEquals(None, variables_lib2.get_global_step())
- variables_lib.Variable(
+ variables_lib.VariableV1(
0,
trainable=False,
dtype=dtypes.int32,
@@ -607,10 +607,10 @@ class ModelVariablesTest(test.TestCase):
with self.cached_session():
with variable_scope.variable_scope('A'):
variables_lib2.local_variable([5])
- a = variables_lib.Variable([5])
+ a = variables_lib.VariableV1([5])
with variable_scope.variable_scope('B'):
variables_lib2.local_variable([5])
- b = variables_lib.Variable([5])
+ b = variables_lib.VariableV1([5])
self.assertEquals([a], variables_lib2.get_trainable_variables('A'))
self.assertEquals([b], variables_lib2.get_trainable_variables('B'))
@@ -953,7 +953,7 @@ class AssignFromCheckpointTest(test.TestCase):
# Create a set of variables to save in the checkpoint.
for var_name in var_names_to_values:
var_value = var_names_to_values[var_name]
- var_list.append(variables_lib.Variable(var_value, name=var_name))
+ var_list.append(variables_lib.VariableV1(var_value, name=var_name))
saver = saver_lib.Saver(var_list)
init_op = variables_lib.variables_initializer(var_list)
sess.run(init_op)
@@ -1106,7 +1106,7 @@ class AssignFromCheckpointFnTest(test.TestCase):
# Create a set of variables to save in the checkpoint.
for var_name in var_names_to_values:
var_value = var_names_to_values[var_name]
- var_list.append(variables_lib.Variable(var_value, name=var_name))
+ var_list.append(variables_lib.VariableV1(var_value, name=var_name))
saver = saver_lib.Saver(var_list)
init_op = variables_lib.variables_initializer(var_list)
sess.run(init_op)
@@ -1297,7 +1297,7 @@ class AssignFromCheckpointFnTest(test.TestCase):
class ZeroInitializerOpTest(test.TestCase):
def _testZeroInitializer(self, shape, initializer, use_init):
- var = variables_lib.Variable(initializer)
+ var = variables_lib.VariableV1(initializer)
var_zero = variables_lib2.zero_initializer(var)
with self.cached_session() as sess:
with self.assertRaisesOpError('Attempting to use uninitialized value'):
@@ -1350,12 +1350,12 @@ class FilterVariablesTest(test.TestCase):
g = ops.Graph()
with g.as_default():
var_list = []
- var_list.append(variables_lib.Variable(0, name='conv1/weights'))
- var_list.append(variables_lib.Variable(0, name='conv1/biases'))
- var_list.append(variables_lib.Variable(0, name='conv2/weights'))
- var_list.append(variables_lib.Variable(0, name='conv2/biases'))
- var_list.append(variables_lib.Variable(0, name='clfs/weights'))
- var_list.append(variables_lib.Variable(0, name='clfs/biases'))
+ var_list.append(variables_lib.VariableV1(0, name='conv1/weights'))
+ var_list.append(variables_lib.VariableV1(0, name='conv1/biases'))
+ var_list.append(variables_lib.VariableV1(0, name='conv2/weights'))
+ var_list.append(variables_lib.VariableV1(0, name='conv2/biases'))
+ var_list.append(variables_lib.VariableV1(0, name='clfs/weights'))
+ var_list.append(variables_lib.VariableV1(0, name='clfs/biases'))
self._var_list = var_list
def _test_filter_variables(self,