diff options
Diffstat (limited to 'tensorflow/python/platform/default')
-rwxr-xr-x | tensorflow/python/platform/default/__init__.py | 0 | ||||
-rw-r--r-- | tensorflow/python/platform/default/_app.py | 11 | ||||
-rw-r--r-- | tensorflow/python/platform/default/_flags.py | 92 | ||||
-rw-r--r-- | tensorflow/python/platform/default/_gfile.py | 404 | ||||
-rw-r--r-- | tensorflow/python/platform/default/_googletest.py | 68 | ||||
-rw-r--r-- | tensorflow/python/platform/default/_init.py | 1 | ||||
-rw-r--r-- | tensorflow/python/platform/default/_logging.py | 182 | ||||
-rw-r--r-- | tensorflow/python/platform/default/_parameterized.py | 2 | ||||
-rw-r--r-- | tensorflow/python/platform/default/_resource_loader.py | 26 | ||||
-rw-r--r-- | tensorflow/python/platform/default/_status_bar.py | 5 | ||||
-rw-r--r-- | tensorflow/python/platform/default/flags_test.py | 53 | ||||
-rw-r--r-- | tensorflow/python/platform/default/gfile_test.py | 147 | ||||
-rw-r--r-- | tensorflow/python/platform/default/logging_test.py | 13 |
13 files changed, 1004 insertions, 0 deletions
diff --git a/tensorflow/python/platform/default/__init__.py b/tensorflow/python/platform/default/__init__.py new file mode 100755 index 0000000000..e69de29bb2 --- /dev/null +++ b/tensorflow/python/platform/default/__init__.py diff --git a/tensorflow/python/platform/default/_app.py b/tensorflow/python/platform/default/_app.py new file mode 100644 index 0000000000..5917d00ce3 --- /dev/null +++ b/tensorflow/python/platform/default/_app.py @@ -0,0 +1,11 @@ +"""Generic entry point script.""" +import sys + +from tensorflow.python.platform import flags + + +def run(): + f = flags.FLAGS + f._parse_flags() + main = sys.modules['__main__'].main + sys.exit(main(sys.argv)) diff --git a/tensorflow/python/platform/default/_flags.py b/tensorflow/python/platform/default/_flags.py new file mode 100644 index 0000000000..ceccda6e5c --- /dev/null +++ b/tensorflow/python/platform/default/_flags.py @@ -0,0 +1,92 @@ +"""Implementation of the flags interface.""" +import tensorflow.python.platform + +import argparse + +_global_parser = argparse.ArgumentParser() + +class _FlagValues(object): + + def __init__(self): + """Global container and accessor for flags and their values.""" + self.__dict__['__flags'] = {} + self.__dict__['__parsed'] = False + + def _parse_flags(self): + result = _global_parser.parse_args() + for flag_name, val in vars(result).items(): + self.__dict__['__flags'][flag_name] = val + self.__dict__['__parsed'] = True + + def __getattr__(self, name): + """Retrieves the 'value' attribute of the flag --name.""" + if not self.__dict__['__parsed']: + self._parse_flags() + if name not in self.__dict__['__flags']: + raise AttributeError(name) + return self.__dict__['__flags'][name] + + def __setattr__(self, name, value): + """Sets the 'value' attribute of the flag --name.""" + if not self.__dict__['__parsed']: + self._parse_flags() + self.__dict__['__flags'][name] = value + + +def _define_helper(flag_name, default_value, docstring, flagtype): + """Registers 'flag_name' with 'default_value' and 'docstring'.""" + _global_parser.add_argument("--" + flag_name, + default=default_value, + help=docstring, + type=flagtype) + + +# Provides the global object that can be used to access flags. +FLAGS = _FlagValues() + + +def DEFINE_string(flag_name, default_value, docstring): + """Defines a flag of type 'string'. + + Args: + flag_name: The name of the flag as a string. + default_value: The default value the flag should take as a string. + docstring: A helpful message explaining the use of the flag. + """ + _define_helper(flag_name, default_value, docstring, str) + + +def DEFINE_integer(flag_name, default_value, docstring): + """Defines a flag of type 'int'. + + Args: + flag_name: The name of the flag as a string. + default_value: The default value the flag should take as an int. + docstring: A helpful message explaining the use of the flag. + """ + _define_helper(flag_name, default_value, docstring, int) + + +def DEFINE_boolean(flag_name, default_value, docstring): + """Defines a flag of type 'boolean'. + + Args: + flag_name: The name of the flag as a string. + default_value: The default value the flag should take as a boolean. + docstring: A helpful message explaining the use of the flag. + """ + _define_helper(flag_name, default_value, docstring, bool) + _global_parser.add_argument('--no' + flag_name, + action='store_false', + dest=flag_name) + + +def DEFINE_float(flag_name, default_value, docstring): + """Defines a flag of type 'float'. + + Args: + flag_name: The name of the flag as a string. + default_value: The default value the flag should take as a float. + docstring: A helpful message explaining the use of the flag. + """ + _define_helper(flag_name, default_value, docstring, float) diff --git a/tensorflow/python/platform/default/_gfile.py b/tensorflow/python/platform/default/_gfile.py new file mode 100644 index 0000000000..cfd25bdf90 --- /dev/null +++ b/tensorflow/python/platform/default/_gfile.py @@ -0,0 +1,404 @@ +"""File processing utilities.""" + +import errno +import functools +import glob as _glob +import os +import shutil +import threading + + +class FileError(IOError): + """An error occurred while reading or writing a file.""" + + +class GOSError(OSError): + """An error occurred while finding a file or in handling pathnames.""" + + +class _GFileBase(object): + """Base I/O wrapper class. Similar semantics to Python's file object.""" + + # pylint: disable=protected-access + def _error_wrapper(fn): + """Decorator wrapping GFileBase class method errors.""" + @functools.wraps(fn) # Preserve methods' __doc__ + def wrap(self, *args, **kwargs): + try: + return fn(self, *args, **kwargs) + except ValueError, e: + # Sometimes a ValueError is raised, e.g., a read() on a closed file. + raise FileError(errno.EIO, e.message, self._name) + except IOError, e: + e.filename = self._name + raise FileError(e) + except OSError, e: + raise GOSError(e) + return wrap + + def _synchronized(fn): + """Synchronizes file I/O for methods in GFileBase.""" + @functools.wraps(fn) + def sync(self, *args, **kwargs): + # Sometimes a GFileBase method is called before the instance + # has been properly initialized. Check that _locker is available. + if hasattr(self, '_locker'): self._locker.lock() + try: + return fn(self, *args, **kwargs) + finally: + if hasattr(self, '_locker'): self._locker.unlock() + return sync + # pylint: enable=protected-access + + @_error_wrapper + def __init__(self, name, mode, locker): + """Create the GFileBase object with the given filename, mode, and locker. + + Args: + name: string, the filename. + mode: string, the mode to open the file with (e.g. "r", "w", "a+"). + locker: the thread locking object (e.g. _PythonLocker) for controlling + thread access to the I/O methods of this class. + """ + self._name = name + self._mode = mode + self._locker = locker + self._fp = open(name, mode) + + def __enter__(self): + """Make GFileBase usable with "with" statement.""" + return self + + def __exit__(self, unused_type, unused_value, unused_traceback): + """Make GFileBase usable with "with" statement.""" + self.close() + + @_error_wrapper + @_synchronized + def __del__(self): + # __del__ is sometimes called before initialization, in which + # case the object is not fully constructed. Check for this here + # before trying to close the file handle. + if hasattr(self, '_fp'): self._fp.close() + + @_error_wrapper + @_synchronized + def flush(self): + """Flush the underlying file handle.""" + return self._fp.flush() + + @property + @_error_wrapper + @_synchronized + def closed(self): + """Returns "True" if the file handle is closed. Otherwise False.""" + return self._fp.closed + + @_error_wrapper + @_synchronized + def write(self, data): + """Write data to the underlying file handle. + + Args: + data: The string to write to the file handle. + """ + self._fp.write(data) + + @_error_wrapper + @_synchronized + def writelines(self, seq): + """Write a sequence of strings to the underlying file handle.""" + self._fp.writelines(seq) + + @_error_wrapper + @_synchronized + def tell(self): + """Return the location from the underlying file handle. + + Returns: + An integer location (which can be used in e.g., seek). + """ + return self._fp.tell() + + @_error_wrapper + @_synchronized + def seek(self, offset, whence=0): + """Seek to offset (conditioned on whence) in the underlying file handle. + + Args: + offset: int, the offset within the file to seek to. + whence: 0, 1, or 2. See python's seek() documentation for details. + """ + self._fp.seek(offset, whence) + + @_error_wrapper + @_synchronized + def truncate(self, new_size=None): + """Truncate the underlying file handle to new_size. + + Args: + new_size: Size after truncation. If None, the file handle is truncated + to 0 bytes. + """ + self._fp.truncate(new_size) + + @_error_wrapper + @_synchronized + def readline(self, max_length=-1): + """Read a single line (up to max_length) from the underlying file handle. + + Args: + max_length: The maximum number of chsaracters to read. + + Returns: + A string, including any newline at the end, or empty string if at EOF. + """ + return self._fp.readline(max_length) + + @_error_wrapper + @_synchronized + def readlines(self, sizehint=None): + """Read lines from the underlying file handle. + + Args: + sizehint: See the python file.readlines() documentation. + + Returns: + A list of strings from the underlying file handle. + """ + if sizehint is not None: + return self._fp.readlines(sizehint) + else: + return self._fp.readlines() + + def __iter__(self): + """Enable line iteration on the underlying handle (not synchronized).""" + return self + + # Not synchronized + @_error_wrapper + def next(self): + """Enable line iteration on the underlying handle (not synchronized). + + Returns: + An line iterator from the underlying handle. + + Example: + # read a file's lines by consuming the iterator with a list + with open("filename", "r") as fp: lines = list(fp) + """ + return self._fp.next() + + @_error_wrapper + @_synchronized + def Size(self): # pylint: disable=invalid-name + """Get byte size of the file from the underlying file handle.""" + cur = self.tell() + try: + self.seek(0, 2) + size = self.tell() + finally: + self.seek(cur) + return size + + @_error_wrapper + @_synchronized + def read(self, n=-1): + """Read n bytes from the underlying file handle. + + Args: + n: Number of bytes to read (if negative, read to end of file handle.) + + Returns: + A string of the bytes read, up to the end of file. + """ + return self._fp.read(n) + + @_error_wrapper + @_synchronized + def close(self): + """Close the underlying file handle.""" + self._fp.close() + + # Declare wrappers as staticmethods at the end so that we can + # use them as decorators. + _error_wrapper = staticmethod(_error_wrapper) + _synchronized = staticmethod(_synchronized) + + +class GFile(_GFileBase): + """File I/O wrappers with thread locking.""" + + def __init__(self, name, mode='r'): + super(GFile, self).__init__(name, mode, _Pythonlocker()) + + +class FastGFile(_GFileBase): + """File I/O wrappers without thread locking.""" + + def __init__(self, name, mode='r'): + super(FastGFile, self).__init__(name, mode, _Nulllocker()) + + +# locker classes. Note that locks must be reentrant, so that multiple +# lock() calls by the owning thread will not block. +class _Pythonlocker(object): + """A locking strategy that uses standard locks from the thread module.""" + + def __init__(self): + self._lock = threading.RLock() + + def lock(self): + self._lock.acquire() + + def unlock(self): + self._lock.release() + + +class _Nulllocker(object): + """A locking strategy where lock() and unlock() methods are no-ops.""" + + def lock(self): + pass + + def unlock(self): + pass + + +def _func_error_wrapper(fn): + """Decorator wrapping function errors.""" + @functools.wraps(fn) # Preserve methods' __doc__ + def wrap(*args, **kwargs): + try: + return fn(*args, **kwargs) + except ValueError, e: + raise FileError(errno.EIO, e.message) + except IOError, e: + raise FileError(e) + except OSError, e: + raise GOSError(e) + return wrap + + +@_func_error_wrapper +def Exists(path): # pylint: disable=invalid-name + """Retruns True iff "path" exists (as a dir, file, non-broken symlink).""" + return os.path.exists(path) + + +@_func_error_wrapper +def IsDirectory(path): # pylint: disable=invalid-name + """Return True iff "path" exists and is a directory.""" + return os.path.isdir(path) + + +@_func_error_wrapper +def Glob(glob): # pylint: disable=invalid-name + """Return a list of filenames matching the glob "glob".""" + return _glob.glob(glob) + + +@_func_error_wrapper +def MkDir(path, mode=0755): # pylint: disable=invalid-name + """Create the directory "path" with the given mode. + + Args: + path: The directory path + mode: The file mode for the directory + + Returns: + None + + Raises: + GOSError: if the path already exists + """ + os.mkdir(path, mode) + + +@_func_error_wrapper +def MakeDirs(path, mode=0755): # pylint: disable=invalid-name + """Recursively create the directory "path" with the given mode. + + Args: + path: The directory path + mode: The file mode for the created directories + + Returns: + None + + + Raises: + GOSError: if the path already exists + """ + os.makedirs(path, mode) + + +@_func_error_wrapper +def RmDir(directory): # pylint: disable=invalid-name + """Removes the directory "directory" iff the directory is empty. + + Args: + directory: The directory to remove. + + Raises: + GOSError: If the directory does not exist or is not empty. + """ + os.rmdir(directory) + + +@_func_error_wrapper +def Remove(path): # pylint: disable=invalid-name + """Delete the (non-directory) file "path". + + Args: + path: The file to remove. + + Raises: + GOSError: If "path" does not exist, is a directory, or cannot be deleted. + """ + os.remove(path) + + +@_func_error_wrapper +def DeleteRecursively(path): # pylint: disable=invalid-name + """Delete the file or directory "path" recursively. + + Args: + path: The path to remove (may be a non-empty directory). + + Raises: + GOSError: If the path does not exist or cannot be deleted. + """ + if IsDirectory(path): + shutil.rmtree(path) + else: + Remove(path) + + +@_func_error_wrapper +def ListDirectory(directory, return_dotfiles=False): # pylint: disable=invalid-name + """Returns a list of files in dir. + + As with the standard os.listdir(), the filenames in the returned list will be + the basenames of the files in dir (not absolute paths). To get a list of + absolute paths of files in a directory, a client could do: + file_list = gfile.ListDir(my_dir) + file_list = [os.path.join(my_dir, f) for f in file_list] + (assuming that my_dir itself specified an absolute path to a directory). + + Args: + directory: the directory to list + return_dotfiles: if True, dotfiles will be returned as well. Even if + this arg is True, '.' and '..' will not be returned. + + Returns: + ['list', 'of', 'files']. The entries '.' and '..' are never returned. + Other entries starting with a dot will only be returned if return_dotfiles + is True. + Raises: + GOSError: if there is an error retrieving the directory listing. + """ + files = os.listdir(directory) + if not return_dotfiles: + files = [f for f in files if not f.startswith('.')] + return files diff --git a/tensorflow/python/platform/default/_googletest.py b/tensorflow/python/platform/default/_googletest.py new file mode 100644 index 0000000000..d2686565a0 --- /dev/null +++ b/tensorflow/python/platform/default/_googletest.py @@ -0,0 +1,68 @@ +"""Imports unittest as a replacement for testing.pybase.googletest.""" +import inspect +import itertools +import os +import tempfile + +# pylint: disable=wildcard-import +from unittest import * + + +unittest_main = main + + +# pylint: disable=invalid-name +# pylint: disable=undefined-variable +def main(*args, **kwargs): + """Delegate to unittest.main after redefining testLoader.""" + if 'TEST_SHARD_STATUS_FILE' in os.environ: + try: + f = None + try: + f = open(os.environ['TEST_SHARD_STATUS_FILE'], 'w') + f.write('') + except IOError: + sys.stderr.write('Error opening TEST_SHARD_STATUS_FILE (%s). Exiting.' + % os.environ['TEST_SHARD_STATUS_FILE']) + sys.exit(1) + finally: + if f is not None: f.close() + + if ('TEST_TOTAL_SHARDS' not in os.environ or + 'TEST_SHARD_INDEX' not in os.environ): + return unittest_main(*args, **kwargs) + + total_shards = int(os.environ['TEST_TOTAL_SHARDS']) + shard_index = int(os.environ['TEST_SHARD_INDEX']) + base_loader = TestLoader() + + delegate_get_names = base_loader.getTestCaseNames + bucket_iterator = itertools.cycle(range(total_shards)) + + def getShardedTestCaseNames(testCaseClass): + filtered_names = [] + for testcase in sorted(delegate_get_names(testCaseClass)): + bucket = bucket_iterator.next() + if bucket == shard_index: + filtered_names.append(testcase) + return filtered_names + + # Override getTestCaseNames + base_loader.getTestCaseNames = getShardedTestCaseNames + + kwargs['testLoader'] = base_loader + unittest_main(*args, **kwargs) + + +def GetTempDir(): + first_frame = inspect.stack()[-1][0] + temp_dir = os.path.join( + tempfile.gettempdir(), os.path.basename(inspect.getfile(first_frame))) + temp_dir = temp_dir.rstrip('.py') + if not os.path.isdir(temp_dir): + os.mkdir(temp_dir, 0755) + return temp_dir + + +def StatefulSessionAvailable(): + return False diff --git a/tensorflow/python/platform/default/_init.py b/tensorflow/python/platform/default/_init.py new file mode 100644 index 0000000000..916d598856 --- /dev/null +++ b/tensorflow/python/platform/default/_init.py @@ -0,0 +1 @@ +# Nothing to do for default platform diff --git a/tensorflow/python/platform/default/_logging.py b/tensorflow/python/platform/default/_logging.py new file mode 100644 index 0000000000..2e289b1abe --- /dev/null +++ b/tensorflow/python/platform/default/_logging.py @@ -0,0 +1,182 @@ +"""Logging utilities.""" +# pylint: disable=unused-import +# pylint: disable=g-bad-import-order +# pylint: disable=invalid-name +import os +import sys +import time +import thread +from logging import getLogger +from logging import log +from logging import debug +from logging import error +from logging import fatal +from logging import info +from logging import warn +from logging import warning +from logging import DEBUG +from logging import ERROR +from logging import FATAL +from logging import INFO +from logging import WARN + +# Controls which methods from pyglib.logging are available within the project +# Do not add methods here without also adding to platform/default/_logging.py +__all__ = ['log', 'debug', 'error', 'fatal', 'info', 'warn', 'warning', + 'DEBUG', 'ERROR', 'FATAL', 'INFO', 'WARN', + 'flush', 'log_every_n', 'log_first_n', 'vlog', + 'TaskLevelStatusMessage', 'get_verbosity', 'set_verbosity'] + +warning = warn + +_level_names = { + FATAL: 'FATAL', + ERROR: 'ERROR', + WARN: 'WARN', + INFO: 'INFO', + DEBUG: 'DEBUG', +} + +# Mask to convert integer thread ids to unsigned quantities for logging +# purposes +_THREAD_ID_MASK = 2 * sys.maxint + 1 + +_log_prefix = None # later set to google2_log_prefix + +# Counter to keep track of number of log entries per token. +_log_counter_per_token = {} + + +def TaskLevelStatusMessage(msg): + error(msg) + + +def flush(): + raise NotImplementedError() + + +# Code below is taken from pyglib/logging +def vlog(level, msg, *args, **kwargs): + log(level, msg, *args, **kwargs) + + +def _GetNextLogCountPerToken(token): + """Wrapper for _log_counter_per_token. + + Args: + token: The token for which to look up the count. + + Returns: + The number of times this function has been called with + *token* as an argument (starting at 0) + """ + global _log_counter_per_token # pylint: disable=global-variable-not-assigned + _log_counter_per_token[token] = 1 + _log_counter_per_token.get(token, -1) + return _log_counter_per_token[token] + + +def log_every_n(level, msg, n, *args): + """Log 'msg % args' at level 'level' once per 'n' times. + + Logs the 1st call, (N+1)st call, (2N+1)st call, etc. + Not threadsafe. + + Args: + level: The level at which to log. + msg: The message to be logged. + n: The number of times this should be called before it is logged. + *args: The args to be substituted into the msg. + """ + count = _GetNextLogCountPerToken(_GetFileAndLine()) + log_if(level, msg, not (count % n), *args) + + +def log_first_n(level, msg, n, *args): # pylint: disable=g-bad-name + """Log 'msg % args' at level 'level' only first 'n' times. + + Not threadsafe. + + Args: + level: The level at which to log. + msg: The message to be logged. + n: The number of times this should be called before it is logged. + *args: The args to be substituted into the msg. + """ + count = _GetNextLogCountPerToken(_GetFileAndLine()) + log_if(level, msg, count < n, *args) + + +def log_if(level, msg, condition, *args): + """Log 'msg % args' at level 'level' only if condition is fulfilled.""" + if condition: + vlog(level, msg, *args) + + +def _GetFileAndLine(): + """Returns (filename, linenumber) for the stack frame.""" + # Use sys._getframe(). This avoids creating a traceback object. + # pylint: disable=protected-access + f = sys._getframe() + # pylint: enable=protected-access + our_file = f.f_code.co_filename + f = f.f_back + while f: + code = f.f_code + if code.co_filename != our_file: + return (code.co_filename, f.f_lineno) + f = f.f_back + return ('<unknown>', 0) + + +def google2_log_prefix(level, timestamp=None, file_and_line=None): + """Assemble a logline prefix using the google2 format.""" + # pylint: disable=global-variable-not-assigned + global _level_names + global _logfile_map, _logfile_map_mutex + # pylint: enable=global-variable-not-assigned + + # Record current time + now = timestamp or time.time() + now_tuple = time.localtime(now) + now_microsecond = int(1e6 * (now % 1.0)) + + (filename, line) = file_and_line or _GetFileAndLine() + basename = os.path.basename(filename) + + # Severity string + severity = 'I' + if level in _level_names: + severity = _level_names[level][0] + + s = '%c%02d%02d %02d:%02d:%02d.%06d %5d %s:%d] ' % ( + severity, + now_tuple[1], # month + now_tuple[2], # day + now_tuple[3], # hour + now_tuple[4], # min + now_tuple[5], # sec + now_microsecond, + _get_thread_id(), + basename, + line) + + return s + + +def get_verbosity(): + """Return how much logging output will be produced.""" + return getLogger().getEffectiveLevel() + + +def set_verbosity(verbosity): + """Sets the threshold for what messages will be logged.""" + getLogger().setLevel(verbosity) + + +def _get_thread_id(): + """Get id of current thread, suitable for logging as an unsigned quantity.""" + thread_id = thread.get_ident() + return thread_id & _THREAD_ID_MASK + + +_log_prefix = google2_log_prefix diff --git a/tensorflow/python/platform/default/_parameterized.py b/tensorflow/python/platform/default/_parameterized.py new file mode 100644 index 0000000000..5d141568ed --- /dev/null +++ b/tensorflow/python/platform/default/_parameterized.py @@ -0,0 +1,2 @@ +"""Extension to unittest to run parameterized tests.""" +raise ImportError("Not implemented yet.") diff --git a/tensorflow/python/platform/default/_resource_loader.py b/tensorflow/python/platform/default/_resource_loader.py new file mode 100644 index 0000000000..69f425072f --- /dev/null +++ b/tensorflow/python/platform/default/_resource_loader.py @@ -0,0 +1,26 @@ +"""Read a file and return its contents.""" + +import os.path + +from tensorflow.python.platform import logging + + +def load_resource(path): + """Load the resource at given path, where path is relative to tensorflow/. + + Args: + path: a string resource path relative to tensorflow/. + + Returns: + The contents of that resource. + + Raises: + IOError: If the path is not found, or the resource can't be opened. + """ + path = os.path.join('tensorflow', path) + path = os.path.abspath(path) + try: + with open(path, 'rb') as f: + return f.read() + except IOError as e: + logging.warning('IOError %s on path %s' % (e, path)) diff --git a/tensorflow/python/platform/default/_status_bar.py b/tensorflow/python/platform/default/_status_bar.py new file mode 100644 index 0000000000..2953908724 --- /dev/null +++ b/tensorflow/python/platform/default/_status_bar.py @@ -0,0 +1,5 @@ +"""A no-op implementation of status bar functions.""" + + +def SetupStatusBarInsideGoogle(unused_link_text, unused_port): + pass diff --git a/tensorflow/python/platform/default/flags_test.py b/tensorflow/python/platform/default/flags_test.py new file mode 100644 index 0000000000..1b15ca138a --- /dev/null +++ b/tensorflow/python/platform/default/flags_test.py @@ -0,0 +1,53 @@ +"""Tests for our flags implementation.""" +import sys + +from tensorflow.python.platform.default import _googletest as googletest + +from tensorflow.python.platform.default import _flags as flags + + +flags.DEFINE_string("string_foo", "default_val", "HelpString") +flags.DEFINE_boolean("bool_foo", True, "HelpString") +flags.DEFINE_integer("int_foo", 42, "HelpString") +flags.DEFINE_float("float_foo", 42.0, "HelpString") + +FLAGS = flags.FLAGS + +class FlagsTest(googletest.TestCase): + + def testString(self): + res = FLAGS.string_foo + self.assertEqual(res, "default_val") + FLAGS.string_foo = "bar" + self.assertEqual("bar", FLAGS.string_foo) + + def testBool(self): + res = FLAGS.bool_foo + self.assertTrue(res) + FLAGS.bool_foo = False + self.assertFalse(FLAGS.bool_foo) + + def testNoBool(self): + FLAGS.bool_foo = True + try: + sys.argv.append("--nobool_foo") + FLAGS._parse_flags() + self.assertFalse(FLAGS.bool_foo) + finally: + sys.argv.pop() + + def testInt(self): + res = FLAGS.int_foo + self.assertEquals(res, 42) + FLAGS.int_foo = -1 + self.assertEqual(-1, FLAGS.int_foo) + + def testFloat(self): + res = FLAGS.float_foo + self.assertEquals(42.0, res) + FLAGS.float_foo = -1.0 + self.assertEqual(-1.0, FLAGS.float_foo) + + +if __name__ == "__main__": + googletest.main() diff --git a/tensorflow/python/platform/default/gfile_test.py b/tensorflow/python/platform/default/gfile_test.py new file mode 100644 index 0000000000..9eec952e95 --- /dev/null +++ b/tensorflow/python/platform/default/gfile_test.py @@ -0,0 +1,147 @@ +import os +import shutil + +from tensorflow.python.platform.default import _gfile as gfile +from tensorflow.python.platform.default import _googletest as googletest +from tensorflow.python.platform.default import _logging as logging + + +class _BaseTest(object): + + @property + def tmp(self): + return self._tmp_dir + + def setUp(self): + self._orig_dir = os.getcwd() + self._tmp_dir = googletest.GetTempDir() + "/" + try: + os.makedirs(self._tmp_dir) + except OSError: + pass # Directory already exists + + def tearDown(self): + try: + shutil.rmtree(self._tmp_dir) + except OSError: + logging.warn("[%s] Post-test directory cleanup failed: %s" + % (self, self._tmp_dir)) + + +class _GFileBaseTest(_BaseTest): + + @property + def gfile(self): + raise NotImplementedError("Do not use _GFileBaseTest directly.") + + def testWith(self): + with self.gfile(self.tmp + "test_with", "w") as fh: + fh.write("hi") + with self.gfile(self.tmp + "test_with", "r") as fh: + self.assertEquals(fh.read(), "hi") + + def testSizeAndTellAndSeek(self): + with self.gfile(self.tmp + "test_tell", "w") as fh: + fh.write("".join(["0"] * 1000)) + with self.gfile(self.tmp + "test_tell", "r") as fh: + self.assertEqual(1000, fh.Size()) + self.assertEqual(0, fh.tell()) + fh.seek(0, 2) + self.assertEqual(1000, fh.tell()) + fh.seek(0) + self.assertEqual(0, fh.tell()) + + def testReadAndWritelines(self): + with self.gfile(self.tmp + "test_writelines", "w") as fh: + fh.writelines(["%d\n" % d for d in range(10)]) + with self.gfile(self.tmp + "test_writelines", "r") as fh: + self.assertEqual(["%d\n" % x for x in range(10)], fh.readlines()) + + def testWriteAndTruncate(self): + with self.gfile(self.tmp + "test_truncate", "w") as fh: + fh.write("ababab") + with self.gfile(self.tmp + "test_truncate", "a+") as fh: + fh.seek(0, 2) + fh.write("hjhjhj") + with self.gfile(self.tmp + "test_truncate", "a+") as fh: + self.assertEqual(fh.Size(), 12) + fh.truncate(6) + with self.gfile(self.tmp + "test_truncate", "r") as fh: + self.assertEqual(fh.read(), "ababab") + + def testErrors(self): + self.assertRaises( + gfile.FileError, lambda: self.gfile(self.tmp + "doesnt_exist", "r")) + with self.gfile(self.tmp + "test_error", "w") as fh: + self.assertRaises(gfile.FileError, lambda: fh.seek(-1)) + # test_error now exists, we can read from it: + with self.gfile(self.tmp + "test_error", "r") as fh: + self.assertRaises(gfile.FileError, lambda: fh.write("ack")) + fh = self.gfile(self.tmp + "test_error", "w") + self.assertFalse(fh.closed) + fh.close() + self.assertTrue(fh.closed) + self.assertRaises(gfile.FileError, lambda: fh.write("ack")) + + def testIteration(self): + with self.gfile(self.tmp + "test_iter", "w") as fh: + fh.writelines(["a\n", "b\n", "c\n"]) + with self.gfile(self.tmp + "test_iter", "r") as fh: + lines = list(fh) + self.assertEqual(["a\n", "b\n", "c\n"], lines) + + +class GFileTest(_GFileBaseTest, googletest.TestCase): + + @property + def gfile(self): + return gfile.GFile + + +class FastGFileTest(_GFileBaseTest, googletest.TestCase): + + @property + def gfile(self): + return gfile.FastGFile + + +class FunctionTests(_BaseTest, googletest.TestCase): + + def testExists(self): + self.assertFalse(gfile.Exists(self.tmp + "test_exists")) + with gfile.GFile(self.tmp + "test_exists", "w"): + pass + self.assertTrue(gfile.Exists(self.tmp + "test_exists")) + + def testMkDirsGlobAndRmDirs(self): + self.assertFalse(gfile.Exists(self.tmp + "test_dir")) + gfile.MkDir(self.tmp + "test_dir") + self.assertTrue(gfile.Exists(self.tmp + "test_dir")) + gfile.RmDir(self.tmp + "test_dir") + self.assertFalse(gfile.Exists(self.tmp + "test_dir")) + gfile.MakeDirs(self.tmp + "test_dir/blah0") + gfile.MakeDirs(self.tmp + "test_dir/blah1") + self.assertEqual([self.tmp + "test_dir/blah0", self.tmp + "test_dir/blah1"], + sorted(gfile.Glob(self.tmp + "test_dir/*"))) + gfile.DeleteRecursively(self.tmp + "test_dir") + self.assertFalse(gfile.Exists(self.tmp + "test_dir")) + + def testErrors(self): + self.assertRaises( + gfile.GOSError, lambda: gfile.RmDir(self.tmp + "dir_doesnt_exist")) + self.assertRaises( + gfile.GOSError, lambda: gfile.Remove(self.tmp + "file_doesnt_exist")) + gfile.MkDir(self.tmp + "error_dir") + with gfile.GFile(self.tmp + "error_dir/file", "w"): + pass # Create file + self.assertRaises( + gfile.GOSError, lambda: gfile.Remove(self.tmp + "error_dir")) + self.assertRaises( + gfile.GOSError, lambda: gfile.RmDir(self.tmp + "error_dir")) + self.assertTrue(gfile.Exists(self.tmp + "error_dir")) + gfile.DeleteRecursively(self.tmp + "error_dir") + self.assertFalse(gfile.Exists(self.tmp + "error_dir")) + + +if __name__ == "__main__": + googletest.main() diff --git a/tensorflow/python/platform/default/logging_test.py b/tensorflow/python/platform/default/logging_test.py new file mode 100644 index 0000000000..fd492bc384 --- /dev/null +++ b/tensorflow/python/platform/default/logging_test.py @@ -0,0 +1,13 @@ +from tensorflow.python.platform.default import _googletest as googletest +from tensorflow.python.platform.default import _logging as logging + + +class EventLoaderTest(googletest.TestCase): + + def test_log(self): + # Just check that logging works without raising an exception. + logging.error("test log message") + + +if __name__ == "__main__": + googletest.main() |