aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python/kernel_tests/variable_scope_test.py
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2016-10-07 13:28:33 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2016-10-07 14:33:24 -0700
commitabb87c796f07a2d10eb299c108b229e5c0f1134b (patch)
tree3f25857537afae19f9f5a46e2aaf7747a0670b96 /tensorflow/python/kernel_tests/variable_scope_test.py
parentc37847c1e5aedf5f33151895bdcbf9de89bbd759 (diff)
Sets variable's shape as early as possible.
Change: 135517618
Diffstat (limited to 'tensorflow/python/kernel_tests/variable_scope_test.py')
-rw-r--r--tensorflow/python/kernel_tests/variable_scope_test.py16
1 files changed, 16 insertions, 0 deletions
diff --git a/tensorflow/python/kernel_tests/variable_scope_test.py b/tensorflow/python/kernel_tests/variable_scope_test.py
index 732049cb7a..1e2db3e565 100644
--- a/tensorflow/python/kernel_tests/variable_scope_test.py
+++ b/tensorflow/python/kernel_tests/variable_scope_test.py
@@ -635,6 +635,22 @@ class VariableScopeTest(tf.test.TestCase):
self.assertEqual(variable_scope.get_local_variable("w", []).name,
"outer/w:0")
+ def testGetVarWithDevice(self):
+ g = tf.Graph()
+ varname_shape = []
+
+ def device_func(op):
+ if op.type == "Variable":
+ varname_shape.append((op.name, tf.TensorShape(op.get_attr("shape"))))
+ return "/gpu:0"
+
+ with g.as_default():
+ with tf.device(device_func):
+ _ = tf.get_variable("x", (100, 200)) # init fn
+ _ = tf.get_variable("y", initializer=numpy.arange(73)) # init constant
+ self.assertEqual(varname_shape[0], ("x", tf.TensorShape([100, 200])))
+ self.assertEqual(varname_shape[1], ("y", tf.TensorShape([73])))
+
def axis0_into1_partitioner(shape=None, **unused_kwargs):
part = [1] * len(shape)