diff options
author | Pavithra Vijay <psv@google.com> | 2018-09-26 20:27:46 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-09-26 20:33:41 -0700 |
commit | de2bcdc7ad149419e270e1443b63581163d75d5d (patch) | |
tree | 203bbff75133c9c2ff26b2950a669365cba2ee56 /tensorflow/python/estimator | |
parent | 0d5c68e30f4637329fa233df506d7b97802a5e9b (diff) |
Add Mirrored distribution strategy support for new metrics with Keras and Estimator
Add support for stateful metrics in model to estimator
PiperOrigin-RevId: 214714322
Diffstat (limited to 'tensorflow/python/estimator')
-rw-r--r-- | tensorflow/python/estimator/keras.py | 39 | ||||
-rw-r--r-- | tensorflow/python/estimator/keras_test.py | 28 |
2 files changed, 42 insertions, 25 deletions
diff --git a/tensorflow/python/estimator/keras.py b/tensorflow/python/estimator/keras.py index 6b2765be82..7546771ed3 100644 --- a/tensorflow/python/estimator/keras.py +++ b/tensorflow/python/estimator/keras.py @@ -21,6 +21,7 @@ from __future__ import print_function import os import re +import six from tensorflow.python.client import session from tensorflow.python.estimator import estimator as estimator_lib @@ -31,6 +32,7 @@ from tensorflow.python.framework import random_seed from tensorflow.python.framework import sparse_tensor as sparse_tensor_lib from tensorflow.python.framework import tensor_util from tensorflow.python.keras import backend as K +from tensorflow.python.keras import metrics from tensorflow.python.keras import models from tensorflow.python.keras import optimizers from tensorflow.python.ops import check_ops @@ -214,25 +216,40 @@ def _convert_keras_metrics_to_estimator(model): if not getattr(model, 'metrics', None): return None - # TODO(psv/fchollet): support stateful metrics eval_metric_ops = {} + + def get_metric_name(metric): + if isinstance(metric, metrics.Metric): + return metric.name + if callable(metric): + return metric.__name__ + assert isinstance(metric, six.string_types) + return metric + # When each metric maps to an output if isinstance(model.metrics, dict): for i, output_name in enumerate(model.metrics.keys()): - metric_name = model.metrics[output_name] - if callable(metric_name): - metric_name = metric_name.__name__ + # `metric` is the user given metric value in `compile`. This can be + # metric name (`acc`), metric function (binary_accuracy) or a metric + # object (BinaryAccuracy()). + metric = model.metrics[output_name] + metric_name = get_metric_name(metric) # When some outputs use the same metric if list(model.metrics.values()).count(metric_name) > 1: metric_name += '_' + output_name - eval_metric_ops[metric_name] = metrics_module.mean( - model.metrics_tensors[i - len(model.metrics)]) + if isinstance(metric, metrics.Metric): + eval_metric_ops[metric_name] = metric + else: + eval_metric_ops[metric_name] = metrics_module.mean( + model.metrics_tensors[i - len(model.metrics)]) else: - for i, metric_name in enumerate(model.metrics): - if callable(metric_name): - metric_name = metric_name.__name__ - eval_metric_ops[metric_name] = metrics_module.mean( - model.metrics_tensors[i]) + for i, metric in enumerate(model.metrics): + metric_name = get_metric_name(metric) + if isinstance(metric, metrics.Metric): + eval_metric_ops[metric_name] = metric + else: + eval_metric_ops[metric_name] = metrics_module.mean( + model.metrics_tensors[i]) return eval_metric_ops diff --git a/tensorflow/python/estimator/keras_test.py b/tensorflow/python/estimator/keras_test.py index 3758243d7b..288f9b8906 100644 --- a/tensorflow/python/estimator/keras_test.py +++ b/tensorflow/python/estimator/keras_test.py @@ -257,7 +257,7 @@ class TestKerasEstimator(test_util.TensorFlowTestCase): keras_model.compile( loss='categorical_crossentropy', optimizer='rmsprop', - metrics=['mse', keras.metrics.categorical_accuracy]) + metrics=['mse', keras.metrics.CategoricalAccuracy()]) with self.cached_session(): est_keras = keras_lib.model_to_estimator( @@ -281,7 +281,7 @@ class TestKerasEstimator(test_util.TensorFlowTestCase): keras_model.compile( loss='categorical_crossentropy', optimizer=rmsprop.RMSPropOptimizer(1e-3), - metrics=['mse', keras.metrics.categorical_accuracy]) + metrics=['mse', keras.metrics.CategoricalAccuracy()]) my_hook = MyHook() with self.cached_session(): @@ -306,7 +306,7 @@ class TestKerasEstimator(test_util.TensorFlowTestCase): keras_model.compile( loss='categorical_crossentropy', optimizer=rmsprop.RMSPropOptimizer(1e-3), - metrics=['mse', keras.metrics.categorical_accuracy]) + metrics=['mse', keras.metrics.CategoricalAccuracy()]) my_hook = MyHook() with self.cached_session(): keras_model.fit(x_train, y_train, epochs=1) @@ -328,7 +328,7 @@ class TestKerasEstimator(test_util.TensorFlowTestCase): keras_model.compile( loss='categorical_crossentropy', optimizer=rmsprop.RMSPropOptimizer(1e-3), - metrics=['mse', keras.metrics.categorical_accuracy]) + metrics=['mse', keras.metrics.CategoricalAccuracy()]) with self.cached_session(): est_keras = keras_lib.model_to_estimator( @@ -351,7 +351,7 @@ class TestKerasEstimator(test_util.TensorFlowTestCase): keras_model.compile( loss='categorical_crossentropy', optimizer=rmsprop.RMSPropOptimizer(1e-3), - metrics=['mse', keras.metrics.categorical_accuracy]) + metrics=['mse', keras.metrics.CategoricalAccuracy()]) with self.cached_session(): est_keras = keras_lib.model_to_estimator( @@ -370,7 +370,7 @@ class TestKerasEstimator(test_util.TensorFlowTestCase): keras_model.compile( loss='categorical_crossentropy', optimizer=rmsprop.RMSPropOptimizer(1e-3), - metrics=['mse', keras.metrics.categorical_accuracy]) + metrics=['mse', keras.metrics.CategoricalAccuracy()]) with self.cached_session(): # Create state @@ -662,7 +662,7 @@ class TestKerasEstimator(test_util.TensorFlowTestCase): keras_model.compile( loss='categorical_crossentropy', optimizer='rmsprop', - metrics=['mse', keras.metrics.categorical_accuracy]) + metrics=['mse', keras.metrics.CategoricalAccuracy()]) tf_config = json.dumps({ 'cluster': { @@ -687,7 +687,7 @@ class TestKerasEstimator(test_util.TensorFlowTestCase): keras_model.compile( loss='categorical_crossentropy', optimizer='rmsprop', - metrics=['mse', keras.metrics.categorical_accuracy]) + metrics=['mse', keras.metrics.CategoricalAccuracy()]) gpu_options = config_pb2.GPUOptions(per_process_gpu_memory_fraction=0.3) sess_config = config_pb2.ConfigProto(gpu_options=gpu_options) @@ -706,7 +706,7 @@ class TestKerasEstimator(test_util.TensorFlowTestCase): keras_model.compile( loss='categorical_crossentropy', optimizer='rmsprop', - metrics=['mse', keras.metrics.categorical_accuracy]) + metrics=['mse', keras.metrics.CategoricalAccuracy()]) with self.cached_session(): est_keras = keras_lib.model_to_estimator( @@ -736,7 +736,7 @@ class TestKerasEstimator(test_util.TensorFlowTestCase): keras_model.compile( loss='categorical_crossentropy', optimizer='rmsprop', - metrics=['mse', keras.metrics.categorical_accuracy]) + metrics=['mse', keras.metrics.CategoricalAccuracy()]) with self.cached_session(): with test.mock.patch.object(tempfile, 'mkdtemp', return_value=_TMP_DIR): @@ -751,7 +751,7 @@ class TestKerasEstimator(test_util.TensorFlowTestCase): keras_model.compile( loss='categorical_crossentropy', optimizer='rmsprop', - metrics=['mse', keras.metrics.categorical_accuracy]) + metrics=['mse', keras.metrics.CategoricalAccuracy()]) with self.cached_session(): with self.assertRaisesRegexp(ValueError, '`model_dir` are set both in ' @@ -765,7 +765,7 @@ class TestKerasEstimator(test_util.TensorFlowTestCase): keras_model.compile( loss='categorical_crossentropy', optimizer=rmsprop.RMSPropOptimizer(1e-3), - metrics=['mse', keras.metrics.categorical_accuracy]) + metrics=['mse', keras.metrics.CategoricalAccuracy()]) with self.cached_session(): keras_model.train_on_batch( np.random.random((10,) + _INPUT_SIZE), @@ -776,7 +776,7 @@ class TestKerasEstimator(test_util.TensorFlowTestCase): keras_model.compile( loss='categorical_crossentropy', optimizer=SGD(lr=0.0001, momentum=0.9), - metrics=['mse', keras.metrics.categorical_accuracy]) + metrics=['mse', keras.metrics.CategoricalAccuracy()]) keras_lib.model_to_estimator( keras_model=keras_model, config=self._config) @@ -786,7 +786,7 @@ class TestKerasEstimator(test_util.TensorFlowTestCase): keras_model.compile( loss='categorical_crossentropy', optimizer=optimizer, - metrics=['mse', keras.metrics.categorical_accuracy]) + metrics=['mse', keras.metrics.CategoricalAccuracy()]) with self.cached_session() as sess: keras_model_fn = keras_lib._create_keras_model_fn(keras_model) global_step = training_util.create_global_step() |