aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/python')
-rw-r--r--tensorflow/python/estimator/BUILD4
-rw-r--r--tensorflow/python/estimator/api/BUILD17
-rw-r--r--tensorflow/python/estimator/canned/baseline.py6
-rw-r--r--tensorflow/python/estimator/canned/boosted_trees.py6
-rw-r--r--tensorflow/python/estimator/canned/dnn.py6
-rw-r--r--tensorflow/python/estimator/canned/dnn_linear_combined.py6
-rw-r--r--tensorflow/python/estimator/canned/linear.py6
-rw-r--r--tensorflow/python/estimator/canned/parsing_utils.py6
-rw-r--r--tensorflow/python/estimator/estimator.py12
-rw-r--r--tensorflow/python/estimator/export/export.py10
-rw-r--r--tensorflow/python/estimator/export/export_output.py10
-rw-r--r--tensorflow/python/estimator/exporter.py10
-rw-r--r--tensorflow/python/estimator/inputs/numpy_io.py4
-rw-r--r--tensorflow/python/estimator/inputs/pandas_io.py4
-rw-r--r--tensorflow/python/estimator/model_fn.py6
-rw-r--r--tensorflow/python/estimator/run_config.py4
-rw-r--r--tensorflow/python/estimator/training.py8
-rw-r--r--tensorflow/python/util/tf_export.py58
-rw-r--r--tensorflow/python/util/tf_export_test.py7
19 files changed, 112 insertions, 78 deletions
diff --git a/tensorflow/python/estimator/BUILD b/tensorflow/python/estimator/BUILD
index d538c6c415..c0d63b79a6 100644
--- a/tensorflow/python/estimator/BUILD
+++ b/tensorflow/python/estimator/BUILD
@@ -12,6 +12,10 @@ py_library(
name = "estimator_py",
srcs = ["estimator_lib.py"],
srcs_version = "PY2AND3",
+ visibility = [
+ "//tensorflow:__pkg__",
+ "//tensorflow:internal",
+ ],
deps = [
":baseline",
":boosted_trees",
diff --git a/tensorflow/python/estimator/api/BUILD b/tensorflow/python/estimator/api/BUILD
new file mode 100644
index 0000000000..cddee9b8f3
--- /dev/null
+++ b/tensorflow/python/estimator/api/BUILD
@@ -0,0 +1,17 @@
+package(
+ default_visibility = [
+ "//tensorflow:internal",
+ ],
+)
+
+licenses(["notice"]) # Apache 2.0
+
+load("//tensorflow/tools/api/generator:api_gen.bzl", "gen_api_init_files")
+load("//tensorflow/tools/api/generator:api_gen.bzl", "ESTIMATOR_API_INIT_FILES")
+
+gen_api_init_files(
+ name = "estimator_python_api_gen",
+ api_name = "estimator",
+ output_files = ESTIMATOR_API_INIT_FILES,
+ package = "tensorflow.python.estimator",
+)
diff --git a/tensorflow/python/estimator/canned/baseline.py b/tensorflow/python/estimator/canned/baseline.py
index 980c057372..3c6816cb03 100644
--- a/tensorflow/python/estimator/canned/baseline.py
+++ b/tensorflow/python/estimator/canned/baseline.py
@@ -59,7 +59,7 @@ from tensorflow.python.ops import math_ops
from tensorflow.python.ops import variable_scope
from tensorflow.python.ops.losses import losses
from tensorflow.python.training import training_util
-from tensorflow.python.util.tf_export import tf_export
+from tensorflow.python.util.tf_export import estimator_export
# The default learning rate of 0.3 is a historical artifact of the initial
# implementation, but seems a reasonable choice.
@@ -174,7 +174,7 @@ def _baseline_model_fn(features, labels, mode, head, optimizer,
train_op_fn=train_op_fn)
-@tf_export('estimator.BaselineClassifier')
+@estimator_export('estimator.BaselineClassifier')
class BaselineClassifier(estimator.Estimator):
"""A classifier that can establish a simple baseline.
@@ -277,7 +277,7 @@ class BaselineClassifier(estimator.Estimator):
config=config)
-@tf_export('estimator.BaselineRegressor')
+@estimator_export('estimator.BaselineRegressor')
class BaselineRegressor(estimator.Estimator):
"""A regressor that can establish a simple baseline.
diff --git a/tensorflow/python/estimator/canned/boosted_trees.py b/tensorflow/python/estimator/canned/boosted_trees.py
index 4e6010a162..6b54f51ca6 100644
--- a/tensorflow/python/estimator/canned/boosted_trees.py
+++ b/tensorflow/python/estimator/canned/boosted_trees.py
@@ -39,7 +39,7 @@ from tensorflow.python.summary import summary
from tensorflow.python.training import distribute as distribute_lib
from tensorflow.python.training import session_run_hook
from tensorflow.python.training import training_util
-from tensorflow.python.util.tf_export import tf_export
+from tensorflow.python.util.tf_export import estimator_export
# TODO(nponomareva): Reveal pruning params here.
_TreeHParams = collections.namedtuple('TreeHParams', [
@@ -712,7 +712,7 @@ def _create_regression_head(label_dimension, weight_column=None):
# pylint: enable=protected-access
-@tf_export('estimator.BoostedTreesClassifier')
+@estimator_export('estimator.BoostedTreesClassifier')
class BoostedTreesClassifier(estimator.Estimator):
"""A Classifier for Tensorflow Boosted Trees models."""
@@ -830,7 +830,7 @@ class BoostedTreesClassifier(estimator.Estimator):
model_fn=_model_fn, model_dir=model_dir, config=config)
-@tf_export('estimator.BoostedTreesRegressor')
+@estimator_export('estimator.BoostedTreesRegressor')
class BoostedTreesRegressor(estimator.Estimator):
"""A Regressor for Tensorflow Boosted Trees models."""
diff --git a/tensorflow/python/estimator/canned/dnn.py b/tensorflow/python/estimator/canned/dnn.py
index 1feac36f35..b924ad5df4 100644
--- a/tensorflow/python/estimator/canned/dnn.py
+++ b/tensorflow/python/estimator/canned/dnn.py
@@ -32,7 +32,7 @@ from tensorflow.python.ops import partitioned_variables
from tensorflow.python.ops import variable_scope
from tensorflow.python.ops.losses import losses
from tensorflow.python.summary import summary
-from tensorflow.python.util.tf_export import tf_export
+from tensorflow.python.util.tf_export import estimator_export
# The default learning rate of 0.05 is a historical artifact of the initial
# implementation, but seems a reasonable choice.
@@ -201,7 +201,7 @@ def _dnn_model_fn(features,
logits=logits)
-@tf_export('estimator.DNNClassifier')
+@estimator_export('estimator.DNNClassifier')
class DNNClassifier(estimator.Estimator):
"""A classifier for TensorFlow DNN models.
@@ -353,7 +353,7 @@ class DNNClassifier(estimator.Estimator):
warm_start_from=warm_start_from)
-@tf_export('estimator.DNNRegressor')
+@estimator_export('estimator.DNNRegressor')
class DNNRegressor(estimator.Estimator):
"""A regressor for TensorFlow DNN models.
diff --git a/tensorflow/python/estimator/canned/dnn_linear_combined.py b/tensorflow/python/estimator/canned/dnn_linear_combined.py
index 95efc0a028..64d81c46ce 100644
--- a/tensorflow/python/estimator/canned/dnn_linear_combined.py
+++ b/tensorflow/python/estimator/canned/dnn_linear_combined.py
@@ -37,7 +37,7 @@ from tensorflow.python.summary import summary
from tensorflow.python.training import distribute as distribute_lib
from tensorflow.python.training import sync_replicas_optimizer
from tensorflow.python.training import training_util
-from tensorflow.python.util.tf_export import tf_export
+from tensorflow.python.util.tf_export import estimator_export
# The default learning rates are a historical artifact of the initial
# implementation.
@@ -225,7 +225,7 @@ def _dnn_linear_combined_model_fn(features,
logits=logits)
-@tf_export('estimator.DNNLinearCombinedClassifier')
+@estimator_export('estimator.DNNLinearCombinedClassifier')
class DNNLinearCombinedClassifier(estimator.Estimator):
"""An estimator for TensorFlow Linear and DNN joined classification models.
@@ -406,7 +406,7 @@ class DNNLinearCombinedClassifier(estimator.Estimator):
warm_start_from=warm_start_from)
-@tf_export('estimator.DNNLinearCombinedRegressor')
+@estimator_export('estimator.DNNLinearCombinedRegressor')
class DNNLinearCombinedRegressor(estimator.Estimator):
"""An estimator for TensorFlow Linear and DNN joined models for regression.
diff --git a/tensorflow/python/estimator/canned/linear.py b/tensorflow/python/estimator/canned/linear.py
index 81657f0c01..705fc3ce06 100644
--- a/tensorflow/python/estimator/canned/linear.py
+++ b/tensorflow/python/estimator/canned/linear.py
@@ -33,7 +33,7 @@ from tensorflow.python.ops import variable_scope
from tensorflow.python.ops.losses import losses
from tensorflow.python.summary import summary
from tensorflow.python.training import ftrl
-from tensorflow.python.util.tf_export import tf_export
+from tensorflow.python.util.tf_export import estimator_export
# The default learning rate of 0.2 is a historical artifact of the initial
@@ -164,7 +164,7 @@ def _linear_model_fn(features, labels, mode, head, feature_columns, optimizer,
logits=logits)
-@tf_export('estimator.LinearClassifier')
+@estimator_export('estimator.LinearClassifier')
class LinearClassifier(estimator.Estimator):
"""Linear classifier model.
@@ -317,7 +317,7 @@ class LinearClassifier(estimator.Estimator):
warm_start_from=warm_start_from)
-@tf_export('estimator.LinearRegressor')
+@estimator_export('estimator.LinearRegressor')
class LinearRegressor(estimator.Estimator):
"""An estimator for TensorFlow Linear regression problems.
diff --git a/tensorflow/python/estimator/canned/parsing_utils.py b/tensorflow/python/estimator/canned/parsing_utils.py
index 74e5e5a1be..1ae0f1e9f7 100644
--- a/tensorflow/python/estimator/canned/parsing_utils.py
+++ b/tensorflow/python/estimator/canned/parsing_utils.py
@@ -23,10 +23,10 @@ import six
from tensorflow.python.feature_column import feature_column as fc
from tensorflow.python.framework import dtypes
from tensorflow.python.ops import parsing_ops
-from tensorflow.python.util.tf_export import tf_export
+from tensorflow.python.util.tf_export import estimator_export
-@tf_export('estimator.classifier_parse_example_spec')
+@estimator_export('estimator.classifier_parse_example_spec')
def classifier_parse_example_spec(feature_columns,
label_key,
label_dtype=dtypes.int64,
@@ -166,7 +166,7 @@ def classifier_parse_example_spec(feature_columns,
return parsing_spec
-@tf_export('estimator.regressor_parse_example_spec')
+@estimator_export('estimator.regressor_parse_example_spec')
def regressor_parse_example_spec(feature_columns,
label_key,
label_dtype=dtypes.float32,
diff --git a/tensorflow/python/estimator/estimator.py b/tensorflow/python/estimator/estimator.py
index 4be1af1e66..41c25f1c73 100644
--- a/tensorflow/python/estimator/estimator.py
+++ b/tensorflow/python/estimator/estimator.py
@@ -66,14 +66,14 @@ from tensorflow.python.util import compat
from tensorflow.python.util import compat_internal
from tensorflow.python.util import function_utils
from tensorflow.python.util import nest
-from tensorflow.python.util.tf_export import tf_export
+from tensorflow.python.util.tf_export import estimator_export
_VALID_MODEL_FN_ARGS = set(
['features', 'labels', 'mode', 'params', 'self', 'config'])
-@tf_export('estimator.Estimator')
+@estimator_export('estimator.Estimator')
class Estimator(object):
"""Estimator class to train and evaluate TensorFlow models.
@@ -566,7 +566,8 @@ class Estimator(object):
allowed_overrides = set([
'_call_input_fn', '_create_global_step',
'_convert_train_steps_to_hooks', '_convert_eval_steps_to_hooks',
- '_tf_api_names', '_validate_features_in_predict_input',
+ '_tf_api_names', '_estimator_api_names', '_estimator_api_constants',
+ '_validate_features_in_predict_input',
'_call_model_fn', '_add_meta_graph_for_mode'
])
estimator_members = set([m for m in Estimator.__dict__.keys()
@@ -1634,11 +1635,12 @@ def _has_dataset_or_queue_runner(maybe_tensor):
# Now, check queue.
return ops.get_default_graph().get_collection(ops.GraphKeys.QUEUE_RUNNERS)
+
VocabInfo = warm_starting_util.VocabInfo # pylint: disable=invalid-name
-tf_export('estimator.VocabInfo', allow_multiple_exports=True)(VocabInfo)
+estimator_export('estimator.VocabInfo')(VocabInfo)
-@tf_export('estimator.WarmStartSettings')
+@estimator_export('estimator.WarmStartSettings')
class WarmStartSettings(
collections.namedtuple('WarmStartSettings', [
'ckpt_to_initialize_from',
diff --git a/tensorflow/python/estimator/export/export.py b/tensorflow/python/estimator/export/export.py
index ff19a0a7f4..010c0f3f59 100644
--- a/tensorflow/python/estimator/export/export.py
+++ b/tensorflow/python/estimator/export/export.py
@@ -34,7 +34,7 @@ from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.saved_model import signature_constants
from tensorflow.python.saved_model import signature_def_utils
from tensorflow.python.util import compat
-from tensorflow.python.util.tf_export import tf_export
+from tensorflow.python.util.tf_export import estimator_export
_SINGLE_FEATURE_DEFAULT_NAME = 'feature'
_SINGLE_RECEIVER_DEFAULT_NAME = 'input'
@@ -93,7 +93,7 @@ def _check_tensor_key(name, error_label='feature'):
raise ValueError('{} keys must be strings: {}.'.format(error_label, name))
-@tf_export('estimator.export.ServingInputReceiver')
+@estimator_export('estimator.export.ServingInputReceiver')
class ServingInputReceiver(
collections.namedtuple(
'ServingInputReceiver',
@@ -161,7 +161,7 @@ class ServingInputReceiver(
receiver_tensors_alternatives=receiver_tensors_alternatives)
-@tf_export('estimator.export.TensorServingInputReceiver')
+@estimator_export('estimator.export.TensorServingInputReceiver')
class TensorServingInputReceiver(
collections.namedtuple(
'TensorServingInputReceiver',
@@ -263,7 +263,7 @@ class SupervisedInputReceiver(
receiver_tensors=receiver_tensors)
-@tf_export('estimator.export.build_parsing_serving_input_receiver_fn')
+@estimator_export('estimator.export.build_parsing_serving_input_receiver_fn')
def build_parsing_serving_input_receiver_fn(feature_spec,
default_batch_size=None):
"""Build a serving_input_receiver_fn expecting fed tf.Examples.
@@ -313,7 +313,7 @@ def _placeholders_from_receiver_tensors_dict(input_vals,
}
-@tf_export('estimator.export.build_raw_serving_input_receiver_fn')
+@estimator_export('estimator.export.build_raw_serving_input_receiver_fn')
def build_raw_serving_input_receiver_fn(features, default_batch_size=None):
"""Build a serving_input_receiver_fn expecting feature Tensors.
diff --git a/tensorflow/python/estimator/export/export_output.py b/tensorflow/python/estimator/export/export_output.py
index d387ea2940..6c26d29985 100644
--- a/tensorflow/python/estimator/export/export_output.py
+++ b/tensorflow/python/estimator/export/export_output.py
@@ -26,10 +26,10 @@ import six
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.saved_model import signature_def_utils
-from tensorflow.python.util.tf_export import tf_export
+from tensorflow.python.util.tf_export import estimator_export
-@tf_export('estimator.export.ExportOutput')
+@estimator_export('estimator.export.ExportOutput')
class ExportOutput(object):
"""Represents an output of a model that can be served.
@@ -100,7 +100,7 @@ class ExportOutput(object):
return output_dict
-@tf_export('estimator.export.ClassificationOutput')
+@estimator_export('estimator.export.ClassificationOutput')
class ClassificationOutput(ExportOutput):
"""Represents the output of a classification head.
@@ -169,7 +169,7 @@ class ClassificationOutput(ExportOutput):
examples, self.classes, self.scores)
-@tf_export('estimator.export.RegressionOutput')
+@estimator_export('estimator.export.RegressionOutput')
class RegressionOutput(ExportOutput):
"""Represents the output of a regression head."""
@@ -202,7 +202,7 @@ class RegressionOutput(ExportOutput):
return signature_def_utils.regression_signature_def(examples, self.value)
-@tf_export('estimator.export.PredictOutput')
+@estimator_export('estimator.export.PredictOutput')
class PredictOutput(ExportOutput):
"""Represents the output of a generic prediction head.
diff --git a/tensorflow/python/estimator/exporter.py b/tensorflow/python/estimator/exporter.py
index 5981fa59b7..7cdf840c97 100644
--- a/tensorflow/python/estimator/exporter.py
+++ b/tensorflow/python/estimator/exporter.py
@@ -28,10 +28,10 @@ from tensorflow.python.framework import errors_impl
from tensorflow.python.platform import gfile
from tensorflow.python.platform import tf_logging
from tensorflow.python.summary import summary_iterator
-from tensorflow.python.util.tf_export import tf_export
+from tensorflow.python.util.tf_export import estimator_export
-@tf_export('estimator.Exporter')
+@estimator_export('estimator.Exporter')
class Exporter(object):
"""A class representing a type of model export."""
@@ -172,7 +172,7 @@ def _verify_compre_fn_args(compare_fn):
(compare_fn, non_valid_args))
-@tf_export('estimator.BestExporter')
+@estimator_export('estimator.BestExporter')
class BestExporter(Exporter):
"""This class exports the serving graph and checkpoints of the best models.
@@ -367,7 +367,7 @@ class BestExporter(Exporter):
return best_eval_result
-@tf_export('estimator.FinalExporter')
+@estimator_export('estimator.FinalExporter')
class FinalExporter(Exporter):
"""This class exports the serving graph and checkpoints in the end.
@@ -418,7 +418,7 @@ class FinalExporter(Exporter):
is_the_final_export)
-@tf_export('estimator.LatestExporter')
+@estimator_export('estimator.LatestExporter')
class LatestExporter(Exporter):
"""This class regularly exports the serving graph and checkpoints.
diff --git a/tensorflow/python/estimator/inputs/numpy_io.py b/tensorflow/python/estimator/inputs/numpy_io.py
index a6f4712910..035c7c148c 100644
--- a/tensorflow/python/estimator/inputs/numpy_io.py
+++ b/tensorflow/python/estimator/inputs/numpy_io.py
@@ -24,7 +24,7 @@ import numpy as np
from six import string_types
from tensorflow.python.estimator.inputs.queues import feeding_functions
-from tensorflow.python.util.tf_export import tf_export
+from tensorflow.python.util.tf_export import estimator_export
# Key name to pack the target into dict of `features`. See
# `_get_unique_target_key` for details.
@@ -87,7 +87,7 @@ def _validate_and_convert_features(x):
return ordered_dict_data
-@tf_export('estimator.inputs.numpy_input_fn')
+@estimator_export('estimator.inputs.numpy_input_fn')
def numpy_input_fn(x,
y=None,
batch_size=128,
diff --git a/tensorflow/python/estimator/inputs/pandas_io.py b/tensorflow/python/estimator/inputs/pandas_io.py
index bd06843021..938e244fb3 100644
--- a/tensorflow/python/estimator/inputs/pandas_io.py
+++ b/tensorflow/python/estimator/inputs/pandas_io.py
@@ -21,7 +21,7 @@ from __future__ import print_function
import numpy as np
from tensorflow.python.estimator.inputs.queues import feeding_functions
-from tensorflow.python.util.tf_export import tf_export
+from tensorflow.python.util.tf_export import estimator_export
try:
# pylint: disable=g-import-not-at-top
@@ -35,7 +35,7 @@ except ImportError:
HAS_PANDAS = False
-@tf_export('estimator.inputs.pandas_input_fn')
+@estimator_export('estimator.inputs.pandas_input_fn')
def pandas_input_fn(x,
y=None,
batch_size=128,
diff --git a/tensorflow/python/estimator/model_fn.py b/tensorflow/python/estimator/model_fn.py
index 3edf9fe940..c60c7f63ba 100644
--- a/tensorflow/python/estimator/model_fn.py
+++ b/tensorflow/python/estimator/model_fn.py
@@ -32,10 +32,10 @@ from tensorflow.python.saved_model import tag_constants
from tensorflow.python.training import monitored_session
from tensorflow.python.training import session_run_hook
from tensorflow.python.util import nest
-from tensorflow.python.util.tf_export import tf_export
+from tensorflow.python.util.tf_export import estimator_export
-@tf_export('estimator.ModeKeys')
+@estimator_export('estimator.ModeKeys')
class ModeKeys(object):
"""Standard names for model modes.
@@ -62,7 +62,7 @@ EXPORT_TAG_MAP = {
}
-@tf_export('estimator.EstimatorSpec')
+@estimator_export('estimator.EstimatorSpec')
class EstimatorSpec(
collections.namedtuple('EstimatorSpec', [
'mode', 'predictions', 'loss', 'train_op', 'eval_metric_ops',
diff --git a/tensorflow/python/estimator/run_config.py b/tensorflow/python/estimator/run_config.py
index c7707be839..b948ce96e0 100644
--- a/tensorflow/python/estimator/run_config.py
+++ b/tensorflow/python/estimator/run_config.py
@@ -29,7 +29,7 @@ from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.training import server_lib
from tensorflow.python.util import compat_internal
from tensorflow.python.util import function_utils
-from tensorflow.python.util.tf_export import tf_export
+from tensorflow.python.util.tf_export import estimator_export
_USE_DEFAULT = object()
@@ -296,7 +296,7 @@ class TaskType(object):
EVALUATOR = 'evaluator'
-@tf_export('estimator.RunConfig')
+@estimator_export('estimator.RunConfig')
class RunConfig(object):
"""This class specifies the configurations for an `Estimator` run."""
diff --git a/tensorflow/python/estimator/training.py b/tensorflow/python/estimator/training.py
index fb6a68b4f7..1572af579b 100644
--- a/tensorflow/python/estimator/training.py
+++ b/tensorflow/python/estimator/training.py
@@ -35,7 +35,7 @@ from tensorflow.python.training import basic_session_run_hooks
from tensorflow.python.training import server_lib
from tensorflow.python.training import session_run_hook
from tensorflow.python.util import compat
-from tensorflow.python.util.tf_export import tf_export
+from tensorflow.python.util.tf_export import estimator_export
_MAX_DELAY_SECS = 60
_DELAY_SECS_PER_WORKER = 5
@@ -115,7 +115,7 @@ def _is_google_env():
return tf_config.get(_ENVIRONMENT_KEY) == _ENVIRONMENT_GOOGLE_VALUE
-@tf_export('estimator.TrainSpec')
+@estimator_export('estimator.TrainSpec')
class TrainSpec(
collections.namedtuple('TrainSpec', ['input_fn', 'max_steps', 'hooks'])):
"""Configuration for the "train" part for the `train_and_evaluate` call.
@@ -167,7 +167,7 @@ class TrainSpec(
cls, input_fn=input_fn, max_steps=max_steps, hooks=hooks)
-@tf_export('estimator.EvalSpec')
+@estimator_export('estimator.EvalSpec')
class EvalSpec(
collections.namedtuple('EvalSpec', [
'input_fn', 'steps', 'name', 'hooks', 'exporters', 'start_delay_secs',
@@ -263,7 +263,7 @@ class EvalSpec(
throttle_secs=throttle_secs)
-@tf_export('estimator.train_and_evaluate')
+@estimator_export('estimator.train_and_evaluate')
def train_and_evaluate(estimator, train_spec, eval_spec):
"""Train and evaluate the `estimator`.
diff --git a/tensorflow/python/util/tf_export.py b/tensorflow/python/util/tf_export.py
index bf3961c692..e154ffb68a 100644
--- a/tensorflow/python/util/tf_export.py
+++ b/tensorflow/python/util/tf_export.py
@@ -41,17 +41,35 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
+import collections
+import functools
import sys
from tensorflow.python.util import tf_decorator
+ESTIMATOR_API_NAME = 'estimator'
+TENSORFLOW_API_NAME = 'tensorflow'
+
+_Attributes = collections.namedtuple(
+ 'ExportedApiAttributes', ['names', 'constants'])
+
+# Attribute values must be unique to each API.
+API_ATTRS = {
+ TENSORFLOW_API_NAME: _Attributes(
+ '_tf_api_names',
+ '_tf_api_constants'),
+ ESTIMATOR_API_NAME: _Attributes(
+ '_estimator_api_names',
+ '_estimator_api_constants')
+}
+
class SymbolAlreadyExposedError(Exception):
"""Raised when adding API names to symbol that already has API names."""
pass
-class tf_export(object): # pylint: disable=invalid-name
+class api_export(object): # pylint: disable=invalid-name
"""Provides ways to export symbols to the TensorFlow API."""
def __init__(self, *args, **kwargs):
@@ -63,15 +81,12 @@ class tf_export(object): # pylint: disable=invalid-name
overrides: List of symbols that this is overriding
(those overrided api exports will be removed). Note: passing overrides
has no effect on exporting a constant.
- allow_multiple_exports: Allows exporting the same symbol multiple
- times with multiple `tf_export` usages. Prefer however, to list all
- of the exported names in a single `tf_export` usage when possible.
-
+ api_name: Name of the API you want to generate (e.g. `tensorflow` or
+ `estimator`). Default is `tensorflow`.
"""
self._names = args
+ self._api_name = kwargs.get('api_name', TENSORFLOW_API_NAME)
self._overrides = kwargs.get('overrides', [])
- self._allow_multiple_exports = kwargs.get(
- 'allow_multiple_exports', False)
def __call__(self, func):
"""Calls this decorator.
@@ -86,25 +101,24 @@ class tf_export(object): # pylint: disable=invalid-name
SymbolAlreadyExposedError: Raised when a symbol already has API names
and kwarg `allow_multiple_exports` not set.
"""
+ api_names_attr = API_ATTRS[self._api_name].names
+
# Undecorate overridden names
for f in self._overrides:
_, undecorated_f = tf_decorator.unwrap(f)
- del undecorated_f._tf_api_names # pylint: disable=protected-access
+ delattr(undecorated_f, api_names_attr)
_, undecorated_func = tf_decorator.unwrap(func)
# Check for an existing api. We check if attribute name is in
# __dict__ instead of using hasattr to verify that subclasses have
# their own _tf_api_names as opposed to just inheriting it.
- if '_tf_api_names' in undecorated_func.__dict__:
- if self._allow_multiple_exports:
- undecorated_func._tf_api_names += self._names # pylint: disable=protected-access
- else:
- raise SymbolAlreadyExposedError(
- 'Symbol %s is already exposed as %s.' %
- (undecorated_func.__name__, undecorated_func._tf_api_names)) # pylint: disable=protected-access
- else:
- undecorated_func._tf_api_names = self._names # pylint: disable=protected-access
+ if api_names_attr in undecorated_func.__dict__:
+ raise SymbolAlreadyExposedError(
+ 'Symbol %s is already exposed as %s.' %
+ (undecorated_func.__name__, getattr(
+ undecorated_func, api_names_attr))) # pylint: disable=protected-access
+ setattr(undecorated_func, api_names_attr, self._names)
return func
def export_constant(self, module_name, name):
@@ -126,8 +140,12 @@ class tf_export(object): # pylint: disable=invalid-name
name: (string) Current constant name.
"""
module = sys.modules[module_name]
- if not hasattr(module, '_tf_api_constants'):
- module._tf_api_constants = [] # pylint: disable=protected-access
+ if not hasattr(module, API_ATTRS[self._api_name].constants):
+ setattr(module, API_ATTRS[self._api_name].constants, [])
# pylint: disable=protected-access
- module._tf_api_constants.append((self._names, name))
+ getattr(module, API_ATTRS[self._api_name].constants).append(
+ (self._names, name))
+
+tf_export = functools.partial(api_export, api_name=TENSORFLOW_API_NAME)
+estimator_export = functools.partial(tf_export, api_name=ESTIMATOR_API_NAME)
diff --git a/tensorflow/python/util/tf_export_test.py b/tensorflow/python/util/tf_export_test.py
index ace3f054ba..b9e26ecb33 100644
--- a/tensorflow/python/util/tf_export_test.py
+++ b/tensorflow/python/util/tf_export_test.py
@@ -128,13 +128,6 @@ class ValidateExportTest(test.TestCase):
with self.assertRaises(tf_export.SymbolAlreadyExposedError):
export_decorator(_test_function)
- def testEAllowMultipleExports(self):
- _test_function._tf_api_names = ['name1', 'name2']
- tf_export.tf_export('nameRed', 'nameBlue', allow_multiple_exports=True)(
- _test_function)
- self.assertEquals(['name1', 'name2', 'nameRed', 'nameBlue'],
- _test_function._tf_api_names)
-
def testOverridesFunction(self):
_test_function2._tf_api_names = ['abc']