author     Pavithra Vijay <psv@google.com>  2018-07-18 10:55:32 -0700
committer  TensorFlower Gardener <gardener@tensorflow.org>  2018-07-18 11:00:37 -0700
commit     80ab99ec746520faa763a7bb171ee8850a597ec1 (patch)
tree       98e0837412182445c696fab8c375517a9a4e3c0c
parent     fb70a5587395b1e68c08a4d396c63c5bd80fa1e1 (diff)
Create new metrics class and add mean metric.
PiperOrigin-RevId: 205102847
-rw-r--r--  tensorflow/python/keras/metrics.py       | 381
-rw-r--r--  tensorflow/python/keras/metrics_test.py  | 196
2 files changed, 541 insertions(+), 36 deletions(-)
diff --git a/tensorflow/python/keras/metrics.py b/tensorflow/python/keras/metrics.py
index e03d7dfe93..72e15763cb 100644
--- a/tensorflow/python/keras/metrics.py
+++ b/tensorflow/python/keras/metrics.py
@@ -19,9 +19,16 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
+from abc import ABCMeta
+from abc import abstractmethod
import six
+from tensorflow.python.eager import context
+from tensorflow.python.eager import function
+from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import ops
from tensorflow.python.keras import backend as K
+from tensorflow.python.keras.engine.base_layer import Layer
from tensorflow.python.keras.losses import binary_crossentropy
from tensorflow.python.keras.losses import categorical_crossentropy
from tensorflow.python.keras.losses import cosine_proximity
@@ -37,11 +44,385 @@ from tensorflow.python.keras.losses import sparse_categorical_crossentropy
from tensorflow.python.keras.losses import squared_hinge
from tensorflow.python.keras.utils.generic_utils import deserialize_keras_object
from tensorflow.python.keras.utils.generic_utils import serialize_keras_object
+from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import confusion_matrix
+from tensorflow.python.ops import control_flow_ops
+from tensorflow.python.ops import init_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import nn
+from tensorflow.python.ops import state_ops
+from tensorflow.python.ops import variable_scope as vs
+from tensorflow.python.ops import weights_broadcast_ops
+from tensorflow.python.training import distribute as distribute_lib
+from tensorflow.python.util import tf_decorator
from tensorflow.python.util.tf_export import tf_export
+def update_state(update_state_fn):
+ """Decorator to wrap metric `update_state()` with `defun()`, `add_update()`.
+
+ Args:
+ update_state_fn: function that accumulates metric statistics.
+
+  Returns:
+    The decorated `update_state()` function. Under eager execution the
+    decorated function returns None; under graph execution it returns an
+    update op that should be executed to update the metric state with the
+    given inputs.
+ """
+
+ def decorated(*args, **kwargs):
+ """Decorated function with `defun()` and `add_update()`."""
+
+ # Converting update_state_fn() into a graph function, so that
+ # we can return a single op that performs all of the variable updates.
+ # Assigning to a different method name to avoid reference cycle.
+ defuned_update_state_fn = function.defun(update_state_fn)
+ update_op = defuned_update_state_fn(*args, **kwargs)
+ if update_op is not None: # update_op will be None in eager execution.
+ metric_obj = args[0]
+ metric_obj.add_update(update_op, inputs=True)
+ return update_op
+
+ return tf_decorator.make_decorator(update_state_fn, decorated)
+
+
+def result(result_fn):
+ """Decorator to wrap metric `result()` function in `merge_call()`.
+
+ Result computation is an idempotent operation that simply calculates the
+ metric value using the state variables.
+
+  If metric state variables are distributed across towers/devices and
+  `result()` is requested from the context of one device, this function wraps
+  `result()` in a distribution strategy `merge_call()` so that the metric
+  state variables are aggregated across devices.
+
+ Args:
+ result_fn: function that computes the metric result.
+
+  Returns:
+    The decorated `result()` function, which returns the metric result
+    tensor.
+ """
+
+ def decorated(*args):
+ """Decorated function with merge_call."""
+ tower_context = distribute_lib.get_tower_context()
+    if tower_context is None:  # already in cross-tower context
+      return result_fn(*args)
+
+ # TODO(psv): Test distribution of metrics using different distribution
+ # strategies.
+
+ # Creating a wrapper for merge_fn. merge_call invokes the given merge_fn
+ # with distribution object as the first parameter. We create a wrapper here
+ # so that the result function need not have that parameter.
+ def merge_fn_wrapper(distribution, merge_fn, *args):
+      # We will get a `PerDevice` merge function here. Take the first one,
+      # as all are identical copies of the function that was passed in.
+ return distribution.unwrap(merge_fn)[0](*args)
+
+ # Wrapping result in merge_call. merge_call is used when we want to leave
+ # tower mode and compute a value in cross tower mode.
+ return tower_context.merge_call(merge_fn_wrapper, result_fn, *args)
+
+ return tf_decorator.make_decorator(result_fn, decorated)
+
+
+def _safe_div(numerator, denominator):
+ """Divides two tensors element-wise, returning 0 if the denominator is <= 0.
+
+ Args:
+ numerator: A `Tensor`.
+ denominator: A `Tensor`, with dtype matching `numerator`.
+
+ Returns:
+ 0 if `denominator` <= 0, else `numerator` / `denominator`
+ """
+ t = math_ops.truediv(numerator, denominator)
+ zero = array_ops.zeros_like(t, dtype=denominator.dtype)
+ condition = math_ops.greater(denominator, zero)
+ zero = math_ops.cast(zero, t.dtype)
+ return array_ops.where(condition, t, zero)
+
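(Not part of the patch: for intuition, a minimal sketch of the same safe-division semantics using only the public TF 1.x API.)

```python
import tensorflow as tf

num = tf.constant([1.0, 2.0, 3.0])
den = tf.constant([2.0, 0.0, -1.0])
# Divide element-wise, substituting 0 wherever the denominator is <= 0.
quotient = tf.where(tf.greater(den, 0.0),
                    tf.truediv(num, den),
                    tf.zeros_like(num))

with tf.Session() as sess:
  print(sess.run(quotient))  # [0.5 0.  0. ]
```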
+
+def _squeeze_or_expand_dimensions(y_pred, y_true, sample_weight):
+ """Squeeze or expand last dimension if needed.
+
+ 1. Squeezes last dim of `y_pred` or `y_true` if their rank differs by 1
+ (using `confusion_matrix.remove_squeezable_dimensions`).
+ 2. Squeezes or expands last dim of `sample_weight` if its rank differs by 1
+ from the new rank of `y_pred`.
+ If `sample_weight` is scalar, it is kept scalar.
+
+ This will use static shape if available. Otherwise, it will add graph
+ operations, which could result in a performance hit.
+
+ Args:
+ y_pred: Predicted values, a `Tensor` of arbitrary dimensions.
+ y_true: Optional label `Tensor` whose dimensions match `y_pred`.
+ sample_weight: Optional weight scalar or `Tensor` whose dimensions match
+ `y_pred`.
+
+  Returns:
+    Tuple of `y_pred`, `y_true` and `sample_weight`, each possibly with its
+    last dimension squeezed; alternatively, `sample_weight` may have been
+    expanded by one dimension.
+ """
+ if y_true is not None:
+ # squeeze last dim of `y_pred` or `y_true` if their rank differs by 1
+ y_true, y_pred = confusion_matrix.remove_squeezable_dimensions(
+ y_true, y_pred)
+ y_pred.get_shape().assert_is_compatible_with(y_true.get_shape())
+
+ if sample_weight is None:
+ return y_pred, y_true, None
+
+ sample_weight = ops.convert_to_tensor(sample_weight)
+ weights_shape = sample_weight.get_shape()
+ weights_rank = weights_shape.ndims
+ if weights_rank == 0: # If weights is scalar, do nothing.
+ return y_pred, y_true, sample_weight
+
+ y_pred_shape = y_pred.get_shape()
+ y_pred_rank = y_pred_shape.ndims
+ if (y_pred_rank is not None) and (weights_rank is not None):
+ # Use static rank.
+ if weights_rank - y_pred_rank == 1:
+ sample_weight = array_ops.squeeze(sample_weight, [-1])
+ elif y_pred_rank - weights_rank == 1:
+ sample_weight = array_ops.expand_dims(sample_weight, [-1])
+ return y_pred, y_true, sample_weight
+
+ # Use dynamic rank.
+ weights_rank_tensor = array_ops.rank(sample_weight)
+ rank_diff = weights_rank_tensor - array_ops.rank(y_pred)
+ maybe_squeeze_weights = lambda: array_ops.squeeze(sample_weight, [-1])
+
+  def _maybe_expand_weights():
+    return control_flow_ops.cond(
+        math_ops.equal(rank_diff, -1),
+        lambda: array_ops.expand_dims(sample_weight, [-1]),
+        lambda: sample_weight)
+
+ def _maybe_adjust_weights():
+ return control_flow_ops.cond(
+ math_ops.equal(rank_diff, 1), maybe_squeeze_weights,
+ _maybe_expand_weights)
+
+ # squeeze or expand last dim of `sample_weight` if its rank differs by 1
+ # from the new rank of `y_pred`.
+ sample_weight = control_flow_ops.cond(
+ math_ops.equal(weights_rank_tensor, 0), lambda: sample_weight,
+ _maybe_adjust_weights)
+ return y_pred, y_true, sample_weight
+
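(Not part of the patch: to make the rank-adjustment rules concrete, a hypothetical NumPy sketch of the static-rank branch; `adjust_weights` is an illustrative helper, not an API in this change.)

```python
import numpy as np

def adjust_weights(pred_rank, weights):
  """Mirrors the static-rank weight adjustment described above."""
  w = np.asarray(weights)
  if w.ndim == 0:                    # scalar weights are left untouched
    return w
  if w.ndim - pred_rank == 1:        # e.g. y_pred (2,), weights (2, 1)
    return np.squeeze(w, axis=-1)
  if pred_rank - w.ndim == 1:        # e.g. y_pred (2, 1), weights (2,)
    return np.expand_dims(w, axis=-1)
  return w

print(adjust_weights(1, [[1.0], [0.2]]).shape)  # (2,)   -- squeezed
print(adjust_weights(2, [1.0, 0.2]).shape)      # (2, 1) -- expanded
```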
+
+@six.add_metaclass(ABCMeta)
+class Metric(Layer):
+ """Encapsulates metric logic and state.
+
+ Usage with eager execution:
+
+ ```python
+ m = SomeMetric(...)
+ for input in ...:
+ m.update_state(input)
+ print('Final result: ', m.result().numpy())
+ ```
+
+ Usage with graph execution:
+
+ ```python
+ m = SomeMetric(...)
+ init_op = tf.global_variables_initializer() # Initialize variables
+ with tf.Session() as sess:
+ sess.run(init_op)
+ for input in ...:
+ update_op = m.update_state(input)
+ sess.run(update_op)
+ print('Final result: ', sess.run(m.result()))
+ ```
+
+  To be implemented by subclasses:
+  * `__init__()`: All state variables should be created in this method by
+    calling `self.add_weight()` like: `self.var = self.add_weight(...)`
+  * `update_state()`: Contains all updates to the state variables, like
+    `self.var.assign_add(...)`. Decorate it with `@update_state`, which
+    converts `update_state()` into a graph function so that a single op
+    performing all of the variable updates can be returned, and adds that
+    update op to the metric layer.
+  * `result()`: Computes and returns a value for the metric from the state
+    variables. Decorate it with `@result`, which wraps `result()` in a
+    distribution strategy `merge_call()`.
+
+ Example subclass implementation:
+
+  ```python
+ class BinaryTruePositives(Metric):
+ def __init__(self, name='binary-true-positives', dtype=dtypes.float64):
+ super(BinaryTruePositives, self).__init__(name=name, dtype=dtype)
+ self.true_positives = self.add_weight(
+ 'true_positives', initializer=init_ops.zeros_initializer)
+
+ @update_state
+ def update_state(self, y_true, y_pred, sample_weight=None):
+ y_true = math_ops.cast(y_true, dtypes.bool)
+ y_pred = math_ops.cast(y_pred, dtypes.bool)
+ y_pred, y_true, sample_weight = _squeeze_or_expand_dimensions(
+ y_pred, y_true, sample_weight)
+
+ values = math_ops.logical_and(
+ math_ops.equal(y_true, True), math_ops.equal(y_pred, True))
+ values = math_ops.cast(values, self._dtype)
+ if sample_weight is not None:
+ sample_weight = math_ops.cast(sample_weight, self._dtype)
+ values = math_ops.multiply(values, sample_weight)
+ state_ops.assign_add(self.true_positives, math_ops.reduce_sum(values))
+
+ @result
+ def result(self):
+ return array_ops.identity(self.true_positives)
+ ```
+ """
+
+ def __init__(self, name=None, dtype=dtypes.float64):
+ super(Metric, self).__init__(name=name, dtype=dtype)
+ self.stateful = True # All metric layers are stateful.
+ self.built = True
+
+ def __call__(self, *args, **kwargs):
+ """Accumulates statistics and then computes metric result value.
+
+    Args:
+      *args: A mini-batch of inputs to the Metric,
+        passed on to `update_state()`.
+      **kwargs: Keyword arguments passed on to `update_state()`.
+
+ Returns:
+ The metric value tensor.
+ """
+ update_op = self.update_state(*args, **kwargs)
+ with ops.control_dependencies([update_op]):
+ return self.result()
+
+ def reset_states(self):
+ """Resets all of the metric state variables.
+
+ This function is called between epochs/steps,
+ when a metric is evaluated during training.
+ """
+ for v in self.variables:
+ K.set_value(v, 0)
+
+ @abstractmethod
+ def update_state(self, *args, **kwargs):
+ """Accumulates statistics for the metric.
+
+    Decorate the function with `@update_state`, which converts
+    `update_state()` into a graph function (so that a single op performing
+    all of the variable updates can be returned) and adds that update op to
+    the metric layer. This means:
+    a) Operations on the same resource are executed in textual order.
+       This should make it easier to do things like add the updated
+       value of a variable to another, for example.
+    b) You don't need to worry about collecting the update ops to execute;
+       all update ops added to the graph by this function will be executed.
+    As a result, code should generally work the same way with graph or
+    eager execution.
+
+    Args:
+      *args: A mini-batch of inputs to the Metric.
+      **kwargs: Keyword-argument inputs to the Metric.
+ """
+    raise NotImplementedError('Must be implemented in subclasses.')
+
+ @abstractmethod
+ def result(self):
+ """Computes and returns the metric value tensor.
+
+ Result computation is an idempotent operation that simply calculates the
+ metric value using the state variables.
+
+    Decorate the function with `@result`, which wraps `result()` in a
+    distribution strategy `merge_call()`.
+ """
+    raise NotImplementedError('Must be implemented in subclasses.')
+
+ ### For use by subclasses ###
+ def add_weight(self,
+ name,
+ shape=(),
+ aggregation=vs.VariableAggregation.SUM,
+ synchronization=vs.VariableSynchronization.ON_READ,
+ initializer=None):
+ """Adds state variable. Only for use by subclasses."""
+ return super(Metric, self).add_weight(
+ name=name,
+ shape=shape,
+ dtype=self._dtype,
+ trainable=False,
+ initializer=initializer,
+ synchronization=synchronization,
+ aggregation=aggregation)
+
+ ### End: For use by subclasses ###
+
+
+class Mean(Metric):
+ """Computes the (weighted) mean of the given values.
+
+  This metric creates two variables, `total` and `count`, that are used to
+  compute the average of `values`. This average is ultimately returned as the
+  mean, an idempotent operation that simply divides `total` by `count`.
+
+ If `sample_weight` is `None`, weights default to 1.
+ Use `sample_weight` of 0 to mask values.
+ """
+
+ def __init__(self, name='mean', dtype=dtypes.float64):
+ super(Mean, self).__init__(name=name, dtype=dtype)
+ # Create new state variables
+ self.total = self.add_weight(
+ 'total', initializer=init_ops.zeros_initializer)
+ self.count = self.add_weight(
+ 'count', initializer=init_ops.zeros_initializer)
+
+ @update_state
+ def update_state(self, values, sample_weight=None):
+ """Accumulates statistics for computing the mean.
+
+ For example, if `values` is [1, 3, 5, 7] then the mean is 4. If
+ the `sample_weight` is specified as [1, 1, 0, 0] then the mean would be 2.
+
+ Args:
+ values: Per-example value.
+ sample_weight: Optional weighting of each example. Defaults to 1.
+ """
+ values = math_ops.cast(values, self._dtype)
+ if sample_weight is None:
+ num_values = math_ops.cast(array_ops.size(values), self._dtype)
+ else:
+ sample_weight = math_ops.cast(sample_weight, self._dtype)
+
+ # Update dimensions of weights to match with values.
+ values, _, sample_weight = _squeeze_or_expand_dimensions(
+ values, None, sample_weight)
+ sample_weight = weights_broadcast_ops.broadcast_weights(
+ sample_weight, values)
+ num_values = math_ops.reduce_sum(sample_weight)
+ values = math_ops.multiply(values, sample_weight)
+ values = math_ops.reduce_sum(values)
+
+ # Update state variables
+ state_ops.assign_add(self.total, values)
+ state_ops.assign_add(self.count, num_values)
+
+ @result
+ def result(self):
+ return _safe_div(self.total, self.count)
+
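(Not part of the patch: a minimal eager-mode sketch of the new `Mean` metric in use, assuming a TF 1.x build that includes this change and supports `tf.enable_eager_execution()`.)

```python
import tensorflow as tf
from tensorflow.python.keras import metrics

tf.enable_eager_execution()

m = metrics.Mean()
m.update_state([1, 3, 5, 7])
print(m.result().numpy())    # 4.0 (total=16, count=4)

# A sample_weight of 0.5 contributes half a sample to the count.
m.update_state([20], sample_weight=[0.5])
print(m.result().numpy())    # 26 / 4.5 ~= 5.78
```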
+
@tf_export('keras.metrics.binary_accuracy')
def binary_accuracy(y_true, y_pred):
return K.mean(math_ops.equal(y_true, math_ops.round(y_pred)), axis=-1)
diff --git a/tensorflow/python/keras/metrics_test.py b/tensorflow/python/keras/metrics_test.py
index 15e793f5fc..6d8269f34d 100644
--- a/tensorflow/python/keras/metrics_test.py
+++ b/tensorflow/python/keras/metrics_test.py
@@ -18,67 +18,72 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
+import os
import numpy as np
-from tensorflow.python import keras
+from tensorflow.python.eager import context
+from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import ops
+from tensorflow.python.framework import test_util
+from tensorflow.python.keras import backend as K
+from tensorflow.python.keras import layers
+from tensorflow.python.keras import metrics
+from tensorflow.python.keras.engine.training import Model
+from tensorflow.python.ops import array_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import state_ops
+from tensorflow.python.ops import variables
from tensorflow.python.platform import test
+from tensorflow.python.training.checkpointable import util as checkpointable_utils
class KerasMetricsTest(test.TestCase):
def test_metrics(self):
with self.test_session():
- y_a = keras.backend.variable(np.random.random((6, 7)))
- y_b = keras.backend.variable(np.random.random((6, 7)))
- for metric in [keras.metrics.binary_accuracy,
- keras.metrics.categorical_accuracy]:
+ y_a = K.variable(np.random.random((6, 7)))
+ y_b = K.variable(np.random.random((6, 7)))
+ for metric in [metrics.binary_accuracy, metrics.categorical_accuracy]:
output = metric(y_a, y_b)
- self.assertEqual(keras.backend.eval(output).shape, (6,))
+ self.assertEqual(K.eval(output).shape, (6,))
def test_sparse_categorical_accuracy(self):
with self.test_session():
- metric = keras.metrics.sparse_categorical_accuracy
- y_a = keras.backend.variable(np.random.randint(0, 7, (6,)))
- y_b = keras.backend.variable(np.random.random((6, 7)))
- self.assertEqual(keras.backend.eval(metric(y_a, y_b)).shape, (6,))
+ metric = metrics.sparse_categorical_accuracy
+ y_a = K.variable(np.random.randint(0, 7, (6,)))
+ y_b = K.variable(np.random.random((6, 7)))
+ self.assertEqual(K.eval(metric(y_a, y_b)).shape, (6,))
def test_sparse_top_k_categorical_accuracy(self):
with self.test_session():
- y_pred = keras.backend.variable(np.array([[0.3, 0.2, 0.1],
- [0.1, 0.2, 0.7]]))
- y_true = keras.backend.variable(np.array([[1], [0]]))
- result = keras.backend.eval(
- keras.metrics.sparse_top_k_categorical_accuracy(y_true, y_pred, k=3))
+ y_pred = K.variable(np.array([[0.3, 0.2, 0.1], [0.1, 0.2, 0.7]]))
+ y_true = K.variable(np.array([[1], [0]]))
+ result = K.eval(
+ metrics.sparse_top_k_categorical_accuracy(y_true, y_pred, k=3))
self.assertEqual(result, 1)
- result = keras.backend.eval(
- keras.metrics.sparse_top_k_categorical_accuracy(y_true, y_pred, k=2))
+ result = K.eval(
+ metrics.sparse_top_k_categorical_accuracy(y_true, y_pred, k=2))
self.assertEqual(result, 0.5)
- result = keras.backend.eval(
- keras.metrics.sparse_top_k_categorical_accuracy(y_true, y_pred, k=1))
+ result = K.eval(
+ metrics.sparse_top_k_categorical_accuracy(y_true, y_pred, k=1))
self.assertEqual(result, 0.)
def test_top_k_categorical_accuracy(self):
with self.test_session():
- y_pred = keras.backend.variable(np.array([[0.3, 0.2, 0.1],
- [0.1, 0.2, 0.7]]))
- y_true = keras.backend.variable(np.array([[0, 1, 0], [1, 0, 0]]))
- result = keras.backend.eval(
- keras.metrics.top_k_categorical_accuracy(y_true, y_pred, k=3))
+ y_pred = K.variable(np.array([[0.3, 0.2, 0.1], [0.1, 0.2, 0.7]]))
+ y_true = K.variable(np.array([[0, 1, 0], [1, 0, 0]]))
+ result = K.eval(metrics.top_k_categorical_accuracy(y_true, y_pred, k=3))
self.assertEqual(result, 1)
- result = keras.backend.eval(
- keras.metrics.top_k_categorical_accuracy(y_true, y_pred, k=2))
+ result = K.eval(metrics.top_k_categorical_accuracy(y_true, y_pred, k=2))
self.assertEqual(result, 0.5)
- result = keras.backend.eval(
- keras.metrics.top_k_categorical_accuracy(y_true, y_pred, k=1))
+ result = K.eval(metrics.top_k_categorical_accuracy(y_true, y_pred, k=1))
self.assertEqual(result, 0.)
def test_stateful_metrics(self):
with self.test_session():
np.random.seed(1334)
- class BinaryTruePositives(keras.layers.Layer):
+ class BinaryTruePositives(layers.Layer):
"""Stateful Metric to count the total true positives over all batches.
Assumes predictions and targets of shape `(samples, 1)`.
@@ -91,11 +96,11 @@ class KerasMetricsTest(test.TestCase):
def __init__(self, name='true_positives', **kwargs):
super(BinaryTruePositives, self).__init__(name=name, **kwargs)
- self.true_positives = keras.backend.variable(value=0, dtype='int32')
+ self.true_positives = K.variable(value=0, dtype='int32')
self.stateful = True
def reset_states(self):
- keras.backend.set_value(self.true_positives, 0)
+ K.set_value(self.true_positives, 0)
def __call__(self, y_true, y_pred):
"""Computes the number of true positives in a batch.
@@ -120,14 +125,14 @@ class KerasMetricsTest(test.TestCase):
return current_true_pos + true_pos
metric_fn = BinaryTruePositives()
- config = keras.metrics.serialize(metric_fn)
- metric_fn = keras.metrics.deserialize(
+ config = metrics.serialize(metric_fn)
+ metric_fn = metrics.deserialize(
config, custom_objects={'BinaryTruePositives': BinaryTruePositives})
# Test on simple model
- inputs = keras.Input(shape=(2,))
- outputs = keras.layers.Dense(1, activation='sigmoid')(inputs)
- model = keras.Model(inputs, outputs)
+ inputs = layers.Input(shape=(2,))
+ outputs = layers.Dense(1, activation='sigmoid')(inputs)
+ model = Model(inputs, outputs)
model.compile(optimizer='sgd',
loss='binary_crossentropy',
metrics=['acc', metric_fn])
@@ -184,6 +189,125 @@ class KerasMetricsTest(test.TestCase):
self.assertAllClose(
val_outs[2], history.history['val_true_positives'][-1], atol=1e-5)
+ @test_util.run_in_graph_and_eager_modes
+ def test_mean(self):
+ m = metrics.Mean(name='my_mean')
+
+ # check config
+ self.assertEqual(m.name, 'my_mean')
+ self.assertTrue(m.stateful)
+ self.assertEqual(m.dtype, dtypes.float64)
+ self.assertEqual(len(m.variables), 2)
+ self.evaluate(variables.global_variables_initializer())
+
+ # check initial state
+ self.assertEqual(self.evaluate(m.total), 0)
+ self.assertEqual(self.evaluate(m.count), 0)
+
+ # check __call__()
+ self.assertEqual(self.evaluate(m(100)), 100)
+ self.assertEqual(self.evaluate(m.total), 100)
+ self.assertEqual(self.evaluate(m.count), 1)
+
+ # check update_state() and result() + state accumulation + tensor input
+ update_op = m.update_state(ops.convert_n_to_tensor([1, 5]))
+ self.evaluate(update_op)
+ self.assertEqual(self.evaluate(m.result()), 106 / 3)
+ self.assertEqual(self.evaluate(m.total), 106) # 100 + 1 + 5
+ self.assertEqual(self.evaluate(m.count), 3)
+
+ # check reset_states()
+ m.reset_states()
+ self.assertEqual(self.evaluate(m.total), 0)
+ self.assertEqual(self.evaluate(m.count), 0)
+
+ @test_util.run_in_graph_and_eager_modes
+ def test_mean_with_sample_weight(self):
+ m = metrics.Mean()
+ self.evaluate(variables.global_variables_initializer())
+
+ # check scalar weight
+ result_t = m(100, sample_weight=0.5)
+ self.assertEqual(self.evaluate(result_t), 50 / 0.5)
+ self.assertEqual(self.evaluate(m.total), 50)
+ self.assertEqual(self.evaluate(m.count), 0.5)
+
+ # check weights not scalar and weights rank matches values rank
+ result_t = m([1, 5], sample_weight=[1, 0.2])
+ result = self.evaluate(result_t)
+ self.assertAlmostEqual(result, 52 / 1.7, 2)
+ self.assertAlmostEqual(self.evaluate(m.total), 52, 2) # 50 + 1 + 5 * 0.2
+ self.assertAlmostEqual(self.evaluate(m.count), 1.7, 2) # 0.5 + 1.2
+
+ # check weights broadcast
+ result_t = m([1, 2], sample_weight=0.5)
+ self.assertAlmostEqual(self.evaluate(result_t), 53.5 / 2.7, 2)
+ self.assertAlmostEqual(self.evaluate(m.total), 53.5, 2) # 52 + 0.5 + 1
+ self.assertAlmostEqual(self.evaluate(m.count), 2.7, 2) # 1.7 + 0.5 + 0.5
+
+ # check weights squeeze
+ result_t = m([1, 5], sample_weight=[[1], [0.2]])
+ self.assertAlmostEqual(self.evaluate(result_t), 55.5 / 3.9, 2)
+ self.assertAlmostEqual(self.evaluate(m.total), 55.5, 2) # 53.5 + 1 + 1
+ self.assertAlmostEqual(self.evaluate(m.count), 3.9, 2) # 2.7 + 1.2
+
+ # check weights expand
+ result_t = m([[1], [5]], sample_weight=[1, 0.2])
+ self.assertAlmostEqual(self.evaluate(result_t), 57.5 / 5.1, 2)
+ self.assertAlmostEqual(self.evaluate(m.total), 57.5, 2) # 55.5 + 1 + 1
+ self.assertAlmostEqual(self.evaluate(m.count), 5.1, 2) # 3.9 + 1.2
+
+ def test_mean_graph_with_placeholder(self):
+ with context.graph_mode(), self.test_session() as sess:
+ m = metrics.Mean()
+ v = array_ops.placeholder(dtypes.float32)
+ w = array_ops.placeholder(dtypes.float32)
+ sess.run(variables.global_variables_initializer())
+
+ # check __call__()
+ result_t = m(v, sample_weight=w)
+ result = sess.run(result_t, feed_dict=({v: 100, w: 0.5}))
+ self.assertEqual(sess.run(m.total), 50)
+ self.assertEqual(sess.run(m.count), 0.5)
+ self.assertEqual(result, 50 / 0.5)
+
+ # check update_state() and result()
+ result = sess.run(result_t, feed_dict=({v: [1, 5], w: [1, 0.2]}))
+ self.assertAlmostEqual(sess.run(m.total), 52, 2) # 50 + 1 + 5 * 0.2
+ self.assertAlmostEqual(sess.run(m.count), 1.7, 2) # 0.5 + 1.2
+ self.assertAlmostEqual(result, 52 / 1.7, 2)
+
+ @test_util.run_in_graph_and_eager_modes
+ def test_save_restore(self):
+ checkpoint_directory = self.get_temp_dir()
+ checkpoint_prefix = os.path.join(checkpoint_directory, 'ckpt')
+ m = metrics.Mean()
+ checkpoint = checkpointable_utils.Checkpoint(mean=m)
+ self.evaluate(variables.global_variables_initializer())
+
+ # update state
+ self.evaluate(m(100.))
+ self.evaluate(m(200.))
+
+ # save checkpoint and then add an update
+ save_path = checkpoint.save(checkpoint_prefix)
+ self.evaluate(m(1000.))
+
+ # restore to the same checkpoint mean object
+ checkpoint.restore(save_path).assert_consumed().run_restore_ops()
+ self.evaluate(m(300.))
+ self.assertEqual(200., self.evaluate(m.result()))
+
+ # restore to a different checkpoint mean object
+ restore_mean = metrics.Mean()
+ restore_checkpoint = checkpointable_utils.Checkpoint(mean=restore_mean)
+ status = restore_checkpoint.restore(save_path)
+ restore_update = restore_mean(300.)
+ status.assert_consumed().run_restore_ops()
+ self.evaluate(restore_update)
+ self.assertEqual(200., self.evaluate(restore_mean.result()))
+ self.assertEqual(3, self.evaluate(restore_mean.count))
+
if __name__ == '__main__':
test.main()