diff options
Diffstat (limited to 'tensorflow/python/keras/metrics.py')
-rw-r--r-- | tensorflow/python/keras/metrics.py | 58 |
1 files changed, 46 insertions, 12 deletions
diff --git a/tensorflow/python/keras/metrics.py b/tensorflow/python/keras/metrics.py index fd3c39cf2e..f4e8419eb0 100644 --- a/tensorflow/python/keras/metrics.py +++ b/tensorflow/python/keras/metrics.py @@ -71,6 +71,22 @@ def check_is_tensor_or_operation(x, name): name, x)) +def clone_metric(metric): + """Returns a clone of the metric if stateful, otherwise returns it as is.""" + if isinstance(metric, Metric): + return metric.__class__.from_config(metric.get_config()) + return metric + + +def clone_metrics(metrics): + """Clones the given metric list/dict.""" + if metrics is None: + return None + if isinstance(metrics, dict): + return {key: clone_metric(value) for key, value in metrics.items()} + return [clone_metric(metric) for metric in metrics] + + def update_state_wrapper(update_state_fn): """Decorator to wrap metric `update_state()` with `add_update()`. @@ -199,7 +215,6 @@ def squeeze_or_expand_dimensions(y_pred, y_true, sample_weight): # squeeze last dim of `y_pred` or `y_true` if their rank differs by 1 y_true, y_pred = confusion_matrix.remove_squeezable_dimensions( y_true, y_pred) - y_pred.get_shape().assert_is_compatible_with(y_true.get_shape()) if sample_weight is None: return y_pred, y_true, None @@ -342,19 +357,14 @@ class Metric(Layer): # weak reference. This is to remove reference cycle that is created here. # This is not an issue in python versions > 3. if context.executing_eagerly(): - update_state = weakmethod(obj.update_state) - else: - update_state = function.defun(obj.update_state) + obj.update_state = weakmethod(obj.update_state) obj.update_state = weakmethod( - types.MethodType(update_state_wrapper(update_state), obj)) + types.MethodType(update_state_wrapper(obj.update_state), obj)) result = weakmethod(obj.result) obj.result = weakmethod(types.MethodType(result_wrapper(result), obj)) else: - # Converting update_state_fn() into a graph function, so that - # we can return a single op that performs all of the variable updates. - defuned_update_state_fn = function.defun(obj.update_state) obj.update_state = types.MethodType( - update_state_wrapper(defuned_update_state_fn), obj) + update_state_wrapper(obj.update_state), obj) obj.result = types.MethodType(result_wrapper(obj.result), obj) return obj @@ -475,6 +485,9 @@ class Mean(Metric): Args: values: Per-example value. sample_weight: Optional weighting of each example. Defaults to 1. + + Returns: + Update op. """ values = math_ops.cast(values, self._dtype) if sample_weight is None: @@ -501,8 +514,9 @@ class Mean(Metric): values = math_ops.reduce_sum(values) # Update state variables - state_ops.assign_add(self.total, values) - state_ops.assign_add(self.count, num_values) + update_total_op = state_ops.assign_add(self.total, values) + update_count_op = state_ops.assign_add(self.count, num_values) + return control_flow_ops.group(update_total_op, update_count_op) def result(self): return safe_div(self.total, self.count) @@ -536,6 +550,9 @@ class MeanMetricWrapper(Mean): sample_weight: Optional weighting of each example. Defaults to 1. Can be a `Tensor` whose rank is either 0, or the same rank as `y_true`, and must be broadcastable to `y_true`. + + Returns: + Update op. """ y_true = math_ops.cast(y_true, self._dtype) y_pred = math_ops.cast(y_pred, self._dtype) @@ -543,7 +560,7 @@ class MeanMetricWrapper(Mean): y_pred, y_true, sample_weight) matches = self._fn(y_true, y_pred, **self._fn_kwargs) - super(MeanMetricWrapper, self).update_state( + return super(MeanMetricWrapper, self).update_state( matches, sample_weight=sample_weight) def get_config(self): @@ -600,6 +617,23 @@ class CategoricalAccuracy(MeanMetricWrapper): categorical_accuracy, name, dtype=dtype) +class SparseCategoricalAccuracy(MeanMetricWrapper): + """Calculates how often predictions matches integer labels. + + This metric creates two local variables, `total` and `count` that are used to + compute the frequency with which `y_pred` matches `y_true`. This frequency is + ultimately returned as `sparse categorical accuracy`: an idempotent operation + that simply divides `total` by `count`. + + If `sample_weight` is `None`, weights default to 1. + Use `sample_weight` of 0 to mask values. + """ + + def __init__(self, name='sparse_categorical_accuracy', dtype=None): + super(SparseCategoricalAccuracy, self).__init__( + sparse_categorical_accuracy, name, dtype=dtype) + + @tf_export('keras.metrics.binary_accuracy') def binary_accuracy(y_true, y_pred, threshold=0.5): threshold = math_ops.cast(threshold, y_pred.dtype) |