aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/tpu
diff options
context:
space:
mode:
authorGravatar Brennan Saeta <saeta@google.com>2018-08-16 10:52:39 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-08-16 11:19:05 -0700
commitdb9f9f5a2cd3e0434225f05d65d9de8a4b5e9d41 (patch)
tree8e7cd82bab96691f3759db238943fedba1f1f269 /tensorflow/contrib/tpu
parent394db95965e1d745f08b4eeb550878ddc175af15 (diff)
[Keras / Cloud TPU] Fix ClusterSpec propagation bug.
PiperOrigin-RevId: 209010536
Diffstat (limited to 'tensorflow/contrib/tpu')
-rw-r--r--tensorflow/contrib/tpu/python/tpu/keras_support.py10
1 files changed, 5 insertions, 5 deletions
diff --git a/tensorflow/contrib/tpu/python/tpu/keras_support.py b/tensorflow/contrib/tpu/python/tpu/keras_support.py
index ff893a722f..30cfae1da9 100644
--- a/tensorflow/contrib/tpu/python/tpu/keras_support.py
+++ b/tensorflow/contrib/tpu/python/tpu/keras_support.py
@@ -872,14 +872,14 @@ class KerasTPUModel(models.Model):
tpu_name_or_address)
master = self._cluster_resolver.master()
cluster_spec = self._cluster_resolver.cluster_spec()
+ config = config_pb2.ConfigProto(isolate_session_state=True)
+ if cluster_spec:
+ config.cluster_def.CopyFrom(cluster_spec.as_cluster_def())
+
self._session = tf_session.Session(
graph=self._graph,
target=master,
- config=config_pb2.ConfigProto(isolate_session_state=True))
-
- # TODO(saeta): Confirm the lines below work in ClusterSpec propagation env.
- if cluster_spec:
- self._session.cluster_def.CopyFrom(cluster_spec.as_cluster_def())
+ config=config)
with self._graph.as_default():
self._session.run(tpu.initialize_system())