diff options
author | Rohan Jain <rohanj@google.com> | 2018-09-26 22:00:22 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-09-26 22:03:37 -0700 |
commit | a40cfd42e20d7e4520c1306666c9dfee97eb0a2e (patch) | |
tree | 380100ade305a7b1fe8e7baa7eed2197daf1eabb /tensorflow/python/feature_column | |
parent | 941e757a2364bb2e7cf41b8d980d7639849c6c5d (diff) |
Automated rollback of commit e00d7744dbab5c73e4d8ffa8a7d361f7b2dcefff
PiperOrigin-RevId: 214721004
Diffstat (limited to 'tensorflow/python/feature_column')
-rw-r--r-- | tensorflow/python/feature_column/feature_column.py | 33 | ||||
-rw-r--r-- | tensorflow/python/feature_column/feature_column_v2.py | 14 |
2 files changed, 37 insertions, 10 deletions
diff --git a/tensorflow/python/feature_column/feature_column.py b/tensorflow/python/feature_column/feature_column.py index 9984379e9d..226e273660 100644 --- a/tensorflow/python/feature_column/feature_column.py +++ b/tensorflow/python/feature_column/feature_column.py @@ -170,7 +170,8 @@ def _internal_input_layer(features, trainable=True, cols_to_vars=None, scope=None, - cols_to_output_tensors=None): + cols_to_output_tensors=None, + from_template=False): """See input_layer. `scope` is a name or variable scope to use.""" feature_columns = _normalize_feature_columns(feature_columns) @@ -186,10 +187,7 @@ def _internal_input_layer(features, if ops.GraphKeys.MODEL_VARIABLES not in weight_collections: weight_collections.append(ops.GraphKeys.MODEL_VARIABLES) - # a non-None `scope` can allow for variable reuse, when, e.g., this function - # is wrapped by a `make_template`. - with variable_scope.variable_scope( - scope, default_name='input_layer', values=features.values()): + def _get_logits(): # pylint: disable=missing-docstring builder = _LazyBuilder(features) output_tensors = [] ordered_columns = [] @@ -217,6 +215,16 @@ def _internal_input_layer(features, _verify_static_batch_size_equality(output_tensors, ordered_columns) return array_ops.concat(output_tensors, 1) + # If we're constructing from the `make_template`, that by default adds a + # variable scope with the name of the layer. In that case, we dont want to + # add another `variable_scope` as that would break checkpoints. + if from_template: + return _get_logits() + else: + with variable_scope.variable_scope( + scope, default_name='input_layer', values=features.values()): + return _get_logits() + @tf_export('feature_column.input_layer') def input_layer(features, @@ -301,17 +309,18 @@ class InputLayer(object): feature_columns, weight_collections=None, trainable=True, - cols_to_vars=None): + cols_to_vars=None, + name='feature_column_input_layer', + create_scope_now=True): """See `input_layer`.""" self._feature_columns = feature_columns self._weight_collections = weight_collections self._trainable = trainable self._cols_to_vars = cols_to_vars + self._name = name self._input_layer_template = template.make_template( - 'feature_column_input_layer', - _internal_input_layer, - create_scope_now_=True) + self._name, _internal_input_layer, create_scope_now_=create_scope_now) self._scope = self._input_layer_template.variable_scope def __call__(self, features): @@ -321,7 +330,11 @@ class InputLayer(object): weight_collections=self._weight_collections, trainable=self._trainable, cols_to_vars=None, - scope=self._scope) + from_template=True) + + @property + def name(self): + return self._name @property def non_trainable_variables(self): diff --git a/tensorflow/python/feature_column/feature_column_v2.py b/tensorflow/python/feature_column/feature_column_v2.py index 57f7af7635..b62c16ea5a 100644 --- a/tensorflow/python/feature_column/feature_column_v2.py +++ b/tensorflow/python/feature_column/feature_column_v2.py @@ -2045,6 +2045,14 @@ class DenseColumn(FeatureColumn): pass +def is_feature_column_v2(feature_columns): + """Returns True if all feature columns are V2.""" + for feature_column in feature_columns: + if not isinstance(feature_column, FeatureColumn): + return False + return True + + def _create_weighted_sum(column, transformation_cache, state_manager, @@ -2782,6 +2790,12 @@ class SharedEmbeddingStateManager(Layer): return self._var_dict[name] +def maybe_create_shared_state_manager(feature_columns): + if is_feature_column_v2(feature_columns): + return SharedEmbeddingStateManager() + return None + + class SharedEmbeddingColumn( DenseColumn, SequenceDenseColumn, collections.namedtuple( |