diff options
Diffstat (limited to 'tensorflow/python')
19 files changed, 112 insertions, 78 deletions
diff --git a/tensorflow/python/estimator/BUILD b/tensorflow/python/estimator/BUILD index d538c6c415..c0d63b79a6 100644 --- a/tensorflow/python/estimator/BUILD +++ b/tensorflow/python/estimator/BUILD @@ -12,6 +12,10 @@ py_library( name = "estimator_py", srcs = ["estimator_lib.py"], srcs_version = "PY2AND3", + visibility = [ + "//tensorflow:__pkg__", + "//tensorflow:internal", + ], deps = [ ":baseline", ":boosted_trees", diff --git a/tensorflow/python/estimator/api/BUILD b/tensorflow/python/estimator/api/BUILD new file mode 100644 index 0000000000..cddee9b8f3 --- /dev/null +++ b/tensorflow/python/estimator/api/BUILD @@ -0,0 +1,17 @@ +package( + default_visibility = [ + "//tensorflow:internal", + ], +) + +licenses(["notice"]) # Apache 2.0 + +load("//tensorflow/tools/api/generator:api_gen.bzl", "gen_api_init_files") +load("//tensorflow/tools/api/generator:api_gen.bzl", "ESTIMATOR_API_INIT_FILES") + +gen_api_init_files( + name = "estimator_python_api_gen", + api_name = "estimator", + output_files = ESTIMATOR_API_INIT_FILES, + package = "tensorflow.python.estimator", +) diff --git a/tensorflow/python/estimator/canned/baseline.py b/tensorflow/python/estimator/canned/baseline.py index 980c057372..3c6816cb03 100644 --- a/tensorflow/python/estimator/canned/baseline.py +++ b/tensorflow/python/estimator/canned/baseline.py @@ -59,7 +59,7 @@ from tensorflow.python.ops import math_ops from tensorflow.python.ops import variable_scope from tensorflow.python.ops.losses import losses from tensorflow.python.training import training_util -from tensorflow.python.util.tf_export import tf_export +from tensorflow.python.util.tf_export import estimator_export # The default learning rate of 0.3 is a historical artifact of the initial # implementation, but seems a reasonable choice. @@ -174,7 +174,7 @@ def _baseline_model_fn(features, labels, mode, head, optimizer, train_op_fn=train_op_fn) -@tf_export('estimator.BaselineClassifier') +@estimator_export('estimator.BaselineClassifier') class BaselineClassifier(estimator.Estimator): """A classifier that can establish a simple baseline. @@ -277,7 +277,7 @@ class BaselineClassifier(estimator.Estimator): config=config) -@tf_export('estimator.BaselineRegressor') +@estimator_export('estimator.BaselineRegressor') class BaselineRegressor(estimator.Estimator): """A regressor that can establish a simple baseline. diff --git a/tensorflow/python/estimator/canned/boosted_trees.py b/tensorflow/python/estimator/canned/boosted_trees.py index 4e6010a162..6b54f51ca6 100644 --- a/tensorflow/python/estimator/canned/boosted_trees.py +++ b/tensorflow/python/estimator/canned/boosted_trees.py @@ -39,7 +39,7 @@ from tensorflow.python.summary import summary from tensorflow.python.training import distribute as distribute_lib from tensorflow.python.training import session_run_hook from tensorflow.python.training import training_util -from tensorflow.python.util.tf_export import tf_export +from tensorflow.python.util.tf_export import estimator_export # TODO(nponomareva): Reveal pruning params here. _TreeHParams = collections.namedtuple('TreeHParams', [ @@ -712,7 +712,7 @@ def _create_regression_head(label_dimension, weight_column=None): # pylint: enable=protected-access -@tf_export('estimator.BoostedTreesClassifier') +@estimator_export('estimator.BoostedTreesClassifier') class BoostedTreesClassifier(estimator.Estimator): """A Classifier for Tensorflow Boosted Trees models.""" @@ -830,7 +830,7 @@ class BoostedTreesClassifier(estimator.Estimator): model_fn=_model_fn, model_dir=model_dir, config=config) -@tf_export('estimator.BoostedTreesRegressor') +@estimator_export('estimator.BoostedTreesRegressor') class BoostedTreesRegressor(estimator.Estimator): """A Regressor for Tensorflow Boosted Trees models.""" diff --git a/tensorflow/python/estimator/canned/dnn.py b/tensorflow/python/estimator/canned/dnn.py index 1feac36f35..b924ad5df4 100644 --- a/tensorflow/python/estimator/canned/dnn.py +++ b/tensorflow/python/estimator/canned/dnn.py @@ -32,7 +32,7 @@ from tensorflow.python.ops import partitioned_variables from tensorflow.python.ops import variable_scope from tensorflow.python.ops.losses import losses from tensorflow.python.summary import summary -from tensorflow.python.util.tf_export import tf_export +from tensorflow.python.util.tf_export import estimator_export # The default learning rate of 0.05 is a historical artifact of the initial # implementation, but seems a reasonable choice. @@ -201,7 +201,7 @@ def _dnn_model_fn(features, logits=logits) -@tf_export('estimator.DNNClassifier') +@estimator_export('estimator.DNNClassifier') class DNNClassifier(estimator.Estimator): """A classifier for TensorFlow DNN models. @@ -353,7 +353,7 @@ class DNNClassifier(estimator.Estimator): warm_start_from=warm_start_from) -@tf_export('estimator.DNNRegressor') +@estimator_export('estimator.DNNRegressor') class DNNRegressor(estimator.Estimator): """A regressor for TensorFlow DNN models. diff --git a/tensorflow/python/estimator/canned/dnn_linear_combined.py b/tensorflow/python/estimator/canned/dnn_linear_combined.py index 95efc0a028..64d81c46ce 100644 --- a/tensorflow/python/estimator/canned/dnn_linear_combined.py +++ b/tensorflow/python/estimator/canned/dnn_linear_combined.py @@ -37,7 +37,7 @@ from tensorflow.python.summary import summary from tensorflow.python.training import distribute as distribute_lib from tensorflow.python.training import sync_replicas_optimizer from tensorflow.python.training import training_util -from tensorflow.python.util.tf_export import tf_export +from tensorflow.python.util.tf_export import estimator_export # The default learning rates are a historical artifact of the initial # implementation. @@ -225,7 +225,7 @@ def _dnn_linear_combined_model_fn(features, logits=logits) -@tf_export('estimator.DNNLinearCombinedClassifier') +@estimator_export('estimator.DNNLinearCombinedClassifier') class DNNLinearCombinedClassifier(estimator.Estimator): """An estimator for TensorFlow Linear and DNN joined classification models. @@ -406,7 +406,7 @@ class DNNLinearCombinedClassifier(estimator.Estimator): warm_start_from=warm_start_from) -@tf_export('estimator.DNNLinearCombinedRegressor') +@estimator_export('estimator.DNNLinearCombinedRegressor') class DNNLinearCombinedRegressor(estimator.Estimator): """An estimator for TensorFlow Linear and DNN joined models for regression. diff --git a/tensorflow/python/estimator/canned/linear.py b/tensorflow/python/estimator/canned/linear.py index 81657f0c01..705fc3ce06 100644 --- a/tensorflow/python/estimator/canned/linear.py +++ b/tensorflow/python/estimator/canned/linear.py @@ -33,7 +33,7 @@ from tensorflow.python.ops import variable_scope from tensorflow.python.ops.losses import losses from tensorflow.python.summary import summary from tensorflow.python.training import ftrl -from tensorflow.python.util.tf_export import tf_export +from tensorflow.python.util.tf_export import estimator_export # The default learning rate of 0.2 is a historical artifact of the initial @@ -164,7 +164,7 @@ def _linear_model_fn(features, labels, mode, head, feature_columns, optimizer, logits=logits) -@tf_export('estimator.LinearClassifier') +@estimator_export('estimator.LinearClassifier') class LinearClassifier(estimator.Estimator): """Linear classifier model. @@ -317,7 +317,7 @@ class LinearClassifier(estimator.Estimator): warm_start_from=warm_start_from) -@tf_export('estimator.LinearRegressor') +@estimator_export('estimator.LinearRegressor') class LinearRegressor(estimator.Estimator): """An estimator for TensorFlow Linear regression problems. diff --git a/tensorflow/python/estimator/canned/parsing_utils.py b/tensorflow/python/estimator/canned/parsing_utils.py index 74e5e5a1be..1ae0f1e9f7 100644 --- a/tensorflow/python/estimator/canned/parsing_utils.py +++ b/tensorflow/python/estimator/canned/parsing_utils.py @@ -23,10 +23,10 @@ import six from tensorflow.python.feature_column import feature_column as fc from tensorflow.python.framework import dtypes from tensorflow.python.ops import parsing_ops -from tensorflow.python.util.tf_export import tf_export +from tensorflow.python.util.tf_export import estimator_export -@tf_export('estimator.classifier_parse_example_spec') +@estimator_export('estimator.classifier_parse_example_spec') def classifier_parse_example_spec(feature_columns, label_key, label_dtype=dtypes.int64, @@ -166,7 +166,7 @@ def classifier_parse_example_spec(feature_columns, return parsing_spec -@tf_export('estimator.regressor_parse_example_spec') +@estimator_export('estimator.regressor_parse_example_spec') def regressor_parse_example_spec(feature_columns, label_key, label_dtype=dtypes.float32, diff --git a/tensorflow/python/estimator/estimator.py b/tensorflow/python/estimator/estimator.py index 4be1af1e66..41c25f1c73 100644 --- a/tensorflow/python/estimator/estimator.py +++ b/tensorflow/python/estimator/estimator.py @@ -66,14 +66,14 @@ from tensorflow.python.util import compat from tensorflow.python.util import compat_internal from tensorflow.python.util import function_utils from tensorflow.python.util import nest -from tensorflow.python.util.tf_export import tf_export +from tensorflow.python.util.tf_export import estimator_export _VALID_MODEL_FN_ARGS = set( ['features', 'labels', 'mode', 'params', 'self', 'config']) -@tf_export('estimator.Estimator') +@estimator_export('estimator.Estimator') class Estimator(object): """Estimator class to train and evaluate TensorFlow models. @@ -566,7 +566,8 @@ class Estimator(object): allowed_overrides = set([ '_call_input_fn', '_create_global_step', '_convert_train_steps_to_hooks', '_convert_eval_steps_to_hooks', - '_tf_api_names', '_validate_features_in_predict_input', + '_tf_api_names', '_estimator_api_names', '_estimator_api_constants', + '_validate_features_in_predict_input', '_call_model_fn', '_add_meta_graph_for_mode' ]) estimator_members = set([m for m in Estimator.__dict__.keys() @@ -1634,11 +1635,12 @@ def _has_dataset_or_queue_runner(maybe_tensor): # Now, check queue. return ops.get_default_graph().get_collection(ops.GraphKeys.QUEUE_RUNNERS) + VocabInfo = warm_starting_util.VocabInfo # pylint: disable=invalid-name -tf_export('estimator.VocabInfo', allow_multiple_exports=True)(VocabInfo) +estimator_export('estimator.VocabInfo')(VocabInfo) -@tf_export('estimator.WarmStartSettings') +@estimator_export('estimator.WarmStartSettings') class WarmStartSettings( collections.namedtuple('WarmStartSettings', [ 'ckpt_to_initialize_from', diff --git a/tensorflow/python/estimator/export/export.py b/tensorflow/python/estimator/export/export.py index ff19a0a7f4..010c0f3f59 100644 --- a/tensorflow/python/estimator/export/export.py +++ b/tensorflow/python/estimator/export/export.py @@ -34,7 +34,7 @@ from tensorflow.python.platform import tf_logging as logging from tensorflow.python.saved_model import signature_constants from tensorflow.python.saved_model import signature_def_utils from tensorflow.python.util import compat -from tensorflow.python.util.tf_export import tf_export +from tensorflow.python.util.tf_export import estimator_export _SINGLE_FEATURE_DEFAULT_NAME = 'feature' _SINGLE_RECEIVER_DEFAULT_NAME = 'input' @@ -93,7 +93,7 @@ def _check_tensor_key(name, error_label='feature'): raise ValueError('{} keys must be strings: {}.'.format(error_label, name)) -@tf_export('estimator.export.ServingInputReceiver') +@estimator_export('estimator.export.ServingInputReceiver') class ServingInputReceiver( collections.namedtuple( 'ServingInputReceiver', @@ -161,7 +161,7 @@ class ServingInputReceiver( receiver_tensors_alternatives=receiver_tensors_alternatives) -@tf_export('estimator.export.TensorServingInputReceiver') +@estimator_export('estimator.export.TensorServingInputReceiver') class TensorServingInputReceiver( collections.namedtuple( 'TensorServingInputReceiver', @@ -263,7 +263,7 @@ class SupervisedInputReceiver( receiver_tensors=receiver_tensors) -@tf_export('estimator.export.build_parsing_serving_input_receiver_fn') +@estimator_export('estimator.export.build_parsing_serving_input_receiver_fn') def build_parsing_serving_input_receiver_fn(feature_spec, default_batch_size=None): """Build a serving_input_receiver_fn expecting fed tf.Examples. @@ -313,7 +313,7 @@ def _placeholders_from_receiver_tensors_dict(input_vals, } -@tf_export('estimator.export.build_raw_serving_input_receiver_fn') +@estimator_export('estimator.export.build_raw_serving_input_receiver_fn') def build_raw_serving_input_receiver_fn(features, default_batch_size=None): """Build a serving_input_receiver_fn expecting feature Tensors. diff --git a/tensorflow/python/estimator/export/export_output.py b/tensorflow/python/estimator/export/export_output.py index d387ea2940..6c26d29985 100644 --- a/tensorflow/python/estimator/export/export_output.py +++ b/tensorflow/python/estimator/export/export_output.py @@ -26,10 +26,10 @@ import six from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops from tensorflow.python.saved_model import signature_def_utils -from tensorflow.python.util.tf_export import tf_export +from tensorflow.python.util.tf_export import estimator_export -@tf_export('estimator.export.ExportOutput') +@estimator_export('estimator.export.ExportOutput') class ExportOutput(object): """Represents an output of a model that can be served. @@ -100,7 +100,7 @@ class ExportOutput(object): return output_dict -@tf_export('estimator.export.ClassificationOutput') +@estimator_export('estimator.export.ClassificationOutput') class ClassificationOutput(ExportOutput): """Represents the output of a classification head. @@ -169,7 +169,7 @@ class ClassificationOutput(ExportOutput): examples, self.classes, self.scores) -@tf_export('estimator.export.RegressionOutput') +@estimator_export('estimator.export.RegressionOutput') class RegressionOutput(ExportOutput): """Represents the output of a regression head.""" @@ -202,7 +202,7 @@ class RegressionOutput(ExportOutput): return signature_def_utils.regression_signature_def(examples, self.value) -@tf_export('estimator.export.PredictOutput') +@estimator_export('estimator.export.PredictOutput') class PredictOutput(ExportOutput): """Represents the output of a generic prediction head. diff --git a/tensorflow/python/estimator/exporter.py b/tensorflow/python/estimator/exporter.py index 5981fa59b7..7cdf840c97 100644 --- a/tensorflow/python/estimator/exporter.py +++ b/tensorflow/python/estimator/exporter.py @@ -28,10 +28,10 @@ from tensorflow.python.framework import errors_impl from tensorflow.python.platform import gfile from tensorflow.python.platform import tf_logging from tensorflow.python.summary import summary_iterator -from tensorflow.python.util.tf_export import tf_export +from tensorflow.python.util.tf_export import estimator_export -@tf_export('estimator.Exporter') +@estimator_export('estimator.Exporter') class Exporter(object): """A class representing a type of model export.""" @@ -172,7 +172,7 @@ def _verify_compre_fn_args(compare_fn): (compare_fn, non_valid_args)) -@tf_export('estimator.BestExporter') +@estimator_export('estimator.BestExporter') class BestExporter(Exporter): """This class exports the serving graph and checkpoints of the best models. @@ -367,7 +367,7 @@ class BestExporter(Exporter): return best_eval_result -@tf_export('estimator.FinalExporter') +@estimator_export('estimator.FinalExporter') class FinalExporter(Exporter): """This class exports the serving graph and checkpoints in the end. @@ -418,7 +418,7 @@ class FinalExporter(Exporter): is_the_final_export) -@tf_export('estimator.LatestExporter') +@estimator_export('estimator.LatestExporter') class LatestExporter(Exporter): """This class regularly exports the serving graph and checkpoints. diff --git a/tensorflow/python/estimator/inputs/numpy_io.py b/tensorflow/python/estimator/inputs/numpy_io.py index a6f4712910..035c7c148c 100644 --- a/tensorflow/python/estimator/inputs/numpy_io.py +++ b/tensorflow/python/estimator/inputs/numpy_io.py @@ -24,7 +24,7 @@ import numpy as np from six import string_types from tensorflow.python.estimator.inputs.queues import feeding_functions -from tensorflow.python.util.tf_export import tf_export +from tensorflow.python.util.tf_export import estimator_export # Key name to pack the target into dict of `features`. See # `_get_unique_target_key` for details. @@ -87,7 +87,7 @@ def _validate_and_convert_features(x): return ordered_dict_data -@tf_export('estimator.inputs.numpy_input_fn') +@estimator_export('estimator.inputs.numpy_input_fn') def numpy_input_fn(x, y=None, batch_size=128, diff --git a/tensorflow/python/estimator/inputs/pandas_io.py b/tensorflow/python/estimator/inputs/pandas_io.py index bd06843021..938e244fb3 100644 --- a/tensorflow/python/estimator/inputs/pandas_io.py +++ b/tensorflow/python/estimator/inputs/pandas_io.py @@ -21,7 +21,7 @@ from __future__ import print_function import numpy as np from tensorflow.python.estimator.inputs.queues import feeding_functions -from tensorflow.python.util.tf_export import tf_export +from tensorflow.python.util.tf_export import estimator_export try: # pylint: disable=g-import-not-at-top @@ -35,7 +35,7 @@ except ImportError: HAS_PANDAS = False -@tf_export('estimator.inputs.pandas_input_fn') +@estimator_export('estimator.inputs.pandas_input_fn') def pandas_input_fn(x, y=None, batch_size=128, diff --git a/tensorflow/python/estimator/model_fn.py b/tensorflow/python/estimator/model_fn.py index 3edf9fe940..c60c7f63ba 100644 --- a/tensorflow/python/estimator/model_fn.py +++ b/tensorflow/python/estimator/model_fn.py @@ -32,10 +32,10 @@ from tensorflow.python.saved_model import tag_constants from tensorflow.python.training import monitored_session from tensorflow.python.training import session_run_hook from tensorflow.python.util import nest -from tensorflow.python.util.tf_export import tf_export +from tensorflow.python.util.tf_export import estimator_export -@tf_export('estimator.ModeKeys') +@estimator_export('estimator.ModeKeys') class ModeKeys(object): """Standard names for model modes. @@ -62,7 +62,7 @@ EXPORT_TAG_MAP = { } -@tf_export('estimator.EstimatorSpec') +@estimator_export('estimator.EstimatorSpec') class EstimatorSpec( collections.namedtuple('EstimatorSpec', [ 'mode', 'predictions', 'loss', 'train_op', 'eval_metric_ops', diff --git a/tensorflow/python/estimator/run_config.py b/tensorflow/python/estimator/run_config.py index c7707be839..b948ce96e0 100644 --- a/tensorflow/python/estimator/run_config.py +++ b/tensorflow/python/estimator/run_config.py @@ -29,7 +29,7 @@ from tensorflow.python.platform import tf_logging as logging from tensorflow.python.training import server_lib from tensorflow.python.util import compat_internal from tensorflow.python.util import function_utils -from tensorflow.python.util.tf_export import tf_export +from tensorflow.python.util.tf_export import estimator_export _USE_DEFAULT = object() @@ -296,7 +296,7 @@ class TaskType(object): EVALUATOR = 'evaluator' -@tf_export('estimator.RunConfig') +@estimator_export('estimator.RunConfig') class RunConfig(object): """This class specifies the configurations for an `Estimator` run.""" diff --git a/tensorflow/python/estimator/training.py b/tensorflow/python/estimator/training.py index fb6a68b4f7..1572af579b 100644 --- a/tensorflow/python/estimator/training.py +++ b/tensorflow/python/estimator/training.py @@ -35,7 +35,7 @@ from tensorflow.python.training import basic_session_run_hooks from tensorflow.python.training import server_lib from tensorflow.python.training import session_run_hook from tensorflow.python.util import compat -from tensorflow.python.util.tf_export import tf_export +from tensorflow.python.util.tf_export import estimator_export _MAX_DELAY_SECS = 60 _DELAY_SECS_PER_WORKER = 5 @@ -115,7 +115,7 @@ def _is_google_env(): return tf_config.get(_ENVIRONMENT_KEY) == _ENVIRONMENT_GOOGLE_VALUE -@tf_export('estimator.TrainSpec') +@estimator_export('estimator.TrainSpec') class TrainSpec( collections.namedtuple('TrainSpec', ['input_fn', 'max_steps', 'hooks'])): """Configuration for the "train" part for the `train_and_evaluate` call. @@ -167,7 +167,7 @@ class TrainSpec( cls, input_fn=input_fn, max_steps=max_steps, hooks=hooks) -@tf_export('estimator.EvalSpec') +@estimator_export('estimator.EvalSpec') class EvalSpec( collections.namedtuple('EvalSpec', [ 'input_fn', 'steps', 'name', 'hooks', 'exporters', 'start_delay_secs', @@ -263,7 +263,7 @@ class EvalSpec( throttle_secs=throttle_secs) -@tf_export('estimator.train_and_evaluate') +@estimator_export('estimator.train_and_evaluate') def train_and_evaluate(estimator, train_spec, eval_spec): """Train and evaluate the `estimator`. diff --git a/tensorflow/python/util/tf_export.py b/tensorflow/python/util/tf_export.py index bf3961c692..e154ffb68a 100644 --- a/tensorflow/python/util/tf_export.py +++ b/tensorflow/python/util/tf_export.py @@ -41,17 +41,35 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +import collections +import functools import sys from tensorflow.python.util import tf_decorator +ESTIMATOR_API_NAME = 'estimator' +TENSORFLOW_API_NAME = 'tensorflow' + +_Attributes = collections.namedtuple( + 'ExportedApiAttributes', ['names', 'constants']) + +# Attribute values must be unique to each API. +API_ATTRS = { + TENSORFLOW_API_NAME: _Attributes( + '_tf_api_names', + '_tf_api_constants'), + ESTIMATOR_API_NAME: _Attributes( + '_estimator_api_names', + '_estimator_api_constants') +} + class SymbolAlreadyExposedError(Exception): """Raised when adding API names to symbol that already has API names.""" pass -class tf_export(object): # pylint: disable=invalid-name +class api_export(object): # pylint: disable=invalid-name """Provides ways to export symbols to the TensorFlow API.""" def __init__(self, *args, **kwargs): @@ -63,15 +81,12 @@ class tf_export(object): # pylint: disable=invalid-name overrides: List of symbols that this is overriding (those overrided api exports will be removed). Note: passing overrides has no effect on exporting a constant. - allow_multiple_exports: Allows exporting the same symbol multiple - times with multiple `tf_export` usages. Prefer however, to list all - of the exported names in a single `tf_export` usage when possible. - + api_name: Name of the API you want to generate (e.g. `tensorflow` or + `estimator`). Default is `tensorflow`. """ self._names = args + self._api_name = kwargs.get('api_name', TENSORFLOW_API_NAME) self._overrides = kwargs.get('overrides', []) - self._allow_multiple_exports = kwargs.get( - 'allow_multiple_exports', False) def __call__(self, func): """Calls this decorator. @@ -86,25 +101,24 @@ class tf_export(object): # pylint: disable=invalid-name SymbolAlreadyExposedError: Raised when a symbol already has API names and kwarg `allow_multiple_exports` not set. """ + api_names_attr = API_ATTRS[self._api_name].names + # Undecorate overridden names for f in self._overrides: _, undecorated_f = tf_decorator.unwrap(f) - del undecorated_f._tf_api_names # pylint: disable=protected-access + delattr(undecorated_f, api_names_attr) _, undecorated_func = tf_decorator.unwrap(func) # Check for an existing api. We check if attribute name is in # __dict__ instead of using hasattr to verify that subclasses have # their own _tf_api_names as opposed to just inheriting it. - if '_tf_api_names' in undecorated_func.__dict__: - if self._allow_multiple_exports: - undecorated_func._tf_api_names += self._names # pylint: disable=protected-access - else: - raise SymbolAlreadyExposedError( - 'Symbol %s is already exposed as %s.' % - (undecorated_func.__name__, undecorated_func._tf_api_names)) # pylint: disable=protected-access - else: - undecorated_func._tf_api_names = self._names # pylint: disable=protected-access + if api_names_attr in undecorated_func.__dict__: + raise SymbolAlreadyExposedError( + 'Symbol %s is already exposed as %s.' % + (undecorated_func.__name__, getattr( + undecorated_func, api_names_attr))) # pylint: disable=protected-access + setattr(undecorated_func, api_names_attr, self._names) return func def export_constant(self, module_name, name): @@ -126,8 +140,12 @@ class tf_export(object): # pylint: disable=invalid-name name: (string) Current constant name. """ module = sys.modules[module_name] - if not hasattr(module, '_tf_api_constants'): - module._tf_api_constants = [] # pylint: disable=protected-access + if not hasattr(module, API_ATTRS[self._api_name].constants): + setattr(module, API_ATTRS[self._api_name].constants, []) # pylint: disable=protected-access - module._tf_api_constants.append((self._names, name)) + getattr(module, API_ATTRS[self._api_name].constants).append( + (self._names, name)) + +tf_export = functools.partial(api_export, api_name=TENSORFLOW_API_NAME) +estimator_export = functools.partial(tf_export, api_name=ESTIMATOR_API_NAME) diff --git a/tensorflow/python/util/tf_export_test.py b/tensorflow/python/util/tf_export_test.py index ace3f054ba..b9e26ecb33 100644 --- a/tensorflow/python/util/tf_export_test.py +++ b/tensorflow/python/util/tf_export_test.py @@ -128,13 +128,6 @@ class ValidateExportTest(test.TestCase): with self.assertRaises(tf_export.SymbolAlreadyExposedError): export_decorator(_test_function) - def testEAllowMultipleExports(self): - _test_function._tf_api_names = ['name1', 'name2'] - tf_export.tf_export('nameRed', 'nameBlue', allow_multiple_exports=True)( - _test_function) - self.assertEquals(['name1', 'name2', 'nameRed', 'nameBlue'], - _test_function._tf_api_names) - def testOverridesFunction(self): _test_function2._tf_api_names = ['abc'] |