aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/examples/udacity
diff options
context:
space:
mode:
authorGravatar mlucool <mlucool@gmail.com>2017-02-02 00:56:57 -0500
committerGravatar Rasmus Munk Larsen <rmlarsen@google.com>2017-02-01 21:56:57 -0800
commita81c968d7484b04ef6bffe6aef235d1d00c8a81a (patch)
treea0e271db5d7c8248aa0718808e12674f3a05933b /tensorflow/examples/udacity
parent4f5f61c6f06e014694395eb489473cfb19f92149 (diff)
Allow for data stored in an arbitrary location (#7200)
If you want to store data somewhere you have more space or be able to ignore a folder in sync, change the data_root.
Diffstat (limited to 'tensorflow/examples/udacity')
-rw-r--r--tensorflow/examples/udacity/1_notmnist.ipynb16
1 files changed, 9 insertions, 7 deletions
diff --git a/tensorflow/examples/udacity/1_notmnist.ipynb b/tensorflow/examples/udacity/1_notmnist.ipynb
index c9ec86f71a..4b0a20b1dd 100644
--- a/tensorflow/examples/udacity/1_notmnist.ipynb
+++ b/tensorflow/examples/udacity/1_notmnist.ipynb
@@ -111,6 +111,7 @@
"source": [
"url = 'http://commondatastorage.googleapis.com/books1000/'\n",
"last_percent_reported = None\n",
+ "data_root = '.' # Change me to store data elsewhere\n",
"\n",
"def download_progress_hook(count, blockSize, totalSize):\n",
" \"\"\"A hook to report the progress of a download. This is mostly intended for users with\n",
@@ -131,17 +132,18 @@
" \n",
"def maybe_download(filename, expected_bytes, force=False):\n",
" \"\"\"Download a file if not present, and make sure it's the right size.\"\"\"\n",
- " if force or not os.path.exists(filename):\n",
+ " dest_filename = os.path.join(data_root, filename)\n",
+ " if force or not os.path.exists(dest_filename):\n",
" print('Attempting to download:', filename) \n",
- " filename, _ = urlretrieve(url + filename, filename, reporthook=download_progress_hook)\n",
+ " filename, _ = urlretrieve(url + filename, dest_filename, reporthook=download_progress_hook)\n",
" print('\\nDownload Complete!')\n",
- " statinfo = os.stat(filename)\n",
+ " statinfo = os.stat(dest_filename)\n",
" if statinfo.st_size == expected_bytes:\n",
- " print('Found and verified', filename)\n",
+ " print('Found and verified', dest_filename)\n",
" else:\n",
" raise Exception(\n",
- " 'Failed to verify ' + filename + '. Can you get to it with a browser?')\n",
- " return filename\n",
+ " 'Failed to verify ' + dest_filename + '. Can you get to it with a browser?')\n",
+ " return dest_filename\n",
"\n",
"train_filename = maybe_download('notMNIST_large.tar.gz', 247336696)\n",
"test_filename = maybe_download('notMNIST_small.tar.gz', 8458043)"
@@ -683,7 +685,7 @@
"cellView": "both"
},
"source": [
- "pickle_file = 'notMNIST.pickle'\n",
+ "pickle_file = os.path.join(data_root, 'notMNIST.pickle')\n",
"\n",
"try:\n",
" f = open(pickle_file, 'wb')\n",