From 100995e2d0fbe56e19a6472d6f921ff0680fdfd2 Mon Sep 17 00:00:00 2001 From: Rui Zhao Date: Mon, 2 Jul 2018 10:54:39 -0700 Subject: Make it possible to serialize Topology class that is created without a serialized topology. PiperOrigin-RevId: 202978167 --- tensorflow/contrib/tpu/BUILD | 10 +++++ tensorflow/contrib/tpu/python/tpu/topology.py | 5 ++- tensorflow/contrib/tpu/python/tpu/topology_test.py | 46 ++++++++++++++++++++++ 3 files changed, 59 insertions(+), 2 deletions(-) create mode 100644 tensorflow/contrib/tpu/python/tpu/topology_test.py diff --git a/tensorflow/contrib/tpu/BUILD b/tensorflow/contrib/tpu/BUILD index 16696793bc..c08f088be7 100644 --- a/tensorflow/contrib/tpu/BUILD +++ b/tensorflow/contrib/tpu/BUILD @@ -307,3 +307,13 @@ tf_py_test( "//tensorflow/python:framework_test_lib", ], ) + +tf_py_test( + name = "topology_test", + size = "small", + srcs = ["python/tpu/topology_test.py"], + additional_deps = [ + ":tpu", + "//tensorflow/python:framework_test_lib", + ], +) diff --git a/tensorflow/contrib/tpu/python/tpu/topology.py b/tensorflow/contrib/tpu/python/tpu/topology.py index cda9a63f20..1fb26e701a 100644 --- a/tensorflow/contrib/tpu/python/tpu/topology.py +++ b/tensorflow/contrib/tpu/python/tpu/topology.py @@ -55,8 +55,9 @@ class Topology(object): rank 3 numpy int32 array that describes a valid coordinate mapping. """ + self._serialized = serialized + if serialized: - self._serialized = serialized self._parse_topology(serialized) else: self._mesh_shape = np.asarray(mesh_shape, dtype=np.int32) @@ -131,7 +132,7 @@ class Topology(object): proto.mesh_shape[:] = list(self._mesh_shape) proto.num_tasks = self._device_coordinates.shape[0] proto.num_tpu_devices_per_task = self._device_coordinates.shape[1] - proto.device_coordinates = list(self._device_coordinates.flatten()) + proto.device_coordinates.extend(list(self._device_coordinates.flatten())) self._serialized = proto.SerializeToString() return self._serialized diff --git a/tensorflow/contrib/tpu/python/tpu/topology_test.py b/tensorflow/contrib/tpu/python/tpu/topology_test.py new file mode 100644 index 0000000000..e67fdb263a --- /dev/null +++ b/tensorflow/contrib/tpu/python/tpu/topology_test.py @@ -0,0 +1,46 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================= + +"""Tests for topology.py.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.contrib.tpu.python.tpu import topology + +from tensorflow.python.platform import test + + +class TopologyTest(test.TestCase): + + def testSerialization(self): + """Test if the class is able to generate serialzied string.""" + original_topology = topology.Topology( + mesh_shape=[1, 1, 2], + device_coordinates=[[[0, 0, 0], [0, 0, 1]]], + ) + serialized_str = original_topology.serialized() + new_topology = topology.Topology(serialized=serialized_str) + + # Make sure the topology recovered from serialized str is same as the + # original topology. + self.assertAllEqual( + original_topology.mesh_shape, new_topology.mesh_shape) + self.assertAllEqual( + original_topology.device_coordinates, new_topology.device_coordinates) + +if __name__ == "__main__": + test.main() -- cgit v1.2.3