aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/training
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2017-08-11 15:41:54 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2017-08-11 15:48:01 -0700
commitfe762ceaf7e5e0ad82fb7a3a05f114c8f5b8d429 (patch)
tree3e713235abc907ff8055182902a3a5e7d076f7cc /tensorflow/contrib/training
parent49f5fd91a47ce0578b19cb5a36865f3890dddb68 (diff)
Fix unicode error in device_setter_test.
PiperOrigin-RevId: 165035012
Diffstat (limited to 'tensorflow/contrib/training')
-rw-r--r--tensorflow/contrib/training/python/training/device_setter.py1
-rw-r--r--tensorflow/contrib/training/python/training/device_setter_test.py9
2 files changed, 10 insertions, 0 deletions
diff --git a/tensorflow/contrib/training/python/training/device_setter.py b/tensorflow/contrib/training/python/training/device_setter.py
index e324acb754..231fc5788f 100644
--- a/tensorflow/contrib/training/python/training/device_setter.py
+++ b/tensorflow/contrib/training/python/training/device_setter.py
@@ -46,6 +46,7 @@ class RandomStrategy(object):
def __call__(self, op):
"""Chooses a ps task index for the given `Operation`."""
key = "%s_%d" % (op.name, self._seed)
+ key = key.encode("utf-8")
# Use MD5 instead of Python's built-in hash() to get consistent outputs
# between runs.
n = int(hashlib.md5(key).hexdigest(), 16)
diff --git a/tensorflow/contrib/training/python/training/device_setter_test.py b/tensorflow/contrib/training/python/training/device_setter_test.py
index 9d6572e39a..20746d911c 100644
--- a/tensorflow/contrib/training/python/training/device_setter_test.py
+++ b/tensorflow/contrib/training/python/training/device_setter_test.py
@@ -17,6 +17,7 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
+import collections
from tensorflow.contrib.training.python.training import device_setter as device_setter_lib
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
@@ -30,6 +31,8 @@ _CLUSTER_SPEC = server_lib.ClusterSpec({
"worker": ["worker0:2222", "worker1:2222", "worker2:2222"]
})
+MockOperation = collections.namedtuple("MockOperation", "name")
+
class RandomStrategyTest(test.TestCase):
@@ -55,6 +58,12 @@ class RandomStrategyTest(test.TestCase):
self.assertDeviceEqual("/job:ps/task:1", x.initializer.device)
self.assertDeviceEqual("/job:worker", a.device)
+ def testHandlesUnicode(self):
+ op = MockOperation(u"A unicode \u018e string \xf1")
+ ps_strategy = device_setter_lib.RandomStrategy(2, seed=0)
+ ps_task = ps_strategy(op)
+ self.assertEqual(ps_task, 1)
+
class GreedyLoadBalancingStrategyTest(test.TestCase):