aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python/feature_column
diff options
context:
space:
mode:
authorGravatar Rohan Jain <rohanj@google.com>2018-09-28 15:07:29 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-09-28 15:11:09 -0700
commitdee0481c07ed952d01b12704c89e50869a383c68 (patch)
treec1f4c239a4e0e39e5bab25bfa2b3220bd96d509e /tensorflow/python/feature_column
parent5863cad53afad2fcc5d8a8dac7c2cf88e0e8ebb9 (diff)
Adding FeatureColumn V2 support for linear canned estimators.
Since we now have support for FeatureColumnV2 for both DNN and Linear models, adding tests for the combined canned estimators as well. PiperOrigin-RevId: 215002573
Diffstat (limited to 'tensorflow/python/feature_column')
-rw-r--r--tensorflow/python/feature_column/BUILD2
-rw-r--r--tensorflow/python/feature_column/feature_column_v2.py100
-rw-r--r--tensorflow/python/feature_column/feature_column_v2_test.py4
3 files changed, 58 insertions, 48 deletions
diff --git a/tensorflow/python/feature_column/BUILD b/tensorflow/python/feature_column/BUILD
index 5800b693b4..ac53a84eef 100644
--- a/tensorflow/python/feature_column/BUILD
+++ b/tensorflow/python/feature_column/BUILD
@@ -156,7 +156,7 @@ py_test(
"//tensorflow/python:variables",
"//tensorflow/python/eager:backprop",
"//tensorflow/python/eager:context",
- "//tensorflow/python/estimator:estimator_py",
+ "//tensorflow/python/estimator:numpy_io",
"//third_party/py/numpy",
],
)
diff --git a/tensorflow/python/feature_column/feature_column_v2.py b/tensorflow/python/feature_column/feature_column_v2.py
index a8d5bfb437..b79373c475 100644
--- a/tensorflow/python/feature_column/feature_column_v2.py
+++ b/tensorflow/python/feature_column/feature_column_v2.py
@@ -271,6 +271,7 @@ class _StateManagerImpl(StateManager):
dtype=dtype,
initializer=initializer,
trainable=self._trainable and trainable,
+ use_resource=True,
# TODO(rohanj): Get rid of this hack once we have a mechanism for
# specifying a default partitioner for an entire layer. In that case,
# the default getter for Layers should work.
@@ -383,8 +384,8 @@ class FeatureLayer(Layer):
if isinstance(column, SharedEmbeddingColumn):
column.create_state(self._shared_state_manager)
else:
- with variable_scope.variable_scope(None, default_name=self.name):
- with variable_scope.variable_scope(None, default_name=column.name):
+ with variable_scope._pure_variable_scope(self.name): # pylint: disable=protected-access
+ with variable_scope._pure_variable_scope(column.name): # pylint: disable=protected-access
column.create_state(self._state_manager)
super(FeatureLayer, self).build(None)
@@ -414,19 +415,20 @@ class FeatureLayer(Layer):
output_tensors = []
ordered_columns = []
for column in sorted(self._feature_columns, key=lambda x: x.name):
- ordered_columns.append(column)
- if isinstance(column, SharedEmbeddingColumn):
- tensor = column.get_dense_tensor(transformation_cache,
- self._shared_state_manager)
- else:
- tensor = column.get_dense_tensor(transformation_cache,
- self._state_manager)
- num_elements = column.variable_shape.num_elements()
- batch_size = array_ops.shape(tensor)[0]
- tensor = array_ops.reshape(tensor, shape=(batch_size, num_elements))
- output_tensors.append(tensor)
- if cols_to_output_tensors is not None:
- cols_to_output_tensors[column] = tensor
+ with ops.name_scope(column.name):
+ ordered_columns.append(column)
+ if isinstance(column, SharedEmbeddingColumn):
+ tensor = column.get_dense_tensor(transformation_cache,
+ self._shared_state_manager)
+ else:
+ tensor = column.get_dense_tensor(transformation_cache,
+ self._state_manager)
+ num_elements = column.variable_shape.num_elements()
+ batch_size = array_ops.shape(tensor)[0]
+ tensor = array_ops.reshape(tensor, shape=(batch_size, num_elements))
+ output_tensors.append(tensor)
+ if cols_to_output_tensors is not None:
+ cols_to_output_tensors[column] = tensor
_verify_static_batch_size_equality(output_tensors, ordered_columns)
return array_ops.concat(output_tensors, 1)
@@ -601,6 +603,7 @@ class LinearModel(Layer):
shape=[self._units],
initializer=init_ops.zeros_initializer(),
trainable=self.trainable,
+ use_resource=True,
# TODO(rohanj): Get rid of this hack once we have a mechanism for
# specifying a default partitioner for an entire layer. In that case,
# the default getter for Layers should work.
@@ -627,36 +630,41 @@ class LinearModel(Layer):
if not isinstance(features, dict):
raise ValueError('We expected a dictionary here. Instead we got: ',
features)
- transformation_cache = FeatureTransformationCache(features)
- weighted_sums = []
- for column in self._feature_columns:
- with ops.name_scope(column.name):
- # All the weights used in the linear model are owned by the state
- # manager associated with this Linear Model.
- weight_var = self._state_manager.get_variable(column, 'weights')
-
- # The embedding weights for the SharedEmbeddingColumn are owned by
- # the shared_state_manager and so we need to pass that in while
- # creating the weighted sum. For all other columns, the state is owned
- # by the Linear Model's state manager.
- if isinstance(column, SharedEmbeddingColumn):
- state_manager = self._shared_state_manager
- else:
- state_manager = self._state_manager
- weighted_sum = _create_weighted_sum(
- column=column,
- transformation_cache=transformation_cache,
- state_manager=state_manager,
- sparse_combiner=self._sparse_combiner,
- weight_var=weight_var)
- weighted_sums.append(weighted_sum)
-
- _verify_static_batch_size_equality(weighted_sums, self._feature_columns)
- predictions_no_bias = math_ops.add_n(
- weighted_sums, name='weighted_sum_no_bias')
- predictions = nn_ops.bias_add(
- predictions_no_bias, self._bias_variable, name='weighted_sum')
- return predictions
+ with ops.name_scope(self.name):
+ transformation_cache = FeatureTransformationCache(features)
+ weighted_sums = []
+ for column in self._feature_columns:
+ with ops.name_scope(column.name):
+ # All the weights used in the linear model are owned by the state
+ # manager associated with this Linear Model.
+ weight_var = self._state_manager.get_variable(column, 'weights')
+
+ # The embedding weights for the SharedEmbeddingColumn are owned by
+ # the shared_state_manager and so we need to pass that in while
+ # creating the weighted sum. For all other columns, the state is owned
+ # by the Linear Model's state manager.
+ if isinstance(column, SharedEmbeddingColumn):
+ state_manager = self._shared_state_manager
+ else:
+ state_manager = self._state_manager
+ weighted_sum = _create_weighted_sum(
+ column=column,
+ transformation_cache=transformation_cache,
+ state_manager=state_manager,
+ sparse_combiner=self._sparse_combiner,
+ weight_var=weight_var)
+ weighted_sums.append(weighted_sum)
+
+ _verify_static_batch_size_equality(weighted_sums, self._feature_columns)
+ predictions_no_bias = math_ops.add_n(
+ weighted_sums, name='weighted_sum_no_bias')
+ predictions = nn_ops.bias_add(
+ predictions_no_bias, self._bias_variable, name='weighted_sum')
+ return predictions
+
+ @property
+ def bias_variable(self):
+ return self._bias_variable
def _transform_features(features, feature_columns, state_manager):
@@ -2605,6 +2613,7 @@ class SharedEmbeddingStateManager(Layer):
dtype=dtype,
trainable=self.trainable and trainable,
initializer=initializer,
+ use_resource=True,
# TODO(rohanj): Get rid of this hack once we have a mechanism for
# specifying a default partitioner for an entire layer. In that case,
# the default getter for Layers should work.
@@ -3279,6 +3288,7 @@ def _safe_embedding_lookup_sparse(embedding_weights,
raise ValueError('Missing embedding_weights %s.' % embedding_weights)
dtype = sparse_weights.dtype if sparse_weights is not None else None
+ # TODO(rohanj): Look into removing this convert_to_tensor call.
embedding_weights = [
ops.convert_to_tensor(w, dtype=dtype) for w in embedding_weights
]
diff --git a/tensorflow/python/feature_column/feature_column_v2_test.py b/tensorflow/python/feature_column/feature_column_v2_test.py
index a13a5010e1..d3787146ed 100644
--- a/tensorflow/python/feature_column/feature_column_v2_test.py
+++ b/tensorflow/python/feature_column/feature_column_v2_test.py
@@ -5170,8 +5170,8 @@ class WeightedCategoricalColumnTest(test.TestCase):
key='ids', num_buckets=3),
weight_feature_key='values')
with ops.Graph().as_default():
- with self.assertRaisesRegexp(
- ValueError, r'Dimensions.*are not compatible'):
+ with self.assertRaisesRegexp(ValueError,
+ r'Dimensions.*are not compatible'):
model = fc.LinearModel((column,))
model({
'ids':