diff options
author | Rohan Jain <rohanj@google.com> | 2018-09-18 19:39:27 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-09-18 19:43:44 -0700 |
commit | 9fe177881224571aff0c267593f747f5fd7a2967 (patch) | |
tree | 9c5051a7336ac9832171ebfee8e610ba550d0f1e /tensorflow/python/feature_column | |
parent | 9ee75bb6e29007b8b5ea4a6d981996d8a4d88373 (diff) |
Getting DNNModel to work with the new feature columns.
PiperOrigin-RevId: 213561495
Diffstat (limited to 'tensorflow/python/feature_column')
-rw-r--r-- | tensorflow/python/feature_column/feature_column.py | 12 | ||||
-rw-r--r-- | tensorflow/python/feature_column/feature_column_v2.py | 14 |
2 files changed, 22 insertions, 4 deletions
diff --git a/tensorflow/python/feature_column/feature_column.py b/tensorflow/python/feature_column/feature_column.py index 9984379e9d..0d189320da 100644 --- a/tensorflow/python/feature_column/feature_column.py +++ b/tensorflow/python/feature_column/feature_column.py @@ -301,17 +301,17 @@ class InputLayer(object): feature_columns, weight_collections=None, trainable=True, - cols_to_vars=None): + cols_to_vars=None, + name='feature_column_input_layer'): """See `input_layer`.""" self._feature_columns = feature_columns self._weight_collections = weight_collections self._trainable = trainable self._cols_to_vars = cols_to_vars + self._name = name self._input_layer_template = template.make_template( - 'feature_column_input_layer', - _internal_input_layer, - create_scope_now_=True) + self._name, _internal_input_layer, create_scope_now_=True) self._scope = self._input_layer_template.variable_scope def __call__(self, features): @@ -324,6 +324,10 @@ class InputLayer(object): scope=self._scope) @property + def name(self): + return self._name + + @property def non_trainable_variables(self): return self._input_layer_template.non_trainable_variables diff --git a/tensorflow/python/feature_column/feature_column_v2.py b/tensorflow/python/feature_column/feature_column_v2.py index 28c5c82d2c..289f6d0d14 100644 --- a/tensorflow/python/feature_column/feature_column_v2.py +++ b/tensorflow/python/feature_column/feature_column_v2.py @@ -2045,6 +2045,14 @@ class DenseColumn(FeatureColumn): pass +def is_feature_column_v2(feature_columns): + """Returns True if all feature columns are V2.""" + for feature_column in feature_columns: + if not isinstance(feature_column, FeatureColumn): + return False + return True + + def _create_weighted_sum(column, transformation_cache, state_manager, @@ -2782,6 +2790,12 @@ class SharedEmbeddingStateManager(Layer): return self._var_dict[name] +def maybe_create_shared_state_manager(feature_columns): + if is_feature_column_v2(feature_columns): + return SharedEmbeddingStateManager() + return None + + class SharedEmbeddingColumn( DenseColumn, SequenceDenseColumn, collections.namedtuple( |