diff options
author | 2016-05-31 15:14:19 -0800 | |
---|---|---|
committer | 2016-05-31 16:18:16 -0700 | |
commit | 775a7490047299f2b64bf4fafcb6504852b8082d (patch) | |
tree | 70a7f3f647cddde050826c2bdf881511d08239b4 /tensorflow/python/kernel_tests/partitioned_variables_test.py | |
parent | 4752b3c78e07adec00b553e2ca2a439e1e563956 (diff) |
Skeleton for PartitionedVariable class take 2
Change: 123694865
Diffstat (limited to 'tensorflow/python/kernel_tests/partitioned_variables_test.py')
-rw-r--r-- | tensorflow/python/kernel_tests/partitioned_variables_test.py | 26 |
1 files changed, 13 insertions, 13 deletions
diff --git a/tensorflow/python/kernel_tests/partitioned_variables_test.py b/tensorflow/python/kernel_tests/partitioned_variables_test.py index 5a9ad22349..7a50f94c07 100644 --- a/tensorflow/python/kernel_tests/partitioned_variables_test.py +++ b/tensorflow/python/kernel_tests/partitioned_variables_test.py @@ -22,12 +22,6 @@ import numpy as np from six.moves import xrange # pylint: disable=redefined-builtin import tensorflow as tf -from tensorflow.python.ops import variable_scope - -# pylint: disable=protected-access -get_partitioned_variable_list = variable_scope._get_partitioned_variable_list -# pylint: enable=protected-access - class PartitionerCreatorsTest(tf.test.TestCase): @@ -39,8 +33,9 @@ class PartitionerCreatorsTest(tf.test.TestCase): axis=axis, max_shard_bytes=max_shard_bytes, max_shards=max_shards) with tf.variable_scope("root", partitioner=partitioner): - v0_list, v0_part = get_partitioned_variable_list( - name, dtype=tf.float32, shape=(4, 8, 16, 32)) + v0 = tf.get_variable(name, dtype=tf.float32, shape=(4, 8, 16, 32)) + v0_list = v0._get_variable_list() + v0_part = v0._get_partitions() self.assertEqual(len(v0_list), expected_axis_shards) self.assertAllEqual(v0_part, expected_partitions) @@ -118,10 +113,13 @@ class PartitionerCreatorsTest(tf.test.TestCase): axis=3, max_shard_bytes=32768, bytes_per_string_element=8) with tf.variable_scope("root", partitioner=partitioner_axis3_str): - v3str_list, v3str_part = get_partitioned_variable_list( + v3str = tf.get_variable( "v3str", - initializer=np.array([""] * 4*8*16*32).reshape(4, 8, 16, 32), - dtype=tf.string, shape=(4, 8, 16, 32)) + initializer=np.array([""] * 4 * 8 * 16 * 32).reshape(4, 8, 16, 32), + dtype=tf.string, + shape=(4, 8, 16, 32)) + v3str_list = v3str._get_variable_list() + v3str_part = v3str._get_partitions() # Now the estimated bytes_per_slice = 4*8*16*bytes_per_string_element # which is equal to 4096. Setting a max_shard_bytes of 32768 @@ -191,9 +189,11 @@ class PartitionedVariablesTestCase(tf.test.TestCase): with self.test_session(): rnd_par = tf.constant([[1, 2, 3, 4], [5, 6, 7, 8]]) with tf.variable_scope("hola") as vs: - vs1 = tf.create_partitioned_variables([2, 4], [1, 2], rnd_par) + vs1 = tf.create_partitioned_variables( + [2, 4], [1, 2], rnd_par, dtype=tf.int32) with tf.variable_scope(vs, reuse=True): - vs2 = tf.create_partitioned_variables([2, 4], [1, 2], rnd_par) + vs2 = tf.create_partitioned_variables( + [2, 4], [1, 2], rnd_par, dtype=tf.int32) tf.initialize_all_variables().run() var1_name = vs1[0]._save_slice_info.full_name var2_name = vs2[0]._save_slice_info.full_name |