diff options
author | 2017-10-27 12:16:44 -0700 | |
---|---|---|
committer | 2017-10-27 12:26:26 -0700 | |
commit | 9158f974a346b4fae89044d8724eb052d466112b (patch) | |
tree | b27152ad24984c8059f34cc708aa35a6fe430541 /tensorflow/tools/gcs_test | |
parent | 3d39b32b9a7833807dad037c3f57c818e9251f85 (diff) |
Use tf.app.run in gcs_smoke, so that the flags are explicitly parsed, instead of parsed when first accessed.
PiperOrigin-RevId: 173702828
Diffstat (limited to 'tensorflow/tools/gcs_test')
-rw-r--r-- | tensorflow/tools/gcs_test/python/gcs_smoke.py | 18 |
1 files changed, 14 insertions, 4 deletions
diff --git a/tensorflow/tools/gcs_test/python/gcs_smoke.py b/tensorflow/tools/gcs_test/python/gcs_smoke.py index 9882f75a8a..ad4cb17ae1 100644 --- a/tensorflow/tools/gcs_test/python/gcs_smoke.py +++ b/tensorflow/tools/gcs_test/python/gcs_smoke.py @@ -35,6 +35,7 @@ flags.DEFINE_integer("num_examples", 10, "Number of examples to generate") FLAGS = flags.FLAGS + def create_examples(num_examples, input_mean): """Create ExampleProto's containing data.""" ids = np.arange(num_examples).reshape([num_examples, 1]) @@ -49,6 +50,7 @@ def create_examples(num_examples, input_mean): examples.append(ex) return examples + def create_dir_test(): """Verifies file_io directory handling methods.""" @@ -122,6 +124,7 @@ def create_dir_test(): print("Deleted directory recursively %s in %s milliseconds" % ( dir_name, elapsed_ms)) + def create_object_test(): """Verifies file_io's object manipulation methods .""" starttime_ms = int(round(time.time() * 1000)) @@ -142,7 +145,8 @@ def create_object_test(): print("Creating file %s." % file_name) file_io.write_string_to_file(file_name, "test file creation.") elapsed_ms = int(round(time.time() * 1000)) - starttime_ms - print("Created %d files in %s milliseconds" % (len(files_to_create), elapsed_ms)) + print("Created %d files in %s milliseconds" % ( + len(files_to_create), elapsed_ms)) # Listing files of pattern1. list_files_pattern = "%s/test_file*.txt" % dir_name @@ -185,7 +189,9 @@ def create_object_test(): file_io.delete_recursively(dir_name) -if __name__ == "__main__": +def main(argv): + del argv # Unused. + # Sanity check on the GCS bucket URL. if not FLAGS.gcs_bucket_url or not FLAGS.gcs_bucket_url.startswith("gs://"): print("ERROR: Invalid GCS bucket URL: \"%s\"" % FLAGS.gcs_bucket_url) @@ -210,7 +216,7 @@ if __name__ == "__main__": # tf_record_iterator works. record_iter = tf.python_io.tf_record_iterator(input_path) read_count = 0 - for r in record_iter: + for _ in record_iter: read_count += 1 print("Read %d records using tf_record_iterator" % read_count) @@ -222,7 +228,7 @@ if __name__ == "__main__": # Verify that running the read op in a session works. print("\n=== Testing TFRecordReader.read op in a session... ===") - with tf.Graph().as_default() as g: + with tf.Graph().as_default(): filename_queue = tf.train.string_input_producer([input_path], num_epochs=1) reader = tf.TFRecordReader() _, serialized_example = reader.read(filename_queue) @@ -249,3 +255,7 @@ if __name__ == "__main__": create_dir_test() create_object_test() + + +if __name__ == "__main__": + tf.app.run(main) |