| author | Mark Daoust <markdaoust@google.com> | 2017-12-07 09:49:28 -0800 |
|---|---|---|
| committer | TensorFlower Gardener <gardener@tensorflow.org> | 2017-12-07 10:00:33 -0800 |
| commit | 0d160ab43fcaa8357ce9eff6795dc30a41100175 (patch) | |
| tree | b410754fa2a790cd8ee5160d896ae362ec28cd9a /tensorflow/examples | |
| parent | 1730f9743c6a57beee8158bc35c689d24c8df729 (diff) | |
Clear softmax_v2 warning for image_retraining and speech_commands tutorials.
`tf.nn.softmax_cross_entropy_with_logits` and `tf.losses.softmax_cross_entropy` both emit the warning.
Almost every use can simply be replaced by `tf.losses.sparse_softmax_cross_entropy`.
PiperOrigin-RevId: 178253804
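As a rough illustration of the replacement the message describes (a sketch only: the placeholder names, batch values, and `num_classes` below are made up for this example, not taken from the tutorials), the old dense one-hot pattern and the new sparse integer-label pattern compute the same mean loss:

```python
import numpy as np
import tensorflow as tf  # TensorFlow 1.x API, as used by these tutorials

num_classes = 5  # illustrative value only

logits = tf.placeholder(tf.float32, [None, num_classes])

# Old pattern: dense one-hot labels plus an explicit reduce_mean
# (this is the op that triggers the softmax_v2 deprecation warning).
onehot_labels = tf.placeholder(tf.float32, [None, num_classes])
dense_loss = tf.reduce_mean(
    tf.nn.softmax_cross_entropy_with_logits(
        labels=onehot_labels, logits=logits))

# New pattern: integer class indexes; the loss op averages internally.
label_indexes = tf.placeholder(tf.int64, [None])
sparse_loss = tf.losses.sparse_softmax_cross_entropy(
    labels=label_indexes, logits=logits)

with tf.Session() as sess:
  logits_np = np.random.randn(4, num_classes).astype(np.float32)
  labels_np = np.array([0, 2, 1, 4], dtype=np.int64)
  onehot_np = np.eye(num_classes, dtype=np.float32)[labels_np]
  old, new = sess.run(
      [dense_loss, sparse_loss],
      feed_dict={logits: logits_np,
                 onehot_labels: onehot_np,
                 label_indexes: labels_np})
  print(old, new)  # the two values should agree up to float precision
```

Dropping the one-hot encoding also removes the `np.zeros` buffers built per example, which is what most of the diff below does.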
Diffstat (limited to 'tensorflow/examples')
| mode | path | lines changed |
|---|---|---|
| -rw-r--r-- | tensorflow/examples/image_retraining/retrain.py | 26 |
| -rw-r--r-- | tensorflow/examples/image_retraining/retrain_test.py | 2 |
| -rw-r--r-- | tensorflow/examples/speech_commands/input_data.py | 7 |
| -rw-r--r-- | tensorflow/examples/speech_commands/train.py | 12 |

4 files changed, 17 insertions, 30 deletions
```diff
diff --git a/tensorflow/examples/image_retraining/retrain.py b/tensorflow/examples/image_retraining/retrain.py
index ebddfb20f4..ec22684eaf 100644
--- a/tensorflow/examples/image_retraining/retrain.py
+++ b/tensorflow/examples/image_retraining/retrain.py
@@ -539,10 +539,8 @@ def get_random_cached_bottlenecks(sess, image_lists, how_many, category,
           sess, image_lists, label_name, image_index, image_dir, category,
           bottleneck_dir, jpeg_data_tensor, decoded_image_tensor,
           resized_input_tensor, bottleneck_tensor, architecture)
-      ground_truth = np.zeros(class_count, dtype=np.float32)
-      ground_truth[label_index] = 1.0
       bottlenecks.append(bottleneck)
-      ground_truths.append(ground_truth)
+      ground_truths.append(label_index)
       filenames.append(image_name)
   else:
     # Retrieve all bottlenecks.
@@ -555,10 +553,8 @@ def get_random_cached_bottlenecks(sess, image_lists, how_many, category,
             sess, image_lists, label_name, image_index, image_dir, category,
             bottleneck_dir, jpeg_data_tensor, decoded_image_tensor,
             resized_input_tensor, bottleneck_tensor, architecture)
-        ground_truth = np.zeros(class_count, dtype=np.float32)
-        ground_truth[label_index] = 1.0
         bottlenecks.append(bottleneck)
-        ground_truths.append(ground_truth)
+        ground_truths.append(label_index)
         filenames.append(image_name)
   return bottlenecks, ground_truths, filenames
 
@@ -610,10 +606,8 @@ def get_random_distorted_bottlenecks(
       bottleneck_values = sess.run(bottleneck_tensor,
                                    {resized_input_tensor: distorted_image_data})
       bottleneck_values = np.squeeze(bottleneck_values)
-      ground_truth = np.zeros(class_count, dtype=np.float32)
-      ground_truth[label_index] = 1.0
       bottlenecks.append(bottleneck_values)
-      ground_truths.append(ground_truth)
+      ground_truths.append(label_index)
   return bottlenecks, ground_truths
 
 
@@ -774,9 +768,8 @@ def add_final_training_ops(class_count, final_tensor_name, bottleneck_tensor,
         shape=[None, bottleneck_tensor_size],
         name='BottleneckInputPlaceholder')
 
-    ground_truth_input = tf.placeholder(tf.float32,
-                                        [None, class_count],
-                                        name='GroundTruthInput')
+    ground_truth_input = tf.placeholder(
+        tf.int64, [None], name='GroundTruthInput')
 
   # Organizing the following ops as `final_training_ops` so they're easier
   # to see in TensorBoard
@@ -823,10 +816,8 @@ def add_final_training_ops(class_count, final_tensor_name, bottleneck_tensor,
   tf.summary.histogram('activations', final_tensor)
 
   with tf.name_scope('cross_entropy'):
-    cross_entropy = tf.nn.softmax_cross_entropy_with_logits(
+    cross_entropy_mean = tf.losses.sparse_softmax_cross_entropy(
         labels=ground_truth_input, logits=logits)
-    with tf.name_scope('total'):
-      cross_entropy_mean = tf.reduce_mean(cross_entropy)
 
   tf.summary.scalar('cross_entropy', cross_entropy_mean)
 
@@ -852,8 +843,7 @@ def add_evaluation_step(result_tensor, ground_truth_tensor):
   with tf.name_scope('accuracy'):
     with tf.name_scope('correct_prediction'):
       prediction = tf.argmax(result_tensor, 1)
-      correct_prediction = tf.equal(
-          prediction, tf.argmax(ground_truth_tensor, 1))
+      correct_prediction = tf.equal(prediction, ground_truth_tensor)
     with tf.name_scope('accuracy'):
       evaluation_step = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))
   tf.summary.scalar('accuracy', evaluation_step)
@@ -1178,7 +1168,7 @@ def main(_):
   if FLAGS.print_misclassified_test_images:
     tf.logging.info('=== MISCLASSIFIED TEST IMAGES ===')
     for i, test_filename in enumerate(test_filenames):
-      if predictions[i] != test_ground_truth[i].argmax():
+      if predictions[i] != test_ground_truth[i]:
         tf.logging.info('%70s %s' %
                         (test_filename,
                          list(image_lists.keys())[predictions[i]]))
diff --git a/tensorflow/examples/image_retraining/retrain_test.py b/tensorflow/examples/image_retraining/retrain_test.py
index 2de4c4ec99..8b8dd45fd7 100644
--- a/tensorflow/examples/image_retraining/retrain_test.py
+++ b/tensorflow/examples/image_retraining/retrain_test.py
@@ -87,7 +87,7 @@ class ImageRetrainingTest(test_util.TensorFlowTestCase):
   def testAddEvaluationStep(self):
     with tf.Graph().as_default():
       final = tf.placeholder(tf.float32, [1], name='final')
-      gt = tf.placeholder(tf.float32, [1], name='gt')
+      gt = tf.placeholder(tf.int64, [1], name='gt')
       self.assertIsNotNone(retrain.add_evaluation_step(final, gt))
 
   def testAddJpegDecoding(self):
diff --git a/tensorflow/examples/speech_commands/input_data.py b/tensorflow/examples/speech_commands/input_data.py
index 751652b330..e7db9cddf0 100644
--- a/tensorflow/examples/speech_commands/input_data.py
+++ b/tensorflow/examples/speech_commands/input_data.py
@@ -417,8 +417,7 @@ class AudioProcessor(object):
       sess: TensorFlow session that was active when processor was created.
 
     Returns:
-      List of sample data for the transformed samples, and list of labels in
-      one-hot form.
+      List of sample data for the transformed samples, and list of label indexes
     """
     # Pick one of the partitions to choose samples from.
     candidates = self.data_index[mode]
@@ -428,7 +427,7 @@ class AudioProcessor(object):
     sample_count = max(0, min(how_many, len(candidates) - offset))
     # Data and labels will be populated and returned.
     data = np.zeros((sample_count, model_settings['fingerprint_size']))
-    labels = np.zeros((sample_count, model_settings['label_count']))
+    labels = np.zeros(sample_count)
     desired_samples = model_settings['desired_samples']
     use_background = self.background_data and (mode == 'training')
     pick_deterministically = (mode != 'training')
@@ -483,7 +482,7 @@ class AudioProcessor(object):
       # Run the graph to produce the output audio.
      data[i - offset, :] = sess.run(self.mfcc_, feed_dict=input_dict).flatten()
       label_index = self.word_to_index[sample['label']]
-      labels[i - offset, label_index] = 1
+      labels[i - offset] = label_index
     return data, labels
 
   def get_unprocessed_data(self, how_many, model_settings, mode):
diff --git a/tensorflow/examples/speech_commands/train.py b/tensorflow/examples/speech_commands/train.py
index bec7dacd21..a4e80041f8 100644
--- a/tensorflow/examples/speech_commands/train.py
+++ b/tensorflow/examples/speech_commands/train.py
@@ -133,7 +133,7 @@ def main(_):
 
   # Define loss and optimizer
   ground_truth_input = tf.placeholder(
-      tf.float32, [None, label_count], name='groundtruth_input')
+      tf.int64, [None], name='groundtruth_input')
 
   # Optionally we can add runtime checks to spot when NaNs or other symptoms of
   # numerical errors start occurring during training.
@@ -144,9 +144,8 @@ def main(_):
 
   # Create the back propagation and training evaluation machinery in the graph.
   with tf.name_scope('cross_entropy'):
-    cross_entropy_mean = tf.reduce_mean(
-        tf.nn.softmax_cross_entropy_with_logits(
-            labels=ground_truth_input, logits=logits))
+    cross_entropy_mean = tf.losses.sparse_softmax_cross_entropy(
+        labels=ground_truth_input, logits=logits)
   tf.summary.scalar('cross_entropy', cross_entropy_mean)
   with tf.name_scope('train'), tf.control_dependencies(control_dependencies):
     learning_rate_input = tf.placeholder(
@@ -154,10 +153,9 @@ def main(_):
     train_step = tf.train.GradientDescentOptimizer(
         learning_rate_input).minimize(cross_entropy_mean)
   predicted_indices = tf.argmax(logits, 1)
-  expected_indices = tf.argmax(ground_truth_input, 1)
-  correct_prediction = tf.equal(predicted_indices, expected_indices)
+  correct_prediction = tf.equal(predicted_indices, ground_truth_input)
   confusion_matrix = tf.confusion_matrix(
-      expected_indices, predicted_indices, num_classes=label_count)
+      ground_truth_input, predicted_indices, num_classes=label_count)
   evaluation_step = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))
   tf.summary.scalar('accuracy', evaluation_step)
 
```
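The evaluation-side edits follow the same idea: once the ground truth is an integer index, there is no second `tf.argmax` over a one-hot tensor. A minimal sketch of that pattern (the placeholder names and `label_count` value are illustrative, mirroring but not copied from `train.py`):

```python
import tensorflow as tf  # TensorFlow 1.x API

label_count = 12  # illustrative class count, not taken from the tutorials
logits = tf.placeholder(tf.float32, [None, label_count])
ground_truth_input = tf.placeholder(tf.int64, [None])

# tf.argmax returns int64, so it compares directly against the integer labels.
predicted_indices = tf.argmax(logits, 1)
correct_prediction = tf.equal(predicted_indices, ground_truth_input)
evaluation_step = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))

# The confusion matrix likewise takes the label indexes as-is.
confusion_matrix = tf.confusion_matrix(
    ground_truth_input, predicted_indices, num_classes=label_count)
```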