Diffstat (limited to 'tensorflow/python/keras/metrics.py')
-rw-r--r--  tensorflow/python/keras/metrics.py  58
1 file changed, 46 insertions(+), 12 deletions(-)
diff --git a/tensorflow/python/keras/metrics.py b/tensorflow/python/keras/metrics.py
index fd3c39cf2e..f4e8419eb0 100644
--- a/tensorflow/python/keras/metrics.py
+++ b/tensorflow/python/keras/metrics.py
@@ -71,6 +71,22 @@ def check_is_tensor_or_operation(x, name):
name, x))
+def clone_metric(metric):
+ """Returns a clone of the metric if stateful, otherwise returns it as is."""
+ if isinstance(metric, Metric):
+ return metric.__class__.from_config(metric.get_config())
+ return metric
+
+
+def clone_metrics(metrics):
+ """Clones the given metric list/dict."""
+ if metrics is None:
+ return None
+ if isinstance(metrics, dict):
+ return {key: clone_metric(value) for key, value in metrics.items()}
+ return [clone_metric(metric) for metric in metrics]
+
+
def update_state_wrapper(update_state_fn):
"""Decorator to wrap metric `update_state()` with `add_update()`.
@@ -199,7 +215,6 @@ def squeeze_or_expand_dimensions(y_pred, y_true, sample_weight):
# squeeze last dim of `y_pred` or `y_true` if their rank differs by 1
y_true, y_pred = confusion_matrix.remove_squeezable_dimensions(
y_true, y_pred)
- y_pred.get_shape().assert_is_compatible_with(y_true.get_shape())
if sample_weight is None:
return y_pred, y_true, None
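
The dropped shape assertion matters for metrics whose labels and predictions legitimately have different trailing dimensions. A hedged sketch of such a pair (shapes chosen purely for illustration):

# Sketch: sparse-style metrics compare integer class ids against per-class
# scores, so y_true and y_pred are not shape-compatible.
from tensorflow.python.framework import constant_op
from tensorflow.python.keras import metrics as metrics_module

y_pred = constant_op.constant([[0.1, 0.1, 0.8],
                               [0.6, 0.3, 0.1]])   # shape (2, 3): class scores
y_true = constant_op.constant([[2.], [1.]])        # shape (2, 1): class ids

y_pred, y_true, _ = metrics_module.squeeze_or_expand_dimensions(
    y_pred, y_true, None)
# Ranks match, so nothing is squeezed; shapes stay (2, 3) and (2, 1).
# The removed assert_is_compatible_with() would have raised on this pair.
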
@@ -342,19 +357,14 @@ class Metric(Layer):
# weak reference. This is to remove the reference cycle that is created here.
# This is not an issue in Python versions >= 3.
if context.executing_eagerly():
- update_state = weakmethod(obj.update_state)
- else:
- update_state = function.defun(obj.update_state)
+ obj.update_state = weakmethod(obj.update_state)
obj.update_state = weakmethod(
- types.MethodType(update_state_wrapper(update_state), obj))
+ types.MethodType(update_state_wrapper(obj.update_state), obj))
result = weakmethod(obj.result)
obj.result = weakmethod(types.MethodType(result_wrapper(result), obj))
else:
- # Converting update_state_fn() into a graph function, so that
- # we can return a single op that performs all of the variable updates.
- defuned_update_state_fn = function.defun(obj.update_state)
obj.update_state = types.MethodType(
- update_state_wrapper(defuned_update_state_fn), obj)
+ update_state_wrapper(obj.update_state), obj)
obj.result = types.MethodType(result_wrapper(obj.result), obj)
return obj
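
For orientation, a minimal hypothetical subclass showing the contract this wiring supports: the author writes a bare update_state() that returns its update op, plus a bare result(); the update_state_wrapper/result_wrapper applied in __new__ handle the add_update() bookkeeping, now without routing update_state through function.defun first. The metric below is modeled on Mean in this file and is not part of the diff.

from tensorflow.python.keras.metrics import Metric
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import init_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import state_ops


class SumOfSquares(Metric):
  """Hypothetical metric: accumulates the sum of squared values seen so far."""

  def __init__(self, name='sum_of_squares', dtype=None):
    super(SumOfSquares, self).__init__(name=name, dtype=dtype)
    self.total = self.add_weight(
        'total', initializer=init_ops.zeros_initializer)

  def update_state(self, values):
    values = math_ops.cast(values, self._dtype)
    # Returning the op lets the wrapper installed in __new__ both register it
    # via add_update() and hand it back to the caller.
    return state_ops.assign_add(
        self.total, math_ops.reduce_sum(values * values))

  def result(self):
    return array_ops.identity(self.total)
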
@@ -475,6 +485,9 @@ class Mean(Metric):
Args:
values: Per-example value.
sample_weight: Optional weighting of each example. Defaults to 1.
+
+ Returns:
+ Update op.
"""
values = math_ops.cast(values, self._dtype)
if sample_weight is None:
@@ -501,8 +514,9 @@ class Mean(Metric):
values = math_ops.reduce_sum(values)
# Update state variables
- state_ops.assign_add(self.total, values)
- state_ops.assign_add(self.count, num_values)
+ update_total_op = state_ops.assign_add(self.total, values)
+ update_count_op = state_ops.assign_add(self.count, num_values)
+ return control_flow_ops.group(update_total_op, update_count_op)
def result(self):
return safe_div(self.total, self.count)
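
With the two assign_adds grouped and returned, a graph-mode caller can now run the update explicitly and then read the result. A small sketch (graph mode and internal imports assumed):

from tensorflow.python.client import session
from tensorflow.python.keras import metrics as metrics_module
from tensorflow.python.ops import variables

m = metrics_module.Mean()
update_op = m.update_state([1., 3., 5., 7.])  # one grouped op after this change

with session.Session() as sess:
  sess.run(variables.variables_initializer(m.variables))
  sess.run(update_op)
  print(sess.run(m.result()))  # total/count = 16/4 = 4.0
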
@@ -536,6 +550,9 @@ class MeanMetricWrapper(Mean):
sample_weight: Optional weighting of each example. Defaults to 1. Can be
a `Tensor` whose rank is either 0, or the same rank as `y_true`,
and must be broadcastable to `y_true`.
+
+ Returns:
+ Update op.
"""
y_true = math_ops.cast(y_true, self._dtype)
y_pred = math_ops.cast(y_pred, self._dtype)
@@ -543,7 +560,7 @@ class MeanMetricWrapper(Mean):
y_pred, y_true, sample_weight)
matches = self._fn(y_true, y_pred, **self._fn_kwargs)
- super(MeanMetricWrapper, self).update_state(
+ return super(MeanMetricWrapper, self).update_state(
matches, sample_weight=sample_weight)
def get_config(self):
@@ -600,6 +617,23 @@ class CategoricalAccuracy(MeanMetricWrapper):
categorical_accuracy, name, dtype=dtype)
+class SparseCategoricalAccuracy(MeanMetricWrapper):
+ """Calculates how often predictions matches integer labels.
+
+ This metric creates two local variables, `total` and `count` that are used to
+ compute the frequency with which `y_pred` matches `y_true`. This frequency is
+ ultimately returned as `sparse categorical accuracy`: an idempotent operation
+ that simply divides `total` by `count`.
+
+ If `sample_weight` is `None`, weights default to 1.
+ Use `sample_weight` of 0 to mask values.
+ """
+
+ def __init__(self, name='sparse_categorical_accuracy', dtype=None):
+ super(SparseCategoricalAccuracy, self).__init__(
+ sparse_categorical_accuracy, name, dtype=dtype)
+
+
@tf_export('keras.metrics.binary_accuracy')
def binary_accuracy(y_true, y_pred, threshold=0.5):
threshold = math_ops.cast(threshold, y_pred.dtype)
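
To round things off, an illustrative graph-mode use of the new class; it also exercises the update op now returned through MeanMetricWrapper.update_state() and Mean.update_state():

from tensorflow.python.client import session
from tensorflow.python.keras import metrics as metrics_module
from tensorflow.python.ops import variables

acc = metrics_module.SparseCategoricalAccuracy()
update_op = acc.update_state(
    [[2], [1]],             # integer class ids, shape (2, 1), not one-hot
    [[0.1, 0.1, 0.8],       # argmax is 2 -> correct
     [0.6, 0.3, 0.1]])      # argmax is 0 -> wrong

with session.Session() as sess:
  sess.run(variables.variables_initializer(acc.variables))
  sess.run(update_op)
  print(sess.run(acc.result()))  # 1 correct out of 2 -> 0.5
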