author     2018-07-18 10:55:32 -0700
committer  2018-07-18 11:00:37 -0700
commit     80ab99ec746520faa763a7bb171ee8850a597ec1 (patch)
tree       98e0837412182445c696fab8c375517a9a4e3c0c
parent     fb70a5587395b1e68c08a4d396c63c5bd80fa1e1 (diff)
Create new metrics class and add mean metric.
PiperOrigin-RevId: 205102847
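
For reference, the docstring added to the new `Metric` base class in this change describes usage under both eager and graph execution. Below is a condensed sketch using the `Mean` metric introduced by this patch; the example values mirror the patch's own docstrings, and the internal `tensorflow.python.keras.metrics` import path is the module this patch touches (the public alias may differ), so treat it as an illustration rather than a guaranteed public API.

```python
import tensorflow as tf
from tensorflow.python.keras import metrics  # module modified by this patch

m = metrics.Mean(name='my_mean')

# Eager execution (tf.enable_eager_execution() at program start in TF 1.x):
# update_state() mutates the state variables in place and returns None.
#   m.update_state([1, 3, 5, 7])
#   print('Final result:', m.result().numpy())   # 4.0

# Graph execution: update_state() returns an op that must be run explicitly.
update_op = m.update_state([1, 3, 5, 7], sample_weight=[1, 1, 0, 0])
init_op = tf.global_variables_initializer()
with tf.Session() as sess:
  sess.run(init_op)
  sess.run(update_op)
  # Weights of 0 mask the values 5 and 7, so the weighted mean is 2.0.
  print('Final result:', sess.run(m.result()))
```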
-rw-r--r--  tensorflow/python/keras/metrics.py       381
-rw-r--r--  tensorflow/python/keras/metrics_test.py  196
2 files changed, 541 insertions, 36 deletions
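
The weighted accumulation implemented by `Mean.update_state()` (and checked by `metrics_test.py` in the diff below) amounts to keeping a running `total` and `count` and dividing on `result()`. The following is a plain NumPy sketch of that bookkeeping, not the TensorFlow implementation itself, and it ignores the rank squeeze/expand handling done by `_squeeze_or_expand_dimensions`; the assertions reproduce the first three steps of `test_mean_with_sample_weight`.

```python
import numpy as np

class NumpyMean(object):
  """NumPy mirror of the Mean metric's bookkeeping (illustration only)."""

  def __init__(self):
    # Same two state variables as Mean.__init__: 'total' and 'count'.
    self.total = 0.0
    self.count = 0.0

  def __call__(self, values, sample_weight=None):
    # update_state(): accumulate sum(values * weights) and sum(weights).
    values = np.asarray(values, dtype=np.float64)
    if sample_weight is None:
      self.total += values.sum()
      self.count += values.size
    else:
      weights = np.broadcast_to(
          np.asarray(sample_weight, dtype=np.float64), values.shape)
      self.total += (values * weights).sum()
      self.count += weights.sum()
    # result(): _safe_div(total, count).
    return self.total / self.count if self.count > 0 else 0.0

m = NumpyMean()
assert m(100, sample_weight=0.5) == 50 / 0.5               # total=50,   count=0.5
assert np.isclose(m([1, 5], sample_weight=[1, 0.2]), 52 / 1.7)    # total=52,   count=1.7
assert np.isclose(m([1, 2], sample_weight=0.5), 53.5 / 2.7)       # total=53.5, count=2.7
```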
diff --git a/tensorflow/python/keras/metrics.py b/tensorflow/python/keras/metrics.py index e03d7dfe93..72e15763cb 100644 --- a/tensorflow/python/keras/metrics.py +++ b/tensorflow/python/keras/metrics.py @@ -19,9 +19,16 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +from abc import ABCMeta +from abc import abstractmethod import six +from tensorflow.python.eager import context +from tensorflow.python.eager import function +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import ops from tensorflow.python.keras import backend as K +from tensorflow.python.keras.engine.base_layer import Layer from tensorflow.python.keras.losses import binary_crossentropy from tensorflow.python.keras.losses import categorical_crossentropy from tensorflow.python.keras.losses import cosine_proximity @@ -37,11 +44,385 @@ from tensorflow.python.keras.losses import sparse_categorical_crossentropy from tensorflow.python.keras.losses import squared_hinge from tensorflow.python.keras.utils.generic_utils import deserialize_keras_object from tensorflow.python.keras.utils.generic_utils import serialize_keras_object +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import confusion_matrix +from tensorflow.python.ops import control_flow_ops +from tensorflow.python.ops import init_ops from tensorflow.python.ops import math_ops from tensorflow.python.ops import nn +from tensorflow.python.ops import state_ops +from tensorflow.python.ops import variable_scope as vs +from tensorflow.python.ops import weights_broadcast_ops +from tensorflow.python.training import distribute as distribute_lib +from tensorflow.python.util import tf_decorator from tensorflow.python.util.tf_export import tf_export +def update_state(update_state_fn): + """Decorator to wrap metric `update_state()` with `defun()`, `add_update()`. + + Args: + update_state_fn: function that accumulates metric statistics. + + Returns: + If eager execution is enabled, returns None. + If graph execution is enabled, returns an update op. This op should be + executed to update the metric state with the given inputs. + """ + + def decorated(*args, **kwargs): + """Decorated function with `defun()` and `add_update()`.""" + + # Converting update_state_fn() into a graph function, so that + # we can return a single op that performs all of the variable updates. + # Assigning to a different method name to avoid reference cycle. + defuned_update_state_fn = function.defun(update_state_fn) + update_op = defuned_update_state_fn(*args, **kwargs) + if update_op is not None: # update_op will be None in eager execution. + metric_obj = args[0] + metric_obj.add_update(update_op, inputs=True) + return update_op + + return tf_decorator.make_decorator(update_state_fn, decorated) + + +def result(result_fn): + """Decorator to wrap metric `result()` function in `merge_call()`. + + Result computation is an idempotent operation that simply calculates the + metric value using the state variables. + + If metric state variables are distributed across towers/devices and + `result()` is requested from the context of one device - This function wraps + `result()` in a distribution strategy `merge_call()`. With this, + the metric state variables will be aggregated across devices. + + Args: + result_fn: function that computes the metric result. + + Returns: + The metric result tensor. 
+ """ + + def decorated(*args): + """Decorated function with merge_call.""" + tower_context = distribute_lib.get_tower_context() + if tower_context is None: # if in cross tower context already + return result_fn() + + # TODO(psv): Test distribution of metrics using different distribution + # strategies. + + # Creating a wrapper for merge_fn. merge_call invokes the given merge_fn + # with distribution object as the first parameter. We create a wrapper here + # so that the result function need not have that parameter. + def merge_fn_wrapper(distribution, merge_fn, *args): + # We will get `PerDevice` merge function. Taking the first one as all are + # identical copies of the function that we had passed below. + return distribution.unwrap(merge_fn)[0](*args) + + # Wrapping result in merge_call. merge_call is used when we want to leave + # tower mode and compute a value in cross tower mode. + return tower_context.merge_call(merge_fn_wrapper, result_fn, *args) + + return tf_decorator.make_decorator(result_fn, decorated) + + +def _safe_div(numerator, denominator): + """Divides two tensors element-wise, returning 0 if the denominator is <= 0. + + Args: + numerator: A `Tensor`. + denominator: A `Tensor`, with dtype matching `numerator`. + + Returns: + 0 if `denominator` <= 0, else `numerator` / `denominator` + """ + t = math_ops.truediv(numerator, denominator) + zero = array_ops.zeros_like(t, dtype=denominator.dtype) + condition = math_ops.greater(denominator, zero) + zero = math_ops.cast(zero, t.dtype) + return array_ops.where(condition, t, zero) + + +def _squeeze_or_expand_dimensions(y_pred, y_true, sample_weight): + """Squeeze or expand last dimension if needed. + + 1. Squeezes last dim of `y_pred` or `y_true` if their rank differs by 1 + (using `confusion_matrix.remove_squeezable_dimensions`). + 2. Squeezes or expands last dim of `sample_weight` if its rank differs by 1 + from the new rank of `y_pred`. + If `sample_weight` is scalar, it is kept scalar. + + This will use static shape if available. Otherwise, it will add graph + operations, which could result in a performance hit. + + Args: + y_pred: Predicted values, a `Tensor` of arbitrary dimensions. + y_true: Optional label `Tensor` whose dimensions match `y_pred`. + sample_weight: Optional weight scalar or `Tensor` whose dimensions match + `y_pred`. + + Returns: + Tuple of `y_pred`, `y_true` and `sample_weight`. Each of them possibly has + the last dimension squeezed, + `sample_weight` could be extended by one dimension. + """ + if y_true is not None: + # squeeze last dim of `y_pred` or `y_true` if their rank differs by 1 + y_true, y_pred = confusion_matrix.remove_squeezable_dimensions( + y_true, y_pred) + y_pred.get_shape().assert_is_compatible_with(y_true.get_shape()) + + if sample_weight is None: + return y_pred, y_true, None + + sample_weight = ops.convert_to_tensor(sample_weight) + weights_shape = sample_weight.get_shape() + weights_rank = weights_shape.ndims + if weights_rank == 0: # If weights is scalar, do nothing. + return y_pred, y_true, sample_weight + + y_pred_shape = y_pred.get_shape() + y_pred_rank = y_pred_shape.ndims + if (y_pred_rank is not None) and (weights_rank is not None): + # Use static rank. + if weights_rank - y_pred_rank == 1: + sample_weight = array_ops.squeeze(sample_weight, [-1]) + elif y_pred_rank - weights_rank == 1: + sample_weight = array_ops.expand_dims(sample_weight, [-1]) + return y_pred, y_true, sample_weight + + # Use dynamic rank. 
+ weights_rank_tensor = array_ops.rank(sample_weight) + rank_diff = weights_rank_tensor - array_ops.rank(y_pred) + maybe_squeeze_weights = lambda: array_ops.squeeze(sample_weight, [-1]) + + def _maybe_expand_weights(): + return control_flow_ops.cond( + math_ops.equal(rank_diff, + -1), lambda: array_ops.expand_dims(sample_weight, [-1]), + lambda: sample_weight) + + def _maybe_adjust_weights(): + return control_flow_ops.cond( + math_ops.equal(rank_diff, 1), maybe_squeeze_weights, + _maybe_expand_weights) + + # squeeze or expand last dim of `sample_weight` if its rank differs by 1 + # from the new rank of `y_pred`. + sample_weight = control_flow_ops.cond( + math_ops.equal(weights_rank_tensor, 0), lambda: sample_weight, + _maybe_adjust_weights) + return y_pred, y_true, sample_weight + + +class Metric(Layer): + """Encapsulates metric logic and state. + + Usage with eager execution: + + ```python + m = SomeMetric(...) + for input in ...: + m.update_state(input) + print('Final result: ', m.result().numpy()) + ``` + + Usage with graph execution: + + ```python + m = SomeMetric(...) + init_op = tf.global_variables_initializer() # Initialize variables + with tf.Session() as sess: + sess.run(init_op) + for input in ...: + update_op = m.update_state(input) + sess.run(update_op) + print('Final result: ', sess.run(m.result())) + ``` + + To be implemented by subclasses: + * `__init__()`: All state variables should be created in this method by + calling `self.add_weight()` like: `self.var = self.add_weight(...)` + * `update_state()`: Has all updates to the state variables like: + self.var.assign_add(...). Please decorate the function with: + @update_state: Converts `update_state()` into a graph function, so that + we can return a single op that performs all of the variable updates and + adds the update op to the metric layer. + * `result()`: Computes and returns a value for the metric + from the state variables. Please decorate the function with: + @result: Wraps `result()` in a distribution strategy merge_call(). + + Example subclass implementation: + + ``` + class BinaryTruePositives(Metric): + def __init__(self, name='binary-true-positives', dtype=dtypes.float64): + super(BinaryTruePositives, self).__init__(name=name, dtype=dtype) + self.true_positives = self.add_weight( + 'true_positives', initializer=init_ops.zeros_initializer) + + @update_state + def update_state(self, y_true, y_pred, sample_weight=None): + y_true = math_ops.cast(y_true, dtypes.bool) + y_pred = math_ops.cast(y_pred, dtypes.bool) + y_pred, y_true, sample_weight = _squeeze_or_expand_dimensions( + y_pred, y_true, sample_weight) + + values = math_ops.logical_and( + math_ops.equal(y_true, True), math_ops.equal(y_pred, True)) + values = math_ops.cast(values, self._dtype) + if sample_weight is not None: + sample_weight = math_ops.cast(sample_weight, self._dtype) + values = math_ops.multiply(values, sample_weight) + state_ops.assign_add(self.true_positives, math_ops.reduce_sum(values)) + + @result + def result(self): + return array_ops.identity(self.true_positives) + ``` + """ + __metaclass__ = ABCMeta + + def __init__(self, name=None, dtype=dtypes.float64): + super(Metric, self).__init__(name=name, dtype=dtype) + self.stateful = True # All metric layers are stateful. + self.built = True + + def __call__(self, *args, **kwargs): + """Accumulates statistics and then computes metric result value. + + Args: + *args: + **kwargs: A mini-batch of inputs to the Metric, + passed on to `update_state()`. + + Returns: + The metric value tensor. 
+ """ + update_op = self.update_state(*args, **kwargs) + with ops.control_dependencies([update_op]): + return self.result() + + def reset_states(self): + """Resets all of the metric state variables. + + This function is called between epochs/steps, + when a metric is evaluated during training. + """ + for v in self.variables: + K.set_value(v, 0) + + @abstractmethod + def update_state(self, *args, **kwargs): + """Accumulates statistics for the metric. + + Please decorate the function with: + @update_state: Converts `update_state()` into a graph function, so that + we can return a single op that performs all of the variable updates + This means: + a) Operations on the same resource are executed in textual order. + This should make it easier to do things like add the updated + value of a variable to another, for example. + b) You don't need to worry about collecting the update ops to execute. + All update ops added to the graph by this function will be executed. + As a result, code should generally work the same way with graph or + eager execution. + and adds the update op to the metric layer. + + Args: + *args: + **kwargs: A mini-batch of inputs to the Metric. + """ + NotImplementedError('Must be implemented in subclasses.') + + @abstractmethod + def result(self): + """Computes and returns the metric value tensor. + + Result computation is an idempotent operation that simply calculates the + metric value using the state variables. + + Please decorate the function with: + @result: Wraps `result()` in a distribution strategy merge_call(). + """ + NotImplementedError('Must be implemented in subclasses.') + + ### For use by subclasses ### + def add_weight(self, + name, + shape=(), + aggregation=vs.VariableAggregation.SUM, + synchronization=vs.VariableSynchronization.ON_READ, + initializer=None): + """Adds state variable. Only for use by subclasses.""" + return super(Metric, self).add_weight( + name=name, + shape=shape, + dtype=self._dtype, + trainable=False, + initializer=initializer, + synchronization=synchronization, + aggregation=aggregation) + + ### End: For use by subclasses ### + + +class Mean(Metric): + """Computes the (weighted) mean of the given values. + + This metric creates two variables, `total` and `count` that are used to + compute the average of `values`. This average is ultimately returned as `mean` + which is an idempotent operation that simply divides `total` by `count`. + + If `sample_weight` is `None`, weights default to 1. + Use `sample_weight` of 0 to mask values. + """ + + def __init__(self, name='mean', dtype=dtypes.float64): + super(Mean, self).__init__(name=name, dtype=dtype) + # Create new state variables + self.total = self.add_weight( + 'total', initializer=init_ops.zeros_initializer) + self.count = self.add_weight( + 'count', initializer=init_ops.zeros_initializer) + + @update_state + def update_state(self, values, sample_weight=None): + """Accumulates statistics for computing the mean. + + For example, if `values` is [1, 3, 5, 7] then the mean is 4. If + the `sample_weight` is specified as [1, 1, 0, 0] then the mean would be 2. + + Args: + values: Per-example value. + sample_weight: Optional weighting of each example. Defaults to 1. + """ + values = math_ops.cast(values, self._dtype) + if sample_weight is None: + num_values = math_ops.cast(array_ops.size(values), self._dtype) + else: + sample_weight = math_ops.cast(sample_weight, self._dtype) + + # Update dimensions of weights to match with values. 
+ values, _, sample_weight = _squeeze_or_expand_dimensions( + values, None, sample_weight) + sample_weight = weights_broadcast_ops.broadcast_weights( + sample_weight, values) + num_values = math_ops.reduce_sum(sample_weight) + values = math_ops.multiply(values, sample_weight) + values = math_ops.reduce_sum(values) + + # Update state variables + state_ops.assign_add(self.total, values) + state_ops.assign_add(self.count, num_values) + + @result + def result(self): + return _safe_div(self.total, self.count) + + @tf_export('keras.metrics.binary_accuracy') def binary_accuracy(y_true, y_pred): return K.mean(math_ops.equal(y_true, math_ops.round(y_pred)), axis=-1) diff --git a/tensorflow/python/keras/metrics_test.py b/tensorflow/python/keras/metrics_test.py index 15e793f5fc..6d8269f34d 100644 --- a/tensorflow/python/keras/metrics_test.py +++ b/tensorflow/python/keras/metrics_test.py @@ -18,67 +18,72 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +import os import numpy as np -from tensorflow.python import keras +from tensorflow.python.eager import context +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import ops +from tensorflow.python.framework import test_util +from tensorflow.python.keras import backend as K +from tensorflow.python.keras import layers +from tensorflow.python.keras import metrics +from tensorflow.python.keras.engine.training import Model +from tensorflow.python.ops import array_ops from tensorflow.python.ops import math_ops from tensorflow.python.ops import state_ops +from tensorflow.python.ops import variables from tensorflow.python.platform import test +from tensorflow.python.training.checkpointable import util as checkpointable_utils class KerasMetricsTest(test.TestCase): def test_metrics(self): with self.test_session(): - y_a = keras.backend.variable(np.random.random((6, 7))) - y_b = keras.backend.variable(np.random.random((6, 7))) - for metric in [keras.metrics.binary_accuracy, - keras.metrics.categorical_accuracy]: + y_a = K.variable(np.random.random((6, 7))) + y_b = K.variable(np.random.random((6, 7))) + for metric in [metrics.binary_accuracy, metrics.categorical_accuracy]: output = metric(y_a, y_b) - self.assertEqual(keras.backend.eval(output).shape, (6,)) + self.assertEqual(K.eval(output).shape, (6,)) def test_sparse_categorical_accuracy(self): with self.test_session(): - metric = keras.metrics.sparse_categorical_accuracy - y_a = keras.backend.variable(np.random.randint(0, 7, (6,))) - y_b = keras.backend.variable(np.random.random((6, 7))) - self.assertEqual(keras.backend.eval(metric(y_a, y_b)).shape, (6,)) + metric = metrics.sparse_categorical_accuracy + y_a = K.variable(np.random.randint(0, 7, (6,))) + y_b = K.variable(np.random.random((6, 7))) + self.assertEqual(K.eval(metric(y_a, y_b)).shape, (6,)) def test_sparse_top_k_categorical_accuracy(self): with self.test_session(): - y_pred = keras.backend.variable(np.array([[0.3, 0.2, 0.1], - [0.1, 0.2, 0.7]])) - y_true = keras.backend.variable(np.array([[1], [0]])) - result = keras.backend.eval( - keras.metrics.sparse_top_k_categorical_accuracy(y_true, y_pred, k=3)) + y_pred = K.variable(np.array([[0.3, 0.2, 0.1], [0.1, 0.2, 0.7]])) + y_true = K.variable(np.array([[1], [0]])) + result = K.eval( + metrics.sparse_top_k_categorical_accuracy(y_true, y_pred, k=3)) self.assertEqual(result, 1) - result = keras.backend.eval( - keras.metrics.sparse_top_k_categorical_accuracy(y_true, y_pred, k=2)) + result = K.eval( + 
metrics.sparse_top_k_categorical_accuracy(y_true, y_pred, k=2)) self.assertEqual(result, 0.5) - result = keras.backend.eval( - keras.metrics.sparse_top_k_categorical_accuracy(y_true, y_pred, k=1)) + result = K.eval( + metrics.sparse_top_k_categorical_accuracy(y_true, y_pred, k=1)) self.assertEqual(result, 0.) def test_top_k_categorical_accuracy(self): with self.test_session(): - y_pred = keras.backend.variable(np.array([[0.3, 0.2, 0.1], - [0.1, 0.2, 0.7]])) - y_true = keras.backend.variable(np.array([[0, 1, 0], [1, 0, 0]])) - result = keras.backend.eval( - keras.metrics.top_k_categorical_accuracy(y_true, y_pred, k=3)) + y_pred = K.variable(np.array([[0.3, 0.2, 0.1], [0.1, 0.2, 0.7]])) + y_true = K.variable(np.array([[0, 1, 0], [1, 0, 0]])) + result = K.eval(metrics.top_k_categorical_accuracy(y_true, y_pred, k=3)) self.assertEqual(result, 1) - result = keras.backend.eval( - keras.metrics.top_k_categorical_accuracy(y_true, y_pred, k=2)) + result = K.eval(metrics.top_k_categorical_accuracy(y_true, y_pred, k=2)) self.assertEqual(result, 0.5) - result = keras.backend.eval( - keras.metrics.top_k_categorical_accuracy(y_true, y_pred, k=1)) + result = K.eval(metrics.top_k_categorical_accuracy(y_true, y_pred, k=1)) self.assertEqual(result, 0.) def test_stateful_metrics(self): with self.test_session(): np.random.seed(1334) - class BinaryTruePositives(keras.layers.Layer): + class BinaryTruePositives(layers.Layer): """Stateful Metric to count the total true positives over all batches. Assumes predictions and targets of shape `(samples, 1)`. @@ -91,11 +96,11 @@ class KerasMetricsTest(test.TestCase): def __init__(self, name='true_positives', **kwargs): super(BinaryTruePositives, self).__init__(name=name, **kwargs) - self.true_positives = keras.backend.variable(value=0, dtype='int32') + self.true_positives = K.variable(value=0, dtype='int32') self.stateful = True def reset_states(self): - keras.backend.set_value(self.true_positives, 0) + K.set_value(self.true_positives, 0) def __call__(self, y_true, y_pred): """Computes the number of true positives in a batch. 
@@ -120,14 +125,14 @@ class KerasMetricsTest(test.TestCase): return current_true_pos + true_pos metric_fn = BinaryTruePositives() - config = keras.metrics.serialize(metric_fn) - metric_fn = keras.metrics.deserialize( + config = metrics.serialize(metric_fn) + metric_fn = metrics.deserialize( config, custom_objects={'BinaryTruePositives': BinaryTruePositives}) # Test on simple model - inputs = keras.Input(shape=(2,)) - outputs = keras.layers.Dense(1, activation='sigmoid')(inputs) - model = keras.Model(inputs, outputs) + inputs = layers.Input(shape=(2,)) + outputs = layers.Dense(1, activation='sigmoid')(inputs) + model = Model(inputs, outputs) model.compile(optimizer='sgd', loss='binary_crossentropy', metrics=['acc', metric_fn]) @@ -184,6 +189,125 @@ class KerasMetricsTest(test.TestCase): self.assertAllClose( val_outs[2], history.history['val_true_positives'][-1], atol=1e-5) + @test_util.run_in_graph_and_eager_modes + def test_mean(self): + m = metrics.Mean(name='my_mean') + + # check config + self.assertEqual(m.name, 'my_mean') + self.assertTrue(m.stateful) + self.assertEqual(m.dtype, dtypes.float64) + self.assertEqual(len(m.variables), 2) + self.evaluate(variables.global_variables_initializer()) + + # check initial state + self.assertEqual(self.evaluate(m.total), 0) + self.assertEqual(self.evaluate(m.count), 0) + + # check __call__() + self.assertEqual(self.evaluate(m(100)), 100) + self.assertEqual(self.evaluate(m.total), 100) + self.assertEqual(self.evaluate(m.count), 1) + + # check update_state() and result() + state accumulation + tensor input + update_op = m.update_state(ops.convert_n_to_tensor([1, 5])) + self.evaluate(update_op) + self.assertEqual(self.evaluate(m.result()), 106 / 3) + self.assertEqual(self.evaluate(m.total), 106) # 100 + 1 + 5 + self.assertEqual(self.evaluate(m.count), 3) + + # check reset_states() + m.reset_states() + self.assertEqual(self.evaluate(m.total), 0) + self.assertEqual(self.evaluate(m.count), 0) + + @test_util.run_in_graph_and_eager_modes + def test_mean_with_sample_weight(self): + m = metrics.Mean() + self.evaluate(variables.global_variables_initializer()) + + # check scalar weight + result_t = m(100, sample_weight=0.5) + self.assertEqual(self.evaluate(result_t), 50 / 0.5) + self.assertEqual(self.evaluate(m.total), 50) + self.assertEqual(self.evaluate(m.count), 0.5) + + # check weights not scalar and weights rank matches values rank + result_t = m([1, 5], sample_weight=[1, 0.2]) + result = self.evaluate(result_t) + self.assertAlmostEqual(result, 52 / 1.7, 2) + self.assertAlmostEqual(self.evaluate(m.total), 52, 2) # 50 + 1 + 5 * 0.2 + self.assertAlmostEqual(self.evaluate(m.count), 1.7, 2) # 0.5 + 1.2 + + # check weights broadcast + result_t = m([1, 2], sample_weight=0.5) + self.assertAlmostEqual(self.evaluate(result_t), 53.5 / 2.7, 2) + self.assertAlmostEqual(self.evaluate(m.total), 53.5, 2) # 52 + 0.5 + 1 + self.assertAlmostEqual(self.evaluate(m.count), 2.7, 2) # 1.7 + 0.5 + 0.5 + + # check weights squeeze + result_t = m([1, 5], sample_weight=[[1], [0.2]]) + self.assertAlmostEqual(self.evaluate(result_t), 55.5 / 3.9, 2) + self.assertAlmostEqual(self.evaluate(m.total), 55.5, 2) # 53.5 + 1 + 1 + self.assertAlmostEqual(self.evaluate(m.count), 3.9, 2) # 2.7 + 1.2 + + # check weights expand + result_t = m([[1], [5]], sample_weight=[1, 0.2]) + self.assertAlmostEqual(self.evaluate(result_t), 57.5 / 5.1, 2) + self.assertAlmostEqual(self.evaluate(m.total), 57.5, 2) # 55.5 + 1 + 1 + self.assertAlmostEqual(self.evaluate(m.count), 5.1, 2) # 3.9 + 1.2 + + def 
test_mean_graph_with_placeholder(self): + with context.graph_mode(), self.test_session() as sess: + m = metrics.Mean() + v = array_ops.placeholder(dtypes.float32) + w = array_ops.placeholder(dtypes.float32) + sess.run(variables.global_variables_initializer()) + + # check __call__() + result_t = m(v, sample_weight=w) + result = sess.run(result_t, feed_dict=({v: 100, w: 0.5})) + self.assertEqual(sess.run(m.total), 50) + self.assertEqual(sess.run(m.count), 0.5) + self.assertEqual(result, 50 / 0.5) + + # check update_state() and result() + result = sess.run(result_t, feed_dict=({v: [1, 5], w: [1, 0.2]})) + self.assertAlmostEqual(sess.run(m.total), 52, 2) # 50 + 1 + 5 * 0.2 + self.assertAlmostEqual(sess.run(m.count), 1.7, 2) # 0.5 + 1.2 + self.assertAlmostEqual(result, 52 / 1.7, 2) + + @test_util.run_in_graph_and_eager_modes + def test_save_restore(self): + checkpoint_directory = self.get_temp_dir() + checkpoint_prefix = os.path.join(checkpoint_directory, 'ckpt') + m = metrics.Mean() + checkpoint = checkpointable_utils.Checkpoint(mean=m) + self.evaluate(variables.global_variables_initializer()) + + # update state + self.evaluate(m(100.)) + self.evaluate(m(200.)) + + # save checkpoint and then add an update + save_path = checkpoint.save(checkpoint_prefix) + self.evaluate(m(1000.)) + + # restore to the same checkpoint mean object + checkpoint.restore(save_path).assert_consumed().run_restore_ops() + self.evaluate(m(300.)) + self.assertEqual(200., self.evaluate(m.result())) + + # restore to a different checkpoint mean object + restore_mean = metrics.Mean() + restore_checkpoint = checkpointable_utils.Checkpoint(mean=restore_mean) + status = restore_checkpoint.restore(save_path) + restore_update = restore_mean(300.) + status.assert_consumed().run_restore_ops() + self.evaluate(restore_update) + self.assertEqual(200., self.evaluate(restore_mean.result())) + self.assertEqual(3, self.evaluate(restore_mean.count)) + if __name__ == '__main__': test.main() |