aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python/feature_column
diff options
context:
space:
mode:
authorGravatar Rohan Jain <rohanj@google.com>2018-09-18 19:39:27 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-09-18 19:43:44 -0700
commit9fe177881224571aff0c267593f747f5fd7a2967 (patch)
tree9c5051a7336ac9832171ebfee8e610ba550d0f1e /tensorflow/python/feature_column
parent9ee75bb6e29007b8b5ea4a6d981996d8a4d88373 (diff)
Getting DNNModel to work with the new feature columns.
PiperOrigin-RevId: 213561495
Diffstat (limited to 'tensorflow/python/feature_column')
-rw-r--r--tensorflow/python/feature_column/feature_column.py12
-rw-r--r--tensorflow/python/feature_column/feature_column_v2.py14
2 files changed, 22 insertions, 4 deletions
diff --git a/tensorflow/python/feature_column/feature_column.py b/tensorflow/python/feature_column/feature_column.py
index 9984379e9d..0d189320da 100644
--- a/tensorflow/python/feature_column/feature_column.py
+++ b/tensorflow/python/feature_column/feature_column.py
@@ -301,17 +301,17 @@ class InputLayer(object):
feature_columns,
weight_collections=None,
trainable=True,
- cols_to_vars=None):
+ cols_to_vars=None,
+ name='feature_column_input_layer'):
"""See `input_layer`."""
self._feature_columns = feature_columns
self._weight_collections = weight_collections
self._trainable = trainable
self._cols_to_vars = cols_to_vars
+ self._name = name
self._input_layer_template = template.make_template(
- 'feature_column_input_layer',
- _internal_input_layer,
- create_scope_now_=True)
+ self._name, _internal_input_layer, create_scope_now_=True)
self._scope = self._input_layer_template.variable_scope
def __call__(self, features):
@@ -324,6 +324,10 @@ class InputLayer(object):
scope=self._scope)
@property
+ def name(self):
+ return self._name
+
+ @property
def non_trainable_variables(self):
return self._input_layer_template.non_trainable_variables
diff --git a/tensorflow/python/feature_column/feature_column_v2.py b/tensorflow/python/feature_column/feature_column_v2.py
index 28c5c82d2c..289f6d0d14 100644
--- a/tensorflow/python/feature_column/feature_column_v2.py
+++ b/tensorflow/python/feature_column/feature_column_v2.py
@@ -2045,6 +2045,14 @@ class DenseColumn(FeatureColumn):
pass
+def is_feature_column_v2(feature_columns):
+ """Returns True if all feature columns are V2."""
+ for feature_column in feature_columns:
+ if not isinstance(feature_column, FeatureColumn):
+ return False
+ return True
+
+
def _create_weighted_sum(column,
transformation_cache,
state_manager,
@@ -2782,6 +2790,12 @@ class SharedEmbeddingStateManager(Layer):
return self._var_dict[name]
+def maybe_create_shared_state_manager(feature_columns):
+ if is_feature_column_v2(feature_columns):
+ return SharedEmbeddingStateManager()
+ return None
+
+
class SharedEmbeddingColumn(
DenseColumn, SequenceDenseColumn,
collections.namedtuple(