aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python/estimator/keras.py
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/python/estimator/keras.py')
-rw-r--r--tensorflow/python/estimator/keras.py11
1 files changed, 10 insertions, 1 deletions
diff --git a/tensorflow/python/estimator/keras.py b/tensorflow/python/estimator/keras.py
index c63deb8f4d..a5f07fea3b 100644
--- a/tensorflow/python/estimator/keras.py
+++ b/tensorflow/python/estimator/keras.py
@@ -44,6 +44,7 @@ from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.saved_model import signature_constants
from tensorflow.python.training import checkpoint_management
from tensorflow.python.training import distribute as distribute_lib
+from tensorflow.python.training import optimizer as tf_optimizer_module
from tensorflow.python.training import saver as saver_lib
from tensorflow.python.training import training_util
from tensorflow.python.training.checkpointable import base as checkpointable
@@ -358,6 +359,14 @@ def _create_keras_model_fn(keras_model, custom_objects=None):
def model_fn(features, labels, mode):
"""model_fn for keras Estimator."""
+ # Raise an error when users use DistributionStrategy with native Keras
+ # optimizers. Currently we only support native TensorFlow optimizers.
+ if distribute_lib.has_distribution_strategy() and \
+ not isinstance(keras_model.optimizer,
+ (tf_optimizer_module.Optimizer, optimizers.TFOptimizer)):
+ raise ValueError('Only TensorFlow native optimizers are supported with '
+ 'DistributionStrategy.')
+
model = _clone_and_build_model(mode, keras_model, custom_objects, features,
labels)
model_output_names = []
@@ -387,7 +396,7 @@ def _create_keras_model_fn(keras_model, custom_objects=None):
loss = model.total_loss
if model.metrics:
- # TODO(fchollet): support stateful metrics
+ # TODO(psv/fchollet): support stateful metrics
eval_metric_ops = {}
# When each metric maps to an output
if isinstance(model.metrics, dict):