diff options
author | Anjali Sridhar <anjalisridhar@google.com> | 2018-09-19 11:43:04 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-09-19 11:46:33 -0700 |
commit | ebe769f166c35c16637cb919ea3ddd096e04befa (patch) | |
tree | b73bb179081d2347d32c6a9e0aba18291e7a59aa /tensorflow/python/keras | |
parent | 1b4999df0c2ef3c8c7d771415924fb58a5476c6a (diff) |
Re-enable flaky keras_test
PiperOrigin-RevId: 213665390
Diffstat (limited to 'tensorflow/python/keras')
-rw-r--r-- | tensorflow/python/keras/engine/training.py | 6 | ||||
-rw-r--r-- | tensorflow/python/keras/engine/training_distributed.py | 6 |
2 files changed, 8 insertions, 4 deletions
diff --git a/tensorflow/python/keras/engine/training.py b/tensorflow/python/keras/engine/training.py index dc464c02b6..7df72d45b4 100644 --- a/tensorflow/python/keras/engine/training.py +++ b/tensorflow/python/keras/engine/training.py @@ -422,8 +422,9 @@ class Model(Network): # Set DistributionStrategy specific parameters. self._distribution_strategy = distribute + # Reset the value of grouped_model + self._grouped_model = None if self._distribution_strategy is not None: - self._grouped_model = None distributed_training_utils.configure_and_create_session( self._distribution_strategy) if not self.built: @@ -445,7 +446,8 @@ class Model(Network): for name in self.output_names: if name not in loss: logging.warning( - 'Output "' + name + '" missing from loss dictionary. We assume ' + 'Output "' + name + + '" missing from loss dictionary. We assume ' 'this was done on purpose. The fit and evaluate APIs will not be ' 'expecting any data to be passed to "' + name + '".') loss_functions.append(losses.get(loss.get(name))) diff --git a/tensorflow/python/keras/engine/training_distributed.py b/tensorflow/python/keras/engine/training_distributed.py index 53291c3956..d133595793 100644 --- a/tensorflow/python/keras/engine/training_distributed.py +++ b/tensorflow/python/keras/engine/training_distributed.py @@ -20,6 +20,7 @@ from __future__ import division from __future__ import print_function import numpy as np from tensorflow.python.framework import constant_op +from tensorflow.python.framework import tensor_shape from tensorflow.python.framework import errors from tensorflow.python.keras import backend as K from tensorflow.python.keras import callbacks as cbks @@ -742,8 +743,9 @@ def _experimental_predict_loop(model, iterator, verbose=0, steps=None): for name, tensor in zip(model.output_names, model.outputs): # TODO(priyag): This is a workaround as we do not know the batch dimension # of the model's output at this point. - tensor.shape.dims = [batch_dimension] + tensor.shape.dims[1:] - initial_loop_values[name] = array_ops.zeros(tensor.shape, tensor.dtype) + shape = tensor_shape.TensorShape(tensor.shape.dims) + shape.dims = [batch_dimension] + shape.dims[1:] + initial_loop_values[name] = array_ops.zeros(shape, tensor.dtype) with current_strategy.scope(): # TODO(priyag, sourabhbajaj): Support steps_per_run if/when we add outfeed. |