diff options
author | 2017-08-11 15:41:54 -0700 | |
---|---|---|
committer | 2017-08-11 15:48:01 -0700 | |
commit | fe762ceaf7e5e0ad82fb7a3a05f114c8f5b8d429 (patch) | |
tree | 3e713235abc907ff8055182902a3a5e7d076f7cc /tensorflow/contrib/training | |
parent | 49f5fd91a47ce0578b19cb5a36865f3890dddb68 (diff) |
Fix unicode error in device_setter_test.
PiperOrigin-RevId: 165035012
Diffstat (limited to 'tensorflow/contrib/training')
-rw-r--r-- | tensorflow/contrib/training/python/training/device_setter.py | 1 | ||||
-rw-r--r-- | tensorflow/contrib/training/python/training/device_setter_test.py | 9 |
2 files changed, 10 insertions, 0 deletions
diff --git a/tensorflow/contrib/training/python/training/device_setter.py b/tensorflow/contrib/training/python/training/device_setter.py index e324acb754..231fc5788f 100644 --- a/tensorflow/contrib/training/python/training/device_setter.py +++ b/tensorflow/contrib/training/python/training/device_setter.py @@ -46,6 +46,7 @@ class RandomStrategy(object): def __call__(self, op): """Chooses a ps task index for the given `Operation`.""" key = "%s_%d" % (op.name, self._seed) + key = key.encode("utf-8") # Use MD5 instead of Python's built-in hash() to get consistent outputs # between runs. n = int(hashlib.md5(key).hexdigest(), 16) diff --git a/tensorflow/contrib/training/python/training/device_setter_test.py b/tensorflow/contrib/training/python/training/device_setter_test.py index 9d6572e39a..20746d911c 100644 --- a/tensorflow/contrib/training/python/training/device_setter_test.py +++ b/tensorflow/contrib/training/python/training/device_setter_test.py @@ -17,6 +17,7 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +import collections from tensorflow.contrib.training.python.training import device_setter as device_setter_lib from tensorflow.python.framework import ops from tensorflow.python.ops import array_ops @@ -30,6 +31,8 @@ _CLUSTER_SPEC = server_lib.ClusterSpec({ "worker": ["worker0:2222", "worker1:2222", "worker2:2222"] }) +MockOperation = collections.namedtuple("MockOperation", "name") + class RandomStrategyTest(test.TestCase): @@ -55,6 +58,12 @@ class RandomStrategyTest(test.TestCase): self.assertDeviceEqual("/job:ps/task:1", x.initializer.device) self.assertDeviceEqual("/job:worker", a.device) + def testHandlesUnicode(self): + op = MockOperation(u"A unicode \u018e string \xf1") + ps_strategy = device_setter_lib.RandomStrategy(2, seed=0) + ps_task = ps_strategy(op) + self.assertEqual(ps_task, 1) + class GreedyLoadBalancingStrategyTest(test.TestCase): |