aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorGravatar Michael Case <mikecase@google.com>2018-06-07 12:05:24 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-06-07 12:14:50 -0700
commit501cf726cbee2ee13efef43884a6552ca211979d (patch)
tree2a93bae901b9f9d32f5d622e2e4d626668b48b99
parent4d0d60a82c52c6c71650db33bf826f03559d91fc (diff)
Internal Change.
PiperOrigin-RevId: 199673803
-rw-r--r--tensorflow/BUILD7
-rw-r--r--tensorflow/api_template.__init__.py17
-rwxr-xr-xtensorflow/contrib/cmake/tf_python.cmake45
-rw-r--r--tensorflow/python/estimator/BUILD4
-rw-r--r--tensorflow/python/estimator/api/BUILD17
-rw-r--r--tensorflow/python/estimator/canned/baseline.py6
-rw-r--r--tensorflow/python/estimator/canned/boosted_trees.py6
-rw-r--r--tensorflow/python/estimator/canned/dnn.py6
-rw-r--r--tensorflow/python/estimator/canned/dnn_linear_combined.py6
-rw-r--r--tensorflow/python/estimator/canned/linear.py6
-rw-r--r--tensorflow/python/estimator/canned/parsing_utils.py6
-rw-r--r--tensorflow/python/estimator/estimator.py12
-rw-r--r--tensorflow/python/estimator/export/export.py10
-rw-r--r--tensorflow/python/estimator/export/export_output.py10
-rw-r--r--tensorflow/python/estimator/exporter.py10
-rw-r--r--tensorflow/python/estimator/inputs/numpy_io.py4
-rw-r--r--tensorflow/python/estimator/inputs/pandas_io.py4
-rw-r--r--tensorflow/python/estimator/model_fn.py6
-rw-r--r--tensorflow/python/estimator/run_config.py4
-rw-r--r--tensorflow/python/estimator/training.py8
-rw-r--r--tensorflow/python/util/tf_export.py58
-rw-r--r--tensorflow/python/util/tf_export_test.py7
-rw-r--r--tensorflow/tools/api/generator/api_gen.bzl20
-rw-r--r--tensorflow/tools/api/generator/create_python_api.py35
-rw-r--r--tensorflow/tools/api/generator/create_python_api_test.py9
25 files changed, 218 insertions, 105 deletions
diff --git a/tensorflow/BUILD b/tensorflow/BUILD
index e0bce820d1..a73c4ca3aa 100644
--- a/tensorflow/BUILD
+++ b/tensorflow/BUILD
@@ -541,14 +541,17 @@ exports_files(
)
gen_api_init_files(
- name = "python_api_gen",
+ name = "tensorflow_python_api_gen",
srcs = ["api_template.__init__.py"],
root_init_template = "api_template.__init__.py",
)
py_library(
name = "tensorflow_py",
- srcs = [":python_api_gen"],
+ srcs = [
+ ":tensorflow_python_api_gen",
+ "//tensorflow/python/estimator/api:estimator_python_api_gen",
+ ],
srcs_version = "PY2AND3",
visibility = ["//visibility:public"],
deps = ["//tensorflow/python"],
diff --git a/tensorflow/api_template.__init__.py b/tensorflow/api_template.__init__.py
index 9b0d7d48af..9662d7b478 100644
--- a/tensorflow/api_template.__init__.py
+++ b/tensorflow/api_template.__init__.py
@@ -22,7 +22,22 @@ from __future__ import print_function
from tensorflow.python import pywrap_tensorflow # pylint: disable=unused-import
# API IMPORTS PLACEHOLDER
-from tensorflow.python.util.lazy_loader import LazyLoader
+try:
+ import os # pylint: disable=g-import-not-at-top
+ # Add `estimator` attribute to allow access to estimator APIs via
+ # "tf.estimator..."
+ from tensorflow.python.estimator.api import estimator # pylint: disable=g-import-not-at-top
+
+ # Add `estimator` to the __path__ to allow "from tensorflow.estimator..."
+ # style imports.
+ from tensorflow.python.estimator import api as estimator_api # pylint: disable=g-import-not-at-top
+ __path__ += [os.path.dirname(estimator_api.__file__)]
+ del estimator_api
+ del os
+except (ImportError, AttributeError):
+ print('tf.estimator package not installed.')
+
+from tensorflow.python.util.lazy_loader import LazyLoader # pylint: disable=g-import-not-at-top
contrib = LazyLoader('contrib', globals(), 'tensorflow.contrib')
del LazyLoader
diff --git a/tensorflow/contrib/cmake/tf_python.cmake b/tensorflow/contrib/cmake/tf_python.cmake
index d019dd48f2..a0c3ddd28b 100755
--- a/tensorflow/contrib/cmake/tf_python.cmake
+++ b/tensorflow/contrib/cmake/tf_python.cmake
@@ -756,6 +756,8 @@ add_custom_command(
"${CMAKE_CURRENT_BINARY_DIR}/tf_python/tensorflow/tools/api/generator/create_python_api.py"
"--root_init_template=${CMAKE_CURRENT_BINARY_DIR}/tf_python/tensorflow/api_template.__init__.py"
"--apidir=${CMAKE_CURRENT_BINARY_DIR}/tf_python/tensorflow"
+ "--package=tensorflow.python"
+ "--apiname=tensorflow"
"${api_init_list_file}"
COMMENT "Generating __init__.py files for Python API."
@@ -765,7 +767,49 @@ add_custom_command(
add_custom_target(tf_python_api SOURCES ${api_init_files})
add_dependencies(tf_python_api tf_python_ops)
+# TODO(mikecase): This can be removed once tf.estimator is moved
+# out of TensorFlow.
+########################################################
+# Generate API __init__.py files for tf.estimator.
+########################################################
+
+# Parse tensorflow/tools/api/generator/BUILD to get list of generated files.
+FILE(READ ${tensorflow_source_dir}/tensorflow/tools/api/generator/api_gen.bzl api_generator_BUILD_text)
+STRING(REGEX MATCH "# BEGIN GENERATED ESTIMATOR FILES.*# END GENERATED ESTIMATOR FILES" api_init_files_text ${api_generator_BUILD_text})
+string(REPLACE "# BEGIN GENERATED ESTIMATOR FILES" "" api_init_files_text ${api_init_files_text})
+string(REPLACE "# END GENERATED ESTIMATOR FILES" "" api_init_files_text ${api_init_files_text})
+string(REPLACE "," ";" api_init_files_list ${api_init_files_text})
+
+set(api_init_files "")
+foreach(api_init_file ${api_init_files_list})
+ string(STRIP "${api_init_file}" api_init_file)
+ if(api_init_file)
+ string(REPLACE "\"" "" api_init_file "${api_init_file}") # Remove quotes
+ list(APPEND api_init_files "${CMAKE_CURRENT_BINARY_DIR}/tf_python/tensorflow/python/estimator/api/${api_init_file}")
+ endif()
+endforeach(api_init_file)
+set(estimator_api_init_list_file "${tensorflow_source_dir}/estimator_api_init_files_list.txt")
+file(WRITE "${estimator_api_init_list_file}" "${api_init_files}")
+
+# Run create_python_api.py to generate __init__.py files.
+add_custom_command(
+ OUTPUT ${api_init_files}
+ DEPENDS tf_python_ops tf_python_copy_scripts_to_destination pywrap_tensorflow_internal tf_python_touchup_modules tf_extension_ops
+
+ # Run create_python_api.py to generate API init files.
+ COMMAND ${CMAKE_COMMAND} -E env PYTHONPATH=${CMAKE_CURRENT_BINARY_DIR}/tf_python ${PYTHON_EXECUTABLE}
+ "${CMAKE_CURRENT_BINARY_DIR}/tf_python/tensorflow/tools/api/generator/create_python_api.py"
+ "--apidir=${CMAKE_CURRENT_BINARY_DIR}/tf_python/tensorflow/python/estimator/api"
+ "--package=tensorflow.python.estimator"
+ "--apiname=estimator"
+ "${estimator_api_init_list_file}"
+
+ COMMENT "Generating __init__.py files for Python API."
+ WORKING_DIRECTORY "${CMAKE_CURRENT_BINARY_DIR}/tf_python"
+)
+add_custom_target(estimator_python_api SOURCES ${api_init_files})
+add_dependencies(estimator_python_api tf_python_ops)
############################################################
# Build a PIP package containing the TensorFlow runtime.
############################################################
@@ -776,6 +820,7 @@ add_dependencies(tf_python_build_pip_package
tf_python_touchup_modules
tf_python_ops
tf_python_api
+ estimator_python_api
tf_extension_ops)
# Fix-up Python files that were not included by the add_python_module() macros.
diff --git a/tensorflow/python/estimator/BUILD b/tensorflow/python/estimator/BUILD
index d538c6c415..c0d63b79a6 100644
--- a/tensorflow/python/estimator/BUILD
+++ b/tensorflow/python/estimator/BUILD
@@ -12,6 +12,10 @@ py_library(
name = "estimator_py",
srcs = ["estimator_lib.py"],
srcs_version = "PY2AND3",
+ visibility = [
+ "//tensorflow:__pkg__",
+ "//tensorflow:internal",
+ ],
deps = [
":baseline",
":boosted_trees",
diff --git a/tensorflow/python/estimator/api/BUILD b/tensorflow/python/estimator/api/BUILD
new file mode 100644
index 0000000000..cddee9b8f3
--- /dev/null
+++ b/tensorflow/python/estimator/api/BUILD
@@ -0,0 +1,17 @@
+package(
+ default_visibility = [
+ "//tensorflow:internal",
+ ],
+)
+
+licenses(["notice"]) # Apache 2.0
+
+load("//tensorflow/tools/api/generator:api_gen.bzl", "gen_api_init_files")
+load("//tensorflow/tools/api/generator:api_gen.bzl", "ESTIMATOR_API_INIT_FILES")
+
+gen_api_init_files(
+ name = "estimator_python_api_gen",
+ api_name = "estimator",
+ output_files = ESTIMATOR_API_INIT_FILES,
+ package = "tensorflow.python.estimator",
+)
diff --git a/tensorflow/python/estimator/canned/baseline.py b/tensorflow/python/estimator/canned/baseline.py
index 980c057372..3c6816cb03 100644
--- a/tensorflow/python/estimator/canned/baseline.py
+++ b/tensorflow/python/estimator/canned/baseline.py
@@ -59,7 +59,7 @@ from tensorflow.python.ops import math_ops
from tensorflow.python.ops import variable_scope
from tensorflow.python.ops.losses import losses
from tensorflow.python.training import training_util
-from tensorflow.python.util.tf_export import tf_export
+from tensorflow.python.util.tf_export import estimator_export
# The default learning rate of 0.3 is a historical artifact of the initial
# implementation, but seems a reasonable choice.
@@ -174,7 +174,7 @@ def _baseline_model_fn(features, labels, mode, head, optimizer,
train_op_fn=train_op_fn)
-@tf_export('estimator.BaselineClassifier')
+@estimator_export('estimator.BaselineClassifier')
class BaselineClassifier(estimator.Estimator):
"""A classifier that can establish a simple baseline.
@@ -277,7 +277,7 @@ class BaselineClassifier(estimator.Estimator):
config=config)
-@tf_export('estimator.BaselineRegressor')
+@estimator_export('estimator.BaselineRegressor')
class BaselineRegressor(estimator.Estimator):
"""A regressor that can establish a simple baseline.
diff --git a/tensorflow/python/estimator/canned/boosted_trees.py b/tensorflow/python/estimator/canned/boosted_trees.py
index 4e6010a162..6b54f51ca6 100644
--- a/tensorflow/python/estimator/canned/boosted_trees.py
+++ b/tensorflow/python/estimator/canned/boosted_trees.py
@@ -39,7 +39,7 @@ from tensorflow.python.summary import summary
from tensorflow.python.training import distribute as distribute_lib
from tensorflow.python.training import session_run_hook
from tensorflow.python.training import training_util
-from tensorflow.python.util.tf_export import tf_export
+from tensorflow.python.util.tf_export import estimator_export
# TODO(nponomareva): Reveal pruning params here.
_TreeHParams = collections.namedtuple('TreeHParams', [
@@ -712,7 +712,7 @@ def _create_regression_head(label_dimension, weight_column=None):
# pylint: enable=protected-access
-@tf_export('estimator.BoostedTreesClassifier')
+@estimator_export('estimator.BoostedTreesClassifier')
class BoostedTreesClassifier(estimator.Estimator):
"""A Classifier for Tensorflow Boosted Trees models."""
@@ -830,7 +830,7 @@ class BoostedTreesClassifier(estimator.Estimator):
model_fn=_model_fn, model_dir=model_dir, config=config)
-@tf_export('estimator.BoostedTreesRegressor')
+@estimator_export('estimator.BoostedTreesRegressor')
class BoostedTreesRegressor(estimator.Estimator):
"""A Regressor for Tensorflow Boosted Trees models."""
diff --git a/tensorflow/python/estimator/canned/dnn.py b/tensorflow/python/estimator/canned/dnn.py
index 1feac36f35..b924ad5df4 100644
--- a/tensorflow/python/estimator/canned/dnn.py
+++ b/tensorflow/python/estimator/canned/dnn.py
@@ -32,7 +32,7 @@ from tensorflow.python.ops import partitioned_variables
from tensorflow.python.ops import variable_scope
from tensorflow.python.ops.losses import losses
from tensorflow.python.summary import summary
-from tensorflow.python.util.tf_export import tf_export
+from tensorflow.python.util.tf_export import estimator_export
# The default learning rate of 0.05 is a historical artifact of the initial
# implementation, but seems a reasonable choice.
@@ -201,7 +201,7 @@ def _dnn_model_fn(features,
logits=logits)
-@tf_export('estimator.DNNClassifier')
+@estimator_export('estimator.DNNClassifier')
class DNNClassifier(estimator.Estimator):
"""A classifier for TensorFlow DNN models.
@@ -353,7 +353,7 @@ class DNNClassifier(estimator.Estimator):
warm_start_from=warm_start_from)
-@tf_export('estimator.DNNRegressor')
+@estimator_export('estimator.DNNRegressor')
class DNNRegressor(estimator.Estimator):
"""A regressor for TensorFlow DNN models.
diff --git a/tensorflow/python/estimator/canned/dnn_linear_combined.py b/tensorflow/python/estimator/canned/dnn_linear_combined.py
index 95efc0a028..64d81c46ce 100644
--- a/tensorflow/python/estimator/canned/dnn_linear_combined.py
+++ b/tensorflow/python/estimator/canned/dnn_linear_combined.py
@@ -37,7 +37,7 @@ from tensorflow.python.summary import summary
from tensorflow.python.training import distribute as distribute_lib
from tensorflow.python.training import sync_replicas_optimizer
from tensorflow.python.training import training_util
-from tensorflow.python.util.tf_export import tf_export
+from tensorflow.python.util.tf_export import estimator_export
# The default learning rates are a historical artifact of the initial
# implementation.
@@ -225,7 +225,7 @@ def _dnn_linear_combined_model_fn(features,
logits=logits)
-@tf_export('estimator.DNNLinearCombinedClassifier')
+@estimator_export('estimator.DNNLinearCombinedClassifier')
class DNNLinearCombinedClassifier(estimator.Estimator):
"""An estimator for TensorFlow Linear and DNN joined classification models.
@@ -406,7 +406,7 @@ class DNNLinearCombinedClassifier(estimator.Estimator):
warm_start_from=warm_start_from)
-@tf_export('estimator.DNNLinearCombinedRegressor')
+@estimator_export('estimator.DNNLinearCombinedRegressor')
class DNNLinearCombinedRegressor(estimator.Estimator):
"""An estimator for TensorFlow Linear and DNN joined models for regression.
diff --git a/tensorflow/python/estimator/canned/linear.py b/tensorflow/python/estimator/canned/linear.py
index 81657f0c01..705fc3ce06 100644
--- a/tensorflow/python/estimator/canned/linear.py
+++ b/tensorflow/python/estimator/canned/linear.py
@@ -33,7 +33,7 @@ from tensorflow.python.ops import variable_scope
from tensorflow.python.ops.losses import losses
from tensorflow.python.summary import summary
from tensorflow.python.training import ftrl
-from tensorflow.python.util.tf_export import tf_export
+from tensorflow.python.util.tf_export import estimator_export
# The default learning rate of 0.2 is a historical artifact of the initial
@@ -164,7 +164,7 @@ def _linear_model_fn(features, labels, mode, head, feature_columns, optimizer,
logits=logits)
-@tf_export('estimator.LinearClassifier')
+@estimator_export('estimator.LinearClassifier')
class LinearClassifier(estimator.Estimator):
"""Linear classifier model.
@@ -317,7 +317,7 @@ class LinearClassifier(estimator.Estimator):
warm_start_from=warm_start_from)
-@tf_export('estimator.LinearRegressor')
+@estimator_export('estimator.LinearRegressor')
class LinearRegressor(estimator.Estimator):
"""An estimator for TensorFlow Linear regression problems.
diff --git a/tensorflow/python/estimator/canned/parsing_utils.py b/tensorflow/python/estimator/canned/parsing_utils.py
index 74e5e5a1be..1ae0f1e9f7 100644
--- a/tensorflow/python/estimator/canned/parsing_utils.py
+++ b/tensorflow/python/estimator/canned/parsing_utils.py
@@ -23,10 +23,10 @@ import six
from tensorflow.python.feature_column import feature_column as fc
from tensorflow.python.framework import dtypes
from tensorflow.python.ops import parsing_ops
-from tensorflow.python.util.tf_export import tf_export
+from tensorflow.python.util.tf_export import estimator_export
-@tf_export('estimator.classifier_parse_example_spec')
+@estimator_export('estimator.classifier_parse_example_spec')
def classifier_parse_example_spec(feature_columns,
label_key,
label_dtype=dtypes.int64,
@@ -166,7 +166,7 @@ def classifier_parse_example_spec(feature_columns,
return parsing_spec
-@tf_export('estimator.regressor_parse_example_spec')
+@estimator_export('estimator.regressor_parse_example_spec')
def regressor_parse_example_spec(feature_columns,
label_key,
label_dtype=dtypes.float32,
diff --git a/tensorflow/python/estimator/estimator.py b/tensorflow/python/estimator/estimator.py
index 4be1af1e66..41c25f1c73 100644
--- a/tensorflow/python/estimator/estimator.py
+++ b/tensorflow/python/estimator/estimator.py
@@ -66,14 +66,14 @@ from tensorflow.python.util import compat
from tensorflow.python.util import compat_internal
from tensorflow.python.util import function_utils
from tensorflow.python.util import nest
-from tensorflow.python.util.tf_export import tf_export
+from tensorflow.python.util.tf_export import estimator_export
_VALID_MODEL_FN_ARGS = set(
['features', 'labels', 'mode', 'params', 'self', 'config'])
-@tf_export('estimator.Estimator')
+@estimator_export('estimator.Estimator')
class Estimator(object):
"""Estimator class to train and evaluate TensorFlow models.
@@ -566,7 +566,8 @@ class Estimator(object):
allowed_overrides = set([
'_call_input_fn', '_create_global_step',
'_convert_train_steps_to_hooks', '_convert_eval_steps_to_hooks',
- '_tf_api_names', '_validate_features_in_predict_input',
+ '_tf_api_names', '_estimator_api_names', '_estimator_api_constants',
+ '_validate_features_in_predict_input',
'_call_model_fn', '_add_meta_graph_for_mode'
])
estimator_members = set([m for m in Estimator.__dict__.keys()
@@ -1634,11 +1635,12 @@ def _has_dataset_or_queue_runner(maybe_tensor):
# Now, check queue.
return ops.get_default_graph().get_collection(ops.GraphKeys.QUEUE_RUNNERS)
+
VocabInfo = warm_starting_util.VocabInfo # pylint: disable=invalid-name
-tf_export('estimator.VocabInfo', allow_multiple_exports=True)(VocabInfo)
+estimator_export('estimator.VocabInfo')(VocabInfo)
-@tf_export('estimator.WarmStartSettings')
+@estimator_export('estimator.WarmStartSettings')
class WarmStartSettings(
collections.namedtuple('WarmStartSettings', [
'ckpt_to_initialize_from',
diff --git a/tensorflow/python/estimator/export/export.py b/tensorflow/python/estimator/export/export.py
index ff19a0a7f4..010c0f3f59 100644
--- a/tensorflow/python/estimator/export/export.py
+++ b/tensorflow/python/estimator/export/export.py
@@ -34,7 +34,7 @@ from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.saved_model import signature_constants
from tensorflow.python.saved_model import signature_def_utils
from tensorflow.python.util import compat
-from tensorflow.python.util.tf_export import tf_export
+from tensorflow.python.util.tf_export import estimator_export
_SINGLE_FEATURE_DEFAULT_NAME = 'feature'
_SINGLE_RECEIVER_DEFAULT_NAME = 'input'
@@ -93,7 +93,7 @@ def _check_tensor_key(name, error_label='feature'):
raise ValueError('{} keys must be strings: {}.'.format(error_label, name))
-@tf_export('estimator.export.ServingInputReceiver')
+@estimator_export('estimator.export.ServingInputReceiver')
class ServingInputReceiver(
collections.namedtuple(
'ServingInputReceiver',
@@ -161,7 +161,7 @@ class ServingInputReceiver(
receiver_tensors_alternatives=receiver_tensors_alternatives)
-@tf_export('estimator.export.TensorServingInputReceiver')
+@estimator_export('estimator.export.TensorServingInputReceiver')
class TensorServingInputReceiver(
collections.namedtuple(
'TensorServingInputReceiver',
@@ -263,7 +263,7 @@ class SupervisedInputReceiver(
receiver_tensors=receiver_tensors)
-@tf_export('estimator.export.build_parsing_serving_input_receiver_fn')
+@estimator_export('estimator.export.build_parsing_serving_input_receiver_fn')
def build_parsing_serving_input_receiver_fn(feature_spec,
default_batch_size=None):
"""Build a serving_input_receiver_fn expecting fed tf.Examples.
@@ -313,7 +313,7 @@ def _placeholders_from_receiver_tensors_dict(input_vals,
}
-@tf_export('estimator.export.build_raw_serving_input_receiver_fn')
+@estimator_export('estimator.export.build_raw_serving_input_receiver_fn')
def build_raw_serving_input_receiver_fn(features, default_batch_size=None):
"""Build a serving_input_receiver_fn expecting feature Tensors.
diff --git a/tensorflow/python/estimator/export/export_output.py b/tensorflow/python/estimator/export/export_output.py
index d387ea2940..6c26d29985 100644
--- a/tensorflow/python/estimator/export/export_output.py
+++ b/tensorflow/python/estimator/export/export_output.py
@@ -26,10 +26,10 @@ import six
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.saved_model import signature_def_utils
-from tensorflow.python.util.tf_export import tf_export
+from tensorflow.python.util.tf_export import estimator_export
-@tf_export('estimator.export.ExportOutput')
+@estimator_export('estimator.export.ExportOutput')
class ExportOutput(object):
"""Represents an output of a model that can be served.
@@ -100,7 +100,7 @@ class ExportOutput(object):
return output_dict
-@tf_export('estimator.export.ClassificationOutput')
+@estimator_export('estimator.export.ClassificationOutput')
class ClassificationOutput(ExportOutput):
"""Represents the output of a classification head.
@@ -169,7 +169,7 @@ class ClassificationOutput(ExportOutput):
examples, self.classes, self.scores)
-@tf_export('estimator.export.RegressionOutput')
+@estimator_export('estimator.export.RegressionOutput')
class RegressionOutput(ExportOutput):
"""Represents the output of a regression head."""
@@ -202,7 +202,7 @@ class RegressionOutput(ExportOutput):
return signature_def_utils.regression_signature_def(examples, self.value)
-@tf_export('estimator.export.PredictOutput')
+@estimator_export('estimator.export.PredictOutput')
class PredictOutput(ExportOutput):
"""Represents the output of a generic prediction head.
diff --git a/tensorflow/python/estimator/exporter.py b/tensorflow/python/estimator/exporter.py
index 5981fa59b7..7cdf840c97 100644
--- a/tensorflow/python/estimator/exporter.py
+++ b/tensorflow/python/estimator/exporter.py
@@ -28,10 +28,10 @@ from tensorflow.python.framework import errors_impl
from tensorflow.python.platform import gfile
from tensorflow.python.platform import tf_logging
from tensorflow.python.summary import summary_iterator
-from tensorflow.python.util.tf_export import tf_export
+from tensorflow.python.util.tf_export import estimator_export
-@tf_export('estimator.Exporter')
+@estimator_export('estimator.Exporter')
class Exporter(object):
"""A class representing a type of model export."""
@@ -172,7 +172,7 @@ def _verify_compre_fn_args(compare_fn):
(compare_fn, non_valid_args))
-@tf_export('estimator.BestExporter')
+@estimator_export('estimator.BestExporter')
class BestExporter(Exporter):
"""This class exports the serving graph and checkpoints of the best models.
@@ -367,7 +367,7 @@ class BestExporter(Exporter):
return best_eval_result
-@tf_export('estimator.FinalExporter')
+@estimator_export('estimator.FinalExporter')
class FinalExporter(Exporter):
"""This class exports the serving graph and checkpoints in the end.
@@ -418,7 +418,7 @@ class FinalExporter(Exporter):
is_the_final_export)
-@tf_export('estimator.LatestExporter')
+@estimator_export('estimator.LatestExporter')
class LatestExporter(Exporter):
"""This class regularly exports the serving graph and checkpoints.
diff --git a/tensorflow/python/estimator/inputs/numpy_io.py b/tensorflow/python/estimator/inputs/numpy_io.py
index a6f4712910..035c7c148c 100644
--- a/tensorflow/python/estimator/inputs/numpy_io.py
+++ b/tensorflow/python/estimator/inputs/numpy_io.py
@@ -24,7 +24,7 @@ import numpy as np
from six import string_types
from tensorflow.python.estimator.inputs.queues import feeding_functions
-from tensorflow.python.util.tf_export import tf_export
+from tensorflow.python.util.tf_export import estimator_export
# Key name to pack the target into dict of `features`. See
# `_get_unique_target_key` for details.
@@ -87,7 +87,7 @@ def _validate_and_convert_features(x):
return ordered_dict_data
-@tf_export('estimator.inputs.numpy_input_fn')
+@estimator_export('estimator.inputs.numpy_input_fn')
def numpy_input_fn(x,
y=None,
batch_size=128,
diff --git a/tensorflow/python/estimator/inputs/pandas_io.py b/tensorflow/python/estimator/inputs/pandas_io.py
index bd06843021..938e244fb3 100644
--- a/tensorflow/python/estimator/inputs/pandas_io.py
+++ b/tensorflow/python/estimator/inputs/pandas_io.py
@@ -21,7 +21,7 @@ from __future__ import print_function
import numpy as np
from tensorflow.python.estimator.inputs.queues import feeding_functions
-from tensorflow.python.util.tf_export import tf_export
+from tensorflow.python.util.tf_export import estimator_export
try:
# pylint: disable=g-import-not-at-top
@@ -35,7 +35,7 @@ except ImportError:
HAS_PANDAS = False
-@tf_export('estimator.inputs.pandas_input_fn')
+@estimator_export('estimator.inputs.pandas_input_fn')
def pandas_input_fn(x,
y=None,
batch_size=128,
diff --git a/tensorflow/python/estimator/model_fn.py b/tensorflow/python/estimator/model_fn.py
index 3edf9fe940..c60c7f63ba 100644
--- a/tensorflow/python/estimator/model_fn.py
+++ b/tensorflow/python/estimator/model_fn.py
@@ -32,10 +32,10 @@ from tensorflow.python.saved_model import tag_constants
from tensorflow.python.training import monitored_session
from tensorflow.python.training import session_run_hook
from tensorflow.python.util import nest
-from tensorflow.python.util.tf_export import tf_export
+from tensorflow.python.util.tf_export import estimator_export
-@tf_export('estimator.ModeKeys')
+@estimator_export('estimator.ModeKeys')
class ModeKeys(object):
"""Standard names for model modes.
@@ -62,7 +62,7 @@ EXPORT_TAG_MAP = {
}
-@tf_export('estimator.EstimatorSpec')
+@estimator_export('estimator.EstimatorSpec')
class EstimatorSpec(
collections.namedtuple('EstimatorSpec', [
'mode', 'predictions', 'loss', 'train_op', 'eval_metric_ops',
diff --git a/tensorflow/python/estimator/run_config.py b/tensorflow/python/estimator/run_config.py
index c7707be839..b948ce96e0 100644
--- a/tensorflow/python/estimator/run_config.py
+++ b/tensorflow/python/estimator/run_config.py
@@ -29,7 +29,7 @@ from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.training import server_lib
from tensorflow.python.util import compat_internal
from tensorflow.python.util import function_utils
-from tensorflow.python.util.tf_export import tf_export
+from tensorflow.python.util.tf_export import estimator_export
_USE_DEFAULT = object()
@@ -296,7 +296,7 @@ class TaskType(object):
EVALUATOR = 'evaluator'
-@tf_export('estimator.RunConfig')
+@estimator_export('estimator.RunConfig')
class RunConfig(object):
"""This class specifies the configurations for an `Estimator` run."""
diff --git a/tensorflow/python/estimator/training.py b/tensorflow/python/estimator/training.py
index fb6a68b4f7..1572af579b 100644
--- a/tensorflow/python/estimator/training.py
+++ b/tensorflow/python/estimator/training.py
@@ -35,7 +35,7 @@ from tensorflow.python.training import basic_session_run_hooks
from tensorflow.python.training import server_lib
from tensorflow.python.training import session_run_hook
from tensorflow.python.util import compat
-from tensorflow.python.util.tf_export import tf_export
+from tensorflow.python.util.tf_export import estimator_export
_MAX_DELAY_SECS = 60
_DELAY_SECS_PER_WORKER = 5
@@ -115,7 +115,7 @@ def _is_google_env():
return tf_config.get(_ENVIRONMENT_KEY) == _ENVIRONMENT_GOOGLE_VALUE
-@tf_export('estimator.TrainSpec')
+@estimator_export('estimator.TrainSpec')
class TrainSpec(
collections.namedtuple('TrainSpec', ['input_fn', 'max_steps', 'hooks'])):
"""Configuration for the "train" part for the `train_and_evaluate` call.
@@ -167,7 +167,7 @@ class TrainSpec(
cls, input_fn=input_fn, max_steps=max_steps, hooks=hooks)
-@tf_export('estimator.EvalSpec')
+@estimator_export('estimator.EvalSpec')
class EvalSpec(
collections.namedtuple('EvalSpec', [
'input_fn', 'steps', 'name', 'hooks', 'exporters', 'start_delay_secs',
@@ -263,7 +263,7 @@ class EvalSpec(
throttle_secs=throttle_secs)
-@tf_export('estimator.train_and_evaluate')
+@estimator_export('estimator.train_and_evaluate')
def train_and_evaluate(estimator, train_spec, eval_spec):
"""Train and evaluate the `estimator`.
diff --git a/tensorflow/python/util/tf_export.py b/tensorflow/python/util/tf_export.py
index bf3961c692..e154ffb68a 100644
--- a/tensorflow/python/util/tf_export.py
+++ b/tensorflow/python/util/tf_export.py
@@ -41,17 +41,35 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
+import collections
+import functools
import sys
from tensorflow.python.util import tf_decorator
+ESTIMATOR_API_NAME = 'estimator'
+TENSORFLOW_API_NAME = 'tensorflow'
+
+_Attributes = collections.namedtuple(
+ 'ExportedApiAttributes', ['names', 'constants'])
+
+# Attribute values must be unique to each API.
+API_ATTRS = {
+ TENSORFLOW_API_NAME: _Attributes(
+ '_tf_api_names',
+ '_tf_api_constants'),
+ ESTIMATOR_API_NAME: _Attributes(
+ '_estimator_api_names',
+ '_estimator_api_constants')
+}
+
class SymbolAlreadyExposedError(Exception):
"""Raised when adding API names to symbol that already has API names."""
pass
-class tf_export(object): # pylint: disable=invalid-name
+class api_export(object): # pylint: disable=invalid-name
"""Provides ways to export symbols to the TensorFlow API."""
def __init__(self, *args, **kwargs):
@@ -63,15 +81,12 @@ class tf_export(object): # pylint: disable=invalid-name
overrides: List of symbols that this is overriding
(those overrided api exports will be removed). Note: passing overrides
has no effect on exporting a constant.
- allow_multiple_exports: Allows exporting the same symbol multiple
- times with multiple `tf_export` usages. Prefer however, to list all
- of the exported names in a single `tf_export` usage when possible.
-
+ api_name: Name of the API you want to generate (e.g. `tensorflow` or
+ `estimator`). Default is `tensorflow`.
"""
self._names = args
+ self._api_name = kwargs.get('api_name', TENSORFLOW_API_NAME)
self._overrides = kwargs.get('overrides', [])
- self._allow_multiple_exports = kwargs.get(
- 'allow_multiple_exports', False)
def __call__(self, func):
"""Calls this decorator.
@@ -86,25 +101,24 @@ class tf_export(object): # pylint: disable=invalid-name
SymbolAlreadyExposedError: Raised when a symbol already has API names
and kwarg `allow_multiple_exports` not set.
"""
+ api_names_attr = API_ATTRS[self._api_name].names
+
# Undecorate overridden names
for f in self._overrides:
_, undecorated_f = tf_decorator.unwrap(f)
- del undecorated_f._tf_api_names # pylint: disable=protected-access
+ delattr(undecorated_f, api_names_attr)
_, undecorated_func = tf_decorator.unwrap(func)
# Check for an existing api. We check if attribute name is in
# __dict__ instead of using hasattr to verify that subclasses have
# their own _tf_api_names as opposed to just inheriting it.
- if '_tf_api_names' in undecorated_func.__dict__:
- if self._allow_multiple_exports:
- undecorated_func._tf_api_names += self._names # pylint: disable=protected-access
- else:
- raise SymbolAlreadyExposedError(
- 'Symbol %s is already exposed as %s.' %
- (undecorated_func.__name__, undecorated_func._tf_api_names)) # pylint: disable=protected-access
- else:
- undecorated_func._tf_api_names = self._names # pylint: disable=protected-access
+ if api_names_attr in undecorated_func.__dict__:
+ raise SymbolAlreadyExposedError(
+ 'Symbol %s is already exposed as %s.' %
+ (undecorated_func.__name__, getattr(
+ undecorated_func, api_names_attr))) # pylint: disable=protected-access
+ setattr(undecorated_func, api_names_attr, self._names)
return func
def export_constant(self, module_name, name):
@@ -126,8 +140,12 @@ class tf_export(object): # pylint: disable=invalid-name
name: (string) Current constant name.
"""
module = sys.modules[module_name]
- if not hasattr(module, '_tf_api_constants'):
- module._tf_api_constants = [] # pylint: disable=protected-access
+ if not hasattr(module, API_ATTRS[self._api_name].constants):
+ setattr(module, API_ATTRS[self._api_name].constants, [])
# pylint: disable=protected-access
- module._tf_api_constants.append((self._names, name))
+ getattr(module, API_ATTRS[self._api_name].constants).append(
+ (self._names, name))
+
+tf_export = functools.partial(api_export, api_name=TENSORFLOW_API_NAME)
+estimator_export = functools.partial(tf_export, api_name=ESTIMATOR_API_NAME)
diff --git a/tensorflow/python/util/tf_export_test.py b/tensorflow/python/util/tf_export_test.py
index ace3f054ba..b9e26ecb33 100644
--- a/tensorflow/python/util/tf_export_test.py
+++ b/tensorflow/python/util/tf_export_test.py
@@ -128,13 +128,6 @@ class ValidateExportTest(test.TestCase):
with self.assertRaises(tf_export.SymbolAlreadyExposedError):
export_decorator(_test_function)
- def testEAllowMultipleExports(self):
- _test_function._tf_api_names = ['name1', 'name2']
- tf_export.tf_export('nameRed', 'nameBlue', allow_multiple_exports=True)(
- _test_function)
- self.assertEquals(['name1', 'name2', 'nameRed', 'nameBlue'],
- _test_function._tf_api_names)
-
def testOverridesFunction(self):
_test_function2._tf_api_names = ['abc']
diff --git a/tensorflow/tools/api/generator/api_gen.bzl b/tensorflow/tools/api/generator/api_gen.bzl
index fe3e4d1434..41713a94ec 100644
--- a/tensorflow/tools/api/generator/api_gen.bzl
+++ b/tensorflow/tools/api/generator/api_gen.bzl
@@ -11,9 +11,6 @@ TENSORFLOW_API_INIT_FILES = [
"distributions/__init__.py",
"distributions/bijectors/__init__.py",
"errors/__init__.py",
- "estimator/__init__.py",
- "estimator/export/__init__.py",
- "estimator/inputs/__init__.py",
"feature_column/__init__.py",
"gfile/__init__.py",
"graph_util/__init__.py",
@@ -91,6 +88,16 @@ TENSORFLOW_API_INIT_FILES = [
# END GENERATED FILES
]
+# keep sorted
+ESTIMATOR_API_INIT_FILES = [
+ # BEGIN GENERATED ESTIMATOR FILES
+ "__init__.py",
+ "estimator/__init__.py",
+ "estimator/export/__init__.py",
+ "estimator/inputs/__init__.py",
+ # END GENERATED ESTIMATOR FILES
+]
+
# Creates a genrule that generates a directory structure with __init__.py
# files that import all exported modules (i.e. modules with tf_export
# decorators).
@@ -110,7 +117,9 @@ TENSORFLOW_API_INIT_FILES = [
def gen_api_init_files(name,
output_files=TENSORFLOW_API_INIT_FILES,
root_init_template=None,
- srcs=[]):
+ srcs=[],
+ api_name="tensorflow",
+ package="tensorflow.python"):
root_init_template_flag = ""
if root_init_template:
root_init_template_flag = "--root_init_template=$(location " + root_init_template + ")"
@@ -119,7 +128,8 @@ def gen_api_init_files(name,
outs = output_files,
cmd = (
"$(location //tensorflow/tools/api/generator:create_python_api) " +
- root_init_template_flag + " --apidir=$(@D) $(OUTS)"),
+ root_init_template_flag + " --apidir=$(@D) --apiname=" + api_name + " --package=" + package + " $(OUTS)"),
srcs = srcs,
tools = ["//tensorflow/tools/api/generator:create_python_api"],
+ visibility = ["//tensorflow:__pkg__"],
)
diff --git a/tensorflow/tools/api/generator/create_python_api.py b/tensorflow/tools/api/generator/create_python_api.py
index de0a50ab44..972bdc84ae 100644
--- a/tensorflow/tools/api/generator/create_python_api.py
+++ b/tensorflow/tools/api/generator/create_python_api.py
@@ -25,10 +25,10 @@ import os
import sys
from tensorflow.python.util import tf_decorator
+from tensorflow.python.util import tf_export
+API_ATTRS = tf_export.API_ATTRS
-_API_CONSTANTS_ATTR = '_tf_api_constants'
-_API_NAMES_ATTR = '_tf_api_names'
_DEFAULT_PACKAGE = 'tensorflow.python'
_GENFILES_DIR_SUFFIX = 'genfiles/'
_SYMBOLS_TO_SKIP_EXPLICITLY = {
@@ -154,12 +154,13 @@ __all__.extend([_s for _s in _names_with_underscore])
return module_text_map
-def get_api_init_text(package):
+def get_api_init_text(package, api_name):
"""Get a map from destination module to __init__.py code for that module.
Args:
package: Base python package containing python with target tf_export
decorators.
+ api_name: API you want to generate (e.g. `tensorflow` or `estimator`).
Returns:
A dictionary where
@@ -187,7 +188,7 @@ def get_api_init_text(package):
attr = getattr(module, module_contents_name)
# If attr is _tf_api_constants attribute, then add the constants.
- if module_contents_name == _API_CONSTANTS_ATTR:
+ if module_contents_name == API_ATTRS[api_name].constants:
for exports, value in attr:
for export in exports:
names = export.split('.')
@@ -196,15 +197,12 @@ def get_api_init_text(package):
-1, dest_module, module.__name__, value, names[-1])
continue
- try:
- _, attr = tf_decorator.unwrap(attr)
- except Exception as e:
- print('5555: %s %s' % (module, module_contents_name), file=sys.stderr)
- raise e
+ _, attr = tf_decorator.unwrap(attr)
# If attr is a symbol with _tf_api_names attribute, then
# add import for it.
- if hasattr(attr, '__dict__') and _API_NAMES_ATTR in attr.__dict__:
- for export in attr._tf_api_names: # pylint: disable=protected-access
+ if (hasattr(attr, '__dict__') and
+ API_ATTRS[api_name].names in attr.__dict__):
+ for export in getattr(attr, API_ATTRS[api_name].names): # pylint: disable=protected-access
names = export.split('.')
dest_module = '.'.join(names[:-1])
module_code_builder.add_import(
@@ -241,7 +239,7 @@ def get_module(dir_path, relative_to_dir):
relative_to_dir: Get module relative to this directory.
Returns:
- module that corresponds to the given directory.
+ Name of module that corresponds to the given directory.
"""
dir_path = dir_path[len(relative_to_dir):]
# Convert path separators to '/' for easier parsing below.
@@ -250,7 +248,7 @@ def get_module(dir_path, relative_to_dir):
def create_api_files(
- output_files, package, root_init_template, output_dir):
+ output_files, package, root_init_template, output_dir, api_name):
"""Creates __init__.py files for the Python API.
Args:
@@ -262,6 +260,7 @@ def create_api_files(
"#API IMPORTS PLACEHOLDER" comment in the template file will be replaced
with imports.
output_dir: output API root directory.
+ api_name: API you want to generate (e.g. `tensorflow` or `estimator`).
Raises:
ValueError: if an output file is not under api/ directory,
@@ -278,7 +277,7 @@ def create_api_files(
os.makedirs(os.path.dirname(file_path))
open(file_path, 'a').close()
- module_text_map = get_api_init_text(package)
+ module_text_map = get_api_init_text(package, api_name)
# Add imports to output files.
missing_output_files = []
@@ -329,6 +328,10 @@ def main():
help='Directory where generated output files are placed. '
'gendir should be a prefix of apidir. Also, apidir '
'should be a prefix of every directory in outputs.')
+ parser.add_argument(
+ '--apiname', required=True, type=str,
+ choices=API_ATTRS.keys(),
+ help='The API you want to generate.')
args = parser.parse_args()
@@ -342,8 +345,8 @@ def main():
# Populate `sys.modules` with modules containing tf_export().
importlib.import_module(args.package)
- create_api_files(
- outputs, args.package, args.root_init_template, args.apidir)
+ create_api_files(outputs, args.package, args.root_init_template,
+ args.apidir, args.apiname)
if __name__ == '__main__':
diff --git a/tensorflow/tools/api/generator/create_python_api_test.py b/tensorflow/tools/api/generator/create_python_api_test.py
index 986340cf6d..651ec9d040 100644
--- a/tensorflow/tools/api/generator/create_python_api_test.py
+++ b/tensorflow/tools/api/generator/create_python_api_test.py
@@ -57,7 +57,8 @@ class CreatePythonApiTest(test.TestCase):
def testFunctionImportIsAdded(self):
imports = create_python_api.get_api_init_text(
- package=create_python_api._DEFAULT_PACKAGE)
+ package=create_python_api._DEFAULT_PACKAGE,
+ api_name='tensorflow')
expected_import = (
'from tensorflow.python.test_module '
'import test_op as test_op1')
@@ -73,7 +74,8 @@ class CreatePythonApiTest(test.TestCase):
def testClassImportIsAdded(self):
imports = create_python_api.get_api_init_text(
- package=create_python_api._DEFAULT_PACKAGE)
+ package=create_python_api._DEFAULT_PACKAGE,
+ api_name='tensorflow')
expected_import = ('from tensorflow.python.test_module '
'import TestClass')
self.assertTrue(
@@ -82,7 +84,8 @@ class CreatePythonApiTest(test.TestCase):
def testConstantIsAdded(self):
imports = create_python_api.get_api_init_text(
- package=create_python_api._DEFAULT_PACKAGE)
+ package=create_python_api._DEFAULT_PACKAGE,
+ api_name='tensorflow')
expected = ('from tensorflow.python.test_module '
'import _TEST_CONSTANT')
self.assertTrue(expected in str(imports),