aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python/tools/api/generator/create_python_api_test.py
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/python/tools/api/generator/create_python_api_test.py')
-rw-r--r--tensorflow/python/tools/api/generator/create_python_api_test.py99
1 files changed, 99 insertions, 0 deletions
diff --git a/tensorflow/python/tools/api/generator/create_python_api_test.py b/tensorflow/python/tools/api/generator/create_python_api_test.py
new file mode 100644
index 0000000000..a565a49d96
--- /dev/null
+++ b/tensorflow/python/tools/api/generator/create_python_api_test.py
@@ -0,0 +1,99 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# =============================================================================
+"""Tests for create_python_api."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import imp
+import sys
+
+from tensorflow.python.platform import test
+from tensorflow.python.tools.api.generator import create_python_api
+from tensorflow.python.util.tf_export import tf_export
+
+
+@tf_export('test_op', 'test_op1')
+def test_op():
+ pass
+
+
+@tf_export('TestClass', 'NewTestClass')
+class TestClass(object):
+ pass
+
+
+_TEST_CONSTANT = 5
+_MODULE_NAME = 'tensorflow.python.test_module'
+
+
+class CreatePythonApiTest(test.TestCase):
+
+ def setUp(self):
+ # Add fake op to a module that has 'tensorflow' in the name.
+ sys.modules[_MODULE_NAME] = imp.new_module(_MODULE_NAME)
+ setattr(sys.modules[_MODULE_NAME], 'test_op', test_op)
+ setattr(sys.modules[_MODULE_NAME], 'TestClass', TestClass)
+ test_op.__module__ = _MODULE_NAME
+ TestClass.__module__ = _MODULE_NAME
+ tf_export('consts._TEST_CONSTANT').export_constant(
+ _MODULE_NAME, '_TEST_CONSTANT')
+
+ def tearDown(self):
+ del sys.modules[_MODULE_NAME]
+
+ def testFunctionImportIsAdded(self):
+ imports = create_python_api.get_api_init_text(
+ package=create_python_api._DEFAULT_PACKAGE,
+ output_package='tensorflow',
+ api_name='tensorflow', api_version=1)
+ expected_import = (
+ 'from tensorflow.python.test_module '
+ 'import test_op as test_op1')
+ self.assertTrue(
+ expected_import in str(imports),
+ msg='%s not in %s' % (expected_import, str(imports)))
+
+ expected_import = ('from tensorflow.python.test_module '
+ 'import test_op')
+ self.assertTrue(
+ expected_import in str(imports),
+ msg='%s not in %s' % (expected_import, str(imports)))
+
+ def testClassImportIsAdded(self):
+ imports = create_python_api.get_api_init_text(
+ package=create_python_api._DEFAULT_PACKAGE,
+ output_package='tensorflow',
+ api_name='tensorflow', api_version=2)
+ expected_import = ('from tensorflow.python.test_module '
+ 'import TestClass')
+ self.assertTrue(
+ 'TestClass' in str(imports),
+ msg='%s not in %s' % (expected_import, str(imports)))
+
+ def testConstantIsAdded(self):
+ imports = create_python_api.get_api_init_text(
+ package=create_python_api._DEFAULT_PACKAGE,
+ output_package='tensorflow',
+ api_name='tensorflow', api_version=1)
+ expected = ('from tensorflow.python.test_module '
+ 'import _TEST_CONSTANT')
+ self.assertTrue(expected in str(imports),
+ msg='%s not in %s' % (expected, str(imports)))
+
+
+if __name__ == '__main__':
+ test.main()