aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python/training/saver_large_partitioned_variable_test.py
diff options
context:
space:
mode:
authorGravatar Justine Tunney <jart@google.com>2016-12-16 18:16:15 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2016-12-16 18:26:25 -0800
commit58201a058853de647b37ddb0ccf63d89b2357f03 (patch)
treea70f1626ef6c6bae985c3b581e2a6f9ce6f6b96e /tensorflow/python/training/saver_large_partitioned_variable_test.py
parent157dbdda6d7d6e4f679c5e89d0bd3e8c0a6085d5 (diff)
Remove hourglass imports from even more tests
Change: 142318245
Diffstat (limited to 'tensorflow/python/training/saver_large_partitioned_variable_test.py')
-rw-r--r--tensorflow/python/training/saver_large_partitioned_variable_test.py28
1 files changed, 18 insertions, 10 deletions
diff --git a/tensorflow/python/training/saver_large_partitioned_variable_test.py b/tensorflow/python/training/saver_large_partitioned_variable_test.py
index 5af5cda342..1a44511cfe 100644
--- a/tensorflow/python/training/saver_large_partitioned_variable_test.py
+++ b/tensorflow/python/training/saver_large_partitioned_variable_test.py
@@ -13,16 +13,24 @@
# limitations under the License.
# =============================================================================
"""Tests for tensorflow.python.training.saver.py."""
+
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import os
-import tensorflow as tf
+from tensorflow.python.client import session
+from tensorflow.python.framework import constant_op
+from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import ops
+from tensorflow.python.ops import partitioned_variables
+from tensorflow.python.ops import variables
+from tensorflow.python.platform import test
+from tensorflow.python.training import saver
-class SaverLargePartitionedVariableTest(tf.test.TestCase):
+class SaverLargePartitionedVariableTest(test.TestCase):
# Need to do this in a separate test because of the amount of memory needed
# to run this test.
@@ -30,19 +38,19 @@ class SaverLargePartitionedVariableTest(tf.test.TestCase):
save_path = os.path.join(self.get_temp_dir(), "large_variable")
var_name = "my_var"
# Saving large partition variable.
- with tf.Session("", graph=tf.Graph()) as sess:
- with tf.device("/cpu:0"):
+ with session.Session("", graph=ops.Graph()) as sess:
+ with ops.device("/cpu:0"):
# Create a partitioned variable which is larger than int32 size but
# split into smaller sized variables.
- init = lambda shape, dtype, partition_info: tf.constant(
+ init = lambda shape, dtype, partition_info: constant_op.constant(
True, dtype, shape)
- partitioned_var = tf.create_partitioned_variables(
- [1 << 31], [4], init, dtype=tf.bool, name=var_name)
- tf.global_variables_initializer().run()
- save = tf.train.Saver(partitioned_var)
+ partitioned_var = partitioned_variables.create_partitioned_variables(
+ [1 << 31], [4], init, dtype=dtypes.bool, name=var_name)
+ variables.global_variables_initializer().run()
+ save = saver.Saver(partitioned_var)
val = save.save(sess, save_path)
self.assertEqual(save_path, val)
if __name__ == "__main__":
- tf.test.main()
+ test.main()