aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python/framework/registry_test.py
blob: 5b4f261cebe56d21848251fcecd9f35a1c9b4976 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
"""Tests for tensorflow.python.framework.registry."""

from tensorflow.python.framework import registry
from tensorflow.python.platform import googletest


class RegistryTest(googletest.TestCase):
  """Unit tests for `registry.Registry` registration, lookup and duplicates."""

  class Foo(object):
    """Placeholder class used as a registrable value."""

  def testRegisterClass(self):
    """A class registered under a name can be looked up by that name."""
    myreg = registry.Registry('testfoo')
    # Before registration the name must be unknown to the registry.
    with self.assertRaises(LookupError):
      myreg.lookup('Foo')
    myreg.register(RegistryTest.Foo, 'Foo')
    # Use the TestCase assertion rather than a bare `assert`: bare asserts are
    # removed under `python -O` and give no diagnostic values on failure.
    self.assertEqual(myreg.lookup('Foo'), RegistryTest.Foo)

  def testRegisterFunction(self):
    """A function registered under a name can be looked up by that name."""
    myreg = registry.Registry('testbar')
    # Before registration the name must be unknown to the registry.
    with self.assertRaises(LookupError):
      myreg.lookup('Bar')
    myreg.register(bar, 'Bar')
    self.assertEqual(myreg.lookup('Bar'), bar)

  def testDuplicate(self):
    """Registering the same name twice in one registry raises KeyError."""
    myreg = registry.Registry('testbar')
    myreg.register(bar, 'Bar')
    with self.assertRaises(KeyError):
      myreg.register(bar, 'Bar')


def bar():
  """No-op function used as a registrable value in the tests above."""


# Run the test suite only when this file is executed directly, not on import.
if __name__ == "__main__":
  googletest.main()