aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/eager/python/metrics_impl.py
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/contrib/eager/python/metrics_impl.py')
-rw-r--r--tensorflow/contrib/eager/python/metrics_impl.py10
1 files changed, 4 insertions, 6 deletions
diff --git a/tensorflow/contrib/eager/python/metrics_impl.py b/tensorflow/contrib/eager/python/metrics_impl.py
index efa6ba0626..6efafccd6b 100644
--- a/tensorflow/contrib/eager/python/metrics_impl.py
+++ b/tensorflow/contrib/eager/python/metrics_impl.py
@@ -291,8 +291,6 @@ class Metric(checkpointable.CheckpointableBase):
class Mean(Metric):
"""Computes the (weighted) mean of the given values."""
- # TODO(josh11b): Maybe have a dtype argument that defaults to tf.float64?
- # Or defaults to type of the input if it is tf.float32, else tf.float64?
def __init__(self, name=None, dtype=dtypes.float64,
use_global_variables=False):
@@ -377,7 +375,7 @@ class Accuracy(Mean):
array_ops.shape(labels), array_ops.shape(predictions),
message="Shapes of labels and predictions are unequal")
matches = math_ops.equal(labels, predictions)
- matches = math_ops.cast(matches, dtypes.float64)
+ matches = math_ops.cast(matches, self.dtype)
super(Accuracy, self).call(matches, weights=weights)
if weights is None:
return labels, predictions
@@ -421,7 +419,7 @@ class CategoricalAccuracy(Mean):
labels = math_ops.argmax(labels, axis=-1)
predictions = math_ops.argmax(predictions, axis=-1)
matches = math_ops.equal(labels, predictions)
- matches = math_ops.cast(matches, dtypes.float64)
+ matches = math_ops.cast(matches, self.dtype)
super(CategoricalAccuracy, self).call(matches, weights=weights)
if weights is None:
return labels, predictions
@@ -472,7 +470,7 @@ class BinaryAccuracy(Mean):
predictions = ops.convert_to_tensor(predictions)
predictions = predictions > self.threshold
matches = math_ops.equal(labels, predictions)
- matches = math_ops.cast(matches, dtypes.float64)
+ matches = math_ops.cast(matches, self.dtype)
super(BinaryAccuracy, self).call(matches, weights=weights)
if weights is None:
return labels, predictions
@@ -520,7 +518,7 @@ class SparseAccuracy(Mean):
predictions = math_ops.argmax(predictions, axis=-1)
labels = math_ops.cast(labels, dtypes.int64)
matches = math_ops.equal(labels, predictions)
- matches = math_ops.cast(matches, dtypes.float64)
+ matches = math_ops.cast(matches, self.dtype)
super(SparseAccuracy, self).call(matches, weights=weights)
if weights is None:
return labels, predictions