diff options
author | Anna R <annarev@google.com> | 2018-07-12 14:46:39 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-07-12 14:51:03 -0700 |
commit | da798407b4ff72f1daa629e054ccd47b162c9d58 (patch) | |
tree | 4bc5251f66dd8bb601d73fd3ec8f035b953bbe6a /tensorflow/python/util | |
parent | c5e563e57feee793499fae9c3ce28f5176404749 (diff) |
Support passing TensorFlow API names as a separate v1 argument to tf_export.
PiperOrigin-RevId: 204368026
Diffstat (limited to 'tensorflow/python/util')
-rw-r--r-- | tensorflow/python/util/py_checkpoint_reader.i | 1 | ||||
-rw-r--r-- | tensorflow/python/util/stat_summarizer.i | 25 | ||||
-rw-r--r-- | tensorflow/python/util/tf_export.py | 45 | ||||
-rw-r--r-- | tensorflow/python/util/tf_export_test.py | 2 |
4 files changed, 43 insertions, 30 deletions
diff --git a/tensorflow/python/util/py_checkpoint_reader.i b/tensorflow/python/util/py_checkpoint_reader.i index 8004898cbc..1c73f7f06f 100644 --- a/tensorflow/python/util/py_checkpoint_reader.i +++ b/tensorflow/python/util/py_checkpoint_reader.i @@ -166,6 +166,7 @@ def NewCheckpointReader(filepattern): return CheckpointReader(compat.as_bytes(filepattern), status) NewCheckpointReader._tf_api_names = ['train.NewCheckpointReader'] +NewCheckpointReader._tf_api_names_v1 = ['train.NewCheckpointReader'] %} %include "tensorflow/c/checkpoint_reader.h" diff --git a/tensorflow/python/util/stat_summarizer.i b/tensorflow/python/util/stat_summarizer.i index 73fa85494b..a5a7984d91 100644 --- a/tensorflow/python/util/stat_summarizer.i +++ b/tensorflow/python/util/stat_summarizer.i @@ -27,8 +27,8 @@ limitations under the License. %ignoreall -%unignore _NewStatSummarizer; -%unignore _DeleteStatSummarizer; +%unignore NewStatSummarizer; +%unignore DeleteStatSummarizer; %unignore tensorflow; %unignore tensorflow::StatSummarizer; %unignore tensorflow::StatSummarizer::StatSummarizer; @@ -43,20 +43,20 @@ limitations under the License. // TODO(ashankar): Remove the unused argument from the API. %{ -tensorflow::StatSummarizer* _NewStatSummarizer( +tensorflow::StatSummarizer* NewStatSummarizer( const string& unused) { return new tensorflow::StatSummarizer(tensorflow::StatSummarizerOptions()); } %} %{ -void _DeleteStatSummarizer(tensorflow::StatSummarizer* ss) { +void DeleteStatSummarizer(tensorflow::StatSummarizer* ss) { delete ss; } %} -tensorflow::StatSummarizer* _NewStatSummarizer(const string& unused); -void _DeleteStatSummarizer(tensorflow::StatSummarizer* ss); +tensorflow::StatSummarizer* NewStatSummarizer(const string& unused); +void DeleteStatSummarizer(tensorflow::StatSummarizer* ss); %extend tensorflow::StatSummarizer { void ProcessStepStatsStr(const string& step_stats_str) { @@ -76,16 +76,3 @@ void _DeleteStatSummarizer(tensorflow::StatSummarizer* ss); %include "tensorflow/core/util/stat_summarizer_options.h" %include "tensorflow/core/util/stat_summarizer.h" %unignoreall - -%insert("python") %{ - -# Wrapping NewStatSummarizer and DeletStatSummarizer because -# SWIG-generated functions are built-in functions and do not support -# setting _tf_api_names attribute. - -def NewStatSummarizer(unused): - return _NewStatSummarizer(unused) - -def DeleteStatSummarizer(stat_summarizer): - _DeleteStatSummarizer(stat_summarizer) -%} diff --git a/tensorflow/python/util/tf_export.py b/tensorflow/python/util/tf_export.py index e154ffb68a..c362d588ab 100644 --- a/tensorflow/python/util/tf_export.py +++ b/tensorflow/python/util/tf_export.py @@ -63,6 +63,15 @@ API_ATTRS = { '_estimator_api_constants') } +API_ATTRS_V1 = { + TENSORFLOW_API_NAME: _Attributes( + '_tf_api_names_v1', + '_tf_api_constants_v1'), + ESTIMATOR_API_NAME: _Attributes( + '_estimator_api_names_v1', + '_estimator_api_constants_v1') +} + class SymbolAlreadyExposedError(Exception): """Raised when adding API names to symbol that already has API names.""" @@ -78,13 +87,16 @@ class api_export(object): # pylint: disable=invalid-name Args: *args: API names in dot delimited format. **kwargs: Optional keyed arguments. - overrides: List of symbols that this is overriding + v1: Names for the TensorFlow V1 API. If not set, we will use V2 API + names both for TensorFlow V1 and V2 APIs. + overrides: List of symbols that this is overriding (those overrided api exports will be removed). Note: passing overrides has no effect on exporting a constant. - api_name: Name of the API you want to generate (e.g. `tensorflow` or + api_name: Name of the API you want to generate (e.g. `tensorflow` or `estimator`). Default is `tensorflow`. """ self._names = args + self._names_v1 = kwargs.get('v1', args) self._api_name = kwargs.get('api_name', TENSORFLOW_API_NAME) self._overrides = kwargs.get('overrides', []) @@ -102,24 +114,27 @@ class api_export(object): # pylint: disable=invalid-name and kwarg `allow_multiple_exports` not set. """ api_names_attr = API_ATTRS[self._api_name].names - + api_names_attr_v1 = API_ATTRS_V1[self._api_name].names # Undecorate overridden names for f in self._overrides: _, undecorated_f = tf_decorator.unwrap(f) delattr(undecorated_f, api_names_attr) + delattr(undecorated_f, api_names_attr_v1) _, undecorated_func = tf_decorator.unwrap(func) + self.set_attr(undecorated_func, api_names_attr, self._names) + self.set_attr(undecorated_func, api_names_attr_v1, self._names_v1) + return func + def set_attr(self, func, api_names_attr, names): # Check for an existing api. We check if attribute name is in # __dict__ instead of using hasattr to verify that subclasses have # their own _tf_api_names as opposed to just inheriting it. - if api_names_attr in undecorated_func.__dict__: + if api_names_attr in func.__dict__: raise SymbolAlreadyExposedError( 'Symbol %s is already exposed as %s.' % - (undecorated_func.__name__, getattr( - undecorated_func, api_names_attr))) # pylint: disable=protected-access - setattr(undecorated_func, api_names_attr, self._names) - return func + (func.__name__, getattr(func, api_names_attr))) # pylint: disable=protected-access + setattr(func, api_names_attr, names) def export_constant(self, module_name, name): """Store export information for constants/string literals. @@ -140,12 +155,20 @@ class api_export(object): # pylint: disable=invalid-name name: (string) Current constant name. """ module = sys.modules[module_name] - if not hasattr(module, API_ATTRS[self._api_name].constants): - setattr(module, API_ATTRS[self._api_name].constants, []) + api_constants_attr = API_ATTRS[self._api_name].constants + api_constants_attr_v1 = API_ATTRS_V1[self._api_name].constants + + if not hasattr(module, api_constants_attr): + setattr(module, api_constants_attr, []) # pylint: disable=protected-access - getattr(module, API_ATTRS[self._api_name].constants).append( + getattr(module, api_constants_attr).append( (self._names, name)) + if not hasattr(module, api_constants_attr_v1): + setattr(module, api_constants_attr_v1, []) + getattr(module, api_constants_attr_v1).append( + (self._names_v1, name)) + tf_export = functools.partial(api_export, api_name=TENSORFLOW_API_NAME) estimator_export = functools.partial(tf_export, api_name=ESTIMATOR_API_NAME) diff --git a/tensorflow/python/util/tf_export_test.py b/tensorflow/python/util/tf_export_test.py index b9e26ecb33..4ae1dc55e0 100644 --- a/tensorflow/python/util/tf_export_test.py +++ b/tensorflow/python/util/tf_export_test.py @@ -60,6 +60,8 @@ class ValidateExportTest(test.TestCase): for symbol in [_test_function, _test_function, TestClassA, TestClassB]: if hasattr(symbol, '_tf_api_names'): del symbol._tf_api_names + if hasattr(symbol, '_tf_api_names_v1'): + del symbol._tf_api_names_v1 def _CreateMockModule(self, name): mock_module = self.MockModule(name) |