diff options
author | Vincent Vanhoucke <vanhoucke@google.com> | 2016-01-26 11:35:23 -0800 |
---|---|---|
committer | Vijay Vasudevan <vrv@google.com> | 2016-01-26 11:41:04 -0800 |
commit | 6e5a7d9af85d67ec1dd499b472472b68a56d51a2 (patch) | |
tree | 61380db73a2854d345772e291569f00f2e34936c | |
parent | b2f0bc2e230dcd690e7cf34e5425f0f499d9557b (diff) |
Updates to 1st Udacity colab:
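- make data extraction idempotent again: the archive is now always re-extracted instead of being skipped when the target directory already exists, so an extraction that someone accidentally cancelled partway through is completed on the next run.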
- Python 3 compatibility (h/t @lzlarryli)
Change: 113079245
-rw-r--r-- | tensorflow/examples/udacity/1_notmnist.ipynb | 41 |
1 file changed, 21 insertions, 20 deletions
diff --git a/tensorflow/examples/udacity/1_notmnist.ipynb b/tensorflow/examples/udacity/1_notmnist.ipynb index 6439c45314..e44622b48a 100644 --- a/tensorflow/examples/udacity/1_notmnist.ipynb +++ b/tensorflow/examples/udacity/1_notmnist.ipynb @@ -48,12 +48,13 @@ "import matplotlib.pyplot as plt\n", "import numpy as np\n", "import os\n", + "import sys\n", "import tarfile\n", - "import urllib\n", "from IPython.display import display, Image\n", "from scipy import ndimage\n", "from sklearn.linear_model import LogisticRegression\n", - "import cPickle as pickle" + "from six.moves.urllib.request import urlretrieve\n", + "from six.moves import cPickle as pickle" ], "outputs": [], "execution_count": 0 @@ -185,17 +186,17 @@ "def extract(filename):\n", " tar = tarfile.open(filename)\n", " root = os.path.splitext(os.path.splitext(filename)[0])[0] # remove .tar.gz\n", - " if not os.path.isdir(root):\n", - " print 'Extracting data for %s. This may take a while. Please wait.' % root\n", - " sys.stdout.flush()\n", - " tar.extractall()\n", - " tar.close()\n", - " data_folders = [os.path.join(root, d) for d in sorted(os.listdir(root)) if d != '.DS_Store']\n", + " print('Extracting data for %s. This may take a while. Please wait.' % root)\n", + " sys.stdout.flush()\n", + " tar.extractall()\n", + " tar.close()\n", + " data_folders = [\n", + " os.path.join(root, d) for d in sorted(os.listdir(root)) if d != '.DS_Store']\n", " if len(data_folders) != num_classes:\n", " raise Exception(\n", " 'Expected %d folders, one per class. Found %d instead.' 
% (\n", " num_classes, len(data_folders)))\n", - " print data_folders\n", + " print(data_folders)\n", " return data_folders\n", " \n", "train_folders = extract(train_filename)\n", @@ -289,7 +290,7 @@ " label_index = 0\n", " image_index = 0\n", " for folder in data_folders:\n", - " print folder\n", + " print(folder)\n", " for image in os.listdir(folder):\n", " if image_index >= max_num_images:\n", " raise Exception('More images than expected: %d >= %d' % (\n", @@ -304,7 +305,7 @@ " labels[image_index] = label_index\n", " image_index += 1\n", " except IOError as e:\n", - " print 'Could not read:', image_file, ':', e, '- it\\'s ok, skipping.'\n", + " print('Could not read:', image_file, ':', e, '- it\\'s ok, skipping.')\n", " label_index += 1\n", " num_images = image_index\n", " dataset = dataset[0:num_images, :, :]\n", @@ -312,10 +313,10 @@ " if num_images < min_num_images:\n", " raise Exception('Many fewer images than expected: %d < %d' % (\n", " num_images, min_num_images))\n", - " print 'Full dataset tensor:', dataset.shape\n", - " print 'Mean:', np.mean(dataset)\n", - " print 'Standard deviation:', np.std(dataset)\n", - " print 'Labels:', labels.shape\n", + " print('Full dataset tensor:', dataset.shape)\n", + " print('Mean:', np.mean(dataset))\n", + " print('Standard deviation:', np.std(dataset))\n", + " print('Labels:', labels.shape)\n", " return dataset, labels\n", "train_dataset, train_labels = load(train_folders, 450000, 550000)\n", "test_dataset, test_labels = load(test_folders, 18000, 20000)" @@ -502,8 +503,8 @@ "valid_labels = train_labels[:valid_size]\n", "train_dataset = train_dataset[valid_size:valid_size+train_size,:,:]\n", "train_labels = train_labels[valid_size:valid_size+train_size]\n", - "print 'Training', train_dataset.shape, train_labels.shape\n", - "print 'Validation', valid_dataset.shape, valid_labels.shape" + "print('Training', train_dataset.shape, train_labels.shape)\n", + "print('Validation', valid_dataset.shape, valid_labels.shape)" ], 
"outputs": [ { @@ -556,7 +557,7 @@ " pickle.dump(save, f, pickle.HIGHEST_PROTOCOL)\n", " f.close()\n", "except Exception as e:\n", - " print 'Unable to save data to', pickle_file, ':', e\n", + " print('Unable to save data to', pickle_file, ':', e)\n", " raise" ], "outputs": [], @@ -599,7 +600,7 @@ }, "source": [ "statinfo = os.stat(pickle_file)\n", - "print 'Compressed pickle size:', statinfo.st_size" + "print('Compressed pickle size:', statinfo.st_size)" ], "outputs": [ { @@ -653,4 +654,4 @@ ] } ] -} +}
\ No newline at end of file |