about summary refs log tree commit diff homepage
path: root/tensorflow/python/keras
diff options
context:
space:
mode:
author: Anjali Sridhar <anjalisridhar@google.com> 2018-09-19 11:43:04 -0700
committer: TensorFlower Gardener <gardener@tensorflow.org> 2018-09-19 11:46:33 -0700
commit: ebe769f166c35c16637cb919ea3ddd096e04befa (patch)
tree: b73bb179081d2347d32c6a9e0aba18291e7a59aa /tensorflow/python/keras
parent: 1b4999df0c2ef3c8c7d771415924fb58a5476c6a (diff)
Re-enable flaky keras_test
PiperOrigin-RevId: 213665390
Diffstat (limited to 'tensorflow/python/keras')
-rw-r--r--tensorflow/python/keras/engine/training.py6
-rw-r--r--tensorflow/python/keras/engine/training_distributed.py6
2 files changed, 8 insertions, 4 deletions
diff --git a/tensorflow/python/keras/engine/training.py b/tensorflow/python/keras/engine/training.py
index dc464c02b6..7df72d45b4 100644
--- a/tensorflow/python/keras/engine/training.py
+++ b/tensorflow/python/keras/engine/training.py
@@ -422,8 +422,9 @@ class Model(Network):
# Set DistributionStrategy specific parameters.
self._distribution_strategy = distribute
+ # Reset the value of grouped_model
+ self._grouped_model = None
if self._distribution_strategy is not None:
- self._grouped_model = None
distributed_training_utils.configure_and_create_session(
self._distribution_strategy)
if not self.built:
@@ -445,7 +446,8 @@ class Model(Network):
for name in self.output_names:
if name not in loss:
logging.warning(
- 'Output "' + name + '" missing from loss dictionary. We assume '
+ 'Output "' + name +
+ '" missing from loss dictionary. We assume '
'this was done on purpose. The fit and evaluate APIs will not be '
'expecting any data to be passed to "' + name + '".')
loss_functions.append(losses.get(loss.get(name)))
diff --git a/tensorflow/python/keras/engine/training_distributed.py b/tensorflow/python/keras/engine/training_distributed.py
index 53291c3956..d133595793 100644
--- a/tensorflow/python/keras/engine/training_distributed.py
+++ b/tensorflow/python/keras/engine/training_distributed.py
@@ -20,6 +20,7 @@ from __future__ import division
from __future__ import print_function
import numpy as np
from tensorflow.python.framework import constant_op
+from tensorflow.python.framework import tensor_shape
from tensorflow.python.framework import errors
from tensorflow.python.keras import backend as K
from tensorflow.python.keras import callbacks as cbks
@@ -742,8 +743,9 @@ def _experimental_predict_loop(model, iterator, verbose=0, steps=None):
for name, tensor in zip(model.output_names, model.outputs):
# TODO(priyag): This is a workaround as we do not know the batch dimension
# of the model's output at this point.
- tensor.shape.dims = [batch_dimension] + tensor.shape.dims[1:]
- initial_loop_values[name] = array_ops.zeros(tensor.shape, tensor.dtype)
+ shape = tensor_shape.TensorShape(tensor.shape.dims)
+ shape.dims = [batch_dimension] + shape.dims[1:]
+ initial_loop_values[name] = array_ops.zeros(shape, tensor.dtype)
with current_strategy.scope():
# TODO(priyag, sourabhbajaj): Support steps_per_run if/when we add outfeed.