aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/examples
diff options
context:
space:
mode:
authorGravatar Mark Daoust <markdaoust@google.com>2017-12-07 09:49:28 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2017-12-07 10:00:33 -0800
commit0d160ab43fcaa8357ce9eff6795dc30a41100175 (patch)
treeb410754fa2a790cd8ee5160d896ae362ec28cd9a /tensorflow/examples
parent1730f9743c6a57beee8158bc35c689d24c8df729 (diff)
Clear softmax_v2 warning for image_retraining and speech_commands tutorials.
`tf.nn.softmax_cross_entropy_with_logits` and `tf.losses.softmax_cross_entropy` both throw the warning. Almost everywhere it's used can simply be replaced by `tf.losses.sparse_softmax_cross_entropy` PiperOrigin-RevId: 178253804
Diffstat (limited to 'tensorflow/examples')
-rw-r--r--tensorflow/examples/image_retraining/retrain.py26
-rw-r--r--tensorflow/examples/image_retraining/retrain_test.py2
-rw-r--r--tensorflow/examples/speech_commands/input_data.py7
-rw-r--r--tensorflow/examples/speech_commands/train.py12
4 files changed, 17 insertions, 30 deletions
diff --git a/tensorflow/examples/image_retraining/retrain.py b/tensorflow/examples/image_retraining/retrain.py
index ebddfb20f4..ec22684eaf 100644
--- a/tensorflow/examples/image_retraining/retrain.py
+++ b/tensorflow/examples/image_retraining/retrain.py
@@ -539,10 +539,8 @@ def get_random_cached_bottlenecks(sess, image_lists, how_many, category,
sess, image_lists, label_name, image_index, image_dir, category,
bottleneck_dir, jpeg_data_tensor, decoded_image_tensor,
resized_input_tensor, bottleneck_tensor, architecture)
- ground_truth = np.zeros(class_count, dtype=np.float32)
- ground_truth[label_index] = 1.0
bottlenecks.append(bottleneck)
- ground_truths.append(ground_truth)
+ ground_truths.append(label_index)
filenames.append(image_name)
else:
# Retrieve all bottlenecks.
@@ -555,10 +553,8 @@ def get_random_cached_bottlenecks(sess, image_lists, how_many, category,
sess, image_lists, label_name, image_index, image_dir, category,
bottleneck_dir, jpeg_data_tensor, decoded_image_tensor,
resized_input_tensor, bottleneck_tensor, architecture)
- ground_truth = np.zeros(class_count, dtype=np.float32)
- ground_truth[label_index] = 1.0
bottlenecks.append(bottleneck)
- ground_truths.append(ground_truth)
+ ground_truths.append(label_index)
filenames.append(image_name)
return bottlenecks, ground_truths, filenames
@@ -610,10 +606,8 @@ def get_random_distorted_bottlenecks(
bottleneck_values = sess.run(bottleneck_tensor,
{resized_input_tensor: distorted_image_data})
bottleneck_values = np.squeeze(bottleneck_values)
- ground_truth = np.zeros(class_count, dtype=np.float32)
- ground_truth[label_index] = 1.0
bottlenecks.append(bottleneck_values)
- ground_truths.append(ground_truth)
+ ground_truths.append(label_index)
return bottlenecks, ground_truths
@@ -774,9 +768,8 @@ def add_final_training_ops(class_count, final_tensor_name, bottleneck_tensor,
shape=[None, bottleneck_tensor_size],
name='BottleneckInputPlaceholder')
- ground_truth_input = tf.placeholder(tf.float32,
- [None, class_count],
- name='GroundTruthInput')
+ ground_truth_input = tf.placeholder(
+ tf.int64, [None], name='GroundTruthInput')
# Organizing the following ops as `final_training_ops` so they're easier
# to see in TensorBoard
@@ -823,10 +816,8 @@ def add_final_training_ops(class_count, final_tensor_name, bottleneck_tensor,
tf.summary.histogram('activations', final_tensor)
with tf.name_scope('cross_entropy'):
- cross_entropy = tf.nn.softmax_cross_entropy_with_logits(
+ cross_entropy_mean = tf.losses.sparse_softmax_cross_entropy(
labels=ground_truth_input, logits=logits)
- with tf.name_scope('total'):
- cross_entropy_mean = tf.reduce_mean(cross_entropy)
tf.summary.scalar('cross_entropy', cross_entropy_mean)
@@ -852,8 +843,7 @@ def add_evaluation_step(result_tensor, ground_truth_tensor):
with tf.name_scope('accuracy'):
with tf.name_scope('correct_prediction'):
prediction = tf.argmax(result_tensor, 1)
- correct_prediction = tf.equal(
- prediction, tf.argmax(ground_truth_tensor, 1))
+ correct_prediction = tf.equal(prediction, ground_truth_tensor)
with tf.name_scope('accuracy'):
evaluation_step = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))
tf.summary.scalar('accuracy', evaluation_step)
@@ -1178,7 +1168,7 @@ def main(_):
if FLAGS.print_misclassified_test_images:
tf.logging.info('=== MISCLASSIFIED TEST IMAGES ===')
for i, test_filename in enumerate(test_filenames):
- if predictions[i] != test_ground_truth[i].argmax():
+ if predictions[i] != test_ground_truth[i]:
tf.logging.info('%70s %s' %
(test_filename,
list(image_lists.keys())[predictions[i]]))
diff --git a/tensorflow/examples/image_retraining/retrain_test.py b/tensorflow/examples/image_retraining/retrain_test.py
index 2de4c4ec99..8b8dd45fd7 100644
--- a/tensorflow/examples/image_retraining/retrain_test.py
+++ b/tensorflow/examples/image_retraining/retrain_test.py
@@ -87,7 +87,7 @@ class ImageRetrainingTest(test_util.TensorFlowTestCase):
def testAddEvaluationStep(self):
with tf.Graph().as_default():
final = tf.placeholder(tf.float32, [1], name='final')
- gt = tf.placeholder(tf.float32, [1], name='gt')
+ gt = tf.placeholder(tf.int64, [1], name='gt')
self.assertIsNotNone(retrain.add_evaluation_step(final, gt))
def testAddJpegDecoding(self):
diff --git a/tensorflow/examples/speech_commands/input_data.py b/tensorflow/examples/speech_commands/input_data.py
index 751652b330..e7db9cddf0 100644
--- a/tensorflow/examples/speech_commands/input_data.py
+++ b/tensorflow/examples/speech_commands/input_data.py
@@ -417,8 +417,7 @@ class AudioProcessor(object):
sess: TensorFlow session that was active when processor was created.
Returns:
- List of sample data for the transformed samples, and list of labels in
- one-hot form.
+ List of sample data for the transformed samples, and list of label indexes
"""
# Pick one of the partitions to choose samples from.
candidates = self.data_index[mode]
@@ -428,7 +427,7 @@ class AudioProcessor(object):
sample_count = max(0, min(how_many, len(candidates) - offset))
# Data and labels will be populated and returned.
data = np.zeros((sample_count, model_settings['fingerprint_size']))
- labels = np.zeros((sample_count, model_settings['label_count']))
+ labels = np.zeros(sample_count)
desired_samples = model_settings['desired_samples']
use_background = self.background_data and (mode == 'training')
pick_deterministically = (mode != 'training')
@@ -483,7 +482,7 @@ class AudioProcessor(object):
# Run the graph to produce the output audio.
data[i - offset, :] = sess.run(self.mfcc_, feed_dict=input_dict).flatten()
label_index = self.word_to_index[sample['label']]
- labels[i - offset, label_index] = 1
+ labels[i - offset] = label_index
return data, labels
def get_unprocessed_data(self, how_many, model_settings, mode):
diff --git a/tensorflow/examples/speech_commands/train.py b/tensorflow/examples/speech_commands/train.py
index bec7dacd21..a4e80041f8 100644
--- a/tensorflow/examples/speech_commands/train.py
+++ b/tensorflow/examples/speech_commands/train.py
@@ -133,7 +133,7 @@ def main(_):
# Define loss and optimizer
ground_truth_input = tf.placeholder(
- tf.float32, [None, label_count], name='groundtruth_input')
+ tf.int64, [None], name='groundtruth_input')
# Optionally we can add runtime checks to spot when NaNs or other symptoms of
# numerical errors start occurring during training.
@@ -144,9 +144,8 @@ def main(_):
# Create the back propagation and training evaluation machinery in the graph.
with tf.name_scope('cross_entropy'):
- cross_entropy_mean = tf.reduce_mean(
- tf.nn.softmax_cross_entropy_with_logits(
- labels=ground_truth_input, logits=logits))
+ cross_entropy_mean = tf.losses.sparse_softmax_cross_entropy(
+ labels=ground_truth_input, logits=logits)
tf.summary.scalar('cross_entropy', cross_entropy_mean)
with tf.name_scope('train'), tf.control_dependencies(control_dependencies):
learning_rate_input = tf.placeholder(
@@ -154,10 +153,9 @@ def main(_):
train_step = tf.train.GradientDescentOptimizer(
learning_rate_input).minimize(cross_entropy_mean)
predicted_indices = tf.argmax(logits, 1)
- expected_indices = tf.argmax(ground_truth_input, 1)
- correct_prediction = tf.equal(predicted_indices, expected_indices)
+ correct_prediction = tf.equal(predicted_indices, ground_truth_input)
confusion_matrix = tf.confusion_matrix(
- expected_indices, predicted_indices, num_classes=label_count)
+ ground_truth_input, predicted_indices, num_classes=label_count)
evaluation_step = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))
tf.summary.scalar('accuracy', evaluation_step)