aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/learn/python/learn/datasets/base.py
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/contrib/learn/python/learn/datasets/base.py')
-rw-r--r--tensorflow/contrib/learn/python/learn/datasets/base.py16
1 files changed, 7 insertions, 9 deletions
diff --git a/tensorflow/contrib/learn/python/learn/datasets/base.py b/tensorflow/contrib/learn/python/learn/datasets/base.py
index cdff6baf83..71978d4394 100644
--- a/tensorflow/contrib/learn/python/learn/datasets/base.py
+++ b/tensorflow/contrib/learn/python/learn/datasets/base.py
@@ -186,8 +186,8 @@ def _is_retriable(e):
@retry(initial_delay=1.0, max_delay=16.0, is_retriable=_is_retriable)
-def urlretrieve_with_retry(url, filename):
- urllib.request.urlretrieve(url, filename)
+def urlretrieve_with_retry(url, filename=None):
+ return urllib.request.urlretrieve(url, filename)
def maybe_download(filename, work_directory, source_url):
@@ -205,11 +205,9 @@ def maybe_download(filename, work_directory, source_url):
gfile.MakeDirs(work_directory)
filepath = os.path.join(work_directory, filename)
if not gfile.Exists(filepath):
- with tempfile.NamedTemporaryFile() as tmpfile:
- temp_file_name = tmpfile.name
- urlretrieve_with_retry(source_url, temp_file_name)
- gfile.Copy(temp_file_name, filepath)
- with gfile.GFile(filepath) as f:
- size = f.size()
- print('Successfully downloaded', filename, size, 'bytes.')
+ temp_file_name, _ = urlretrieve_with_retry(source_url)
+ gfile.Copy(temp_file_name, filepath)
+ with gfile.GFile(filepath) as f:
+ size = f.size()
+ print('Successfully downloaded', filename, size, 'bytes.')
return filepath