diff options
author | 2016-12-16 18:16:15 -0800 | |
---|---|---|
committer | 2016-12-16 18:26:25 -0800 | |
commit | 58201a058853de647b37ddb0ccf63d89b2357f03 (patch) | |
tree | a70f1626ef6c6bae985c3b581e2a6f9ce6f6b96e /tensorflow/python/training/saver_large_variable_test.py | |
parent | 157dbdda6d7d6e4f679c5e89d0bd3e8c0a6085d5 (diff) |
Remove hourglass imports from even more tests
Change: 142318245
Diffstat (limited to 'tensorflow/python/training/saver_large_variable_test.py')
-rw-r--r-- | tensorflow/python/training/saver_large_variable_test.py | 36 |
1 files changed, 23 insertions, 13 deletions
diff --git a/tensorflow/python/training/saver_large_variable_test.py b/tensorflow/python/training/saver_large_variable_test.py index 1e6d9e0c77..9d171ea568 100644 --- a/tensorflow/python/training/saver_large_variable_test.py +++ b/tensorflow/python/training/saver_large_variable_test.py @@ -12,39 +12,49 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================= - """Tests for tensorflow.python.training.saver.py.""" + from __future__ import absolute_import from __future__ import division from __future__ import print_function import os -import tensorflow as tf +from tensorflow.core.protobuf import saver_pb2 +from tensorflow.python.client import session +from tensorflow.python.framework import constant_op +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import errors_impl +from tensorflow.python.framework import ops +from tensorflow.python.ops import variables +from tensorflow.python.platform import test +from tensorflow.python.training import saver -class SaverLargeVariableTest(tf.test.TestCase): +class SaverLargeVariableTest(test.TestCase): # NOTE: This is in a separate file from saver_test.py because the # large allocations do not play well with TSAN, and cause flaky # failures. def testLargeVariable(self): save_path = os.path.join(self.get_temp_dir(), "large_variable") - with tf.Session("", graph=tf.Graph()) as sess: + with session.Session("", graph=ops.Graph()) as sess: # Declare a variable that is exactly 2GB. This should fail, # because a serialized checkpoint includes other header # metadata. - with tf.device("/cpu:0"): - var = tf.Variable( - tf.constant(False, shape=[2, 1024, 1024, 1024], dtype=tf.bool)) - save = tf.train.Saver({var.op.name: var}, - write_version=tf.train.SaverDef.V1) + with ops.device("/cpu:0"): + var = variables.Variable( + constant_op.constant( + False, shape=[2, 1024, 1024, 1024], dtype=dtypes.bool)) + save = saver.Saver( + { + var.op.name: var + }, write_version=saver_pb2.SaverDef.V1) var.initializer.run() - with self.assertRaisesRegexp( - tf.errors.InvalidArgumentError, - "Tensor slice is too large to serialize"): + with self.assertRaisesRegexp(errors_impl.InvalidArgumentError, + "Tensor slice is too large to serialize"): save.save(sess, save_path) if __name__ == "__main__": - tf.test.main() + test.main() |