aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python/tools/api/generator/doc_srcs_test.py
blob: 481d9874a4bcdcdadcdcb16b5b5c1b10b765dc48 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# =============================================================================
"""Tests for tensorflow.python.tools.api.generator.doc_srcs."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import argparse
import importlib
import sys

from tensorflow.python.platform import test
from tensorflow.python.tools.api.generator import doc_srcs


FLAGS = None


class DocSrcsTest(test.TestCase):

  def testModulesAreValidAPIModules(self):
    for module_name in doc_srcs.get_doc_sources(FLAGS.api_name):
      # Convert module_name to corresponding __init__.py file path.
      file_path = module_name.replace('.', '/')
      if file_path:
        file_path += '/'
      file_path += '__init__.py'

      self.assertIn(
          file_path, FLAGS.outputs,
          msg='%s is not a valid API module' % module_name)

  def testHaveDocstringOrDocstringModule(self):
    for module_name, docsrc in doc_srcs.get_doc_sources(FLAGS.api_name).items():
      self.assertFalse(
          docsrc.docstring and docsrc.docstring_module_name,
          msg=('%s contains DocSource has both a docstring and a '
               'docstring_module_name. Only one of "docstring" or '
               '"docstring_module_name" should be set.') % (module_name))

  def testDocstringModulesAreValidModules(self):
    for _, docsrc in doc_srcs.get_doc_sources(FLAGS.api_name).items():
      if docsrc.docstring_module_name:
        doc_module_name = '.'.join([
            FLAGS.package, docsrc.docstring_module_name])
        self.assertIn(
            doc_module_name, sys.modules,
            msg=('docsources_module %s is not a valid module under %s.' %
                 (docsrc.docstring_module_name, FLAGS.package)))


if __name__ == '__main__':
  parser = argparse.ArgumentParser()
  parser.add_argument(
      'outputs', metavar='O', type=str, nargs='+',
      help='create_python_api output files.')
  parser.add_argument(
      '--package', type=str,
      help='Base package that imports modules containing the target tf_export '
           'decorators.')
  parser.add_argument(
      '--api_name', type=str,
      help='API name: tensorflow or estimator')
  FLAGS, unparsed = parser.parse_known_args()

  importlib.import_module(FLAGS.package)

  # Now update argv, so that unittest library does not get confused.
  sys.argv = [sys.argv[0]] + unparsed
  test.main()