aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/tools/gcs_test
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2017-10-27 12:16:44 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2017-10-27 12:26:26 -0700
commit9158f974a346b4fae89044d8724eb052d466112b (patch)
treeb27152ad24984c8059f34cc708aa35a6fe430541 /tensorflow/tools/gcs_test
parent3d39b32b9a7833807dad037c3f57c818e9251f85 (diff)
Use tf.app.run in gcs_smoke, so that the flags are explicitly parsed, instead of parsed when first accessed.
PiperOrigin-RevId: 173702828
Diffstat (limited to 'tensorflow/tools/gcs_test')
-rw-r--r--tensorflow/tools/gcs_test/python/gcs_smoke.py18
1 files changed, 14 insertions, 4 deletions
diff --git a/tensorflow/tools/gcs_test/python/gcs_smoke.py b/tensorflow/tools/gcs_test/python/gcs_smoke.py
index 9882f75a8a..ad4cb17ae1 100644
--- a/tensorflow/tools/gcs_test/python/gcs_smoke.py
+++ b/tensorflow/tools/gcs_test/python/gcs_smoke.py
@@ -35,6 +35,7 @@ flags.DEFINE_integer("num_examples", 10, "Number of examples to generate")
FLAGS = flags.FLAGS
+
def create_examples(num_examples, input_mean):
"""Create ExampleProto's containing data."""
ids = np.arange(num_examples).reshape([num_examples, 1])
@@ -49,6 +50,7 @@ def create_examples(num_examples, input_mean):
examples.append(ex)
return examples
+
def create_dir_test():
"""Verifies file_io directory handling methods."""
@@ -122,6 +124,7 @@ def create_dir_test():
print("Deleted directory recursively %s in %s milliseconds" % (
dir_name, elapsed_ms))
+
def create_object_test():
"""Verifies file_io's object manipulation methods ."""
starttime_ms = int(round(time.time() * 1000))
@@ -142,7 +145,8 @@ def create_object_test():
print("Creating file %s." % file_name)
file_io.write_string_to_file(file_name, "test file creation.")
elapsed_ms = int(round(time.time() * 1000)) - starttime_ms
- print("Created %d files in %s milliseconds" % (len(files_to_create), elapsed_ms))
+ print("Created %d files in %s milliseconds" % (
+ len(files_to_create), elapsed_ms))
# Listing files of pattern1.
list_files_pattern = "%s/test_file*.txt" % dir_name
@@ -185,7 +189,9 @@ def create_object_test():
file_io.delete_recursively(dir_name)
-if __name__ == "__main__":
+def main(argv):
+ del argv # Unused.
+
# Sanity check on the GCS bucket URL.
if not FLAGS.gcs_bucket_url or not FLAGS.gcs_bucket_url.startswith("gs://"):
print("ERROR: Invalid GCS bucket URL: \"%s\"" % FLAGS.gcs_bucket_url)
@@ -210,7 +216,7 @@ if __name__ == "__main__":
# tf_record_iterator works.
record_iter = tf.python_io.tf_record_iterator(input_path)
read_count = 0
- for r in record_iter:
+ for _ in record_iter:
read_count += 1
print("Read %d records using tf_record_iterator" % read_count)
@@ -222,7 +228,7 @@ if __name__ == "__main__":
# Verify that running the read op in a session works.
print("\n=== Testing TFRecordReader.read op in a session... ===")
- with tf.Graph().as_default() as g:
+ with tf.Graph().as_default():
filename_queue = tf.train.string_input_producer([input_path], num_epochs=1)
reader = tf.TFRecordReader()
_, serialized_example = reader.read(filename_queue)
@@ -249,3 +255,7 @@ if __name__ == "__main__":
create_dir_test()
create_object_test()
+
+
+if __name__ == "__main__":
+ tf.app.run(main)