diff options
author | Alexandre Passos <apassos@google.com> | 2018-09-27 13:18:33 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-09-27 13:23:04 -0700 |
commit | 4cedc8b6e738b7a188c9c091cf667bacafae44b7 (patch) | |
tree | 56de35940e5f9daedd5f39a82d2cd90cf374e4e4 /tensorflow/contrib/framework | |
parent | c898e63d07fc63315be98f0772736e5d7f2fb44c (diff) |
Updating the V2 variables API.
PiperOrigin-RevId: 214824023
Diffstat (limited to 'tensorflow/contrib/framework')
-rw-r--r-- | tensorflow/contrib/framework/python/ops/variables_test.py | 28 |
1 files changed, 14 insertions, 14 deletions
diff --git a/tensorflow/contrib/framework/python/ops/variables_test.py b/tensorflow/contrib/framework/python/ops/variables_test.py index f9b0efd1da..c223df5b6e 100644 --- a/tensorflow/contrib/framework/python/ops/variables_test.py +++ b/tensorflow/contrib/framework/python/ops/variables_test.py @@ -192,7 +192,7 @@ class GlobalStepTest(test.TestCase): def test_invalid_dtype(self): with ops.Graph().as_default() as g: self.assertEquals(None, variables_lib2.get_global_step()) - variables_lib.Variable( + variables_lib.VariableV1( 0.0, trainable=False, dtype=dtypes.float32, @@ -205,7 +205,7 @@ class GlobalStepTest(test.TestCase): def test_invalid_shape(self): with ops.Graph().as_default() as g: self.assertEquals(None, variables_lib2.get_global_step()) - variables_lib.Variable( + variables_lib.VariableV1( [0], trainable=False, dtype=dtypes.int32, @@ -229,7 +229,7 @@ class GlobalStepTest(test.TestCase): def test_get_global_step(self): with ops.Graph().as_default() as g: self.assertEquals(None, variables_lib2.get_global_step()) - variables_lib.Variable( + variables_lib.VariableV1( 0, trainable=False, dtype=dtypes.int32, @@ -607,10 +607,10 @@ class ModelVariablesTest(test.TestCase): with self.cached_session(): with variable_scope.variable_scope('A'): variables_lib2.local_variable([5]) - a = variables_lib.Variable([5]) + a = variables_lib.VariableV1([5]) with variable_scope.variable_scope('B'): variables_lib2.local_variable([5]) - b = variables_lib.Variable([5]) + b = variables_lib.VariableV1([5]) self.assertEquals([a], variables_lib2.get_trainable_variables('A')) self.assertEquals([b], variables_lib2.get_trainable_variables('B')) @@ -953,7 +953,7 @@ class AssignFromCheckpointTest(test.TestCase): # Create a set of variables to save in the checkpoint. for var_name in var_names_to_values: var_value = var_names_to_values[var_name] - var_list.append(variables_lib.Variable(var_value, name=var_name)) + var_list.append(variables_lib.VariableV1(var_value, name=var_name)) saver = saver_lib.Saver(var_list) init_op = variables_lib.variables_initializer(var_list) sess.run(init_op) @@ -1106,7 +1106,7 @@ class AssignFromCheckpointFnTest(test.TestCase): # Create a set of variables to save in the checkpoint. for var_name in var_names_to_values: var_value = var_names_to_values[var_name] - var_list.append(variables_lib.Variable(var_value, name=var_name)) + var_list.append(variables_lib.VariableV1(var_value, name=var_name)) saver = saver_lib.Saver(var_list) init_op = variables_lib.variables_initializer(var_list) sess.run(init_op) @@ -1297,7 +1297,7 @@ class AssignFromCheckpointFnTest(test.TestCase): class ZeroInitializerOpTest(test.TestCase): def _testZeroInitializer(self, shape, initializer, use_init): - var = variables_lib.Variable(initializer) + var = variables_lib.VariableV1(initializer) var_zero = variables_lib2.zero_initializer(var) with self.cached_session() as sess: with self.assertRaisesOpError('Attempting to use uninitialized value'): @@ -1350,12 +1350,12 @@ class FilterVariablesTest(test.TestCase): g = ops.Graph() with g.as_default(): var_list = [] - var_list.append(variables_lib.Variable(0, name='conv1/weights')) - var_list.append(variables_lib.Variable(0, name='conv1/biases')) - var_list.append(variables_lib.Variable(0, name='conv2/weights')) - var_list.append(variables_lib.Variable(0, name='conv2/biases')) - var_list.append(variables_lib.Variable(0, name='clfs/weights')) - var_list.append(variables_lib.Variable(0, name='clfs/biases')) + var_list.append(variables_lib.VariableV1(0, name='conv1/weights')) + var_list.append(variables_lib.VariableV1(0, name='conv1/biases')) + var_list.append(variables_lib.VariableV1(0, name='conv2/weights')) + var_list.append(variables_lib.VariableV1(0, name='conv2/biases')) + var_list.append(variables_lib.VariableV1(0, name='clfs/weights')) + var_list.append(variables_lib.VariableV1(0, name='clfs/biases')) self._var_list = var_list def _test_filter_variables(self, |