aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python/framework/registry_test.py
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/python/framework/registry_test.py')
-rw-r--r--tensorflow/python/framework/registry_test.py38
1 files changed, 38 insertions, 0 deletions
diff --git a/tensorflow/python/framework/registry_test.py b/tensorflow/python/framework/registry_test.py
new file mode 100644
index 0000000000..5b4f261ceb
--- /dev/null
+++ b/tensorflow/python/framework/registry_test.py
@@ -0,0 +1,38 @@
+"""Tests for tensorflow.ops.registry."""
+
+from tensorflow.python.framework import registry
+from tensorflow.python.platform import googletest
+
+
+class RegistryTest(googletest.TestCase):
+
+ class Foo(object):
+ pass
+
+ def testRegisterClass(self):
+ myreg = registry.Registry('testfoo')
+ with self.assertRaises(LookupError):
+ myreg.lookup('Foo')
+ myreg.register(RegistryTest.Foo, 'Foo')
+ assert myreg.lookup('Foo') == RegistryTest.Foo
+
+ def testRegisterFunction(self):
+ myreg = registry.Registry('testbar')
+ with self.assertRaises(LookupError):
+ myreg.lookup('Bar')
+ myreg.register(bar, 'Bar')
+ assert myreg.lookup('Bar') == bar
+
+ def testDuplicate(self):
+ myreg = registry.Registry('testbar')
+ myreg.register(bar, 'Bar')
+ with self.assertRaises(KeyError):
+ myreg.register(bar, 'Bar')
+
+
+def bar():
+ pass
+
+
+if __name__ == '__main__':
+ googletest.main()