aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2016-12-12 12:56:28 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2016-12-12 13:03:07 -0800
commitdc60657830f2ab7bf2ba69a7270b1964b9bcbda2 (patch)
tree89b5022430ad89d2f8f4e1b606e5fa5e06b22e52
parentdb979cdd0a5da682f48cbe804cfb8ba8b9835e1a (diff)
Squeezing labels before passing them into the top k eval metric for TensorForest. Adding eval_metrics tests.
Change: 141799551
-rw-r--r--tensorflow/contrib/tensor_forest/BUILD13
-rw-r--r--tensorflow/contrib/tensor_forest/client/eval_metrics.py70
-rw-r--r--tensorflow/contrib/tensor_forest/client/eval_metrics_test.py86
3 files changed, 138 insertions, 31 deletions
diff --git a/tensorflow/contrib/tensor_forest/BUILD b/tensorflow/contrib/tensor_forest/BUILD
index 6828ef2222..e5ed22fdd3 100644
--- a/tensorflow/contrib/tensor_forest/BUILD
+++ b/tensorflow/contrib/tensor_forest/BUILD
@@ -124,6 +124,19 @@ py_library(
srcs_version = "PY2AND3",
)
+py_test(
+ name = "eval_metrics_test",
+ size = "small",
+ srcs = ["client/eval_metrics_test.py"],
+ srcs_version = "PY2AND3",
+ deps = [
+ ":eval_metrics",
+ "//tensorflow:tensorflow_py",
+ "//tensorflow/python:framework_test_lib",
+ "//tensorflow/python:platform_test",
+ ],
+)
+
py_library(
name = "client_lib",
srcs_version = "PY2AND3",
diff --git a/tensorflow/contrib/tensor_forest/client/eval_metrics.py b/tensorflow/contrib/tensor_forest/client/eval_metrics.py
index c376e4557f..03ceb6f638 100644
--- a/tensorflow/contrib/tensor_forest/client/eval_metrics.py
+++ b/tensorflow/contrib/tensor_forest/client/eval_metrics.py
@@ -30,8 +30,10 @@ INFERENCE_PRED_NAME = 'predictions'
def _top_k_generator(k):
def _top_k(probabilities, targets):
- return metric_ops.streaming_mean(nn.in_top_k(probabilities,
- math_ops.to_int32(targets), k))
+ targets = math_ops.to_int32(targets)
+ if targets.get_shape().ndims > 1:
+ targets = array_ops.squeeze(targets, squeeze_dims=[1])
+ return metric_ops.streaming_mean(nn.in_top_k(probabilities, targets, k))
return _top_k
@@ -45,8 +47,8 @@ def _r2(probabilities, targets, weights=None):
targets = math_ops.to_float(targets)
y_mean = math_ops.reduce_mean(targets, 0)
squares_total = math_ops.reduce_sum(math_ops.square(targets - y_mean), 0)
- squares_residuals = math_ops.reduce_sum(math_ops.square(
- targets - probabilities), 0)
+ squares_residuals = math_ops.reduce_sum(
+ math_ops.square(targets - probabilities), 0)
score = 1 - math_ops.reduce_sum(squares_residuals / squares_total)
return metric_ops.streaming_mean(score, weights=weights)
@@ -57,16 +59,19 @@ def _squeeze_and_onehot(targets, depth):
def _sigmoid_entropy(probabilities, targets, weights=None):
- return metric_ops.streaming_mean(losses.sigmoid_cross_entropy(
- probabilities, _squeeze_and_onehot(targets,
- array_ops.shape(probabilities)[1])),
- weights=weights)
+ return metric_ops.streaming_mean(
+ losses.sigmoid_cross_entropy(probabilities,
+ _squeeze_and_onehot(
+ targets,
+ array_ops.shape(probabilities)[1])),
+ weights=weights)
def _softmax_entropy(probabilities, targets, weights=None):
- return metric_ops.streaming_mean(losses.sparse_softmax_cross_entropy(
- probabilities, math_ops.to_int32(targets)),
- weights=weights)
+ return metric_ops.streaming_mean(
+ losses.sparse_softmax_cross_entropy(probabilities,
+ math_ops.to_int32(targets)),
+ weights=weights)
def _predictions(predictions, unused_targets, **unused_kwargs):
@@ -89,26 +94,29 @@ def _recall(predictions, targets, weights=None):
return metric_ops.streaming_recall(predictions, targets, weights=weights)
-_EVAL_METRICS = {'sigmoid_entropy': _sigmoid_entropy,
- 'softmax_entropy': _softmax_entropy,
- 'accuracy': _accuracy,
- 'r2': _r2,
- 'predictions': _predictions,
- 'top_5': _top_k_generator(5),
- 'classification_log_loss': _class_log_loss,
- 'precision': _precision,
- 'recall': _recall}
-
-
-_PREDICTION_KEYS = {'sigmoid_entropy': INFERENCE_PROB_NAME,
- 'softmax_entropy': INFERENCE_PROB_NAME,
- 'accuracy': INFERENCE_PRED_NAME,
- 'r2': INFERENCE_PROB_NAME,
- 'predictions': INFERENCE_PRED_NAME,
- 'top_5': INFERENCE_PROB_NAME,
- 'classification_log_loss': INFERENCE_PROB_NAME,
- 'precision': INFERENCE_PRED_NAME,
- 'recall': INFERENCE_PRED_NAME}
+_EVAL_METRICS = {
+ 'sigmoid_entropy': _sigmoid_entropy,
+ 'softmax_entropy': _softmax_entropy,
+ 'accuracy': _accuracy,
+ 'r2': _r2,
+ 'predictions': _predictions,
+ 'top_5': _top_k_generator(5),
+ 'classification_log_loss': _class_log_loss,
+ 'precision': _precision,
+ 'recall': _recall
+}
+
+_PREDICTION_KEYS = {
+ 'sigmoid_entropy': INFERENCE_PROB_NAME,
+ 'softmax_entropy': INFERENCE_PROB_NAME,
+ 'accuracy': INFERENCE_PRED_NAME,
+ 'r2': INFERENCE_PROB_NAME,
+ 'predictions': INFERENCE_PRED_NAME,
+ 'top_5': INFERENCE_PROB_NAME,
+ 'classification_log_loss': INFERENCE_PROB_NAME,
+ 'precision': INFERENCE_PRED_NAME,
+ 'recall': INFERENCE_PRED_NAME
+}
def get_metric(metric_name):
diff --git a/tensorflow/contrib/tensor_forest/client/eval_metrics_test.py b/tensorflow/contrib/tensor_forest/client/eval_metrics_test.py
new file mode 100644
index 0000000000..be3ef1a822
--- /dev/null
+++ b/tensorflow/contrib/tensor_forest/client/eval_metrics_test.py
@@ -0,0 +1,86 @@
+# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for tf.contrib.tensor_forest.client.eval_metrics."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import tensorflow as tf
+
+from tensorflow.contrib.tensor_forest.client import eval_metrics
+
+from tensorflow.python.framework import test_util
+from tensorflow.python.platform import googletest
+
+
+class EvalMetricsTest(test_util.TensorFlowTestCase):
+
+ def testTop2(self):
+ top_2_fn = eval_metrics._top_k_generator(2)
+ probabilities = tf.constant([[0.1, 0.2, 0.3], [0.4, 0.7, 0.5],
+ [0.9, 0.8, 0.2], [0.6, 0.4, 0.8]])
+ targets = tf.constant([[0], [2], [1], [1]])
+ in_top_2_op, update_op = top_2_fn(probabilities, targets)
+ with self.test_session():
+ # initializes internal accuracy vars
+ tf.local_variables_initializer().run()
+ # need to call in order to run the in_top_2_op internal operations because
+ # it is a streaming function
+ update_op.eval()
+ self.assertNear(0.5, in_top_2_op.eval(), 0.0001)
+
+ def testTop3(self):
+ top_3_fn = eval_metrics._top_k_generator(3)
+ probabilities = tf.constant([[0.1, 0.2, 0.6, 0.3, 0.5, 0.5],
+ [0.1, 0.4, 0.7, 0.3, 0.5, 0.2],
+ [0.1, 0.3, 0.8, 0.7, 0.4, 0.9],
+ [0.9, 0.8, 0.1, 0.8, 0.2, 0.7],
+ [0.3, 0.6, 0.9, 0.4, 0.8, 0.6]])
+ targets = tf.constant([3, 0, 2, 5, 1])
+ in_top_3_op, update_op = top_3_fn(probabilities, targets)
+ with self.test_session():
+ # initializes internal accuracy vars
+ tf.local_variables_initializer().run()
+ # need to call in order to run the in_top_3_op internal operations because
+ # it is a streaming function
+ update_op.eval()
+ self.assertNear(0.4, in_top_3_op.eval(), 0.0001)
+
+ def testAccuracy(self):
+ predictions = tf.constant([0, 1, 3, 6, 5, 2, 7, 6, 4, 9])
+ targets = tf.constant([0, 1, 4, 6, 5, 1, 7, 5, 4, 8])
+ accuracy_op, update_op = eval_metrics._accuracy(predictions, targets)
+ with self.test_session():
+ tf.local_variables_initializer().run()
+ # need to call in order to run the accuracy_op internal operations because
+ # it is a streaming function
+ update_op.eval()
+ self.assertNear(0.6, accuracy_op.eval(), 0.0001)
+
+ def testR2(self):
+ probabilities = tf.constant([1.2, 3.9, 2.1, 0.9, 2.2, 0.1, 6.0, 4.0, 0.9])
+ targets = tf.constant([1.0, 4.3, 2.6, 0.5, 1.1, 0.7, 5.1, 3.4, 1.8])
+ r2_op, update_op = eval_metrics._r2(probabilities, targets)
+ with self.test_session():
+ # initializes internal accuracy vars
+ tf.local_variables_initializer().run()
+ # need to call in order to run the r2_op internal operations because
+ # it is a streaming function
+ update_op.eval()
+ self.assertNear(-19.7729, r2_op.eval(), 0.0001)
+
+
+if __name__ == '__main__':
+ googletest.main()