diff options
author | 2018-08-16 10:52:39 -0700 | |
---|---|---|
committer | 2018-08-16 11:19:05 -0700 | |
commit | db9f9f5a2cd3e0434225f05d65d9de8a4b5e9d41 (patch) | |
tree | 8e7cd82bab96691f3759db238943fedba1f1f269 /tensorflow/contrib/tpu | |
parent | 394db95965e1d745f08b4eeb550878ddc175af15 (diff) |
[Keras / Cloud TPU] Fix ClusterSpec propagation bug.
PiperOrigin-RevId: 209010536
Diffstat (limited to 'tensorflow/contrib/tpu')
-rw-r--r-- | tensorflow/contrib/tpu/python/tpu/keras_support.py | 10 |
1 files changed, 5 insertions, 5 deletions
diff --git a/tensorflow/contrib/tpu/python/tpu/keras_support.py b/tensorflow/contrib/tpu/python/tpu/keras_support.py index ff893a722f..30cfae1da9 100644 --- a/tensorflow/contrib/tpu/python/tpu/keras_support.py +++ b/tensorflow/contrib/tpu/python/tpu/keras_support.py @@ -872,14 +872,14 @@ class KerasTPUModel(models.Model): tpu_name_or_address) master = self._cluster_resolver.master() cluster_spec = self._cluster_resolver.cluster_spec() + config = config_pb2.ConfigProto(isolate_session_state=True) + if cluster_spec: + config.cluster_def.CopyFrom(cluster_spec.as_cluster_def()) + self._session = tf_session.Session( graph=self._graph, target=master, - config=config_pb2.ConfigProto(isolate_session_state=True)) - - # TODO(saeta): Confirm the lines below work in ClusterSpec propagation env. - if cluster_spec: - self._session.cluster_def.CopyFrom(cluster_spec.as_cluster_def()) + config=config) with self._graph.as_default(): self._session.run(tpu.initialize_system()) |