aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/examples
diff options
context:
space:
mode:
authorGravatar Dandelion Mané <dandelion@google.com>2017-12-11 17:02:01 -0800
committerGravatar Dandelion Mané <dandelion@google.com>2017-12-11 17:02:01 -0800
commitee09e0f0bedb45776ba0369aaec94814daca6452 (patch)
tree155ee610051daa60a75469dad0ad8b6447c9f2f8 /tensorflow/examples
parentabd5375ba8d373045321d1eebdb4501c36ab0ccd (diff)
parent634515e14e8bf5aa4bdfe149b77c9aa53383891e (diff)
Merge commit for internal changes
Diffstat (limited to 'tensorflow/examples')
-rw-r--r--tensorflow/examples/tutorials/layers/cnn_mnist.py4
-rw-r--r--tensorflow/examples/tutorials/mnist/mnist.py4
-rw-r--r--tensorflow/examples/tutorials/mnist/mnist_deep.py10
-rw-r--r--tensorflow/examples/tutorials/mnist/mnist_softmax.py11
-rw-r--r--tensorflow/examples/tutorials/mnist/mnist_softmax_xla.py13
-rw-r--r--tensorflow/examples/tutorials/mnist/mnist_with_summaries.py13
6 files changed, 24 insertions, 31 deletions
diff --git a/tensorflow/examples/tutorials/layers/cnn_mnist.py b/tensorflow/examples/tutorials/layers/cnn_mnist.py
index 2124843fcb..1e8d7d05e1 100644
--- a/tensorflow/examples/tutorials/layers/cnn_mnist.py
+++ b/tensorflow/examples/tutorials/layers/cnn_mnist.py
@@ -97,9 +97,7 @@ def cnn_model_fn(features, labels, mode):
return tf.estimator.EstimatorSpec(mode=mode, predictions=predictions)
# Calculate Loss (for both TRAIN and EVAL modes)
- onehot_labels = tf.one_hot(indices=tf.cast(labels, tf.int32), depth=10)
- loss = tf.losses.softmax_cross_entropy(
- onehot_labels=onehot_labels, logits=logits)
+ loss = tf.losses.sparse_softmax_cross_entropy(labels=labels, logits=logits)
# Configure the Training Op (for TRAIN mode)
if mode == tf.estimator.ModeKeys.TRAIN:
diff --git a/tensorflow/examples/tutorials/mnist/mnist.py b/tensorflow/examples/tutorials/mnist/mnist.py
index 3585043a2a..7cedd0e264 100644
--- a/tensorflow/examples/tutorials/mnist/mnist.py
+++ b/tensorflow/examples/tutorials/mnist/mnist.py
@@ -94,9 +94,7 @@ def loss(logits, labels):
loss: Loss tensor of type float.
"""
labels = tf.to_int64(labels)
- cross_entropy = tf.nn.sparse_softmax_cross_entropy_with_logits(
- labels=labels, logits=logits, name='xentropy')
- return tf.reduce_mean(cross_entropy, name='xentropy_mean')
+ return tf.losses.sparse_softmax_cross_entropy(labels=labels, logits=logits)
def training(loss, learning_rate):
diff --git a/tensorflow/examples/tutorials/mnist/mnist_deep.py b/tensorflow/examples/tutorials/mnist/mnist_deep.py
index a4dbab5123..1e0294db27 100644
--- a/tensorflow/examples/tutorials/mnist/mnist_deep.py
+++ b/tensorflow/examples/tutorials/mnist/mnist_deep.py
@@ -125,27 +125,27 @@ def bias_variable(shape):
def main(_):
# Import data
- mnist = input_data.read_data_sets(FLAGS.data_dir, one_hot=True)
+ mnist = input_data.read_data_sets(FLAGS.data_dir)
# Create the model
x = tf.placeholder(tf.float32, [None, 784])
# Define loss and optimizer
- y_ = tf.placeholder(tf.float32, [None, 10])
+ y_ = tf.placeholder(tf.int64, [None])
# Build the graph for the deep net
y_conv, keep_prob = deepnn(x)
with tf.name_scope('loss'):
- cross_entropy = tf.nn.softmax_cross_entropy_with_logits(labels=y_,
- logits=y_conv)
+ cross_entropy = tf.losses.sparse_softmax_cross_entropy(
+ labels=y_, logits=y_conv)
cross_entropy = tf.reduce_mean(cross_entropy)
with tf.name_scope('adam_optimizer'):
train_step = tf.train.AdamOptimizer(1e-4).minimize(cross_entropy)
with tf.name_scope('accuracy'):
- correct_prediction = tf.equal(tf.argmax(y_conv, 1), tf.argmax(y_, 1))
+ correct_prediction = tf.equal(tf.argmax(y_conv, 1), y_)
correct_prediction = tf.cast(correct_prediction, tf.float32)
accuracy = tf.reduce_mean(correct_prediction)
diff --git a/tensorflow/examples/tutorials/mnist/mnist_softmax.py b/tensorflow/examples/tutorials/mnist/mnist_softmax.py
index addd2d3810..fb3ac94203 100644
--- a/tensorflow/examples/tutorials/mnist/mnist_softmax.py
+++ b/tensorflow/examples/tutorials/mnist/mnist_softmax.py
@@ -34,7 +34,7 @@ FLAGS = None
def main(_):
# Import data
- mnist = input_data.read_data_sets(FLAGS.data_dir, one_hot=True)
+ mnist = input_data.read_data_sets(FLAGS.data_dir)
# Create the model
x = tf.placeholder(tf.float32, [None, 784])
@@ -43,7 +43,7 @@ def main(_):
y = tf.matmul(x, W) + b
# Define loss and optimizer
- y_ = tf.placeholder(tf.float32, [None, 10])
+ y_ = tf.placeholder(tf.int64, [None])
# The raw formulation of cross-entropy,
#
@@ -52,10 +52,9 @@ def main(_):
#
# can be numerically unstable.
#
- # So here we use tf.nn.softmax_cross_entropy_with_logits on the raw
+ # So here we use tf.losses.sparse_softmax_cross_entropy on the raw
# outputs of 'y', and then average across the batch.
- cross_entropy = tf.reduce_mean(
- tf.nn.softmax_cross_entropy_with_logits(labels=y_, logits=y))
+ cross_entropy = tf.losses.sparse_softmax_cross_entropy(labels=y_, logits=y)
train_step = tf.train.GradientDescentOptimizer(0.5).minimize(cross_entropy)
sess = tf.InteractiveSession()
@@ -66,7 +65,7 @@ def main(_):
sess.run(train_step, feed_dict={x: batch_xs, y_: batch_ys})
# Test trained model
- correct_prediction = tf.equal(tf.argmax(y, 1), tf.argmax(y_, 1))
+ correct_prediction = tf.equal(tf.argmax(y, 1), y_)
accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))
print(sess.run(accuracy, feed_dict={x: mnist.test.images,
y_: mnist.test.labels}))
diff --git a/tensorflow/examples/tutorials/mnist/mnist_softmax_xla.py b/tensorflow/examples/tutorials/mnist/mnist_softmax_xla.py
index eaff05913a..e89317494f 100644
--- a/tensorflow/examples/tutorials/mnist/mnist_softmax_xla.py
+++ b/tensorflow/examples/tutorials/mnist/mnist_softmax_xla.py
@@ -32,7 +32,7 @@ FLAGS = None
def main(_):
# Import data
- mnist = input_data.read_data_sets(FLAGS.data_dir, one_hot=True)
+ mnist = input_data.read_data_sets(FLAGS.data_dir)
# Create the model
x = tf.placeholder(tf.float32, [None, 784])
@@ -41,7 +41,7 @@ def main(_):
y = tf.matmul(x, w) + b
# Define loss and optimizer
- y_ = tf.placeholder(tf.float32, [None, 10])
+ y_ = tf.placeholder(tf.int64, [None])
# The raw formulation of cross-entropy,
#
@@ -50,10 +50,9 @@ def main(_):
#
# can be numerically unstable.
#
- # So here we use tf.nn.softmax_cross_entropy_with_logits on the raw
- # outputs of 'y', and then average across the batch.
- cross_entropy = tf.reduce_mean(
- tf.nn.softmax_cross_entropy_with_logits(labels=y_, logits=y))
+ # So here we use tf.losses.sparse_softmax_cross_entropy on the raw
+ # logit outputs of 'y', and then average across the batch.
+ cross_entropy = tf.losses.sparse_softmax_cross_entropy(labels=y_, logits=y)
train_step = tf.train.GradientDescentOptimizer(0.5).minimize(cross_entropy)
config = tf.ConfigProto()
@@ -86,7 +85,7 @@ def main(_):
sess.run(train_step, feed_dict={x: batch_xs, y_: batch_ys})
# Test trained model
- correct_prediction = tf.equal(tf.argmax(y, 1), tf.argmax(y_, 1))
+ correct_prediction = tf.equal(tf.argmax(y, 1), y_)
accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))
print(sess.run(accuracy,
feed_dict={x: mnist.test.images,
diff --git a/tensorflow/examples/tutorials/mnist/mnist_with_summaries.py b/tensorflow/examples/tutorials/mnist/mnist_with_summaries.py
index c401d09df8..7967e22d6a 100644
--- a/tensorflow/examples/tutorials/mnist/mnist_with_summaries.py
+++ b/tensorflow/examples/tutorials/mnist/mnist_with_summaries.py
@@ -38,7 +38,6 @@ FLAGS = None
def train():
# Import data
mnist = input_data.read_data_sets(FLAGS.data_dir,
- one_hot=True,
fake_data=FLAGS.fake_data)
sess = tf.InteractiveSession()
@@ -47,7 +46,7 @@ def train():
# Input placeholders
with tf.name_scope('input'):
x = tf.placeholder(tf.float32, [None, 784], name='x-input')
- y_ = tf.placeholder(tf.float32, [None, 10], name='y-input')
+ y_ = tf.placeholder(tf.int64, [None], name='y-input')
with tf.name_scope('input_reshape'):
image_shaped_input = tf.reshape(x, [-1, 28, 28, 1])
@@ -117,12 +116,12 @@ def train():
#
# can be numerically unstable.
#
- # So here we use tf.nn.softmax_cross_entropy_with_logits on the
- # raw outputs of the nn_layer above, and then average across
+ # So here we use tf.losses.sparse_softmax_cross_entropy on the
+ # raw logit outputs of the nn_layer above, and then average across
# the batch.
- diff = tf.nn.softmax_cross_entropy_with_logits(labels=y_, logits=y)
with tf.name_scope('total'):
- cross_entropy = tf.reduce_mean(diff)
+ cross_entropy = tf.losses.sparse_softmax_cross_entropy(
+ labels=y_, logits=y)
tf.summary.scalar('cross_entropy', cross_entropy)
with tf.name_scope('train'):
@@ -131,7 +130,7 @@ def train():
with tf.name_scope('accuracy'):
with tf.name_scope('correct_prediction'):
- correct_prediction = tf.equal(tf.argmax(y, 1), tf.argmax(y_, 1))
+ correct_prediction = tf.equal(tf.argmax(y, 1), y_)
with tf.name_scope('accuracy'):
accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))
tf.summary.scalar('accuracy', accuracy)