path: root/tensorflow/python/feature_column
author    Rohan Jain <rohanj@google.com>    2018-09-26 22:00:22 -0700
committer TensorFlower Gardener <gardener@tensorflow.org>    2018-09-26 22:03:37 -0700
commit    a40cfd42e20d7e4520c1306666c9dfee97eb0a2e (patch)
tree      380100ade305a7b1fe8e7baa7eed2197daf1eabb /tensorflow/python/feature_column
parent    941e757a2364bb2e7cf41b8d980d7639849c6c5d (diff)
Automated rollback of commit e00d7744dbab5c73e4d8ffa8a7d361f7b2dcefff
PiperOrigin-RevId: 214721004
Diffstat (limited to 'tensorflow/python/feature_column')
-rw-r--r--  tensorflow/python/feature_column/feature_column.py     33
-rw-r--r--  tensorflow/python/feature_column/feature_column_v2.py  14
2 files changed, 37 insertions, 10 deletions
diff --git a/tensorflow/python/feature_column/feature_column.py b/tensorflow/python/feature_column/feature_column.py
index 9984379e9d..226e273660 100644
--- a/tensorflow/python/feature_column/feature_column.py
+++ b/tensorflow/python/feature_column/feature_column.py
@@ -170,7 +170,8 @@ def _internal_input_layer(features,
trainable=True,
cols_to_vars=None,
scope=None,
- cols_to_output_tensors=None):
+ cols_to_output_tensors=None,
+ from_template=False):
"""See input_layer. `scope` is a name or variable scope to use."""
feature_columns = _normalize_feature_columns(feature_columns)
@@ -186,10 +187,7 @@ def _internal_input_layer(features,
if ops.GraphKeys.MODEL_VARIABLES not in weight_collections:
weight_collections.append(ops.GraphKeys.MODEL_VARIABLES)
- # a non-None `scope` can allow for variable reuse, when, e.g., this function
- # is wrapped by a `make_template`.
- with variable_scope.variable_scope(
- scope, default_name='input_layer', values=features.values()):
+ def _get_logits(): # pylint: disable=missing-docstring
builder = _LazyBuilder(features)
output_tensors = []
ordered_columns = []
@@ -217,6 +215,16 @@ def _internal_input_layer(features,
_verify_static_batch_size_equality(output_tensors, ordered_columns)
return array_ops.concat(output_tensors, 1)
+ # If we're constructing from a `make_template`, that by default adds a
+ # variable scope with the name of the layer. In that case, we don't want to
+ # add another `variable_scope`, as that would break checkpoints.
+ if from_template:
+ return _get_logits()
+ else:
+ with variable_scope.variable_scope(
+ scope, default_name='input_layer', values=features.values()):
+ return _get_logits()
+
@tf_export('feature_column.input_layer')
def input_layer(features,
@@ -301,17 +309,18 @@ class InputLayer(object):
feature_columns,
weight_collections=None,
trainable=True,
- cols_to_vars=None):
+ cols_to_vars=None,
+ name='feature_column_input_layer',
+ create_scope_now=True):
"""See `input_layer`."""
self._feature_columns = feature_columns
self._weight_collections = weight_collections
self._trainable = trainable
self._cols_to_vars = cols_to_vars
+ self._name = name
self._input_layer_template = template.make_template(
- 'feature_column_input_layer',
- _internal_input_layer,
- create_scope_now_=True)
+ self._name, _internal_input_layer, create_scope_now_=create_scope_now)
self._scope = self._input_layer_template.variable_scope
def __call__(self, features):
@@ -321,7 +330,11 @@ class InputLayer(object):
weight_collections=self._weight_collections,
trainable=self._trainable,
cols_to_vars=None,
- scope=self._scope)
+ from_template=True)
+
+ @property
+ def name(self):
+ return self._name
@property
def non_trainable_variables(self):
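
For context, a minimal usage sketch of the `InputLayer` constructor arguments added in this diff (not part of the commit; the import path, column, and feature values are illustrative, and a TF 1.x graph context is assumed):

import tensorflow as tf
from tensorflow.python.feature_column import feature_column as fc

# An illustrative numeric column and a matching feature tensor.
price = fc.numeric_column('price')
features = {'price': tf.constant([[1.0], [5.0]])}

# `name` now controls the template/variable-scope name (previously fixed to
# 'feature_column_input_layer'), and `create_scope_now=False` defers scope
# creation until the first call. __call__ passes from_template=True, so no
# second variable_scope is opened and existing checkpoint names are preserved.
layer = fc.InputLayer([price], name='my_input_layer', create_scope_now=False)
net = layer(features)  # variables, if any, are created under 'my_input_layer'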
diff --git a/tensorflow/python/feature_column/feature_column_v2.py b/tensorflow/python/feature_column/feature_column_v2.py
index 57f7af7635..b62c16ea5a 100644
--- a/tensorflow/python/feature_column/feature_column_v2.py
+++ b/tensorflow/python/feature_column/feature_column_v2.py
@@ -2045,6 +2045,14 @@ class DenseColumn(FeatureColumn):
pass
+def is_feature_column_v2(feature_columns):
+ """Returns True if all feature columns are V2."""
+ for feature_column in feature_columns:
+ if not isinstance(feature_column, FeatureColumn):
+ return False
+ return True
+
+
def _create_weighted_sum(column,
transformation_cache,
state_manager,
@@ -2782,6 +2790,12 @@ class SharedEmbeddingStateManager(Layer):
return self._var_dict[name]
+def maybe_create_shared_state_manager(feature_columns):
+ if is_feature_column_v2(feature_columns):
+ return SharedEmbeddingStateManager()
+ return None
+
+
class SharedEmbeddingColumn(
DenseColumn, SequenceDenseColumn,
collections.namedtuple(
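
Similarly, a sketch of how the new V2 helpers might be called by client code (not part of the commit; the empty column list is a hypothetical stand-in for embedding/shared-embedding columns built elsewhere):

from tensorflow.python.feature_column import feature_column_v2 as fc_v2

# Hypothetical placeholder; in practice these would be V2 feature columns
# such as the embedding/shared-embedding columns defined in this module.
columns = []

# A SharedEmbeddingStateManager is only created when every column is a V2
# FeatureColumn; callers on the V1 path get None and keep their old behavior.
state_manager = fc_v2.maybe_create_shared_state_manager(columns)

if fc_v2.is_feature_column_v2(columns):
    pass  # V2 path: pass `state_manager` through to the columns
else:
    pass  # V1 path: fall back to the _LazyBuilder-based code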