author     A. Unique TensorFlower <gardener@tensorflow.org>  2016-12-12 12:56:28 -0800
committer  TensorFlower Gardener <gardener@tensorflow.org>   2016-12-12 13:03:07 -0800
commit     dc60657830f2ab7bf2ba69a7270b1964b9bcbda2 (patch)
tree       89b5022430ad89d2f8f4e1b606e5fa5e06b22e52
parent     db979cdd0a5da682f48cbe804cfb8ba8b9835e1a (diff)
Squeezing labels before passing them into the top k eval metric for TensorForest. Adding eval_metrics tests.
Change: 141799551
-rw-r--r--  tensorflow/contrib/tensor_forest/BUILD                         13
-rw-r--r--  tensorflow/contrib/tensor_forest/client/eval_metrics.py        70
-rw-r--r--  tensorflow/contrib/tensor_forest/client/eval_metrics_test.py   86
3 files changed, 138 insertions, 31 deletions
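
The behavioral change is confined to `_top_k_generator` in eval_metrics.py: `tf.nn.in_top_k` requires rank-1 integer targets, but labels fed through an estimator input pipeline frequently arrive with shape `[batch_size, 1]`, so the metric now squeezes that trailing dimension before the comparison. A minimal standalone sketch of the same idea (hypothetical code using the TF 1.x graph-mode API this file targets; names are illustrative and not from the patch):

```python
import tensorflow as tf

probabilities = tf.constant([[0.1, 0.7, 0.2],
                             [0.8, 0.1, 0.1]])
labels = tf.constant([[1], [0]])  # rank 2 ([2, 1]): in_top_k would reject this
labels = tf.to_int32(labels)
if labels.get_shape().ndims > 1:
  # Drop the trailing dimension, mirroring the patched _top_k: shape becomes [2].
  labels = tf.squeeze(labels, squeeze_dims=[1])
hits = tf.nn.in_top_k(probabilities, labels, 1)

with tf.Session() as sess:
  print(sess.run(hits))  # [ True  True]
```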
```diff
diff --git a/tensorflow/contrib/tensor_forest/BUILD b/tensorflow/contrib/tensor_forest/BUILD
index 6828ef2222..e5ed22fdd3 100644
--- a/tensorflow/contrib/tensor_forest/BUILD
+++ b/tensorflow/contrib/tensor_forest/BUILD
@@ -124,6 +124,19 @@ py_library(
     srcs_version = "PY2AND3",
 )
 
+py_test(
+    name = "eval_metrics_test",
+    size = "small",
+    srcs = ["client/eval_metrics_test.py"],
+    srcs_version = "PY2AND3",
+    deps = [
+        ":eval_metrics",
+        "//tensorflow:tensorflow_py",
+        "//tensorflow/python:framework_test_lib",
+        "//tensorflow/python:platform_test",
+    ],
+)
+
 py_library(
     name = "client_lib",
     srcs_version = "PY2AND3",
diff --git a/tensorflow/contrib/tensor_forest/client/eval_metrics.py b/tensorflow/contrib/tensor_forest/client/eval_metrics.py
index c376e4557f..03ceb6f638 100644
--- a/tensorflow/contrib/tensor_forest/client/eval_metrics.py
+++ b/tensorflow/contrib/tensor_forest/client/eval_metrics.py
@@ -30,8 +30,10 @@ INFERENCE_PRED_NAME = 'predictions'
 
 def _top_k_generator(k):
   def _top_k(probabilities, targets):
-    return metric_ops.streaming_mean(nn.in_top_k(probabilities,
-                                     math_ops.to_int32(targets), k))
+    targets = math_ops.to_int32(targets)
+    if targets.get_shape().ndims > 1:
+      targets = array_ops.squeeze(targets, squeeze_dims=[1])
+    return metric_ops.streaming_mean(nn.in_top_k(probabilities, targets, k))
   return _top_k
 
 
@@ -45,8 +47,8 @@ def _r2(probabilities, targets, weights=None):
   targets = math_ops.to_float(targets)
   y_mean = math_ops.reduce_mean(targets, 0)
   squares_total = math_ops.reduce_sum(math_ops.square(targets - y_mean), 0)
-  squares_residuals = math_ops.reduce_sum(math_ops.square(
-      targets - probabilities), 0)
+  squares_residuals = math_ops.reduce_sum(
+      math_ops.square(targets - probabilities), 0)
   score = 1 - math_ops.reduce_sum(squares_residuals / squares_total)
   return metric_ops.streaming_mean(score, weights=weights)
 
@@ -57,16 +59,19 @@ def _squeeze_and_onehot(targets, depth):
 
 
 def _sigmoid_entropy(probabilities, targets, weights=None):
-  return metric_ops.streaming_mean(losses.sigmoid_cross_entropy(
-      probabilities, _squeeze_and_onehot(targets,
-                                         array_ops.shape(probabilities)[1])),
-                                   weights=weights)
+  return metric_ops.streaming_mean(
+      losses.sigmoid_cross_entropy(probabilities,
+                                   _squeeze_and_onehot(
+                                       targets,
+                                       array_ops.shape(probabilities)[1])),
+      weights=weights)
 
 
 def _softmax_entropy(probabilities, targets, weights=None):
-  return metric_ops.streaming_mean(losses.sparse_softmax_cross_entropy(
-      probabilities, math_ops.to_int32(targets)),
-                                   weights=weights)
+  return metric_ops.streaming_mean(
+      losses.sparse_softmax_cross_entropy(probabilities,
+                                          math_ops.to_int32(targets)),
+      weights=weights)
 
 
 def _predictions(predictions, unused_targets, **unused_kwargs):
@@ -89,26 +94,29 @@ def _recall(predictions, targets, weights=None):
   return metric_ops.streaming_recall(predictions, targets, weights=weights)
 
 
-_EVAL_METRICS = {'sigmoid_entropy': _sigmoid_entropy,
-                 'softmax_entropy': _softmax_entropy,
-                 'accuracy': _accuracy,
-                 'r2': _r2,
-                 'predictions': _predictions,
-                 'top_5': _top_k_generator(5),
-                 'classification_log_loss': _class_log_loss,
-                 'precision': _precision,
-                 'recall': _recall}
-
-
-_PREDICTION_KEYS = {'sigmoid_entropy': INFERENCE_PROB_NAME,
-                    'softmax_entropy': INFERENCE_PROB_NAME,
-                    'accuracy': INFERENCE_PRED_NAME,
-                    'r2': INFERENCE_PROB_NAME,
-                    'predictions': INFERENCE_PRED_NAME,
-                    'top_5': INFERENCE_PROB_NAME,
-                    'classification_log_loss': INFERENCE_PROB_NAME,
-                    'precision': INFERENCE_PRED_NAME,
-                    'recall': INFERENCE_PRED_NAME}
+_EVAL_METRICS = {
+    'sigmoid_entropy': _sigmoid_entropy,
+    'softmax_entropy': _softmax_entropy,
+    'accuracy': _accuracy,
+    'r2': _r2,
+    'predictions': _predictions,
+    'top_5': _top_k_generator(5),
+    'classification_log_loss': _class_log_loss,
+    'precision': _precision,
+    'recall': _recall
+}
+
+_PREDICTION_KEYS = {
+    'sigmoid_entropy': INFERENCE_PROB_NAME,
+    'softmax_entropy': INFERENCE_PROB_NAME,
+    'accuracy': INFERENCE_PRED_NAME,
+    'r2': INFERENCE_PROB_NAME,
+    'predictions': INFERENCE_PRED_NAME,
+    'top_5': INFERENCE_PROB_NAME,
+    'classification_log_loss': INFERENCE_PROB_NAME,
+    'precision': INFERENCE_PRED_NAME,
+    'recall': INFERENCE_PRED_NAME
+}
 
 
 def get_metric(metric_name):
diff --git a/tensorflow/contrib/tensor_forest/client/eval_metrics_test.py b/tensorflow/contrib/tensor_forest/client/eval_metrics_test.py
new file mode 100644
index 0000000000..be3ef1a822
--- /dev/null
+++ b/tensorflow/contrib/tensor_forest/client/eval_metrics_test.py
@@ -0,0 +1,86 @@
+# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for tf.contrib.tensor_forest.client.eval_metrics."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import tensorflow as tf
+
+from tensorflow.contrib.tensor_forest.client import eval_metrics
+
+from tensorflow.python.framework import test_util
+from tensorflow.python.platform import googletest
+
+
+class EvalMetricsTest(test_util.TensorFlowTestCase):
+
+  def testTop2(self):
+    top_2_fn = eval_metrics._top_k_generator(2)
+    probabilities = tf.constant([[0.1, 0.2, 0.3], [0.4, 0.7, 0.5],
+                                 [0.9, 0.8, 0.2], [0.6, 0.4, 0.8]])
+    targets = tf.constant([[0], [2], [1], [1]])
+    in_top_2_op, update_op = top_2_fn(probabilities, targets)
+    with self.test_session():
+      # initializes internal accuracy vars
+      tf.local_variables_initializer().run()
+      # need to call in order to run the in_top_2_op internal operations because
+      # it is a streaming function
+      update_op.eval()
+      self.assertNear(0.5, in_top_2_op.eval(), 0.0001)
+
+  def testTop3(self):
+    top_3_fn = eval_metrics._top_k_generator(3)
+    probabilities = tf.constant([[0.1, 0.2, 0.6, 0.3, 0.5, 0.5],
+                                 [0.1, 0.4, 0.7, 0.3, 0.5, 0.2],
+                                 [0.1, 0.3, 0.8, 0.7, 0.4, 0.9],
+                                 [0.9, 0.8, 0.1, 0.8, 0.2, 0.7],
+                                 [0.3, 0.6, 0.9, 0.4, 0.8, 0.6]])
+    targets = tf.constant([3, 0, 2, 5, 1])
+    in_top_3_op, update_op = top_3_fn(probabilities, targets)
+    with self.test_session():
+      # initializes internal accuracy vars
+      tf.local_variables_initializer().run()
+      # need to call in order to run the in_top_3_op internal operations because
+      # it is a streaming function
+      update_op.eval()
+      self.assertNear(0.4, in_top_3_op.eval(), 0.0001)
+
+  def testAccuracy(self):
+    predictions = tf.constant([0, 1, 3, 6, 5, 2, 7, 6, 4, 9])
+    targets = tf.constant([0, 1, 4, 6, 5, 1, 7, 5, 4, 8])
+    accuracy_op, update_op = eval_metrics._accuracy(predictions, targets)
+    with self.test_session():
+      tf.local_variables_initializer().run()
+      # need to call in order to run the accuracy_op internal operations because
+      # it is a streaming function
+      update_op.eval()
+      self.assertNear(0.6, accuracy_op.eval(), 0.0001)
+
+  def testR2(self):
+    probabilities = tf.constant([1.2, 3.9, 2.1, 0.9, 2.2, 0.1, 6.0, 4.0, 0.9])
+    targets = tf.constant([1.0, 4.3, 2.6, 0.5, 1.1, 0.7, 5.1, 3.4, 1.8])
+    r2_op, update_op = eval_metrics._r2(probabilities, targets)
+    with self.test_session():
+      # initializes internal accuracy vars
+      tf.local_variables_initializer().run()
+      # need to call in order to run the r2_op internal operations because
+      # it is a streaming function
+      update_op.eval()
+      self.assertNear(0.813583, r2_op.eval(), 0.0001)
+
+
+if __name__ == '__main__':
+  googletest.main()
```
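
As a sanity check on the testR2 expectation (not part of the commit): for the rank-1 tensors used in the test, the graph `_r2` builds reduces to score = 1 - sum((targets - predictions)^2) / sum((targets - mean(targets))^2), and `streaming_mean` of that single scalar returns the scalar itself. A quick NumPy reproduction of the asserted value:

```python
import numpy as np

# Inputs copied from testR2 above.
p = np.array([1.2, 3.9, 2.1, 0.9, 2.2, 0.1, 6.0, 4.0, 0.9])
t = np.array([1.0, 4.3, 2.6, 0.5, 1.1, 0.7, 5.1, 3.4, 1.8])

squares_total = np.sum((t - t.mean()) ** 2)  # total sum of squares, ~22.3156
squares_residuals = np.sum((t - p) ** 2)     # residual sum of squares, 4.16
score = 1 - squares_residuals / squares_total
print(score)  # ~0.813583, the value asserted by testR2
```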