aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/examples/tutorials
diff options
context:
space:
mode:
authorGravatar Benoit Steiner <bsteiner@google.com>2017-05-10 21:12:21 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2017-05-11 11:02:28 -0700
commitee112cff56081fb9d0b74c987a8935acc360b05c (patch)
tree6026d8b42ccc09d9c0d1b2d091916cfcb4f5a057 /tensorflow/examples/tutorials
parent27c89207d2f31fe4b4b42c789b96d62cde4e2133 (diff)
Merge changes from github.
PiperOrigin-RevId: 155709893
Diffstat (limited to 'tensorflow/examples/tutorials')
-rw-r--r--tensorflow/examples/tutorials/estimators/abalone.py16
1 files changed, 13 insertions, 3 deletions
diff --git a/tensorflow/examples/tutorials/estimators/abalone.py b/tensorflow/examples/tutorials/estimators/abalone.py
index 932ce8a8b2..3c0ea2e409 100644
--- a/tensorflow/examples/tutorials/estimators/abalone.py
+++ b/tensorflow/examples/tutorials/estimators/abalone.py
@@ -134,12 +134,22 @@ def main(unused_argv):
# Instantiate Estimator
nn = tf.contrib.learn.Estimator(model_fn=model_fn, params=model_params)
-
+
+ def get_train_inputs():
+ x = tf.constant(training_set.data)
+ y = tf.constant(training_set.target)
+ return x, y
+
# Fit
- nn.fit(x=training_set.data, y=training_set.target, steps=5000)
+ nn.fit(input_fn=get_train_inputs, steps=5000)
# Score accuracy
- ev = nn.evaluate(x=test_set.data, y=test_set.target, steps=1)
+ def get_test_inputs():
+ x = tf.constant(test_set.data)
+ y = tf.constant(test_set.target)
+ return x, y
+
+ ev = nn.evaluate(input_fn=get_test_inputs, steps=1)
print("Loss: %s" % ev["loss"])
print("Root Mean Squared Error: %s" % ev["rmse"])