diff options
30 files changed, 251 insertions, 40 deletions
diff --git a/tensorflow/BUILD b/tensorflow/BUILD index a4e7602bea..957528cecb 100644 --- a/tensorflow/BUILD +++ b/tensorflow/BUILD @@ -830,3 +830,14 @@ py_library( visibility = ["//visibility:public"], deps = ["//tensorflow/python"], ) + +py_library( + name = "experimental_tensorflow_py", + srcs = ["experimental_api.py"], + srcs_version = "PY2AND3", + visibility = ["//tensorflow/tools/api/tests:__subpackages__"], + deps = [ + "//tensorflow/python", + "//tensorflow/tools/api/generator:python_api", + ], +) diff --git a/tensorflow/core/api_def/python_api/api_def_Assign.pbtxt b/tensorflow/core/api_def/python_api/api_def_Assign.pbtxt new file mode 100644 index 0000000000..34062ede91 --- /dev/null +++ b/tensorflow/core/api_def/python_api/api_def_Assign.pbtxt @@ -0,0 +1,4 @@ +op { + graph_op_name: "Assign" + visibility: HIDDEN +} diff --git a/tensorflow/core/api_def/python_api/api_def_AssignAdd.pbtxt b/tensorflow/core/api_def/python_api/api_def_AssignAdd.pbtxt new file mode 100644 index 0000000000..4553c6c6e7 --- /dev/null +++ b/tensorflow/core/api_def/python_api/api_def_AssignAdd.pbtxt @@ -0,0 +1,4 @@ +op { + graph_op_name: "AssignAdd" + visibility: HIDDEN +} diff --git a/tensorflow/core/api_def/python_api/api_def_AssignSub.pbtxt b/tensorflow/core/api_def/python_api/api_def_AssignSub.pbtxt new file mode 100644 index 0000000000..aec68d5c21 --- /dev/null +++ b/tensorflow/core/api_def/python_api/api_def_AssignSub.pbtxt @@ -0,0 +1,4 @@ +op { + graph_op_name: "AssignSub" + visibility: HIDDEN +} diff --git a/tensorflow/core/api_def/python_api/api_def_SparseReduceMax.pbtxt b/tensorflow/core/api_def/python_api/api_def_SparseReduceMax.pbtxt new file mode 100644 index 0000000000..a885e97d23 --- /dev/null +++ b/tensorflow/core/api_def/python_api/api_def_SparseReduceMax.pbtxt @@ -0,0 +1,4 @@ +op { + graph_op_name: "SparseReduceMax" + visibility: HIDDEN +} diff --git a/tensorflow/core/api_def/python_api/api_def_SparseReduceMaxSparse.pbtxt b/tensorflow/core/api_def/python_api/api_def_SparseReduceMaxSparse.pbtxt new file mode 100644 index 0000000000..c7d6978801 --- /dev/null +++ b/tensorflow/core/api_def/python_api/api_def_SparseReduceMaxSparse.pbtxt @@ -0,0 +1,4 @@ +op { + graph_op_name: "SparseReduceMaxSparse" + visibility: HIDDEN +} diff --git a/tensorflow/core/api_def/python_api/api_def_SparseReduceSum.pbtxt b/tensorflow/core/api_def/python_api/api_def_SparseReduceSum.pbtxt new file mode 100644 index 0000000000..ee30d7aaf1 --- /dev/null +++ b/tensorflow/core/api_def/python_api/api_def_SparseReduceSum.pbtxt @@ -0,0 +1,4 @@ +op { + graph_op_name: "SparseReduceSum" + visibility: HIDDEN +} diff --git a/tensorflow/core/api_def/python_api/api_def_SparseReduceSumSparse.pbtxt b/tensorflow/core/api_def/python_api/api_def_SparseReduceSumSparse.pbtxt new file mode 100644 index 0000000000..0dd89fb4c5 --- /dev/null +++ b/tensorflow/core/api_def/python_api/api_def_SparseReduceSumSparse.pbtxt @@ -0,0 +1,4 @@ +op { + graph_op_name: "SparseReduceSumSparse" + visibility: HIDDEN +} diff --git a/tensorflow/core/api_def/python_api/api_def_SparseSlice.pbtxt b/tensorflow/core/api_def/python_api/api_def_SparseSlice.pbtxt new file mode 100644 index 0000000000..716f8781d0 --- /dev/null +++ b/tensorflow/core/api_def/python_api/api_def_SparseSlice.pbtxt @@ -0,0 +1,4 @@ +op { + graph_op_name: "SparseSlice" + visibility: HIDDEN +} diff --git a/tensorflow/core/api_def/python_api/api_def_SparseSoftmax.pbtxt b/tensorflow/core/api_def/python_api/api_def_SparseSoftmax.pbtxt new file mode 100644 index 0000000000..fc29dd6513 --- /dev/null +++ b/tensorflow/core/api_def/python_api/api_def_SparseSoftmax.pbtxt @@ -0,0 +1,4 @@ +op { + graph_op_name: "SparseSoftmax" + visibility: HIDDEN +} diff --git a/tensorflow/experimental_api.py b/tensorflow/experimental_api.py new file mode 100644 index 0000000000..63a8aa9cb1 --- /dev/null +++ b/tensorflow/experimental_api.py @@ -0,0 +1,38 @@ +# Copyright 2015 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +# Bring in all of the public TensorFlow interface into this +# module. + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +# pylint: disable=g-bad-import-order +from tensorflow.python import pywrap_tensorflow # pylint: disable=unused-import +# pylint: disable=wildcard-import +from tensorflow.tools.api.generator.api import * # pylint: disable=redefined-builtin +# pylint: enable=wildcard-import + +from tensorflow.python.util.lazy_loader import LazyLoader +contrib = LazyLoader('contrib', globals(), 'tensorflow.contrib') +del LazyLoader + +from tensorflow.python.platform import flags # pylint: disable=g-import-not-at-top +app.flags = flags # pylint: disable=undefined-variable + +del absolute_import +del division +del print_function diff --git a/tensorflow/python/framework/dtypes.py b/tensorflow/python/framework/dtypes.py index 99ae8b24f1..0edae92fd4 100644 --- a/tensorflow/python/framework/dtypes.py +++ b/tensorflow/python/framework/dtypes.py @@ -343,7 +343,9 @@ tf_export("uint8").export_constant(__name__, "uint8") uint16 = DType(types_pb2.DT_UINT16) tf_export("uint16").export_constant(__name__, "uint16") uint32 = DType(types_pb2.DT_UINT32) +tf_export("uint32").export_constant(__name__, "uint32") uint64 = DType(types_pb2.DT_UINT64) +tf_export("uint64").export_constant(__name__, "uint32") int16 = DType(types_pb2.DT_INT16) tf_export("int16").export_constant(__name__, "int16") int8 = DType(types_pb2.DT_INT8) diff --git a/tensorflow/python/keras/_impl/keras/datasets/fashion_mnist.py b/tensorflow/python/keras/_impl/keras/datasets/fashion_mnist.py index b9ae41a0d4..508e95f719 100644 --- a/tensorflow/python/keras/_impl/keras/datasets/fashion_mnist.py +++ b/tensorflow/python/keras/_impl/keras/datasets/fashion_mnist.py @@ -24,8 +24,10 @@ import os import numpy as np from tensorflow.python.keras._impl.keras.utils.data_utils import get_file +from tensorflow.python.util.tf_export import tf_export +@tf_export('keras.datasets.fashion_mnist.load_data') def load_data(): """Loads the Fashion-MNIST dataset. diff --git a/tensorflow/python/keras/_impl/keras/engine/input_layer.py b/tensorflow/python/keras/_impl/keras/engine/input_layer.py index 29a17555e0..b51dd8a218 100644 --- a/tensorflow/python/keras/_impl/keras/engine/input_layer.py +++ b/tensorflow/python/keras/_impl/keras/engine/input_layer.py @@ -28,6 +28,7 @@ from tensorflow.python.ops import array_ops from tensorflow.python.util.tf_export import tf_export +@tf_export('keras.layers.InputLayer') class InputLayer(base_layer.Layer): """Layer to be used as an entry point into a Network (a graph of layers). diff --git a/tensorflow/python/keras/datasets/fashion_mnist/__init__.py b/tensorflow/python/keras/datasets/fashion_mnist/__init__.py index e69de29bb2..7f5ddecc47 100644 --- a/tensorflow/python/keras/datasets/fashion_mnist/__init__.py +++ b/tensorflow/python/keras/datasets/fashion_mnist/__init__.py @@ -0,0 +1,25 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Fashion-MNIST dataset.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.python.keras._impl.keras.datasets.fashion_mnist import load_data + +del absolute_import +del division +del print_function diff --git a/tensorflow/python/ops/image_ops_impl.py b/tensorflow/python/ops/image_ops_impl.py index 702e47d28f..3369fe3c9b 100644 --- a/tensorflow/python/ops/image_ops_impl.py +++ b/tensorflow/python/ops/image_ops_impl.py @@ -1795,6 +1795,7 @@ _rgb_to_yiq_kernel = [[0.299, 0.59590059, [0.114, -0.32134392, 0.31119955]] +@tf_export('image.rgb_to_yiq') def rgb_to_yiq(images): """Converts one or more images from RGB to YIQ. @@ -1820,6 +1821,7 @@ _yiq_to_rgb_kernel = [[1, 1, 1], [0.95598634, -0.27201283, -1.10674021], [0.6208248, -0.64720424, 1.70423049]] +@tf_export('image.yiq_to_rgb') def yiq_to_rgb(images): """Converts one or more images from YIQ to RGB. @@ -1847,6 +1849,7 @@ _rgb_to_yuv_kernel = [[0.299, -0.14714119, [0.114, 0.43601035, -0.10001026]] +@tf_export('image.rgb_to_yuv') def rgb_to_yuv(images): """Converts one or more images from RGB to YUV. @@ -1872,6 +1875,7 @@ _yuv_to_rgb_kernel = [[1, 1, 1], [0, -0.394642334, 2.03206185], [1.13988303, -0.58062185, 0]] +@tf_export('image.yuv_to_rgb') def yuv_to_rgb(images): """Converts one or more images from YUV to RGB. diff --git a/tensorflow/python/ops/initializers_ns.py b/tensorflow/python/ops/initializers_ns.py index c21079f297..e7996efe93 100644 --- a/tensorflow/python/ops/initializers_ns.py +++ b/tensorflow/python/ops/initializers_ns.py @@ -39,5 +39,8 @@ global_variables = _variables.global_variables_initializer local_variables = _variables.local_variables_initializer # Seal API. +del absolute_import +del division +del print_function del init_ops del _variables diff --git a/tensorflow/python/ops/linalg/linalg_impl.py b/tensorflow/python/ops/linalg/linalg_impl.py index 2be2d5a3d4..8343c62816 100644 --- a/tensorflow/python/ops/linalg/linalg_impl.py +++ b/tensorflow/python/ops/linalg/linalg_impl.py @@ -32,15 +32,18 @@ cholesky = linalg_ops.cholesky cholesky_solve = linalg_ops.cholesky_solve det = linalg_ops.matrix_determinant slogdet = gen_linalg_ops.log_matrix_determinant +tf_export('linalg.slogdet')(slogdet) diag = array_ops.matrix_diag diag_part = array_ops.matrix_diag_part eigh = linalg_ops.self_adjoint_eig eigvalsh = linalg_ops.self_adjoint_eigvals einsum = special_math_ops.einsum expm = gen_linalg_ops.matrix_exponential +tf_export('linalg.expm')(expm) eye = linalg_ops.eye inv = linalg_ops.matrix_inverse logm = gen_linalg_ops.matrix_logarithm +tf_export('linalg.logm')(logm) lstsq = linalg_ops.matrix_solve_ls norm = linalg_ops.norm qr = linalg_ops.qr diff --git a/tensorflow/python/ops/manip_ops.py b/tensorflow/python/ops/manip_ops.py index 91e15b47b9..6d335cdc21 100644 --- a/tensorflow/python/ops/manip_ops.py +++ b/tensorflow/python/ops/manip_ops.py @@ -23,9 +23,11 @@ from __future__ import print_function from tensorflow.python.ops import gen_manip_ops as _gen_manip_ops from tensorflow.python.util.all_util import remove_undocumented +from tensorflow.python.util.tf_export import tf_export # pylint: disable=protected-access +@tf_export('manip.roll') def roll(input, shift, axis): # pylint: disable=redefined-builtin return _gen_manip_ops.roll(input, shift, axis) diff --git a/tensorflow/python/ops/math_ops.py b/tensorflow/python/ops/math_ops.py index 0b3509360e..7869cab86e 100644 --- a/tensorflow/python/ops/math_ops.py +++ b/tensorflow/python/ops/math_ops.py @@ -1184,11 +1184,16 @@ def floordiv(x, y, name=None): realdiv = gen_math_ops.real_div +tf_export("realdiv")(realdiv) truncatediv = gen_math_ops.truncate_div +tf_export("truncatediv")(truncatediv) # TODO(aselle): Rename this to floordiv when we can. floor_div = gen_math_ops.floor_div +tf_export("floor_div")(floor_div) truncatemod = gen_math_ops.truncate_mod +tf_export("truncatemod")(truncatemod) floormod = gen_math_ops.floor_mod +tf_export("floormod")(floormod) def _mul_dispatch(x, y, name=None): @@ -2111,6 +2116,7 @@ def matmul(a, _OverrideBinaryOperatorHelper(matmul, "matmul") sparse_matmul = gen_math_ops.sparse_mat_mul +tf_export("sparse_matmul")(sparse_matmul) @ops.RegisterStatistics("MatMul", "flops") diff --git a/tensorflow/python/ops/nn_impl.py b/tensorflow/python/ops/nn_impl.py index 9d6f65dbbf..47cc4da7f2 100644 --- a/tensorflow/python/ops/nn_impl.py +++ b/tensorflow/python/ops/nn_impl.py @@ -303,12 +303,12 @@ def _swish_grad(features, grad): # @Defun decorator with noinline=True so that sigmoid(features) is re-computed # during backprop, and we can free the sigmoid(features) expression immediately # after use during the forward pass. +@tf_export("nn.swish") @function.Defun( grad_func=_swish_grad, shape_func=_swish_shape, func_name="swish", noinline=True) -@tf_export("nn.swish") def swish(features): # pylint: disable=g-doc-args """Computes the Swish activation function: `x * sigmoid(x)`. @@ -1343,4 +1343,4 @@ def sampled_softmax_loss(weights, sampled_losses = nn_ops.softmax_cross_entropy_with_logits( labels=labels, logits=logits) # sampled_losses is a [batch_size] tensor. - return sampled_losses
\ No newline at end of file + return sampled_losses diff --git a/tensorflow/python/ops/variable_scope.py b/tensorflow/python/ops/variable_scope.py index 7f650ff6a9..cf495970d7 100644 --- a/tensorflow/python/ops/variable_scope.py +++ b/tensorflow/python/ops/variable_scope.py @@ -1385,7 +1385,6 @@ get_variable.__doc__ = get_variable_or_local_docstring % ( "GraphKeys.GLOBAL_VARIABLES") -@functools.wraps(get_variable) @tf_export("get_local_variable") def get_local_variable(*args, **kwargs): kwargs["trainable"] = False diff --git a/tensorflow/python/platform/googletest.py b/tensorflow/python/platform/googletest.py index 96219faab7..8141cf92c5 100644 --- a/tensorflow/python/platform/googletest.py +++ b/tensorflow/python/platform/googletest.py @@ -36,6 +36,7 @@ from tensorflow.python.platform import benchmark from tensorflow.python.platform import tf_logging as logging from tensorflow.python.util import tf_decorator from tensorflow.python.util import tf_inspect +from tensorflow.python.util.tf_export import tf_export Benchmark = benchmark.TensorFlowBenchmark # pylint: disable=invalid-name @@ -138,6 +139,7 @@ def StatefulSessionAvailable(): return False +@tf_export('test.StubOutForTesting') class StubOutForTesting(object): """Support class for stubbing methods out for unit testing. diff --git a/tensorflow/python/platform/test.py b/tensorflow/python/platform/test.py index 9b7655722a..1660791feb 100644 --- a/tensorflow/python/platform/test.py +++ b/tensorflow/python/platform/test.py @@ -62,6 +62,8 @@ if sys.version_info.major == 2: else: from unittest import mock # pylint: disable=g-import-not-at-top +tf_export('test.mock')(mock) + # Import Benchmark class Benchmark = _googletest.Benchmark # pylint: disable=invalid-name diff --git a/tensorflow/tools/api/generator/BUILD b/tensorflow/tools/api/generator/BUILD index e731127a63..14ce8dbeb3 100644 --- a/tensorflow/tools/api/generator/BUILD +++ b/tensorflow/tools/api/generator/BUILD @@ -1,5 +1,6 @@ # Description: # Scripts used to generate TensorFlow Python API. + licenses(["notice"]) # Apache 2.0 exports_files(["LICENSE"]) @@ -21,7 +22,7 @@ py_binary( srcs = ["create_python_api.py"], srcs_version = "PY2AND3", deps = [ - "//tensorflow:tensorflow_py", + "//tensorflow/python", ], ) @@ -80,6 +81,7 @@ genrule( "api/keras/datasets/boston_housing/__init__.py", "api/keras/datasets/cifar10/__init__.py", "api/keras/datasets/cifar100/__init__.py", + "api/keras/datasets/fashion_mnist/__init__.py", "api/keras/datasets/imdb/__init__.py", "api/keras/datasets/mnist/__init__.py", "api/keras/datasets/reuters/__init__.py", @@ -102,6 +104,7 @@ genrule( "api/linalg/__init__.py", "api/logging/__init__.py", "api/losses/__init__.py", + "api/manip/__init__.py", "api/metrics/__init__.py", "api/nn/__init__.py", "api/nn/rnn_cell/__init__.py", @@ -133,7 +136,9 @@ py_library( name = "python_api", srcs = [":python_api_gen"], srcs_version = "PY2AND3", + visibility = ["//tensorflow:__subpackages__"], deps = [ "//tensorflow/contrib:contrib_py", # keep + "//tensorflow/python", # keep ], ) diff --git a/tensorflow/tools/api/generator/create_python_api.py b/tensorflow/tools/api/generator/create_python_api.py index 1557314939..bb7c3e77a3 100644 --- a/tensorflow/tools/api/generator/create_python_api.py +++ b/tensorflow/tools/api/generator/create_python_api.py @@ -23,15 +23,14 @@ import collections import os import sys -# This import is needed so that we can traverse over TensorFlow modules. -import tensorflow as tf # pylint: disable=unused-import +from tensorflow import python as tf from tensorflow.python.util import tf_decorator _API_CONSTANTS_ATTR = '_tf_api_constants' _API_NAMES_ATTR = '_tf_api_names' _API_DIR = '/api/' -_CONTRIB_IMPORT = 'from tensorflow import contrib' +_OUTPUT_MODULE = 'tensorflow.tools.api.generator.api' _GENERATED_FILE_HEADER = """\"\"\"Imports for Python API. This file is MACHINE GENERATED! Do not edit. @@ -92,7 +91,7 @@ def get_api_imports(): if module_contents_name == _API_CONSTANTS_ATTR: for exports, value in attr: for export in exports: - names = ['tf'] + export.split('.') + names = export.split('.') dest_module = '.'.join(names[:-1]) import_str = format_import(module.__name__, value, names[-1]) module_imports[dest_module].append(import_str) @@ -104,29 +103,43 @@ def get_api_imports(): if hasattr(attr, '__dict__') and _API_NAMES_ATTR in attr.__dict__: # The same op might be accessible from multiple modules. # We only want to consider location where function was defined. - if attr.__module__ != module.__name__: + # Here we check if the op is defined in another TensorFlow module in + # sys.modules. + if (hasattr(attr, '__module__') and + attr.__module__.startswith(tf.__name__) and + attr.__module__ != module.__name__ and + attr.__module__ in sys.modules and + module_contents_name in dir(sys.modules[attr.__module__])): continue for export in attr._tf_api_names: # pylint: disable=protected-access - names = ['tf'] + export.split('.') + names = export.split('.') dest_module = '.'.join(names[:-1]) import_str = format_import( module.__name__, module_contents_name, names[-1]) module_imports[dest_module].append(import_str) # Import all required modules in their parent modules. - # For e.g. if we import 'tf.foo.bar.Value'. Then, we also - # import 'bar' in 'tf.foo'. - dest_modules = set(module_imports.keys()) - for dest_module in dest_modules: - dest_module_split = dest_module.split('.') - for dest_submodule_index in range(1, len(dest_module_split)): - dest_submodule = '.'.join(dest_module_split[:dest_submodule_index]) + # For e.g. if we import 'foo.bar.Value'. Then, we also + # import 'bar' in 'foo'. + imported_modules = set(module_imports.keys()) + for module in imported_modules: + if not module: + continue + module_split = module.split('.') + parent_module = '' # we import submodules in their parent_module + + for submodule_index in range(len(module_split)): + import_from = _OUTPUT_MODULE + if submodule_index > 0: + parent_module += ('.' + module_split[submodule_index-1] if parent_module + else module_split[submodule_index-1]) + import_from += '.' + parent_module submodule_import = format_import( - '', dest_module_split[dest_submodule_index], - dest_module_split[dest_submodule_index]) - if submodule_import not in module_imports[dest_submodule]: - module_imports[dest_submodule].append(submodule_import) + import_from, module_split[submodule_index], + module_split[submodule_index]) + if submodule_import not in module_imports[parent_module]: + module_imports[parent_module].append(submodule_import) return module_imports @@ -151,8 +164,8 @@ def create_api_files(output_files): # First get module directory under _API_DIR. module_dir = os.path.dirname( output_file[output_file.rfind(_API_DIR)+len(_API_DIR):]) - # Convert / to . and prefix with tf. - module_name = '.'.join(['tf', module_dir.replace('/', '.')]).strip('.') + # Convert / to . + module_name = module_dir.replace('/', '.').strip('.') module_name_to_file_path[module_name] = output_file # Create file for each expected output in genrule. @@ -162,16 +175,14 @@ def create_api_files(output_files): open(file_path, 'a').close() module_imports = get_api_imports() - module_imports['tf'].append(_CONTRIB_IMPORT) # Include all of contrib. # Add imports to output files. missing_output_files = [] for module, exports in module_imports.items(): # Make sure genrule output file list is in sync with API exports. if module not in module_name_to_file_path: - module_without_tf = module[len('tf.'):] module_file_path = '"api/%s/__init__.py"' % ( - module_without_tf.replace('.', '/')) + module.replace('.', '/')) missing_output_files.append(module_file_path) continue with open(module_name_to_file_path[module], 'w') as fp: diff --git a/tensorflow/tools/api/golden/tensorflow.initializers.pbtxt b/tensorflow/tools/api/golden/tensorflow.initializers.pbtxt index 21a0f84d22..eaf0036cac 100644 --- a/tensorflow/tools/api/golden/tensorflow.initializers.pbtxt +++ b/tensorflow/tools/api/golden/tensorflow.initializers.pbtxt @@ -1,18 +1,10 @@ path: "tensorflow.initializers" tf_module { member { - name: "absolute_import" - mtype: "<type \'instance\'>" - } - member { name: "constant" mtype: "<type \'type\'>" } member { - name: "division" - mtype: "<type \'instance\'>" - } - member { name: "identity" mtype: "<type \'type\'>" } @@ -25,10 +17,6 @@ tf_module { mtype: "<type \'type\'>" } member { - name: "print_function" - mtype: "<type \'instance\'>" - } - member { name: "random_normal" mtype: "<type \'type\'>" } diff --git a/tensorflow/tools/api/golden/tensorflow.keras.datasets.fashion_mnist.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.datasets.fashion_mnist.pbtxt index 791cfda233..a0e14356fa 100644 --- a/tensorflow/tools/api/golden/tensorflow.keras.datasets.fashion_mnist.pbtxt +++ b/tensorflow/tools/api/golden/tensorflow.keras.datasets.fashion_mnist.pbtxt @@ -1,3 +1,7 @@ path: "tensorflow.keras.datasets.fashion_mnist" tf_module { + member_method { + name: "load_data" + argspec: "args=[], varargs=None, keywords=None, defaults=None" + } } diff --git a/tensorflow/tools/api/tests/BUILD b/tensorflow/tools/api/tests/BUILD index 608a34ab7b..15bf1abb5f 100644 --- a/tensorflow/tools/api/tests/BUILD +++ b/tensorflow/tools/api/tests/BUILD @@ -23,6 +23,7 @@ py_test( ], srcs_version = "PY2AND3", deps = [ + "//tensorflow:experimental_tensorflow_py", "//tensorflow:tensorflow_py", "//tensorflow/python:client_testlib", "//tensorflow/python:lib", diff --git a/tensorflow/tools/api/tests/api_compatibility_test.py b/tensorflow/tools/api/tests/api_compatibility_test.py index 2a784973e1..40b152d587 100644 --- a/tensorflow/tools/api/tests/api_compatibility_test.py +++ b/tensorflow/tools/api/tests/api_compatibility_test.py @@ -34,6 +34,7 @@ import sys import unittest import tensorflow as tf +from tensorflow import experimental_api as api from google.protobuf import text_format @@ -46,6 +47,9 @@ from tensorflow.tools.api.lib import python_object_to_proto_visitor from tensorflow.tools.common import public_api from tensorflow.tools.common import traverse +if hasattr(tf, 'experimental_api'): + del tf.experimental_api + # FLAGS defined at the bottom: FLAGS = None # DEFINE_boolean, update_goldens, default False: @@ -109,7 +113,8 @@ class ApiCompatibilityTest(test.TestCase): expected_dict, actual_dict, verbose=False, - update_goldens=False): + update_goldens=False, + additional_missing_object_message=''): """Diff given dicts of protobufs and report differences a readable way. Args: @@ -120,6 +125,8 @@ class ApiCompatibilityTest(test.TestCase): verbose: Whether to log the full diffs, or simply report which files were different. update_goldens: Whether to update goldens when there are diffs found. + additional_missing_object_message: Message to print when a symbol is + missing. """ diffs = [] verbose_diffs = [] @@ -138,7 +145,8 @@ class ApiCompatibilityTest(test.TestCase): verbose_diff_message = '' # First check if the key is not found in one or the other. if key in only_in_expected: - diff_message = 'Object %s expected but not found (removed).' % key + diff_message = 'Object %s expected but not found (removed). %s' % ( + key, additional_missing_object_message) verbose_diff_message = diff_message elif key in only_in_actual: diff_message = 'New object %s found (added).' % key @@ -229,6 +237,64 @@ class ApiCompatibilityTest(test.TestCase): verbose=FLAGS.verbose_diffs, update_goldens=FLAGS.update_goldens) + @unittest.skipUnless( + sys.version_info.major == 2, + 'API compabitility test goldens are generated using python2.') + def testNewAPIBackwardsCompatibility(self): + # Extract all API stuff. + visitor = python_object_to_proto_visitor.PythonObjectToProtoVisitor() + + public_api_visitor = public_api.PublicAPIVisitor(visitor) + public_api_visitor.do_not_descend_map['tf'].append('contrib') + public_api_visitor.do_not_descend_map['tf.GPUOptions'] = ['Experimental'] + # TODO(annarev): these symbols have been added recently with tf_export + # decorators, but they are not exported with old API. Export them using + # old API approach and remove them from here. + public_api_visitor.private_map['tf'] = [ + 'to_complex128', 'to_complex64', 'add_to_collections', + 'unsorted_segment_mean'] + traverse.traverse(api, public_api_visitor) + + proto_dict = visitor.GetProtos() + + # Read all golden files. + expression = os.path.join( + resource_loader.get_root_dir_with_all_resources(), + _KeyToFilePath('*')) + golden_file_list = file_io.get_matching_files(expression) + + def _ReadFileToProto(filename): + """Read a filename, create a protobuf from its contents.""" + ret_val = api_objects_pb2.TFAPIObject() + text_format.Merge(file_io.read_file_to_string(filename), ret_val) + return ret_val + + golden_proto_dict = { + _FileNameToKey(filename): _ReadFileToProto(filename) + for filename in golden_file_list + } + + # user_ops is an empty module. It is currently available in TensorFlow API + # but we don't keep empty modules in the new API. + # We delete user_ops from golden_proto_dict to make sure assert passes + # when diffing new API against goldens. + # TODO(annarev): remove user_ops from goldens once we switch to new API. + tf_module = golden_proto_dict['tensorflow'].tf_module + for i in range(len(tf_module.member)): + if tf_module.member[i].name == 'user_ops': + del tf_module.member[i] + break + + # Diff them. Do not fail if called with update. + # If the test is run to update goldens, only report diffs but do not fail. + self._AssertProtoDictEquals( + golden_proto_dict, + proto_dict, + verbose=FLAGS.verbose_diffs, + update_goldens=False, + additional_missing_object_message= + 'Check if tf_export decorator/call is missing for this symbol.') + if __name__ == '__main__': parser = argparse.ArgumentParser() |