aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python/estimator
diff options
context:
space:
mode:
authorGravatar Pavithra Vijay <psv@google.com>2018-09-26 20:27:46 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-09-26 20:33:41 -0700
commitde2bcdc7ad149419e270e1443b63581163d75d5d (patch)
tree203bbff75133c9c2ff26b2950a669365cba2ee56 /tensorflow/python/estimator
parent0d5c68e30f4637329fa233df506d7b97802a5e9b (diff)
Add Mirrored distribution strategy support for new metrics with Keras and Estimator
Add support for stateful metrics in model to estimator PiperOrigin-RevId: 214714322
Diffstat (limited to 'tensorflow/python/estimator')
-rw-r--r--tensorflow/python/estimator/keras.py39
-rw-r--r--tensorflow/python/estimator/keras_test.py28
2 files changed, 42 insertions, 25 deletions
diff --git a/tensorflow/python/estimator/keras.py b/tensorflow/python/estimator/keras.py
index 6b2765be82..7546771ed3 100644
--- a/tensorflow/python/estimator/keras.py
+++ b/tensorflow/python/estimator/keras.py
@@ -21,6 +21,7 @@ from __future__ import print_function
import os
import re
+import six
from tensorflow.python.client import session
from tensorflow.python.estimator import estimator as estimator_lib
@@ -31,6 +32,7 @@ from tensorflow.python.framework import random_seed
from tensorflow.python.framework import sparse_tensor as sparse_tensor_lib
from tensorflow.python.framework import tensor_util
from tensorflow.python.keras import backend as K
+from tensorflow.python.keras import metrics
from tensorflow.python.keras import models
from tensorflow.python.keras import optimizers
from tensorflow.python.ops import check_ops
@@ -214,25 +216,40 @@ def _convert_keras_metrics_to_estimator(model):
if not getattr(model, 'metrics', None):
return None
- # TODO(psv/fchollet): support stateful metrics
eval_metric_ops = {}
+
+ def get_metric_name(metric):
+ if isinstance(metric, metrics.Metric):
+ return metric.name
+ if callable(metric):
+ return metric.__name__
+ assert isinstance(metric, six.string_types)
+ return metric
+
# When each metric maps to an output
if isinstance(model.metrics, dict):
for i, output_name in enumerate(model.metrics.keys()):
- metric_name = model.metrics[output_name]
- if callable(metric_name):
- metric_name = metric_name.__name__
+ # `metric` is the user given metric value in `compile`. This can be
+ # metric name (`acc`), metric function (binary_accuracy) or a metric
+ # object (BinaryAccuracy()).
+ metric = model.metrics[output_name]
+ metric_name = get_metric_name(metric)
# When some outputs use the same metric
if list(model.metrics.values()).count(metric_name) > 1:
metric_name += '_' + output_name
- eval_metric_ops[metric_name] = metrics_module.mean(
- model.metrics_tensors[i - len(model.metrics)])
+ if isinstance(metric, metrics.Metric):
+ eval_metric_ops[metric_name] = metric
+ else:
+ eval_metric_ops[metric_name] = metrics_module.mean(
+ model.metrics_tensors[i - len(model.metrics)])
else:
- for i, metric_name in enumerate(model.metrics):
- if callable(metric_name):
- metric_name = metric_name.__name__
- eval_metric_ops[metric_name] = metrics_module.mean(
- model.metrics_tensors[i])
+ for i, metric in enumerate(model.metrics):
+ metric_name = get_metric_name(metric)
+ if isinstance(metric, metrics.Metric):
+ eval_metric_ops[metric_name] = metric
+ else:
+ eval_metric_ops[metric_name] = metrics_module.mean(
+ model.metrics_tensors[i])
return eval_metric_ops
diff --git a/tensorflow/python/estimator/keras_test.py b/tensorflow/python/estimator/keras_test.py
index 3758243d7b..288f9b8906 100644
--- a/tensorflow/python/estimator/keras_test.py
+++ b/tensorflow/python/estimator/keras_test.py
@@ -257,7 +257,7 @@ class TestKerasEstimator(test_util.TensorFlowTestCase):
keras_model.compile(
loss='categorical_crossentropy',
optimizer='rmsprop',
- metrics=['mse', keras.metrics.categorical_accuracy])
+ metrics=['mse', keras.metrics.CategoricalAccuracy()])
with self.cached_session():
est_keras = keras_lib.model_to_estimator(
@@ -281,7 +281,7 @@ class TestKerasEstimator(test_util.TensorFlowTestCase):
keras_model.compile(
loss='categorical_crossentropy',
optimizer=rmsprop.RMSPropOptimizer(1e-3),
- metrics=['mse', keras.metrics.categorical_accuracy])
+ metrics=['mse', keras.metrics.CategoricalAccuracy()])
my_hook = MyHook()
with self.cached_session():
@@ -306,7 +306,7 @@ class TestKerasEstimator(test_util.TensorFlowTestCase):
keras_model.compile(
loss='categorical_crossentropy',
optimizer=rmsprop.RMSPropOptimizer(1e-3),
- metrics=['mse', keras.metrics.categorical_accuracy])
+ metrics=['mse', keras.metrics.CategoricalAccuracy()])
my_hook = MyHook()
with self.cached_session():
keras_model.fit(x_train, y_train, epochs=1)
@@ -328,7 +328,7 @@ class TestKerasEstimator(test_util.TensorFlowTestCase):
keras_model.compile(
loss='categorical_crossentropy',
optimizer=rmsprop.RMSPropOptimizer(1e-3),
- metrics=['mse', keras.metrics.categorical_accuracy])
+ metrics=['mse', keras.metrics.CategoricalAccuracy()])
with self.cached_session():
est_keras = keras_lib.model_to_estimator(
@@ -351,7 +351,7 @@ class TestKerasEstimator(test_util.TensorFlowTestCase):
keras_model.compile(
loss='categorical_crossentropy',
optimizer=rmsprop.RMSPropOptimizer(1e-3),
- metrics=['mse', keras.metrics.categorical_accuracy])
+ metrics=['mse', keras.metrics.CategoricalAccuracy()])
with self.cached_session():
est_keras = keras_lib.model_to_estimator(
@@ -370,7 +370,7 @@ class TestKerasEstimator(test_util.TensorFlowTestCase):
keras_model.compile(
loss='categorical_crossentropy',
optimizer=rmsprop.RMSPropOptimizer(1e-3),
- metrics=['mse', keras.metrics.categorical_accuracy])
+ metrics=['mse', keras.metrics.CategoricalAccuracy()])
with self.cached_session():
# Create state
@@ -662,7 +662,7 @@ class TestKerasEstimator(test_util.TensorFlowTestCase):
keras_model.compile(
loss='categorical_crossentropy',
optimizer='rmsprop',
- metrics=['mse', keras.metrics.categorical_accuracy])
+ metrics=['mse', keras.metrics.CategoricalAccuracy()])
tf_config = json.dumps({
'cluster': {
@@ -687,7 +687,7 @@ class TestKerasEstimator(test_util.TensorFlowTestCase):
keras_model.compile(
loss='categorical_crossentropy',
optimizer='rmsprop',
- metrics=['mse', keras.metrics.categorical_accuracy])
+ metrics=['mse', keras.metrics.CategoricalAccuracy()])
gpu_options = config_pb2.GPUOptions(per_process_gpu_memory_fraction=0.3)
sess_config = config_pb2.ConfigProto(gpu_options=gpu_options)
@@ -706,7 +706,7 @@ class TestKerasEstimator(test_util.TensorFlowTestCase):
keras_model.compile(
loss='categorical_crossentropy',
optimizer='rmsprop',
- metrics=['mse', keras.metrics.categorical_accuracy])
+ metrics=['mse', keras.metrics.CategoricalAccuracy()])
with self.cached_session():
est_keras = keras_lib.model_to_estimator(
@@ -736,7 +736,7 @@ class TestKerasEstimator(test_util.TensorFlowTestCase):
keras_model.compile(
loss='categorical_crossentropy',
optimizer='rmsprop',
- metrics=['mse', keras.metrics.categorical_accuracy])
+ metrics=['mse', keras.metrics.CategoricalAccuracy()])
with self.cached_session():
with test.mock.patch.object(tempfile, 'mkdtemp', return_value=_TMP_DIR):
@@ -751,7 +751,7 @@ class TestKerasEstimator(test_util.TensorFlowTestCase):
keras_model.compile(
loss='categorical_crossentropy',
optimizer='rmsprop',
- metrics=['mse', keras.metrics.categorical_accuracy])
+ metrics=['mse', keras.metrics.CategoricalAccuracy()])
with self.cached_session():
with self.assertRaisesRegexp(ValueError, '`model_dir` are set both in '
@@ -765,7 +765,7 @@ class TestKerasEstimator(test_util.TensorFlowTestCase):
keras_model.compile(
loss='categorical_crossentropy',
optimizer=rmsprop.RMSPropOptimizer(1e-3),
- metrics=['mse', keras.metrics.categorical_accuracy])
+ metrics=['mse', keras.metrics.CategoricalAccuracy()])
with self.cached_session():
keras_model.train_on_batch(
np.random.random((10,) + _INPUT_SIZE),
@@ -776,7 +776,7 @@ class TestKerasEstimator(test_util.TensorFlowTestCase):
keras_model.compile(
loss='categorical_crossentropy',
optimizer=SGD(lr=0.0001, momentum=0.9),
- metrics=['mse', keras.metrics.categorical_accuracy])
+ metrics=['mse', keras.metrics.CategoricalAccuracy()])
keras_lib.model_to_estimator(
keras_model=keras_model, config=self._config)
@@ -786,7 +786,7 @@ class TestKerasEstimator(test_util.TensorFlowTestCase):
keras_model.compile(
loss='categorical_crossentropy',
optimizer=optimizer,
- metrics=['mse', keras.metrics.categorical_accuracy])
+ metrics=['mse', keras.metrics.CategoricalAccuracy()])
with self.cached_session() as sess:
keras_model_fn = keras_lib._create_keras_model_fn(keras_model)
global_step = training_util.create_global_step()