diff options
author | 2017-08-28 12:02:32 -0700 | |
---|---|---|
committer | 2017-08-28 12:06:33 -0700 | |
commit | 45c3602cf43104be7073b4041aa049b3bb4390fa (patch) | |
tree | d080d536eb66c62c27af6b652e17b6eaa0279c5f /tensorflow/python/util/tf_export_test.py | |
parent | 95d240c5fbecec9fdbef55dc1154c4f454752633 (diff) |
Adding tf_export decorator to support exporting tensorflow ops and constants.
PiperOrigin-RevId: 166736891
Diffstat (limited to 'tensorflow/python/util/tf_export_test.py')
-rw-r--r-- | tensorflow/python/util/tf_export_test.py | 157 |
1 files changed, 157 insertions, 0 deletions
diff --git a/tensorflow/python/util/tf_export_test.py b/tensorflow/python/util/tf_export_test.py new file mode 100644 index 0000000000..3b7636c34e --- /dev/null +++ b/tensorflow/python/util/tf_export_test.py @@ -0,0 +1,157 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""tf_export tests.""" + +# pylint: disable=unused-import +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import sys + +from tensorflow.python.platform import test +from tensorflow.python.util import tf_decorator +from tensorflow.python.util import tf_export + + +def _test_function(unused_arg=0): + pass + + +def _test_function2(unused_arg=0): + pass + + +class TestClassA(object): + pass + + +class TestClassB(TestClassA): + pass + + +class ValidateExportTest(test.TestCase): + """Tests for tf_export class.""" + + class MockModule(object): + + def __init__(self, name): + self.__name__ = name + + def setUp(self): + self._modules = [] + + def tearDown(self): + for name in self._modules: + del sys.modules[name] + self._modules = [] + for symbol in [_test_function, _test_function, TestClassA, TestClassB]: + if hasattr(symbol, '_tf_api_names'): + del symbol._tf_api_names + + def _CreateMockModule(self, name): + mock_module = self.MockModule(name) + sys.modules[name] = mock_module + self._modules.append(name) + return mock_module + + def testExportSingleFunction(self): + export_decorator = tf_export.tf_export('nameA', 'nameB') + decorated_function = export_decorator(_test_function) + self.assertEquals(decorated_function, _test_function) + self.assertEquals(('nameA', 'nameB'), decorated_function._tf_api_names) + + def testExportMultipleFunctions(self): + export_decorator1 = tf_export.tf_export('nameA', 'nameB') + export_decorator2 = tf_export.tf_export('nameC', 'nameD') + decorated_function1 = export_decorator1(_test_function) + decorated_function2 = export_decorator2(_test_function2) + self.assertEquals(decorated_function1, _test_function) + self.assertEquals(decorated_function2, _test_function2) + self.assertEquals(('nameA', 'nameB'), decorated_function1._tf_api_names) + self.assertEquals(('nameC', 'nameD'), decorated_function2._tf_api_names) + + def testExportClasses(self): + export_decorator_a = tf_export.tf_export('TestClassA1') + export_decorator_a(TestClassA) + self.assertEquals(('TestClassA1',), TestClassA._tf_api_names) + self.assertTrue('_tf_api_names' not in TestClassB.__dict__) + + export_decorator_b = tf_export.tf_export('TestClassB1') + export_decorator_b(TestClassB) + self.assertEquals(('TestClassA1',), TestClassA._tf_api_names) + self.assertEquals(('TestClassB1',), TestClassB._tf_api_names) + + def testExportSingleConstant(self): + module1 = self._CreateMockModule('module1') + + test_constant = 123 + export_decorator = tf_export.tf_export('NAME_A', 'NAME_B') + export_decorator.export_constant('module1', test_constant) + self.assertEquals([(('NAME_A', 'NAME_B'), 123)], + module1._tf_api_constants) + + def testExportMultipleConstants(self): + module1 = self._CreateMockModule('module1') + module2 = self._CreateMockModule('module2') + + test_constant1 = 123 + test_constant2 = 'abc' + test_constant3 = 0.5 + + export_decorator1 = tf_export.tf_export('NAME_A', 'NAME_B') + export_decorator2 = tf_export.tf_export('NAME_C', 'NAME_D') + export_decorator3 = tf_export.tf_export('NAME_E', 'NAME_F') + export_decorator1.export_constant('module1', test_constant1) + export_decorator2.export_constant('module2', test_constant2) + export_decorator3.export_constant('module2', test_constant3) + self.assertEquals([(('NAME_A', 'NAME_B'), 123)], + module1._tf_api_constants) + self.assertEquals([(('NAME_C', 'NAME_D'), 'abc'), + (('NAME_E', 'NAME_F'), 0.5)], + module2._tf_api_constants) + + def testRaisesExceptionIfAlreadyHasAPINames(self): + _test_function._tf_api_names = ['abc'] + export_decorator = tf_export.tf_export('nameA', 'nameB') + with self.assertRaises(tf_export.SymbolAlreadyExposedError): + export_decorator(_test_function) + + def testOverridesFunction(self): + _test_function2._tf_api_names = ['abc'] + + export_decorator = tf_export.tf_export( + 'nameA', 'nameB', overrides=[_test_function2]) + export_decorator(_test_function) + + # _test_function overrides _test_function2. So, _tf_api_names + # should be removed from _test_function2. + self.assertFalse(hasattr(_test_function2, '_tf_api_names')) + + def testMultipleDecorators(self): + def get_wrapper(func): + def wrapper(*unused_args, **unused_kwargs): + pass + return tf_decorator.make_decorator(func, wrapper) + decorated_function = get_wrapper(_test_function) + + export_decorator = tf_export.tf_export('nameA', 'nameB') + exported_function = export_decorator(decorated_function) + self.assertEquals(decorated_function, exported_function) + self.assertEquals(('nameA', 'nameB'), _test_function._tf_api_names) + + +if __name__ == '__main__': + test.main() |