diff options
author | Igor Saprykin <isaprykin@google.com> | 2017-07-06 15:59:14 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2017-07-06 16:03:23 -0700 |
commit | eb8bebd6709374dc6961e350d9988b8c15b42d31 (patch) | |
tree | 09ed5f44136e17d453196cf4e746986c051738d3 /tensorflow/examples/tutorials | |
parent | f9c9cacb06964f68737bf78f51e61452a3b480a8 (diff) |
Fix ValueError("None values not supported.") in the Abalone tutorial.
PiperOrigin-RevId: 161143481
Diffstat (limited to 'tensorflow/examples/tutorials')
-rw-r--r-- | tensorflow/examples/tutorials/estimators/abalone.py | 19 |
1 files changed, 12 insertions, 7 deletions
diff --git a/tensorflow/examples/tutorials/estimators/abalone.py b/tensorflow/examples/tutorials/estimators/abalone.py index 4765d5dabf..737b3ee5d6 100644 --- a/tensorflow/examples/tutorials/estimators/abalone.py +++ b/tensorflow/examples/tutorials/estimators/abalone.py @@ -87,25 +87,30 @@ def model_fn(features, labels, mode, params): # Reshape output layer to 1-dim Tensor to return predictions predictions = tf.reshape(output_layer, [-1]) - predictions_dict = {"ages": predictions} + + # Provide an estimator spec for `ModeKeys.PREDICT`. + if mode == tf.estimator.ModeKeys.PREDICT: + return tf.estimator.EstimatorSpec( + mode=mode, + predictions={"ages": predictions}) # Calculate loss using mean squared error loss = tf.losses.mean_squared_error(labels, predictions) + optimizer = tf.train.GradientDescentOptimizer( + learning_rate=params["learning_rate"]) + train_op = optimizer.minimize( + loss=loss, global_step=tf.train.get_global_step()) + # Calculate root mean squared error as additional eval metric eval_metric_ops = { "rmse": tf.metrics.root_mean_squared_error( tf.cast(labels, tf.float64), predictions) } - optimizer = tf.train.GradientDescentOptimizer( - learning_rate=params["learning_rate"]) - train_op = optimizer.minimize( - loss=loss, global_step=tf.train.get_global_step()) - + # Provide an estimator spec for `ModeKeys.EVAL` and `ModeKeys.TRAIN` modes. return tf.estimator.EstimatorSpec( mode=mode, - predictions=predictions_dict, loss=loss, train_op=train_op, eval_metric_ops=eval_metric_ops) |