diff options
author | A. Unique TensorFlower <gardener@tensorflow.org> | 2016-10-07 13:28:33 -0800 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2016-10-07 14:33:24 -0700 |
commit | abb87c796f07a2d10eb299c108b229e5c0f1134b (patch) | |
tree | 3f25857537afae19f9f5a46e2aaf7747a0670b96 /tensorflow/python/kernel_tests/variable_scope_test.py | |
parent | c37847c1e5aedf5f33151895bdcbf9de89bbd759 (diff) |
Sets variable's shape as early as possible.
Change: 135517618
Diffstat (limited to 'tensorflow/python/kernel_tests/variable_scope_test.py')
-rw-r--r-- | tensorflow/python/kernel_tests/variable_scope_test.py | 16 |
1 files changed, 16 insertions, 0 deletions
diff --git a/tensorflow/python/kernel_tests/variable_scope_test.py b/tensorflow/python/kernel_tests/variable_scope_test.py index 732049cb7a..1e2db3e565 100644 --- a/tensorflow/python/kernel_tests/variable_scope_test.py +++ b/tensorflow/python/kernel_tests/variable_scope_test.py @@ -635,6 +635,22 @@ class VariableScopeTest(tf.test.TestCase): self.assertEqual(variable_scope.get_local_variable("w", []).name, "outer/w:0") + def testGetVarWithDevice(self): + g = tf.Graph() + varname_shape = [] + + def device_func(op): + if op.type == "Variable": + varname_shape.append((op.name, tf.TensorShape(op.get_attr("shape")))) + return "/gpu:0" + + with g.as_default(): + with tf.device(device_func): + _ = tf.get_variable("x", (100, 200)) # init fn + _ = tf.get_variable("y", initializer=numpy.arange(73)) # init constant + self.assertEqual(varname_shape[0], ("x", tf.TensorShape([100, 200]))) + self.assertEqual(varname_shape[1], ("y", tf.TensorShape([73]))) + def axis0_into1_partitioner(shape=None, **unused_kwargs): part = [1] * len(shape) |