1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
|
# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Extenders of tf.estimator.Estimator."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from tensorflow.python.estimator import estimator as estimator_lib
from tensorflow.python.estimator import model_fn as model_fn_lib
from tensorflow.python.estimator import util as estimator_util
from tensorflow.python.util import tf_inspect
_VALID_METRIC_FN_ARGS = set(['features', 'labels', 'predictions', 'config'])
def add_metrics(estimator, metric_fn):
  """Creates a new ${tf.estimator.Estimator} which has given metrics.

  Example:

  ```python
    def my_auc(labels, predictions):
      return {'auc': tf.metrics.auc(labels, predictions['logistic'])}

    estimator = tf.estimator.DNNClassifier(...)
    estimator = tf.contrib.estimator.add_metrics(estimator, my_auc)
    estimator.train(...)
    estimator.evaluate(...)
  ```

  Example usage of custom metric which uses features:

  ```python
    def my_auc(features, labels, predictions):
      return {'auc': tf.metrics.auc(
        labels, predictions['logistic'], weights=features['weight'])}

    estimator = tf.estimator.DNNClassifier(...)
    estimator = tf.contrib.estimator.add_metrics(estimator, my_auc)
    estimator.train(...)
    estimator.evaluate(...)
  ```

  Args:
    estimator: A ${tf.estimator.Estimator} object.
    metric_fn: A function which should obey the following signature:
      - Args: can only have the following four arguments, in any order:
        * predictions: Predictions `Tensor` or dict of `Tensor` created by
          given `estimator`.
        * features: Input `dict` of `Tensor` objects created by `input_fn`
          which is given to `estimator.evaluate` as an argument.
        * labels: Labels `Tensor` or dict of `Tensor` created by `input_fn`
          which is given to `estimator.evaluate` as an argument.
        * config: config attribute of the `estimator`.
      - Returns:
        Dict of metric results keyed by name. Final metrics are a union of
        this and the `estimator`'s existing metrics. If there is a name
        conflict between this and the `estimator`'s existing metrics, this
        will override the existing one. The values of the dict are the
        results of calling a metric function, namely a
        `(metric_tensor, update_op)` tuple.

  Returns:
    A new ${tf.estimator.Estimator} which has a union of original metrics with
    given ones. It is built with the same `model_dir` and `config` as the
    given `estimator`.

  Raises:
    ValueError: If `metric_fn` declares an argument other than `features`,
      `labels`, `predictions` or `config`.
  """
  _verify_metric_fn_args(metric_fn)

  def new_model_fn(features, labels, mode, config):
    """Wraps the original model_fn, adding the extra metrics in EVAL mode."""
    spec = estimator.model_fn(features, labels, mode, config)
    # Metrics are only attached for evaluation; other modes pass through.
    if mode != model_fn_lib.ModeKeys.EVAL:
      return spec
    new_metrics = _call_metric_fn(metric_fn, features, labels, spec.predictions,
                                  config)
    # Merge into a fresh dict so that the dict owned by the wrapped
    # estimator's spec is never mutated in place.
    all_metrics = dict(spec.eval_metric_ops or {})
    # Entries from `metric_fn` win on name conflicts.
    all_metrics.update(new_metrics)
    return spec._replace(eval_metric_ops=all_metrics)

  return estimator_lib.Estimator(
      model_fn=new_model_fn,
      model_dir=estimator.model_dir,
      config=estimator.config)
def _verify_metric_fn_args(metric_fn):
  """Checks that `metric_fn` only declares supported argument names.

  Args:
    metric_fn: The callable whose argument names are inspected.

  Raises:
    ValueError: If any declared argument is not in `_VALID_METRIC_FN_ARGS`.
  """
  declared = set(estimator_util.fn_args(metric_fn))
  if tf_inspect.ismethod(metric_fn):
    # Drop the receiver name if it was reported among the arguments.
    declared.discard('self')
  unexpected = list(declared.difference(_VALID_METRIC_FN_ARGS))
  if not unexpected:
    return
  raise ValueError('metric_fn (%s) has following not expected args: %s' %
                   (metric_fn, unexpected))
def _call_metric_fn(metric_fn, features, labels, predictions, config):
  """Invokes `metric_fn`, passing only the arguments it declares.

  Args:
    metric_fn: The user-supplied metric callable.
    features: Value forwarded as `features` when `metric_fn` declares it.
    labels: Value forwarded as `labels` when `metric_fn` declares it.
    predictions: Value forwarded as `predictions` when `metric_fn` declares it.
    config: Value forwarded as `config` when `metric_fn` declares it.

  Returns:
    Whatever `metric_fn` returns.
  """
  accepted = estimator_util.fn_args(metric_fn)
  candidates = (
      ('features', features),
      ('labels', labels),
      ('predictions', predictions),
      ('config', config),
  )
  # Keep only the candidates whose name appears in the callable's signature.
  selected = {name: value for name, value in candidates if name in accepted}
  return metric_fn(**selected)
|