author    A. Unique TensorFlower <gardener@tensorflow.org>  2017-10-20 17:46:17 -0700
committer TensorFlower Gardener <gardener@tensorflow.org>   2017-10-20 17:50:08 -0700
commit    62df65c7255e2a8878cd29f66fe80ff8952de157 (patch)
tree      49dd9af4e43e3ab2b05ce32f9ec2ada921d3cb22
parent    29c7b46585aabab6b1a1677324667c2d5720181c (diff)
Add dtype argument to Mean and Accuracy object-oriented metrics.
PiperOrigin-RevId: 172957714
-rw-r--r--  tensorflow/contrib/eager/python/metrics_impl.py  27
-rw-r--r--  tensorflow/contrib/eager/python/metrics_test.py  20
2 files changed, 36 insertions, 11 deletions
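The commit adds a dtype argument (defaulting to float64) to the Mean and Accuracy constructors, so the accumulator variables and the returned result can be kept in a lower-precision type. A minimal usage sketch under the contrib import path of this era; the enable_eager_execution() call and the exact printed values are assumptions for illustration, not part of this commit:

import tensorflow as tf
from tensorflow.contrib.eager.python import metrics

tf.contrib.eager.enable_eager_execution()

# Default accumulator dtype remains float64.
m = metrics.Mean()
m([1.0, 2.0, 3.0])
print(m.result().dtype)    # float64

# Override to float32, e.g. to match the model's dtype.
a = metrics.Accuracy(dtype=tf.float32)
a([0, 1, 1], [0, 1, 0])    # labels, predictions: 2 of 3 correct
print(a.result().numpy())  # ~0.6667, as a float32 scalar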
diff --git a/tensorflow/contrib/eager/python/metrics_impl.py b/tensorflow/contrib/eager/python/metrics_impl.py
index 2a624b218c..2139c2b4b9 100644
--- a/tensorflow/contrib/eager/python/metrics_impl.py
+++ b/tensorflow/contrib/eager/python/metrics_impl.py
@@ -198,13 +198,19 @@ class Mean(Metric):
# TODO(josh11b): Maybe have a dtype argument that defaults to tf.float64?
# Or defaults to type of the input if it is tf.float32, else tf.float64?
- def build(self, values, weights=None):
- del values, weights # build() does not use call's arguments
+ def __init__(self, name=None, dtype=dtypes.float64):
+ super(Mean, self).__init__(name=name)
+ self.dtype = dtype
+
+ def build(self, *args, **kwargs):
+ # build() does not use call's arguments, by using *args, **kwargs
+ # we make it easier to inherit from Mean().
+ del args, kwargs
self.numer = self.add_variable(name="numer", shape=(),
- dtype=dtypes.float64,
+ dtype=self.dtype,
initializer=init_ops.zeros_initializer)
self.denom = self.add_variable(name="denom", shape=(),
- dtype=dtypes.float64,
+ dtype=self.dtype,
initializer=init_ops.zeros_initializer)
def call(self, values, weights=None):
@@ -219,13 +225,13 @@ class Mean(Metric):
"""
if weights is None:
self.denom.assign_add(
- math_ops.cast(array_ops.size(values), dtypes.float64))
+ math_ops.cast(array_ops.size(values), self.dtype))
values = math_ops.reduce_sum(values)
- self.numer.assign_add(math_ops.cast(values, dtypes.float64))
+ self.numer.assign_add(math_ops.cast(values, self.dtype))
else:
- weights = math_ops.cast(weights, dtypes.float64)
+ weights = math_ops.cast(weights, self.dtype)
self.denom.assign_add(math_ops.reduce_sum(weights))
- values = math_ops.cast(values, dtypes.float64) * weights
+ values = math_ops.cast(values, self.dtype) * weights
self.numer.assign_add(math_ops.reduce_sum(values))
def result(self):
@@ -235,9 +241,8 @@ class Mean(Metric):
class Accuracy(Mean):
"""Calculates how often `predictions` matches `labels`."""
- def build(self, labels, predictions, weights=None):
- del labels, predictions, weights
- super(Accuracy, self).build(None) # Arguments are unused
+ def __init__(self, name=None, dtype=dtypes.float64):
+ super(Accuracy, self).__init__(name=name, dtype=dtype)
def call(self, labels, predictions, weights=None):
"""Accumulate accuracy statistics.
diff --git a/tensorflow/contrib/eager/python/metrics_test.py b/tensorflow/contrib/eager/python/metrics_test.py
index bfb79cd72e..9743666c89 100644
--- a/tensorflow/contrib/eager/python/metrics_test.py
+++ b/tensorflow/contrib/eager/python/metrics_test.py
@@ -34,6 +34,8 @@ class MetricsTest(test.TestCase):
m(1000)
m([10000.0, 100000.0])
self.assertEqual(111111.0/6, m.result().numpy())
+ self.assertEqual(dtypes.float64, m.dtype)
+ self.assertEqual(dtypes.float64, m.result().dtype)
def testWeightedMean(self):
m = metrics.Mean()
@@ -41,6 +43,14 @@ class MetricsTest(test.TestCase):
m([500000, 5000, 500]) # weights of 1 each
self.assertNear(535521/4.5, m.result().numpy(), 0.001)
+ def testMeanDtype(self):
+ # Can override default dtype of float64.
+ m = metrics.Mean(dtype=dtypes.float32)
+ m([0, 2])
+ self.assertEqual(1, m.result().numpy())
+ self.assertEqual(dtypes.float32, m.dtype)
+ self.assertEqual(dtypes.float32, m.result().dtype)
+
def testAccuracy(self):
m = metrics.Accuracy()
m([0, 1, 2, 3], [0, 0, 0, 0]) # 1 correct
@@ -49,6 +59,8 @@ class MetricsTest(test.TestCase):
m([6], [6]) # 1 correct
m([7], [2]) # 0 correct
self.assertEqual(3.0/8, m.result().numpy())
+ self.assertEqual(dtypes.float64, m.dtype)
+ self.assertEqual(dtypes.float64, m.result().dtype)
def testWeightedAccuracy(self):
m = metrics.Accuracy()
@@ -60,6 +72,14 @@ class MetricsTest(test.TestCase):
m([7], [2]) # 0 correct, weight 1
self.assertEqual(2.5/5, m.result().numpy())
+ def testAccuracyDtype(self):
+ # Can override default dtype of float64.
+ m = metrics.Accuracy(dtype=dtypes.float32)
+ m([0, 0], [0, 1])
+ self.assertEqual(0.5, m.result().numpy())
+ self.assertEqual(dtypes.float32, m.dtype)
+ self.assertEqual(dtypes.float32, m.result().dtype)
+
def testTwoMeans(self):
# Verify two metrics with the same class and name don't
# accidentally share state.