author     Vincent Vanhoucke <vanhoucke@google.com>  2016-01-26 11:35:23 -0800
committer  Vijay Vasudevan <vrv@google.com>          2016-01-26 11:41:04 -0800
commit     6e5a7d9af85d67ec1dd499b472472b68a56d51a2 (patch)
tree       61380db73a2854d345772e291569f00f2e34936c
parent     b2f0bc2e230dcd690e7cf34e5425f0f499d9557b (diff)
Updates to 1st Udacity colab:
- re-make data extraction idempotent in case someone accidentally cancels the extraction.
- Python 3 compatibility (h/t @lzlarryli)
Change: 113079245
-rw-r--r--  tensorflow/examples/udacity/1_notmnist.ipynb  41
1 file changed, 21 insertions, 20 deletions
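
Note on the Python 3 fix: the notebook swaps its bare urllib/cPickle imports for six.moves, which resolves to the right module on either interpreter. A minimal sketch of the pattern this commit adopts (the URL and filenames below are placeholders for illustration, not part of the commit):

    # six.moves resolves to urllib.urlretrieve / cPickle on Python 2 and to
    # urllib.request.urlretrieve / pickle on Python 3.
    from six.moves.urllib.request import urlretrieve
    from six.moves import cPickle as pickle

    # Hypothetical URL and filenames, for illustration only.
    filename, _ = urlretrieve('http://example.com/notMNIST_small.tar.gz',
                              'notMNIST_small.tar.gz')
    with open('notMNIST.pickle', 'wb') as f:
        pickle.dump({'source': filename}, f, pickle.HIGHEST_PROTOCOL)
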
diff --git a/tensorflow/examples/udacity/1_notmnist.ipynb b/tensorflow/examples/udacity/1_notmnist.ipynb
index 6439c45314..e44622b48a 100644
--- a/tensorflow/examples/udacity/1_notmnist.ipynb
+++ b/tensorflow/examples/udacity/1_notmnist.ipynb
@@ -48,12 +48,13 @@
"import matplotlib.pyplot as plt\n",
"import numpy as np\n",
"import os\n",
+ "import sys\n",
"import tarfile\n",
- "import urllib\n",
"from IPython.display import display, Image\n",
"from scipy import ndimage\n",
"from sklearn.linear_model import LogisticRegression\n",
- "import cPickle as pickle"
+ "from six.moves.urllib.request import urlretrieve\n",
+ "from six.moves import cPickle as pickle"
],
"outputs": [],
"execution_count": 0
@@ -185,17 +186,17 @@
"def extract(filename):\n",
" tar = tarfile.open(filename)\n",
" root = os.path.splitext(os.path.splitext(filename)[0])[0] # remove .tar.gz\n",
- " if not os.path.isdir(root):\n",
- " print 'Extracting data for %s. This may take a while. Please wait.' % root\n",
- " sys.stdout.flush()\n",
- " tar.extractall()\n",
- " tar.close()\n",
- " data_folders = [os.path.join(root, d) for d in sorted(os.listdir(root)) if d != '.DS_Store']\n",
+ " print('Extracting data for %s. This may take a while. Please wait.' % root)\n",
+ " sys.stdout.flush()\n",
+ " tar.extractall()\n",
+ " tar.close()\n",
+ " data_folders = [\n",
+ " os.path.join(root, d) for d in sorted(os.listdir(root)) if d != '.DS_Store']\n",
" if len(data_folders) != num_classes:\n",
" raise Exception(\n",
" 'Expected %d folders, one per class. Found %d instead.' % (\n",
" num_classes, len(data_folders)))\n",
- " print data_folders\n",
+ " print(data_folders)\n",
" return data_folders\n",
" \n",
"train_folders = extract(train_filename)\n",
@@ -289,7 +290,7 @@
" label_index = 0\n",
" image_index = 0\n",
" for folder in data_folders:\n",
- " print folder\n",
+ " print(folder)\n",
" for image in os.listdir(folder):\n",
" if image_index >= max_num_images:\n",
" raise Exception('More images than expected: %d >= %d' % (\n",
@@ -304,7 +305,7 @@
" labels[image_index] = label_index\n",
" image_index += 1\n",
" except IOError as e:\n",
- " print 'Could not read:', image_file, ':', e, '- it\\'s ok, skipping.'\n",
+ " print('Could not read:', image_file, ':', e, '- it\\'s ok, skipping.')\n",
" label_index += 1\n",
" num_images = image_index\n",
" dataset = dataset[0:num_images, :, :]\n",
@@ -312,10 +313,10 @@
" if num_images < min_num_images:\n",
" raise Exception('Many fewer images than expected: %d < %d' % (\n",
" num_images, min_num_images))\n",
- " print 'Full dataset tensor:', dataset.shape\n",
- " print 'Mean:', np.mean(dataset)\n",
- " print 'Standard deviation:', np.std(dataset)\n",
- " print 'Labels:', labels.shape\n",
+ " print('Full dataset tensor:', dataset.shape)\n",
+ " print('Mean:', np.mean(dataset))\n",
+ " print('Standard deviation:', np.std(dataset))\n",
+ " print('Labels:', labels.shape)\n",
" return dataset, labels\n",
"train_dataset, train_labels = load(train_folders, 450000, 550000)\n",
"test_dataset, test_labels = load(test_folders, 18000, 20000)"
@@ -502,8 +503,8 @@
"valid_labels = train_labels[:valid_size]\n",
"train_dataset = train_dataset[valid_size:valid_size+train_size,:,:]\n",
"train_labels = train_labels[valid_size:valid_size+train_size]\n",
- "print 'Training', train_dataset.shape, train_labels.shape\n",
- "print 'Validation', valid_dataset.shape, valid_labels.shape"
+ "print('Training', train_dataset.shape, train_labels.shape)\n",
+ "print('Validation', valid_dataset.shape, valid_labels.shape)"
],
"outputs": [
{
@@ -556,7 +557,7 @@
" pickle.dump(save, f, pickle.HIGHEST_PROTOCOL)\n",
" f.close()\n",
"except Exception as e:\n",
- " print 'Unable to save data to', pickle_file, ':', e\n",
+ " print('Unable to save data to', pickle_file, ':', e)\n",
" raise"
],
"outputs": [],
@@ -599,7 +600,7 @@
},
"source": [
"statinfo = os.stat(pickle_file)\n",
- "print 'Compressed pickle size:', statinfo.st_size"
+ "print('Compressed pickle size:', statinfo.st_size)"
],
"outputs": [
{
@@ -653,4 +654,4 @@
]
}
]
-}
+}
\ No newline at end of file