blob: 5b4f261cebe56d21848251fcecd9f35a1c9b4976 (
plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
|
"""Tests for tensorflow.ops.registry."""
from tensorflow.python.framework import registry
from tensorflow.python.platform import googletest
class RegistryTest(googletest.TestCase):
    """Tests registration, lookup and duplicate handling of `registry.Registry`."""

    class Foo(object):
        """Empty class used as a value to register under the name 'Foo'."""

    def testRegisterClass(self):
        """A class becomes retrievable by name only after it is registered."""
        myreg = registry.Registry('testfoo')
        # Looking up a name that was never registered must fail.
        with self.assertRaises(LookupError):
            myreg.lookup('Foo')
        myreg.register(RegistryTest.Foo, 'Foo')
        # Use a TestCase assertion rather than a bare `assert`: bare asserts are
        # stripped under `python -O`, which would make this check vacuous.
        self.assertEqual(myreg.lookup('Foo'), RegistryTest.Foo)

    def testRegisterFunction(self):
        """A function becomes retrievable by name only after it is registered."""
        myreg = registry.Registry('testbar')
        with self.assertRaises(LookupError):
            myreg.lookup('Bar')
        myreg.register(bar, 'Bar')
        # Same reasoning as above: avoid bare `assert` in tests.
        self.assertEqual(myreg.lookup('Bar'), bar)

    def testDuplicate(self):
        """Registering a second value under an existing name raises KeyError."""
        myreg = registry.Registry('testbar')
        myreg.register(bar, 'Bar')
        with self.assertRaises(KeyError):
            myreg.register(bar, 'Bar')
def bar():
    """Do nothing; serves as a plain function value for the registry tests."""
    return None
# When this file is executed directly (not imported), hand control to the
# googletest test runner.
if __name__ == '__main__':
    googletest.main()
|