aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python/kernel_tests/partitioned_variables_test.py
diff options
context:
space:
mode:
authorGravatar Wei Ho <weiho4+github@gmail.com>2016-05-31 15:14:19 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2016-05-31 16:18:16 -0700
commit775a7490047299f2b64bf4fafcb6504852b8082d (patch)
tree70a7f3f647cddde050826c2bdf881511d08239b4 /tensorflow/python/kernel_tests/partitioned_variables_test.py
parent4752b3c78e07adec00b553e2ca2a439e1e563956 (diff)
Skeleton for PartitionedVariable class take 2
Change: 123694865
Diffstat (limited to 'tensorflow/python/kernel_tests/partitioned_variables_test.py')
-rw-r--r--tensorflow/python/kernel_tests/partitioned_variables_test.py26
1 files changed, 13 insertions, 13 deletions
diff --git a/tensorflow/python/kernel_tests/partitioned_variables_test.py b/tensorflow/python/kernel_tests/partitioned_variables_test.py
index 5a9ad22349..7a50f94c07 100644
--- a/tensorflow/python/kernel_tests/partitioned_variables_test.py
+++ b/tensorflow/python/kernel_tests/partitioned_variables_test.py
@@ -22,12 +22,6 @@ import numpy as np
from six.moves import xrange # pylint: disable=redefined-builtin
import tensorflow as tf
-from tensorflow.python.ops import variable_scope
-
-# pylint: disable=protected-access
-get_partitioned_variable_list = variable_scope._get_partitioned_variable_list
-# pylint: enable=protected-access
-
class PartitionerCreatorsTest(tf.test.TestCase):
@@ -39,8 +33,9 @@ class PartitionerCreatorsTest(tf.test.TestCase):
axis=axis, max_shard_bytes=max_shard_bytes, max_shards=max_shards)
with tf.variable_scope("root", partitioner=partitioner):
- v0_list, v0_part = get_partitioned_variable_list(
- name, dtype=tf.float32, shape=(4, 8, 16, 32))
+ v0 = tf.get_variable(name, dtype=tf.float32, shape=(4, 8, 16, 32))
+ v0_list = v0._get_variable_list()
+ v0_part = v0._get_partitions()
self.assertEqual(len(v0_list), expected_axis_shards)
self.assertAllEqual(v0_part, expected_partitions)
@@ -118,10 +113,13 @@ class PartitionerCreatorsTest(tf.test.TestCase):
axis=3, max_shard_bytes=32768, bytes_per_string_element=8)
with tf.variable_scope("root", partitioner=partitioner_axis3_str):
- v3str_list, v3str_part = get_partitioned_variable_list(
+ v3str = tf.get_variable(
"v3str",
- initializer=np.array([""] * 4*8*16*32).reshape(4, 8, 16, 32),
- dtype=tf.string, shape=(4, 8, 16, 32))
+ initializer=np.array([""] * 4 * 8 * 16 * 32).reshape(4, 8, 16, 32),
+ dtype=tf.string,
+ shape=(4, 8, 16, 32))
+ v3str_list = v3str._get_variable_list()
+ v3str_part = v3str._get_partitions()
# Now the estimated bytes_per_slice = 4*8*16*bytes_per_string_element
# which is equal to 4096. Setting a max_shard_bytes of 32768
@@ -191,9 +189,11 @@ class PartitionedVariablesTestCase(tf.test.TestCase):
with self.test_session():
rnd_par = tf.constant([[1, 2, 3, 4], [5, 6, 7, 8]])
with tf.variable_scope("hola") as vs:
- vs1 = tf.create_partitioned_variables([2, 4], [1, 2], rnd_par)
+ vs1 = tf.create_partitioned_variables(
+ [2, 4], [1, 2], rnd_par, dtype=tf.int32)
with tf.variable_scope(vs, reuse=True):
- vs2 = tf.create_partitioned_variables([2, 4], [1, 2], rnd_par)
+ vs2 = tf.create_partitioned_variables(
+ [2, 4], [1, 2], rnd_par, dtype=tf.int32)
tf.initialize_all_variables().run()
var1_name = vs1[0]._save_slice_info.full_name
var2_name = vs2[0]._save_slice_info.full_name