diff options
Diffstat (limited to 'tensorflow/python/estimator/keras.py')
-rw-r--r-- | tensorflow/python/estimator/keras.py | 11 |
1 files changed, 10 insertions, 1 deletions
diff --git a/tensorflow/python/estimator/keras.py b/tensorflow/python/estimator/keras.py index c63deb8f4d..a5f07fea3b 100644 --- a/tensorflow/python/estimator/keras.py +++ b/tensorflow/python/estimator/keras.py @@ -44,6 +44,7 @@ from tensorflow.python.platform import tf_logging as logging from tensorflow.python.saved_model import signature_constants from tensorflow.python.training import checkpoint_management from tensorflow.python.training import distribute as distribute_lib +from tensorflow.python.training import optimizer as tf_optimizer_module from tensorflow.python.training import saver as saver_lib from tensorflow.python.training import training_util from tensorflow.python.training.checkpointable import base as checkpointable @@ -358,6 +359,14 @@ def _create_keras_model_fn(keras_model, custom_objects=None): def model_fn(features, labels, mode): """model_fn for keras Estimator.""" + # Raise an error when users use DistributionStrategy with native Keras + # optimizers. Currently we only support native TensorFlow optimizers. + if distribute_lib.has_distribution_strategy() and \ + not isinstance(keras_model.optimizer, + (tf_optimizer_module.Optimizer, optimizers.TFOptimizer)): + raise ValueError('Only TensorFlow native optimizers are supported with ' + 'DistributionStrategy.') + model = _clone_and_build_model(mode, keras_model, custom_objects, features, labels) model_output_names = [] @@ -387,7 +396,7 @@ def _create_keras_model_fn(keras_model, custom_objects=None): loss = model.total_loss if model.metrics: - # TODO(fchollet): support stateful metrics + # TODO(psv/fchollet): support stateful metrics eval_metric_ops = {} # When each metric maps to an output if isinstance(model.metrics, dict): |