author     Rohan Jain <rohanj@google.com>                    2018-09-27 20:52:53 -0700
committer  TensorFlower Gardener <gardener@tensorflow.org>  2018-09-27 20:56:36 -0700
commit     370d385c3029a7972ba201c8303942b30f09521c (patch)
tree       aba3c0539b9c56bec5e09748fbeb12c53e095366 /tensorflow/python/feature_column
parent     986193d79e00f1780fb3278ed890a72f7285f66e (diff)
Creating a LinearModel that works with V2 feature columns.
In a subsequent change, I'll update canned estimators to support FeatureColumn V2 and use this LinearModel.

PiperOrigin-RevId: 214882241
Diffstat (limited to 'tensorflow/python/feature_column')
-rw-r--r--  tensorflow/python/feature_column/feature_column_v2.py        574
-rw-r--r--  tensorflow/python/feature_column/feature_column_v2_test.py  1862
2 files changed, 507 insertions(+), 1929 deletions(-)
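In practice, the API change is a move from the one-shot `linear_model(features, columns)` function to a callable Keras-style layer. A minimal sketch mirroring the updated tests below (graph-mode TF 1.x of this era, with `fc` aliasing the v2 module the way the test file does):

```python
from tensorflow.python.feature_column import feature_column_v2 as fc

price = fc.numeric_column('price')
features = {'price': [[1.], [5.]]}

# Old functional API, removed from feature_column_v2 by this change:
#   predictions = fc.linear_model(features, [price])

# New layer introduced here: construct once, then call on features.
model = fc.LinearModel([price])
predictions = model(features)
price_var, bias = model.variables  # per-column weights first, bias last
```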
diff --git a/tensorflow/python/feature_column/feature_column_v2.py b/tensorflow/python/feature_column/feature_column_v2.py
index 538641c251..a8d5bfb437 100644
--- a/tensorflow/python/feature_column/feature_column_v2.py
+++ b/tensorflow/python/feature_column/feature_column_v2.py
@@ -136,14 +136,11 @@ import six
from tensorflow.python.eager import context
-from tensorflow.python.feature_column import feature_column as fc_old
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import sparse_tensor as sparse_tensor_lib
from tensorflow.python.framework import tensor_shape
-from tensorflow.python.keras.engine import training
from tensorflow.python.keras.engine.base_layer import Layer
-from tensorflow.python.layers import base
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import check_ops
from tensorflow.python.ops import control_flow_ops
@@ -153,7 +150,6 @@ from tensorflow.python.ops import lookup_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import nn_ops
from tensorflow.python.ops import parsing_ops
-from tensorflow.python.ops import resource_variable_ops
from tensorflow.python.ops import sparse_ops
from tensorflow.python.ops import string_ops
from tensorflow.python.ops import variable_scope
@@ -245,28 +241,19 @@ class StateManager(object):
raise NotImplementedError('StateManager.get_resource')
-class _InputLayerStateManager(StateManager):
- """Manages the state of InputLayer."""
+class _StateManagerImpl(StateManager):
+ """Manages the state of FeatureLayer and LinearModel."""
- def __init__(self, layer, feature_columns, trainable):
- """Creates an _InputLayerStateManager object.
+ def __init__(self, layer, trainable):
+ """Creates an _StateManagerImpl object.
Args:
layer: The input layer this state manager is associated with.
- feature_columns: List of feature columns for the input layer
trainable: Whether by default, variables created are trainable or not.
"""
self._trainable = trainable
self._layer = layer
- self._cols_to_vars_map = {}
- self._cols_to_names_map = {}
- for column in sorted(feature_columns, key=lambda x: x.name):
- self._cols_to_vars_map[column] = {}
- base_name = column.name
- if isinstance(column, SharedEmbeddingColumn):
- base_name = column.shared_collection_name
- with variable_scope.variable_scope(base_name) as vs:
- self._cols_to_names_map[column] = _strip_leading_slashes(vs.name)
+ self._cols_to_vars_map = collections.defaultdict(lambda: {})
def create_variable(self,
feature_column,
@@ -277,19 +264,19 @@ class _InputLayerStateManager(StateManager):
initializer=None):
if name in self._cols_to_vars_map[feature_column]:
raise ValueError('Variable already exists.')
- with variable_scope.variable_scope(self._cols_to_names_map[feature_column]):
- var = self._layer.add_variable(
- name=name,
- shape=shape,
- dtype=dtype,
- initializer=initializer,
- trainable=self._trainable and trainable,
- # TODO(rohanj): Get rid of this hack once we have a mechanism for
- # specifying a default partitioner for an entire layer. In that case,
- # the default getter for Layers should work.
- getter=variable_scope.get_variable)
- self._cols_to_vars_map[feature_column][name] = var
- return var
+
+ var = self._layer.add_variable(
+ name=name,
+ shape=shape,
+ dtype=dtype,
+ initializer=initializer,
+ trainable=self._trainable and trainable,
+ # TODO(rohanj): Get rid of this hack once we have a mechanism for
+ # specifying a default partitioner for an entire layer. In that case,
+ # the default getter for Layers should work.
+ getter=variable_scope.get_variable)
+ self._cols_to_vars_map[feature_column][name] = var
+ return var
def get_variable(self, feature_column, name):
if name in self._cols_to_vars_map[feature_column]:
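Annotation: the state manager's bookkeeping reduces to a two-level mapping, feature column → variable name → variable; the `defaultdict(lambda: {})` above is equivalent to `defaultdict(dict)`. A standalone sketch of that create/get contract, with plain Python objects standing in for TF variables:

```python
import collections


class _StateManagerSketch(object):
  """Illustrative stand-in for _StateManagerImpl's bookkeeping."""

  def __init__(self):
    # feature column -> {variable name -> variable}
    self._cols_to_vars_map = collections.defaultdict(dict)

  def create_variable(self, feature_column, name, var):
    if name in self._cols_to_vars_map[feature_column]:
      raise ValueError('Variable already exists.')
    self._cols_to_vars_map[feature_column][name] = var
    return var

  def get_variable(self, feature_column, name):
    if name in self._cols_to_vars_map[feature_column]:
      return self._cols_to_vars_map[feature_column][name]
    raise ValueError('Variable does not exist.')
```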
@@ -313,12 +300,15 @@ class FeatureLayer(Layer):
keywords_embedded = embedding_column(
categorical_column_with_hash_bucket("keywords", 10K), dimensions=16)
columns = [price, keywords_embedded, ...]
- features = tf.parse_example(..., features=make_parse_example_spec(columns))
feature_layer = FeatureLayer(columns)
+
+ features = tf.parse_example(..., features=make_parse_example_spec(columns))
dense_tensor = feature_layer(features)
for units in [128, 64, 32]:
dense_tensor = tf.layers.dense(dense_tensor, units, tf.nn.relu)
- prediction = tf.layers.dense(dense_tensor, 1)."""
+ prediction = tf.layers.dense(dense_tensor, 1)
+ ```
+ """
def __init__(self,
feature_columns,
@@ -375,8 +365,7 @@ class FeatureLayer(Layer):
super(FeatureLayer, self).__init__(name=name, trainable=trainable, **kwargs)
self._feature_columns = _normalize_feature_columns(feature_columns)
- self._state_manager = _InputLayerStateManager(self, self._feature_columns,
- self.trainable)
+ self._state_manager = _StateManagerImpl(self, self.trainable)
self._shared_state_manager = shared_state_manager
for column in sorted(self._feature_columns, key=lambda x: x.name):
if not isinstance(column, DenseColumn):
@@ -395,7 +384,8 @@ class FeatureLayer(Layer):
column.create_state(self._shared_state_manager)
else:
with variable_scope.variable_scope(None, default_name=self.name):
- column.create_state(self._state_manager)
+ with variable_scope.variable_scope(None, default_name=column.name):
+ column.create_state(self._state_manager)
super(FeatureLayer, self).build(None)
def call(self, features, cols_to_output_tensors=None):
@@ -448,20 +438,18 @@ class FeatureLayer(Layer):
return (input_shape[0], total_elements)
-def linear_model(features,
- feature_columns,
- units=1,
- sparse_combiner='sum',
- weight_collections=None,
- trainable=True,
- cols_to_vars=None):
- """Returns a linear prediction `Tensor` based on given `feature_columns`.
+def _strip_leading_slashes(name):
+ return name.rsplit('/', 1)[-1]
+
+
+class LinearModel(Layer):
+ """Produces a linear prediction `Tensor` based on given `feature_columns`.
- This function generates a weighted sum based on output dimension `units`.
+ This layer generates a weighted sum based on output dimension `units`.
Weighted sum refers to logits in classification problems. It refers to the
prediction itself for linear regression problems.
- Note on supported columns: `linear_model` treats categorical columns as
+ Note on supported columns: `LinearModel` treats categorical columns as
`indicator_column`s. To be specific, assume the input as `SparseTensor` looks
like:
@@ -486,308 +474,189 @@ def linear_model(features,
keywords = categorical_column_with_hash_bucket("keywords", 10K)
keywords_price = crossed_column('keywords', price_buckets, ...)
columns = [price_buckets, keywords, keywords_price ...]
+ linear_model = LinearModel(columns)
+
features = tf.parse_example(..., features=make_parse_example_spec(columns))
- prediction = linear_model(features, columns)
+ prediction = linear_model(features)
```
-
- Args:
- features: A mapping from key to tensors. `_FeatureColumn`s look up via these
- keys. For example `numeric_column('price')` will look at 'price' key in
- this dict. Values are `Tensor` or `SparseTensor` depending on
- corresponding `_FeatureColumn`.
- feature_columns: An iterable containing the FeatureColumns to use as inputs
- to your model. All items should be instances of classes derived from
- `_FeatureColumn`s.
- units: An integer, dimensionality of the output space. Default value is 1.
- sparse_combiner: A string specifying how to reduce if a categorical column
- is multivalent. Except `numeric_column`, almost all columns passed to
- `linear_model` are considered as categorical columns. It combines each
- categorical column independently. Currently "mean", "sqrtn" and "sum" are
- supported, with "sum" the default for linear model. "sqrtn" often achieves
- good accuracy, in particular with bag-of-words columns.
- * "sum": do not normalize features in the column
- * "mean": do l1 normalization on features in the column
- * "sqrtn": do l2 normalization on features in the column
- For example, for two features represented as the categorical columns:
-
- ```python
- # Feature 1
-
- shape = [2, 2]
- {
- [0, 0]: "a"
- [0, 1]: "b"
- [1, 0]: "c"
- }
-
- # Feature 2
-
- shape = [2, 3]
- {
- [0, 0]: "d"
- [1, 0]: "e"
- [1, 1]: "f"
- [1, 2]: "g"
- }
- ```
- with `sparse_combiner` as "mean", the linear model outputs conceptly are:
- ```
- y_0 = 1.0 / 2.0 * ( w_a + w_ b) + w_c + b_0
- y_1 = w_d + 1.0 / 3.0 * ( w_e + w_ f + w_g) + b_1
- ```
- where `y_i` is the output, `b_i` is the bias, and `w_x` is the weight
- assigned to the presence of `x` in the input features.
- weight_collections: A list of collection names to which the Variable will be
- added. Note that, variables will also be added to collections
- `tf.GraphKeys.GLOBAL_VARIABLES` and `ops.GraphKeys.MODEL_VARIABLES`.
- trainable: If `True` also add the variable to the graph collection
- `GraphKeys.TRAINABLE_VARIABLES` (see `tf.Variable`).
- cols_to_vars: If not `None`, must be a dictionary that will be filled with a
- mapping from `_FeatureColumn` to associated list of `Variable`s. For
- example, after the call, we might have cols_to_vars = {
- _NumericColumn(
- key='numeric_feature1', shape=(1,):
- [<tf.Variable 'linear_model/price2/weights:0' shape=(1, 1)>],
- 'bias': [<tf.Variable 'linear_model/bias_weights:0' shape=(1,)>],
- _NumericColumn(
- key='numeric_feature2', shape=(2,)):
- [<tf.Variable 'linear_model/price1/weights:0' shape=(2, 1)>]}
- If a column creates no variables, its value will be an empty list. Note
- that cols_to_vars will also contain a string key 'bias' that maps to a
- list of Variables.
-
- Returns:
- A `Tensor` which represents predictions/logits of a linear model. Its shape
- is (batch_size, units) and its dtype is `float32`.
-
- Raises:
- ValueError: if an item in `feature_columns` is neither a `_DenseColumn`
- nor `_CategoricalColumn`.
- """
- with variable_scope.variable_scope(None, 'linear_model') as vs:
- model_name = _strip_leading_slashes(vs.name)
- linear_model_layer = _LinearModel(
- feature_columns=feature_columns,
- units=units,
- sparse_combiner=sparse_combiner,
- weight_collections=weight_collections,
- trainable=trainable,
- name=model_name)
- retval = linear_model_layer(features) # pylint: disable=not-callable
- if cols_to_vars is not None:
- cols_to_vars.update(linear_model_layer.cols_to_vars())
- return retval
-
-
-def _add_to_collections(var, weight_collections):
- """Adds a var to the list of weight_collections provided.
-
- Handles the case for partitioned and non-partitioned variables.
-
- Args:
- var: A variable or Partitioned Variable.
- weight_collections: List of collections to add variable to.
- """
- for weight_collection in weight_collections:
- # The layer self.add_variable call already adds it to GLOBAL_VARIABLES.
- if weight_collection == ops.GraphKeys.GLOBAL_VARIABLES:
- continue
- # TODO(rohanj): Explore adding a _get_variable_list method on `Variable`
- # so that we don't have to do this check.
- if isinstance(var, variables.PartitionedVariable):
- for constituent_var in list(var):
- ops.add_to_collection(weight_collection, constituent_var)
- else:
- ops.add_to_collection(weight_collection, var)
-
-
-class _FCLinearWrapper(base.Layer):
- """Wraps a _FeatureColumn in a layer for use in a linear model.
-
- See `linear_model` above.
"""
def __init__(self,
- feature_column,
+ feature_columns,
units=1,
sparse_combiner='sum',
- weight_collections=None,
trainable=True,
name=None,
+ shared_state_manager=None,
**kwargs):
- super(_FCLinearWrapper, self).__init__(
- trainable=trainable, name=name, **kwargs)
- self._feature_column = feature_column
- self._units = units
- self._sparse_combiner = sparse_combiner
- self._weight_collections = weight_collections
+ """Constructs a LinearModel.
- def build(self, _):
- if isinstance(self._feature_column, fc_old._CategoricalColumn): # pylint: disable=protected-access
- weight = self.add_variable(
- name='weights',
- shape=(self._feature_column._num_buckets, self._units), # pylint: disable=protected-access
- initializer=init_ops.zeros_initializer(),
- trainable=self.trainable)
- else:
- num_elements = self._feature_column._variable_shape.num_elements() # pylint: disable=protected-access
- weight = self.add_variable(
- name='weights',
- shape=[num_elements, self._units],
- initializer=init_ops.zeros_initializer(),
- trainable=self.trainable)
- _add_to_collections(weight, self._weight_collections)
- self._weight_var = weight
- self.built = True
-
- def call(self, builder):
- weighted_sum = fc_old._create_weighted_sum( # pylint: disable=protected-access
- column=self._feature_column,
- builder=builder,
- units=self._units,
- sparse_combiner=self._sparse_combiner,
- weight_collections=self._weight_collections,
- trainable=self.trainable,
- weight_var=self._weight_var)
- return weighted_sum
+ Args:
+ feature_columns: An iterable containing the FeatureColumns to use as
+ inputs to your model. All items should be instances of classes derived
+ from `_FeatureColumn`s.
+ units: An integer, dimensionality of the output space. Default value is 1.
+ sparse_combiner: A string specifying how to reduce if a categorical column
+ is multivalent. Except for `numeric_column`, almost all columns passed to
+ `LinearModel` are treated as categorical columns. Each categorical
+ column is combined independently. Currently "mean", "sqrtn" and "sum"
+ are supported, with "sum" the default for the linear model. "sqrtn" often
+ achieves good accuracy, in particular with bag-of-words columns.
+ * "sum": do not normalize features in the column
+ * "mean": do l1 normalization on features in the column
+ * "sqrtn": do l2 normalization on features in the column
+ For example, for two features represented as the categorical columns:
+
+ ```python
+ # Feature 1
+
+ shape = [2, 2]
+ {
+ [0, 0]: "a"
+ [0, 1]: "b"
+ [1, 0]: "c"
+ }
+
+ # Feature 2
+
+ shape = [2, 3]
+ {
+ [0, 0]: "d"
+ [1, 0]: "e"
+ [1, 1]: "f"
+ [1, 2]: "g"
+ }
+ ```
+
+ With `sparse_combiner` as "mean", the linear model outputs are conceptually:
+ ```
+ y_0 = 1.0 / 2.0 * (w_a + w_b) + w_c + b_0
+ y_1 = w_d + 1.0 / 3.0 * (w_e + w_f + w_g) + b_1
+ ```
+ where `y_i` is the output, `b_i` is the bias, and `w_x` is the weight
+ assigned to the presence of `x` in the input features.
+ trainable: If `True` also add the variable to the graph collection
+ `GraphKeys.TRAINABLE_VARIABLES` (see `tf.Variable`).
+ name: Name to give to the Linear Model. All variables and ops created will
+ be scoped by this name.
+ shared_state_manager: SharedEmbeddingStateManager that manages the state
+ of SharedEmbeddingColumns. For more info, look at `FeatureLayer`.
+ **kwargs: Keyword arguments to construct a layer.
+ Raises:
+ ValueError: if an item in `feature_columns` is neither a `DenseColumn`
+ nor `CategoricalColumn`.
+ """
+ super(LinearModel, self).__init__(name=name, trainable=trainable, **kwargs)
-class _BiasLayer(base.Layer):
- """A layer for the bias term.
- """
+ self._feature_columns = _normalize_feature_columns(feature_columns)
+ self._feature_columns = sorted(self._feature_columns, key=lambda x: x.name)
+ for column in self._feature_columns:
+ if not isinstance(column, (DenseColumn, CategoricalColumn)):
+ raise ValueError(
+ 'Items of feature_columns must be either a '
+ 'DenseColumn or CategoricalColumn. Given: {}'.format(column))
- def __init__(self,
- units=1,
- trainable=True,
- weight_collections=None,
- name=None,
- **kwargs):
- super(_BiasLayer, self).__init__(trainable=trainable, name=name, **kwargs)
self._units = units
- self._weight_collections = weight_collections
-
- def build(self, _):
- self._bias_variable = self.add_variable(
- 'bias_weights',
- shape=[self._units],
- initializer=init_ops.zeros_initializer(),
- trainable=self.trainable)
- _add_to_collections(self._bias_variable, self._weight_collections)
- self.built = True
-
- def call(self, _):
- return self._bias_variable
+ self._sparse_combiner = sparse_combiner
+ self._state_manager = _StateManagerImpl(self, self.trainable)
+ self._shared_state_manager = shared_state_manager
+ self._bias_variable = None
-def _get_expanded_variable_list(var_list):
- returned_list = []
- for variable in var_list:
- if (isinstance(variable, variables.Variable) or
- resource_variable_ops.is_resource_variable(variable)):
- returned_list.append(variable) # Single variable case.
- else: # Must be a PartitionedVariable, so convert into a list.
- returned_list.extend(list(variable))
- return returned_list
+ def build(self, _):
+ # Create state for shared embedding columns.
+ for column in self._feature_columns:
+ if isinstance(column, SharedEmbeddingColumn):
+ column.create_state(self._shared_state_manager)
+ # We need variable scopes for now because we want the variable partitioning
+ # information to percolate down. We also use `_pure_variable_scope` here
+ # since we want to open up a name_scope in the `call` method while creating
+ # the ops.
+ with variable_scope._pure_variable_scope(self.name): # pylint: disable=protected-access
+ for column in self._feature_columns:
+ with variable_scope._pure_variable_scope(column.name): # pylint: disable=protected-access
+ # Create the state for each feature column
+ if not isinstance(column, SharedEmbeddingColumn):
+ column.create_state(self._state_manager)
+
+ # Create a weight variable for each column.
+ if isinstance(column, CategoricalColumn):
+ first_dim = column.num_buckets
+ else:
+ first_dim = column.variable_shape.num_elements()
+ self._state_manager.create_variable(
+ column,
+ name='weights',
+ dtype=dtypes.float32,
+ shape=(first_dim, self._units),
+ initializer=init_ops.zeros_initializer(),
+ trainable=self.trainable)
+
+ # Create a bias variable.
+ self._bias_variable = self.add_variable(
+ name='bias_weights',
+ dtype=dtypes.float32,
+ shape=[self._units],
+ initializer=init_ops.zeros_initializer(),
+ trainable=self.trainable,
+ # TODO(rohanj): Get rid of this hack once we have a mechanism for
+ # specifying a default partitioner for an entire layer. In that case,
+ # the default getter for Layers should work.
+ getter=variable_scope.get_variable)
-def _strip_leading_slashes(name):
- return name.rsplit('/', 1)[-1]
+ super(LinearModel, self).build(None)
+ def call(self, features):
+ """Returns a `Tensor` the represents the predictions of a linear model.
-class _LinearModel(training.Model):
- """Creates a linear model using feature columns.
+ Args:
+ features: A mapping from key to tensors. `_FeatureColumn`s look up via
+ these keys. For example `numeric_column('price')` will look at 'price'
+ key in this dict. Values are `Tensor` or `SparseTensor` depending on
+ corresponding `_FeatureColumn`.
- See `linear_model` for details.
- """
+ Returns:
+ A `Tensor` which represents predictions/logits of a linear model. Its
+ shape is (batch_size, units) and its dtype is `float32`.
- def __init__(self,
- feature_columns,
- units=1,
- sparse_combiner='sum',
- weight_collections=None,
- trainable=True,
- name=None,
- **kwargs):
- super(_LinearModel, self).__init__(name=name, **kwargs)
- self._feature_columns = fc_old._normalize_feature_columns( # pylint: disable=protected-access
- feature_columns)
- self._weight_collections = list(weight_collections or [])
- if ops.GraphKeys.GLOBAL_VARIABLES not in self._weight_collections:
- self._weight_collections.append(ops.GraphKeys.GLOBAL_VARIABLES)
- if ops.GraphKeys.MODEL_VARIABLES not in self._weight_collections:
- self._weight_collections.append(ops.GraphKeys.MODEL_VARIABLES)
-
- column_layers = {}
- for column in sorted(self._feature_columns, key=lambda x: x.name):
- with variable_scope.variable_scope(
- None, default_name=column._var_scope_name) as vs: # pylint: disable=protected-access
- # Having the fully expressed variable scope name ends up doubly
- # expressing the outer scope (scope with which this method was called)
- # in the name of the variable that would get created.
- column_name = _strip_leading_slashes(vs.name)
- column_layer = _FCLinearWrapper(column, units, sparse_combiner,
- self._weight_collections, trainable,
- column_name, **kwargs)
- column_layers[column_name] = column_layer
- self._column_layers = self._add_layers(column_layers)
- self._bias_layer = _BiasLayer(
- units=units,
- trainable=trainable,
- weight_collections=self._weight_collections,
- name='bias_layer',
- **kwargs)
- self._cols_to_vars = {}
-
- def cols_to_vars(self):
- """Returns a dict mapping _FeatureColumns to variables.
-
- See `linear_model` for more information.
- This is not populated till `call` is called i.e. layer is built.
+ Raises:
+ ValueError: If features are not a dictionary.
"""
- return self._cols_to_vars
-
- def call(self, features):
- with variable_scope.variable_scope(self.name):
- for column in self._feature_columns:
- if not isinstance(
- column,
- (
- fc_old._DenseColumn, # pylint: disable=protected-access
- fc_old._CategoricalColumn)): # pylint: disable=protected-access
- raise ValueError(
- 'Items of feature_columns must be either a '
- '_DenseColumn or _CategoricalColumn. Given: {}'.format(column))
- weighted_sums = []
- ordered_columns = []
- builder = fc_old._LazyBuilder(features) # pylint: disable=protected-access
- for layer in sorted(self._column_layers.values(), key=lambda x: x.name):
- column = layer._feature_column # pylint: disable=protected-access
- ordered_columns.append(column)
- weighted_sum = layer(builder)
+ if not isinstance(features, dict):
+ raise ValueError('We expected a dictionary here. Instead we got: ',
+ features)
+ transformation_cache = FeatureTransformationCache(features)
+ weighted_sums = []
+ for column in self._feature_columns:
+ with ops.name_scope(column.name):
+ # All the weights used in the linear model are owned by the state
+ # manager associated with this Linear Model.
+ weight_var = self._state_manager.get_variable(column, 'weights')
+
+ # The embedding weights for the SharedEmbeddingColumn are owned by
+ # the shared_state_manager and so we need to pass that in while
+ # creating the weighted sum. For all other columns, the state is owned
+ # by the Linear Model's state manager.
+ if isinstance(column, SharedEmbeddingColumn):
+ state_manager = self._shared_state_manager
+ else:
+ state_manager = self._state_manager
+ weighted_sum = _create_weighted_sum(
+ column=column,
+ transformation_cache=transformation_cache,
+ state_manager=state_manager,
+ sparse_combiner=self._sparse_combiner,
+ weight_var=weight_var)
weighted_sums.append(weighted_sum)
- self._cols_to_vars[column] = ops.get_collection(
- ops.GraphKeys.GLOBAL_VARIABLES, scope=layer.scope_name)
-
- _verify_static_batch_size_equality(weighted_sums, ordered_columns)
- predictions_no_bias = math_ops.add_n(
- weighted_sums, name='weighted_sum_no_bias')
- predictions = nn_ops.bias_add(
- predictions_no_bias,
- self._bias_layer( # pylint: disable=not-callable
- builder,
- scope=variable_scope.get_variable_scope()), # pylint: disable=not-callable
- name='weighted_sum')
- bias = self._bias_layer.variables[0]
- self._cols_to_vars['bias'] = _get_expanded_variable_list([bias])
- return predictions
- def _add_layers(self, layers):
- # "Magic" required for keras.Model classes to track all the variables in
- # a list of layers.Layer objects.
- # TODO(ashankar): Figure out API so user code doesn't have to do this.
- for name, layer in layers.items():
- setattr(self, 'layer-%s' % name, layer)
- return layers
+ _verify_static_batch_size_equality(weighted_sums, self._feature_columns)
+ predictions_no_bias = math_ops.add_n(
+ weighted_sums, name='weighted_sum_no_bias')
+ predictions = nn_ops.bias_add(
+ predictions_no_bias, self._bias_variable, name='weighted_sum')
+ return predictions
def _transform_features(features, feature_columns, state_manager):
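A note on the variable layout `build` produces: every column gets a single `(first_dim, units)` weight, where `first_dim` is `num_buckets` for categorical columns and `variable_shape.num_elements()` for dense ones, plus one shared `[units]` bias. A small sketch of that shape rule (the example columns are illustrative, drawn from the tests):

```python
def _linear_weight_shape(units, num_buckets=None, num_elements=None):
  """Mirrors LinearModel.build's first_dim logic (sketch)."""
  first_dim = num_buckets if num_buckets is not None else num_elements
  return (first_dim, units)

# categorical_column_with_hash_bucket('wire_cast', 4) with units=3 -> (4, 3)
assert _linear_weight_shape(3, num_buckets=4) == (4, 3)
# numeric_column('price', shape=2) with units=1 -> (2, 1)
assert _linear_weight_shape(1, num_elements=2) == (2, 1)
# The bias is always shape [units].
```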
@@ -2053,58 +1922,32 @@ def is_feature_column_v2(feature_columns):
return True
-def _create_weighted_sum(column,
- transformation_cache,
- state_manager,
- units,
- sparse_combiner,
- weight_collections,
- trainable,
- weight_var=None):
+def _create_weighted_sum(column, transformation_cache, state_manager,
+ sparse_combiner, weight_var):
"""Creates a weighted sum for a dense/categorical column for linear_model."""
if isinstance(column, CategoricalColumn):
return _create_categorical_column_weighted_sum(
column=column,
transformation_cache=transformation_cache,
state_manager=state_manager,
- units=units,
sparse_combiner=sparse_combiner,
- weight_collections=weight_collections,
- trainable=trainable,
weight_var=weight_var)
else:
return _create_dense_column_weighted_sum(
column=column,
transformation_cache=transformation_cache,
state_manager=state_manager,
- units=units,
- weight_collections=weight_collections,
- trainable=trainable,
weight_var=weight_var)
-def _create_dense_column_weighted_sum(column,
- transformation_cache,
- state_manager,
- units,
- weight_collections,
- trainable,
- weight_var=None):
+def _create_dense_column_weighted_sum(column, transformation_cache,
+ state_manager, weight_var):
"""Create a weighted sum of a dense column for linear_model."""
tensor = column.get_dense_tensor(transformation_cache, state_manager)
num_elements = column.variable_shape.num_elements()
batch_size = array_ops.shape(tensor)[0]
tensor = array_ops.reshape(tensor, shape=(batch_size, num_elements))
- if weight_var is not None:
- weight = weight_var
- else:
- weight = variable_scope.get_variable(
- name='weights',
- shape=[num_elements, units],
- initializer=init_ops.zeros_initializer(),
- trainable=trainable,
- collections=weight_collections)
- return math_ops.matmul(tensor, weight, name='weighted_sum')
+ return math_ops.matmul(tensor, weight_var, name='weighted_sum')
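The dense branch is now just reshape-then-matmul against the externally created weight; the same arithmetic in NumPy, with values taken from the reshaping test further below:

```python
import numpy as np

dense = np.array([[[1., 2.]], [[5., 6.]]])  # column shape [1, 2], batch of 2
num_elements = 2
flat = dense.reshape(-1, num_elements)       # (batch_size, num_elements)
weight = np.array([[10.], [100.]])           # (num_elements, units=1)
print(flat.dot(weight))                      # [[210.], [650.]]
```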
class CategoricalColumn(FeatureColumn):
@@ -2145,14 +1988,8 @@ class CategoricalColumn(FeatureColumn):
pass
-def _create_categorical_column_weighted_sum(column,
- transformation_cache,
- state_manager,
- units,
- sparse_combiner,
- weight_collections,
- trainable,
- weight_var=None):
+def _create_categorical_column_weighted_sum(
+ column, transformation_cache, state_manager, sparse_combiner, weight_var):
# pylint: disable=g-doc-return-or-yield,g-doc-args
"""Create a weighted sum of a categorical column for linear_model.
@@ -2191,17 +2028,8 @@ def _create_categorical_column_weighted_sum(column,
weight_tensor = sparse_ops.sparse_reshape(
weight_tensor, [array_ops.shape(weight_tensor)[0], -1])
- if weight_var is not None:
- weight = weight_var
- else:
- weight = variable_scope.get_variable(
- name='weights',
- shape=(column.num_buckets, units),
- initializer=init_ops.zeros_initializer(),
- trainable=trainable,
- collections=weight_collections)
return _safe_embedding_lookup_sparse(
- weight,
+ weight_var,
id_tensor,
sparse_weights=weight_tensor,
combiner=sparse_combiner,
@@ -2836,6 +2664,10 @@ class SharedEmbeddingColumn(
def create_state(self, state_manager):
"""Creates the shared embedding lookup variable."""
+ if not isinstance(state_manager, SharedEmbeddingStateManager):
+ raise ValueError('Expected state_manager to be of type '
+ 'SharedEmbeddingStateManager. Obtained type: {}'.format(
+ type(state_manager)))
embedding_shape = (self.categorical_column.num_buckets, self.dimension)
state_manager.create_variable(
name=self.shared_collection_name,
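Before the test diff: the `sparse_combiner` behavior exercised below can be checked by hand. With the wire_cast column's values hashing to ids [2, 0, 3] (row 0 holds id 2; row 1 holds ids 0 and 3), "sum" adds the per-id weights while "mean" divides by the number of ids present in the row:

```python
w = [10., 100., 1000., 10000.]  # per-bucket weights assigned in the tests
bias = 5.

row0 = w[2] + bias                      # 1005. (same for 'sum' and 'mean')
row1_sum = w[0] + w[3] + bias           # 10015. -> test_sparse_bias
row1_mean = (w[0] + w[3]) / 2. + bias   # 5010.  -> test_sparse_combiner
```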
diff --git a/tensorflow/python/feature_column/feature_column_v2_test.py b/tensorflow/python/feature_column/feature_column_v2_test.py
index 2970431167..a13a5010e1 100644
--- a/tensorflow/python/feature_column/feature_column_v2_test.py
+++ b/tensorflow/python/feature_column/feature_column_v2_test.py
@@ -31,9 +31,7 @@ from tensorflow.python.client import session
from tensorflow.python.eager import backprop
from tensorflow.python.eager import context
from tensorflow.python.estimator.inputs import numpy_io
-from tensorflow.python.feature_column import feature_column as fc_old
from tensorflow.python.feature_column import feature_column_v2 as fc
-from tensorflow.python.feature_column.feature_column_v2 import _LinearModel
from tensorflow.python.feature_column.feature_column_v2 import _transform_features
from tensorflow.python.feature_column.feature_column_v2 import FeatureColumn
from tensorflow.python.feature_column.feature_column_v2 import FeatureLayer
@@ -48,7 +46,6 @@ from tensorflow.python.framework import test_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import lookup_ops
from tensorflow.python.ops import parsing_ops
-from tensorflow.python.ops import partitioned_variables
from tensorflow.python.ops import variable_scope
from tensorflow.python.ops import variables as variables_lib
from tensorflow.python.platform import test
@@ -360,26 +357,12 @@ class NumericColumnTest(test.TestCase):
self.assertEqual(a.default_value, ((3., 2.),))
def test_linear_model(self):
- price = fc_old.numeric_column('price')
- with ops.Graph().as_default():
- features = {'price': [[1.], [5.]]}
- predictions = fc.linear_model(features, [price])
- bias = get_linear_model_bias()
- price_var = get_linear_model_column_var(price)
- with _initialized_session() as sess:
- self.assertAllClose([0.], bias.eval())
- self.assertAllClose([[0.]], price_var.eval())
- self.assertAllClose([[0.], [0.]], predictions.eval())
- sess.run(price_var.assign([[10.]]))
- self.assertAllClose([[10.], [50.]], predictions.eval())
-
- def test_keras_linear_model(self):
- price = fc_old.numeric_column('price')
+ price = fc.numeric_column('price')
with ops.Graph().as_default():
features = {'price': [[1.], [5.]]}
- predictions = get_keras_linear_model_predictions(features, [price])
- bias = get_linear_model_bias()
- price_var = get_linear_model_column_var(price)
+ model = fc.LinearModel([price])
+ predictions = model(features)
+ price_var, bias = model.variables
with _initialized_session() as sess:
self.assertAllClose([0.], bias.eval())
self.assertAllClose([[0.]], price_var.eval())
@@ -564,13 +547,13 @@ class BucketizedColumnTest(test.TestCase):
def test_linear_model_one_input_value(self):
"""Tests linear_model() for input with shape=[1]."""
- price = fc_old.numeric_column('price', shape=[1])
- bucketized_price = fc_old.bucketized_column(price, boundaries=[0, 2, 4, 6])
+ price = fc.numeric_column('price', shape=[1])
+ bucketized_price = fc.bucketized_column(price, boundaries=[0, 2, 4, 6])
with ops.Graph().as_default():
features = {'price': [[-1.], [1.], [5.], [6.]]}
- predictions = fc.linear_model(features, [bucketized_price])
- bias = get_linear_model_bias()
- bucketized_price_var = get_linear_model_column_var(bucketized_price)
+ model = fc.LinearModel([bucketized_price])
+ predictions = model(features)
+ bucketized_price_var, bias = model.variables
with _initialized_session() as sess:
self.assertAllClose([0.], bias.eval())
# One weight variable per bucket, all initialized to zero.
@@ -589,13 +572,13 @@ class BucketizedColumnTest(test.TestCase):
def test_linear_model_two_input_values(self):
"""Tests linear_model() for input with shape=[2]."""
- price = fc_old.numeric_column('price', shape=[2])
- bucketized_price = fc_old.bucketized_column(price, boundaries=[0, 2, 4, 6])
+ price = fc.numeric_column('price', shape=[2])
+ bucketized_price = fc.bucketized_column(price, boundaries=[0, 2, 4, 6])
with ops.Graph().as_default():
features = {'price': [[-1., 1.], [5., 6.]]}
- predictions = fc.linear_model(features, [bucketized_price])
- bias = get_linear_model_bias()
- bucketized_price_var = get_linear_model_column_var(bucketized_price)
+ model = fc.LinearModel([bucketized_price])
+ predictions = model(features)
+ bucketized_price_var, bias = model.variables
with _initialized_session() as sess:
self.assertAllClose([0.], bias.eval())
# One weight per bucket per input column, all initialized to zero.
@@ -616,62 +599,6 @@ class BucketizedColumnTest(test.TestCase):
sess.run(bias.assign([1.]))
self.assertAllClose([[81.], [141.]], predictions.eval())
- def test_keras_linear_model_one_input_value(self):
- """Tests _LinearModel for input with shape=[1]."""
- price = fc_old.numeric_column('price', shape=[1])
- bucketized_price = fc_old.bucketized_column(price, boundaries=[0, 2, 4, 6])
- with ops.Graph().as_default():
- features = {'price': [[-1.], [1.], [5.], [6.]]}
- predictions = get_keras_linear_model_predictions(features,
- [bucketized_price])
- bias = get_linear_model_bias()
- bucketized_price_var = get_linear_model_column_var(bucketized_price)
- with _initialized_session() as sess:
- self.assertAllClose([0.], bias.eval())
- # One weight variable per bucket, all initialized to zero.
- self.assertAllClose([[0.], [0.], [0.], [0.], [0.]],
- bucketized_price_var.eval())
- self.assertAllClose([[0.], [0.], [0.], [0.]], predictions.eval())
- sess.run(
- bucketized_price_var.assign([[10.], [20.], [30.], [40.], [50.]]))
- # price -1. is in the 0th bucket, whose weight is 10.
- # price 1. is in the 1st bucket, whose weight is 20.
- # price 5. is in the 3rd bucket, whose weight is 40.
- # price 6. is in the 4th bucket, whose weight is 50.
- self.assertAllClose([[10.], [20.], [40.], [50.]], predictions.eval())
- sess.run(bias.assign([1.]))
- self.assertAllClose([[11.], [21.], [41.], [51.]], predictions.eval())
-
- def test_keras_linear_model_two_input_values(self):
- """Tests _LinearModel for input with shape=[2]."""
- price = fc_old.numeric_column('price', shape=[2])
- bucketized_price = fc_old.bucketized_column(price, boundaries=[0, 2, 4, 6])
- with ops.Graph().as_default():
- features = {'price': [[-1., 1.], [5., 6.]]}
- predictions = get_keras_linear_model_predictions(features,
- [bucketized_price])
- bias = get_linear_model_bias()
- bucketized_price_var = get_linear_model_column_var(bucketized_price)
- with _initialized_session() as sess:
- self.assertAllClose([0.], bias.eval())
- # One weight per bucket per input column, all initialized to zero.
- self.assertAllClose(
- [[0.], [0.], [0.], [0.], [0.], [0.], [0.], [0.], [0.], [0.]],
- bucketized_price_var.eval())
- self.assertAllClose([[0.], [0.]], predictions.eval())
- sess.run(
- bucketized_price_var.assign([[10.], [20.], [30.], [40.], [50.],
- [60.], [70.], [80.], [90.], [100.]]))
- # 1st example:
- # price -1. is in the 0th bucket, whose weight is 10.
- # price 1. is in the 6th bucket, whose weight is 70.
- # 2nd example:
- # price 5. is in the 3rd bucket, whose weight is 40.
- # price 6. is in the 9th bucket, whose weight is 100.
- self.assertAllClose([[80.], [140.]], predictions.eval())
- sess.run(bias.assign([1.]))
- self.assertAllClose([[81.], [141.]], predictions.eval())
-
class HashedCategoricalColumnTest(test.TestCase):
@@ -852,39 +779,18 @@ class HashedCategoricalColumnTest(test.TestCase):
transformation_cache.get(hashed_sparse, None), id_weight_pair.id_tensor)
def test_linear_model(self):
- wire_column = fc_old.categorical_column_with_hash_bucket('wire', 4)
- self.assertEqual(4, wire_column._num_buckets)
- with ops.Graph().as_default():
- predictions = fc.linear_model({
- wire_column.name: sparse_tensor.SparseTensorValue(
- indices=((0, 0), (1, 0), (1, 1)),
- values=('marlo', 'skywalker', 'omar'),
- dense_shape=(2, 2))
- }, (wire_column,))
- bias = get_linear_model_bias()
- wire_var = get_linear_model_column_var(wire_column)
- with _initialized_session():
- self.assertAllClose((0.,), bias.eval())
- self.assertAllClose(((0.,), (0.,), (0.,), (0.,)), wire_var.eval())
- self.assertAllClose(((0.,), (0.,)), predictions.eval())
- wire_var.assign(((1.,), (2.,), (3.,), (4.,))).eval()
- # 'marlo' -> 3: wire_var[3] = 4
- # 'skywalker' -> 2, 'omar' -> 2: wire_var[2] + wire_var[2] = 3+3 = 6
- self.assertAllClose(((4.,), (6.,)), predictions.eval())
-
- def test_keras_linear_model(self):
- wire_column = fc_old.categorical_column_with_hash_bucket('wire', 4)
- self.assertEqual(4, wire_column._num_buckets)
+ wire_column = fc.categorical_column_with_hash_bucket('wire', 4)
+ self.assertEqual(4, wire_column.num_buckets)
with ops.Graph().as_default():
- predictions = get_keras_linear_model_predictions({
+ model = fc.LinearModel((wire_column,))
+ predictions = model({
wire_column.name:
sparse_tensor.SparseTensorValue(
indices=((0, 0), (1, 0), (1, 1)),
values=('marlo', 'skywalker', 'omar'),
dense_shape=(2, 2))
- }, (wire_column,))
- bias = get_linear_model_bias()
- wire_var = get_linear_model_column_var(wire_column)
+ })
+ wire_var, bias = model.variables
with _initialized_session():
self.assertAllClose((0.,), bias.eval())
self.assertAllClose(((0.,), (0.,), (0.,), (0.,)), wire_var.eval())
@@ -1103,93 +1009,12 @@ class CrossedColumnTest(test.TestCase):
Uses data from test_get_sparse_tesnsors_simple.
"""
- a = fc_old.numeric_column('a', dtype=dtypes.int32, shape=(2,))
- b = fc_old.bucketized_column(a, boundaries=(0, 1))
- crossed = fc_old.crossed_column([b, 'c'], hash_bucket_size=5, hash_key=5)
- with ops.Graph().as_default():
- predictions = fc.linear_model({
- 'a': constant_op.constant(((-1., .5), (.5, 1.))),
- 'c': sparse_tensor.SparseTensor(
- indices=((0, 0), (1, 0), (1, 1)),
- values=['cA', 'cB', 'cC'],
- dense_shape=(2, 2)),
- }, (crossed,))
- bias = get_linear_model_bias()
- crossed_var = get_linear_model_column_var(crossed)
- with _initialized_session() as sess:
- self.assertAllClose((0.,), bias.eval())
- self.assertAllClose(
- ((0.,), (0.,), (0.,), (0.,), (0.,)), crossed_var.eval())
- self.assertAllClose(((0.,), (0.,)), predictions.eval())
- sess.run(crossed_var.assign(((1.,), (2.,), (3.,), (4.,), (5.,))))
- # Expected ids after cross = (1, 0, 1, 3, 4, 2)
- self.assertAllClose(((3.,), (14.,)), predictions.eval())
- sess.run(bias.assign((.1,)))
- self.assertAllClose(((3.1,), (14.1,)), predictions.eval())
-
- def test_linear_model_with_weights(self):
-
- class _TestColumnWithWeights(fc_old._CategoricalColumn):
- """Produces sparse IDs and sparse weights."""
-
- @property
- def name(self):
- return 'test_column'
-
- @property
- def _parse_example_spec(self):
- return {
- self.name: parsing_ops.VarLenFeature(dtypes.int32),
- '{}_weights'.format(self.name): parsing_ops.VarLenFeature(
- dtypes.float32),
- }
-
- @property
- def _num_buckets(self):
- return 5
-
- def _transform_feature(self, inputs):
- return (inputs.get(self.name),
- inputs.get('{}_weights'.format(self.name)))
-
- def _get_sparse_tensors(self, inputs, weight_collections=None,
- trainable=None):
- """Populates both id_tensor and weight_tensor."""
- ids_and_weights = inputs.get(self)
- return fc_old._CategoricalColumn.IdWeightPair(
- id_tensor=ids_and_weights[0], weight_tensor=ids_and_weights[1])
-
- t = _TestColumnWithWeights()
- crossed = fc_old.crossed_column([t, 'c'], hash_bucket_size=5, hash_key=5)
- with ops.Graph().as_default():
- with self.assertRaisesRegexp(
- ValueError,
- 'crossed_column does not support weight_tensor.*{}'.format(t.name)):
- fc.linear_model({
- t.name: sparse_tensor.SparseTensor(
- indices=((0, 0), (1, 0), (1, 1)),
- values=[0, 1, 2],
- dense_shape=(2, 2)),
- '{}_weights'.format(t.name): sparse_tensor.SparseTensor(
- indices=((0, 0), (1, 0), (1, 1)),
- values=[1., 10., 2.],
- dense_shape=(2, 2)),
- 'c': sparse_tensor.SparseTensor(
- indices=((0, 0), (1, 0), (1, 1)),
- values=['cA', 'cB', 'cC'],
- dense_shape=(2, 2)),
- }, (crossed,))
-
- def test_keras_linear_model(self):
- """Tests _LinearModel.
-
- Uses data from test_get_sparse_tesnsors_simple.
- """
- a = fc_old.numeric_column('a', dtype=dtypes.int32, shape=(2,))
- b = fc_old.bucketized_column(a, boundaries=(0, 1))
- crossed = fc_old.crossed_column([b, 'c'], hash_bucket_size=5, hash_key=5)
+ a = fc.numeric_column('a', dtype=dtypes.int32, shape=(2,))
+ b = fc.bucketized_column(a, boundaries=(0, 1))
+ crossed = fc.crossed_column([b, 'c'], hash_bucket_size=5, hash_key=5)
with ops.Graph().as_default():
- predictions = get_keras_linear_model_predictions({
+ model = fc.LinearModel((crossed,))
+ predictions = model({
'a':
constant_op.constant(((-1., .5), (.5, 1.))),
'c':
@@ -1197,13 +1022,12 @@ class CrossedColumnTest(test.TestCase):
indices=((0, 0), (1, 0), (1, 1)),
values=['cA', 'cB', 'cC'],
dense_shape=(2, 2)),
- }, (crossed,))
- bias = get_linear_model_bias()
- crossed_var = get_linear_model_column_var(crossed)
+ })
+ crossed_var, bias = model.variables
with _initialized_session() as sess:
self.assertAllClose((0.,), bias.eval())
- self.assertAllClose(((0.,), (0.,), (0.,), (0.,), (0.,)),
- crossed_var.eval())
+ self.assertAllClose(
+ ((0.,), (0.,), (0.,), (0.,), (0.,)), crossed_var.eval())
self.assertAllClose(((0.,), (0.,)), predictions.eval())
sess.run(crossed_var.assign(((1.,), (2.,), (3.,), (4.,), (5.,))))
# Expected ids after cross = (1, 0, 1, 3, 4, 2)
@@ -1211,9 +1035,9 @@ class CrossedColumnTest(test.TestCase):
sess.run(bias.assign((.1,)))
self.assertAllClose(((3.1,), (14.1,)), predictions.eval())
- def test_keras_linear_model_with_weights(self):
+ def test_linear_model_with_weights(self):
- class _TestColumnWithWeights(fc_old._CategoricalColumn):
+ class _TestColumnWithWeights(fc.CategoricalColumn):
"""Produces sparse IDs and sparse weights."""
@property
@@ -1221,38 +1045,36 @@ class CrossedColumnTest(test.TestCase):
return 'test_column'
@property
- def _parse_example_spec(self):
+ def parse_example_spec(self):
return {
- self.name:
- parsing_ops.VarLenFeature(dtypes.int32),
- '{}_weights'.format(self.name):
- parsing_ops.VarLenFeature(dtypes.float32),
- }
+ self.name: parsing_ops.VarLenFeature(dtypes.int32),
+ '{}_weights'.format(self.name): parsing_ops.VarLenFeature(
+ dtypes.float32),
+ }
@property
- def _num_buckets(self):
+ def num_buckets(self):
return 5
- def _transform_feature(self, inputs):
- return (inputs.get(self.name),
- inputs.get('{}_weights'.format(self.name)))
+ def transform_feature(self, transformation_cache, state_manager):
+ return (transformation_cache.get(self.name, state_manager),
+ transformation_cache.get('{}_weights'.format(self.name),
+ state_manager))
- def _get_sparse_tensors(self,
- inputs,
- weight_collections=None,
- trainable=None):
+ def get_sparse_tensors(self, transformation_cache, state_manager):
"""Populates both id_tensor and weight_tensor."""
- ids_and_weights = inputs.get(self)
- return fc_old._CategoricalColumn.IdWeightPair(
+ ids_and_weights = transformation_cache.get(self, state_manager)
+ return fc.CategoricalColumn.IdWeightPair(
id_tensor=ids_and_weights[0], weight_tensor=ids_and_weights[1])
t = _TestColumnWithWeights()
- crossed = fc_old.crossed_column([t, 'c'], hash_bucket_size=5, hash_key=5)
+ crossed = fc.crossed_column([t, 'c'], hash_bucket_size=5, hash_key=5)
with ops.Graph().as_default():
with self.assertRaisesRegexp(
ValueError,
'crossed_column does not support weight_tensor.*{}'.format(t.name)):
- get_keras_linear_model_predictions({
+ model = fc.LinearModel((crossed,))
+ model({
t.name:
sparse_tensor.SparseTensor(
indices=((0, 0), (1, 0), (1, 1)),
@@ -1268,37 +1090,7 @@ class CrossedColumnTest(test.TestCase):
indices=((0, 0), (1, 0), (1, 1)),
values=['cA', 'cB', 'cC'],
dense_shape=(2, 2)),
- }, (crossed,))
-
-
-def get_linear_model_bias(name='linear_model'):
- with variable_scope.variable_scope(name, reuse=True):
- return variable_scope.get_variable('bias_weights')
-
-
-def get_linear_model_column_var(column, name='linear_model'):
- return ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES,
- name + '/' + column.name)[0]
-
-
-def get_keras_linear_model_predictions(features,
- feature_columns,
- units=1,
- sparse_combiner='sum',
- weight_collections=None,
- trainable=True,
- cols_to_vars=None):
- keras_linear_model = _LinearModel(
- feature_columns,
- units,
- sparse_combiner,
- weight_collections,
- trainable,
- name='linear_model')
- retval = keras_linear_model(features) # pylint: disable=not-callable
- if cols_to_vars is not None:
- cols_to_vars.update(keras_linear_model.cols_to_vars())
- return retval
+ })
class LinearModelTest(test.TestCase):
@@ -1306,56 +1098,50 @@ class LinearModelTest(test.TestCase):
def test_raises_if_empty_feature_columns(self):
with self.assertRaisesRegexp(ValueError,
'feature_columns must not be empty'):
- fc.linear_model(features={}, feature_columns=[])
+ fc.LinearModel(feature_columns=[])
def test_should_be_feature_column(self):
- with self.assertRaisesRegexp(ValueError, 'must be a _FeatureColumn'):
- fc.linear_model(features={'a': [[0]]}, feature_columns='NotSupported')
+ with self.assertRaisesRegexp(ValueError, 'must be a FeatureColumn'):
+ fc.LinearModel(feature_columns='NotSupported')
def test_should_be_dense_or_categorical_column(self):
- class NotSupportedColumn(fc_old._FeatureColumn):
+ class NotSupportedColumn(fc.FeatureColumn):
@property
def name(self):
return 'NotSupportedColumn'
- def _transform_feature(self, cache):
+ def transform_feature(self, transformation_cache, state_manager):
pass
@property
- def _parse_example_spec(self):
+ def parse_example_spec(self):
pass
with self.assertRaisesRegexp(
- ValueError, 'must be either a _DenseColumn or _CategoricalColumn'):
- fc.linear_model(
- features={'a': [[0]]}, feature_columns=[NotSupportedColumn()])
+ ValueError, 'must be either a DenseColumn or CategoricalColumn'):
+ fc.LinearModel(feature_columns=[NotSupportedColumn()])
def test_does_not_support_dict_columns(self):
with self.assertRaisesRegexp(
ValueError, 'Expected feature_columns to be iterable, found dict.'):
- fc.linear_model(
- features={'a': [[0]]},
- feature_columns={'a': fc_old.numeric_column('a')})
+ fc.LinearModel(feature_columns={'a': fc.numeric_column('a')})
def test_raises_if_duplicate_name(self):
with self.assertRaisesRegexp(
ValueError, 'Duplicate feature column name found for columns'):
- fc.linear_model(
- features={'a': [[0]]},
- feature_columns=[
- fc_old.numeric_column('a'),
- fc_old.numeric_column('a')
- ])
+ fc.LinearModel(
+ feature_columns=[fc.numeric_column('a'),
+ fc.numeric_column('a')])
def test_dense_bias(self):
- price = fc_old.numeric_column('price')
+ price = fc.numeric_column('price')
with ops.Graph().as_default():
features = {'price': [[1.], [5.]]}
- predictions = fc.linear_model(features, [price])
- bias = get_linear_model_bias()
- price_var = get_linear_model_column_var(price)
+ model = fc.LinearModel([price])
+ predictions = model(features)
+ price_var, bias = model.variables
with _initialized_session() as sess:
self.assertAllClose([0.], bias.eval())
sess.run(price_var.assign([[10.]]))
@@ -1363,16 +1149,16 @@ class LinearModelTest(test.TestCase):
self.assertAllClose([[15.], [55.]], predictions.eval())
def test_sparse_bias(self):
- wire_cast = fc_old.categorical_column_with_hash_bucket('wire_cast', 4)
+ wire_cast = fc.categorical_column_with_hash_bucket('wire_cast', 4)
with ops.Graph().as_default():
wire_tensor = sparse_tensor.SparseTensor(
values=['omar', 'stringer', 'marlo'], # hashed to = [2, 0, 3]
indices=[[0, 0], [1, 0], [1, 1]],
dense_shape=[2, 2])
features = {'wire_cast': wire_tensor}
- predictions = fc.linear_model(features, [wire_cast])
- bias = get_linear_model_bias()
- wire_cast_var = get_linear_model_column_var(wire_cast)
+ model = fc.LinearModel([wire_cast])
+ predictions = model(features)
+ wire_cast_var, bias = model.variables
with _initialized_session() as sess:
self.assertAllClose([0.], bias.eval())
self.assertAllClose([[0.], [0.], [0.], [0.]], wire_cast_var.eval())
@@ -1381,18 +1167,17 @@ class LinearModelTest(test.TestCase):
self.assertAllClose([[1005.], [10015.]], predictions.eval())
def test_dense_and_sparse_bias(self):
- wire_cast = fc_old.categorical_column_with_hash_bucket('wire_cast', 4)
- price = fc_old.numeric_column('price')
+ wire_cast = fc.categorical_column_with_hash_bucket('wire_cast', 4)
+ price = fc.numeric_column('price')
with ops.Graph().as_default():
wire_tensor = sparse_tensor.SparseTensor(
values=['omar', 'stringer', 'marlo'], # hashed to = [2, 0, 3]
indices=[[0, 0], [1, 0], [1, 1]],
dense_shape=[2, 2])
features = {'wire_cast': wire_tensor, 'price': [[1.], [5.]]}
- predictions = fc.linear_model(features, [wire_cast, price])
- bias = get_linear_model_bias()
- wire_cast_var = get_linear_model_column_var(wire_cast)
- price_var = get_linear_model_column_var(price)
+ model = fc.LinearModel([wire_cast, price])
+ predictions = model(features)
+ price_var, wire_cast_var, bias = model.variables
with _initialized_session() as sess:
sess.run(wire_cast_var.assign([[10.], [100.], [1000.], [10000.]]))
sess.run(bias.assign([5.]))
@@ -1402,38 +1187,36 @@ class LinearModelTest(test.TestCase):
def test_dense_and_sparse_column(self):
"""When the column is both dense and sparse, uses sparse tensors."""
- class _DenseAndSparseColumn(fc_old._DenseColumn, fc_old._CategoricalColumn):
+ class _DenseAndSparseColumn(fc.DenseColumn, fc.CategoricalColumn):
@property
def name(self):
return 'dense_and_sparse_column'
@property
- def _parse_example_spec(self):
+ def parse_example_spec(self):
return {self.name: parsing_ops.VarLenFeature(self.dtype)}
- def _transform_feature(self, inputs):
- return inputs.get(self.name)
+ def transform_feature(self, transformation_cache, state_manager):
+ return transformation_cache.get(self.name, state_manager)
@property
- def _variable_shape(self):
+ def variable_shape(self):
raise ValueError('Should not use this method.')
- def _get_dense_tensor(self, inputs, weight_collections=None,
- trainable=None):
+ def get_dense_tensor(self, transformation_cache, state_manager):
raise ValueError('Should not use this method.')
@property
- def _num_buckets(self):
+ def num_buckets(self):
return 4
- def _get_sparse_tensors(self, inputs, weight_collections=None,
- trainable=None):
+ def get_sparse_tensors(self, transformation_cache, state_manager):
sp_tensor = sparse_tensor.SparseTensor(
indices=[[0, 0], [1, 0], [1, 1]],
values=[2, 0, 3],
dense_shape=[2, 2])
- return fc_old._CategoricalColumn.IdWeightPair(sp_tensor, None)
+ return fc.CategoricalColumn.IdWeightPair(sp_tensor, None)
dense_and_sparse_column = _DenseAndSparseColumn()
with ops.Graph().as_default():
@@ -1442,10 +1225,9 @@ class LinearModelTest(test.TestCase):
indices=[[0, 0], [1, 0], [1, 1]],
dense_shape=[2, 2])
features = {dense_and_sparse_column.name: sp_tensor}
- predictions = fc.linear_model(features, [dense_and_sparse_column])
- bias = get_linear_model_bias()
- dense_and_sparse_column_var = get_linear_model_column_var(
- dense_and_sparse_column)
+ model = fc.LinearModel([dense_and_sparse_column])
+ predictions = model(features)
+ dense_and_sparse_column_var, bias = model.variables
with _initialized_session() as sess:
sess.run(dense_and_sparse_column_var.assign(
[[10.], [100.], [1000.], [10000.]]))
@@ -1453,12 +1235,12 @@ class LinearModelTest(test.TestCase):
self.assertAllClose([[1005.], [10015.]], predictions.eval())
def test_dense_multi_output(self):
- price = fc_old.numeric_column('price')
+ price = fc.numeric_column('price')
with ops.Graph().as_default():
features = {'price': [[1.], [5.]]}
- predictions = fc.linear_model(features, [price], units=3)
- bias = get_linear_model_bias()
- price_var = get_linear_model_column_var(price)
+ model = fc.LinearModel([price], units=3)
+ predictions = model(features)
+ price_var, bias = model.variables
with _initialized_session() as sess:
self.assertAllClose(np.zeros((3,)), bias.eval())
self.assertAllClose(np.zeros((1, 3)), price_var.eval())
@@ -1468,16 +1250,16 @@ class LinearModelTest(test.TestCase):
predictions.eval())
def test_sparse_multi_output(self):
- wire_cast = fc_old.categorical_column_with_hash_bucket('wire_cast', 4)
+ wire_cast = fc.categorical_column_with_hash_bucket('wire_cast', 4)
with ops.Graph().as_default():
wire_tensor = sparse_tensor.SparseTensor(
values=['omar', 'stringer', 'marlo'], # hashed to = [2, 0, 3]
indices=[[0, 0], [1, 0], [1, 1]],
dense_shape=[2, 2])
features = {'wire_cast': wire_tensor}
- predictions = fc.linear_model(features, [wire_cast], units=3)
- bias = get_linear_model_bias()
- wire_cast_var = get_linear_model_column_var(wire_cast)
+ model = fc.LinearModel([wire_cast], units=3)
+ predictions = model(features)
+ wire_cast_var, bias = model.variables
with _initialized_session() as sess:
self.assertAllClose(np.zeros((3,)), bias.eval())
self.assertAllClose(np.zeros((4, 3)), wire_cast_var.eval())
@@ -1490,18 +1272,19 @@ class LinearModelTest(test.TestCase):
predictions.eval())
def test_dense_multi_dimension(self):
- price = fc_old.numeric_column('price', shape=2)
+ price = fc.numeric_column('price', shape=2)
with ops.Graph().as_default():
features = {'price': [[1., 2.], [5., 6.]]}
- predictions = fc.linear_model(features, [price])
- price_var = get_linear_model_column_var(price)
+ model = fc.LinearModel([price])
+ predictions = model(features)
+ price_var, _ = model.variables
with _initialized_session() as sess:
self.assertAllClose([[0.], [0.]], price_var.eval())
sess.run(price_var.assign([[10.], [100.]]))
self.assertAllClose([[210.], [650.]], predictions.eval())
def test_sparse_multi_rank(self):
- wire_cast = fc_old.categorical_column_with_hash_bucket('wire_cast', 4)
+ wire_cast = fc.categorical_column_with_hash_bucket('wire_cast', 4)
with ops.Graph().as_default():
wire_tensor = array_ops.sparse_placeholder(dtypes.string)
wire_value = sparse_tensor.SparseTensorValue(
@@ -1509,8 +1292,9 @@ class LinearModelTest(test.TestCase):
indices=[[0, 0, 0], [0, 1, 0], [1, 0, 0], [1, 0, 1]],
dense_shape=[2, 2, 2])
features = {'wire_cast': wire_tensor}
- predictions = fc.linear_model(features, [wire_cast])
- wire_cast_var = get_linear_model_column_var(wire_cast)
+ model = fc.LinearModel([wire_cast])
+ predictions = model(features)
+ wire_cast_var, _ = model.variables
with _initialized_session() as sess:
self.assertAllClose(np.zeros((4, 1)), wire_cast_var.eval())
self.assertAllClose(
@@ -1522,25 +1306,24 @@ class LinearModelTest(test.TestCase):
predictions.eval(feed_dict={wire_tensor: wire_value}))
def test_sparse_combiner(self):
- wire_cast = fc_old.categorical_column_with_hash_bucket('wire_cast', 4)
+ wire_cast = fc.categorical_column_with_hash_bucket('wire_cast', 4)
with ops.Graph().as_default():
wire_tensor = sparse_tensor.SparseTensor(
values=['omar', 'stringer', 'marlo'], # hashed to = [2, 0, 3]
indices=[[0, 0], [1, 0], [1, 1]],
dense_shape=[2, 2])
features = {'wire_cast': wire_tensor}
- predictions = fc.linear_model(
- features, [wire_cast], sparse_combiner='mean')
- bias = get_linear_model_bias()
- wire_cast_var = get_linear_model_column_var(wire_cast)
+ model = fc.LinearModel([wire_cast], sparse_combiner='mean')
+ predictions = model(features)
+ wire_cast_var, bias = model.variables
with _initialized_session() as sess:
sess.run(wire_cast_var.assign([[10.], [100.], [1000.], [10000.]]))
sess.run(bias.assign([5.]))
self.assertAllClose([[1005.], [5010.]], predictions.eval())
def test_sparse_combiner_with_negative_weights(self):
- wire_cast = fc_old.categorical_column_with_hash_bucket('wire_cast', 4)
- wire_cast_weights = fc_old.weighted_categorical_column(wire_cast, 'weights')
+ wire_cast = fc.categorical_column_with_hash_bucket('wire_cast', 4)
+ wire_cast_weights = fc.weighted_categorical_column(wire_cast, 'weights')
with ops.Graph().as_default():
wire_tensor = sparse_tensor.SparseTensor(
@@ -1551,22 +1334,21 @@ class LinearModelTest(test.TestCase):
'wire_cast': wire_tensor,
'weights': constant_op.constant([[1., 1., -1.0]])
}
- predictions = fc.linear_model(
- features, [wire_cast_weights], sparse_combiner='sum')
- bias = get_linear_model_bias()
- wire_cast_var = get_linear_model_column_var(wire_cast)
+ model = fc.LinearModel([wire_cast_weights], sparse_combiner='sum')
+ predictions = model(features)
+ wire_cast_var, bias = model.variables
with _initialized_session() as sess:
sess.run(wire_cast_var.assign([[10.], [100.], [1000.], [10000.]]))
sess.run(bias.assign([5.]))
self.assertAllClose([[1005.], [-9985.]], predictions.eval())
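For the negative-weight case, each id's example weight scales its learned coefficient before the "sum" combine; spelled out with this test's numbers (assuming the same [2, 0, 3] hashing as the wire_cast tests above):

```python
w = [10., 100., 1000., 10000.]  # per-bucket coefficients
bias = 5.
# Row 0 holds id 2 with weight 1.; row 1 holds id 0 (1.) and id 3 (-1.).
row0 = 1. * w[2] + bias               # 1005.
row1 = 1. * w[0] + -1. * w[3] + bias  # -9985.
```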
def test_dense_multi_dimension_multi_output(self):
- price = fc_old.numeric_column('price', shape=2)
+ price = fc.numeric_column('price', shape=2)
with ops.Graph().as_default():
features = {'price': [[1., 2.], [5., 6.]]}
- predictions = fc.linear_model(features, [price], units=3)
- bias = get_linear_model_bias()
- price_var = get_linear_model_column_var(price)
+ model = fc.LinearModel([price], units=3)
+ predictions = model(features)
+ price_var, bias = model.variables
with _initialized_session() as sess:
self.assertAllClose(np.zeros((3,)), bias.eval())
self.assertAllClose(np.zeros((2, 3)), price_var.eval())
@@ -1576,21 +1358,22 @@ class LinearModelTest(test.TestCase):
predictions.eval())
def test_raises_if_shape_mismatch(self):
- price = fc_old.numeric_column('price', shape=2)
+ price = fc.numeric_column('price', shape=2)
with ops.Graph().as_default():
features = {'price': [[1.], [5.]]}
with self.assertRaisesRegexp(
Exception,
r'Cannot reshape a tensor with 2 elements to shape \[2,2\]'):
- fc.linear_model(features, [price])
+ model = fc.LinearModel([price])
+ model(features)
def test_dense_reshaping(self):
- price = fc_old.numeric_column('price', shape=[1, 2])
+ price = fc.numeric_column('price', shape=[1, 2])
with ops.Graph().as_default():
features = {'price': [[[1., 2.]], [[5., 6.]]]}
- predictions = fc.linear_model(features, [price])
- bias = get_linear_model_bias()
- price_var = get_linear_model_column_var(price)
+ model = fc.LinearModel([price])
+ predictions = model(features)
+ price_var, bias = model.variables
with _initialized_session() as sess:
self.assertAllClose([0.], bias.eval())
self.assertAllClose([[0.], [0.]], price_var.eval())
@@ -1599,17 +1382,16 @@ class LinearModelTest(test.TestCase):
self.assertAllClose([[210.], [650.]], predictions.eval())
def test_dense_multi_column(self):
- price1 = fc_old.numeric_column('price1', shape=2)
- price2 = fc_old.numeric_column('price2')
+ price1 = fc.numeric_column('price1', shape=2)
+ price2 = fc.numeric_column('price2')
with ops.Graph().as_default():
features = {
'price1': [[1., 2.], [5., 6.]],
'price2': [[3.], [4.]]
}
- predictions = fc.linear_model(features, [price1, price2])
- bias = get_linear_model_bias()
- price1_var = get_linear_model_column_var(price1)
- price2_var = get_linear_model_column_var(price2)
+ model = fc.LinearModel([price1, price2])
+ predictions = model(features)
+ price1_var, price2_var, bias = model.variables
with _initialized_session() as sess:
self.assertAllClose([0.], bias.eval())
self.assertAllClose([[0.], [0.]], price1_var.eval())
@@ -1620,115 +1402,55 @@ class LinearModelTest(test.TestCase):
sess.run(bias.assign([7.]))
self.assertAllClose([[3217.], [4657.]], predictions.eval())
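A linear model over several dense columns is the sum of each column's features times its weights, plus the shared bias. Checking the expected values above with numpy (illustrative only):

import numpy as np

price1 = np.array([[1., 2.], [5., 6.]])
price2 = np.array([[3.], [4.]])
w1 = np.array([[10.], [100.]])
w2 = np.array([[1000.]])
pred = price1 @ w1 + price2 @ w2 + 7.  # [[3217.], [4657.]]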
- def test_fills_cols_to_vars(self):
- price1 = fc_old.numeric_column('price1', shape=2)
- price2 = fc_old.numeric_column('price2')
- with ops.Graph().as_default():
- features = {'price1': [[1., 2.], [5., 6.]], 'price2': [[3.], [4.]]}
- cols_to_vars = {}
- fc.linear_model(features, [price1, price2], cols_to_vars=cols_to_vars)
- bias = get_linear_model_bias()
- price1_var = get_linear_model_column_var(price1)
- price2_var = get_linear_model_column_var(price2)
- self.assertAllEqual(cols_to_vars['bias'], [bias])
- self.assertAllEqual(cols_to_vars[price1], [price1_var])
- self.assertAllEqual(cols_to_vars[price2], [price2_var])
-
- def test_fills_cols_to_vars_partitioned_variables(self):
- price1 = fc_old.numeric_column('price1', shape=2)
- price2 = fc_old.numeric_column('price2', shape=3)
- with ops.Graph().as_default():
- features = {
- 'price1': [[1., 2.], [6., 7.]],
- 'price2': [[3., 4., 5.], [8., 9., 10.]]
- }
- cols_to_vars = {}
- with variable_scope.variable_scope(
- 'linear',
- partitioner=partitioned_variables.fixed_size_partitioner(2, axis=0)):
- fc.linear_model(features, [price1, price2], cols_to_vars=cols_to_vars)
- with _initialized_session():
- self.assertEqual([0.], cols_to_vars['bias'][0].eval())
- # Partitioning shards the [2, 1] price1 var into 2 [1, 1] Variables.
- self.assertAllEqual([[0.]], cols_to_vars[price1][0].eval())
- self.assertAllEqual([[0.]], cols_to_vars[price1][1].eval())
- # Partitioning shards the [3, 1] price2 var into a [2, 1] Variable and
- # a [1, 1] Variable.
- self.assertAllEqual([[0.], [0.]], cols_to_vars[price2][0].eval())
- self.assertAllEqual([[0.]], cols_to_vars[price2][1].eval())
-
- def test_dense_collection(self):
- price = fc_old.numeric_column('price')
- with ops.Graph().as_default() as g:
- features = {'price': [[1.], [5.]]}
- fc.linear_model(features, [price], weight_collections=['my-vars'])
- my_vars = g.get_collection('my-vars')
- bias = get_linear_model_bias()
- price_var = get_linear_model_column_var(price)
- self.assertIn(bias, my_vars)
- self.assertIn(price_var, my_vars)
-
- def test_sparse_collection(self):
- wire_cast = fc_old.categorical_column_with_hash_bucket('wire_cast', 4)
- with ops.Graph().as_default() as g:
- wire_tensor = sparse_tensor.SparseTensor(
- values=['omar'], indices=[[0, 0]], dense_shape=[1, 1])
- features = {'wire_cast': wire_tensor}
- fc.linear_model(
- features, [wire_cast], weight_collections=['my-vars'])
- my_vars = g.get_collection('my-vars')
- bias = get_linear_model_bias()
- wire_cast_var = get_linear_model_column_var(wire_cast)
- self.assertIn(bias, my_vars)
- self.assertIn(wire_cast_var, my_vars)
-
def test_dense_trainable_default(self):
- price = fc_old.numeric_column('price')
+ price = fc.numeric_column('price')
with ops.Graph().as_default() as g:
features = {'price': [[1.], [5.]]}
- fc.linear_model(features, [price])
- bias = get_linear_model_bias()
- price_var = get_linear_model_column_var(price)
+ model = fc.LinearModel([price])
+ model(features)
+ price_var, bias = model.variables
trainable_vars = g.get_collection(ops.GraphKeys.TRAINABLE_VARIABLES)
self.assertIn(bias, trainable_vars)
self.assertIn(price_var, trainable_vars)
def test_sparse_trainable_default(self):
- wire_cast = fc_old.categorical_column_with_hash_bucket('wire_cast', 4)
+ wire_cast = fc.categorical_column_with_hash_bucket('wire_cast', 4)
with ops.Graph().as_default() as g:
wire_tensor = sparse_tensor.SparseTensor(
values=['omar'], indices=[[0, 0]], dense_shape=[1, 1])
features = {'wire_cast': wire_tensor}
- fc.linear_model(features, [wire_cast])
+ model = fc.LinearModel([wire_cast])
+ model(features)
trainable_vars = g.get_collection(ops.GraphKeys.TRAINABLE_VARIABLES)
- bias = get_linear_model_bias()
- wire_cast_var = get_linear_model_column_var(wire_cast)
+ wire_cast_var, bias = model.variables
self.assertIn(bias, trainable_vars)
self.assertIn(wire_cast_var, trainable_vars)
def test_dense_trainable_false(self):
- price = fc_old.numeric_column('price')
+ price = fc.numeric_column('price')
with ops.Graph().as_default() as g:
features = {'price': [[1.], [5.]]}
- fc.linear_model(features, [price], trainable=False)
+ model = fc.LinearModel([price], trainable=False)
+ model(features)
trainable_vars = g.get_collection(ops.GraphKeys.TRAINABLE_VARIABLES)
self.assertEqual([], trainable_vars)
def test_sparse_trainable_false(self):
- wire_cast = fc_old.categorical_column_with_hash_bucket('wire_cast', 4)
+ wire_cast = fc.categorical_column_with_hash_bucket('wire_cast', 4)
with ops.Graph().as_default() as g:
wire_tensor = sparse_tensor.SparseTensor(
values=['omar'], indices=[[0, 0]], dense_shape=[1, 1])
features = {'wire_cast': wire_tensor}
- fc.linear_model(features, [wire_cast], trainable=False)
+ model = fc.LinearModel([wire_cast], trainable=False)
+ model(features)
trainable_vars = g.get_collection(ops.GraphKeys.TRAINABLE_VARIABLES)
self.assertEqual([], trainable_vars)
def test_column_order(self):
- price_a = fc_old.numeric_column('price_a')
- price_b = fc_old.numeric_column('price_b')
- wire_cast = fc_old.categorical_column_with_hash_bucket('wire_cast', 4)
- with ops.Graph().as_default() as g:
+ price_a = fc.numeric_column('price_a')
+ price_b = fc.numeric_column('price_b')
+ wire_cast = fc.categorical_column_with_hash_bucket('wire_cast', 4)
+ with ops.Graph().as_default():
features = {
'price_a': [[1.]],
'price_b': [[3.]],
@@ -1736,15 +1458,15 @@ class LinearModelTest(test.TestCase):
sparse_tensor.SparseTensor(
values=['omar'], indices=[[0, 0]], dense_shape=[1, 1])
}
- fc.linear_model(
- features, [price_a, wire_cast, price_b],
- weight_collections=['my-vars'])
- my_vars = g.get_collection('my-vars')
+ model = fc.LinearModel([price_a, wire_cast, price_b])
+ model(features)
+
+ my_vars = model.variables
self.assertIn('price_a', my_vars[0].name)
self.assertIn('price_b', my_vars[1].name)
self.assertIn('wire_cast', my_vars[2].name)
- with ops.Graph().as_default() as g:
+ with ops.Graph().as_default():
features = {
'price_a': [[1.]],
'price_b': [[3.]],
@@ -1752,17 +1474,45 @@ class LinearModelTest(test.TestCase):
sparse_tensor.SparseTensor(
values=['omar'], indices=[[0, 0]], dense_shape=[1, 1])
}
- fc.linear_model(
- features, [wire_cast, price_b, price_a],
- weight_collections=['my-vars'])
- my_vars = g.get_collection('my-vars')
+ model = fc.LinearModel([wire_cast, price_b, price_a])
+ model(features)
+
+ my_vars = model.variables
self.assertIn('price_a', my_vars[0].name)
self.assertIn('price_b', my_vars[1].name)
self.assertIn('wire_cast', my_vars[2].name)
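The assertions pass for both constructor orders because the model builds its weights in sorted column-name order, so model.variables is deterministic regardless of the order in which columns were passed. A sketch of the ordering rule the test relies on:

# Columns are processed sorted by name, independent of input order.
cols = [wire_cast, price_b, price_a]
ordered = sorted(cols, key=lambda col: col.name)
print([col.name for col in ordered])  # ['price_a', 'price_b', 'wire_cast']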
+ def test_variable_names(self):
+ price1 = fc.numeric_column('price1')
+ dense_feature = fc.numeric_column('dense_feature')
+ dense_feature_bucketized = fc.bucketized_column(
+ dense_feature, boundaries=[0.])
+ some_sparse_column = fc.categorical_column_with_hash_bucket(
+ 'sparse_feature', hash_bucket_size=5)
+ some_embedding_column = fc.embedding_column(
+ some_sparse_column, dimension=10)
+ all_cols = [price1, dense_feature_bucketized, some_embedding_column]
+
+ with ops.Graph().as_default():
+ model = fc.LinearModel(all_cols)
+ features = {
+ 'price1': [[3.], [4.]],
+ 'dense_feature': [[-1.], [4.]],
+ 'sparse_feature': [['a'], ['x']],
+ }
+ model(features)
+ variable_names = [var.name for var in model.variables]
+ self.assertItemsEqual([
+ 'linear_model/dense_feature_bucketized/weights:0',
+ 'linear_model/price1/weights:0',
+ 'linear_model/sparse_feature_embedding/embedding_weights:0',
+ 'linear_model/sparse_feature_embedding/weights:0',
+ 'linear_model/bias_weights:0',
+ ], variable_names)
+
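The expected names spell out the new naming scheme: every column owns 'linear_model/<column name>/weights', an embedding column additionally owns its 'embedding_weights' table, and the layer adds a single 'linear_model/bias_weights'. A hypothetical helper that reproduces the list asserted above (expected_names is not part of the API):

def expected_names(column_names, embedding_names):
  # One weight vector per column, one extra table per embedding, one bias.
  names = ['linear_model/%s/weights:0' % n for n in column_names]
  names += ['linear_model/%s/embedding_weights:0' % n for n in embedding_names]
  return names + ['linear_model/bias_weights:0']

expected_names(
    ['price1', 'dense_feature_bucketized', 'sparse_feature_embedding'],
    ['sparse_feature_embedding'])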
def test_static_batch_size_mismatch(self):
- price1 = fc_old.numeric_column('price1')
- price2 = fc_old.numeric_column('price2')
+ price1 = fc.numeric_column('price1')
+ price2 = fc.numeric_column('price2')
with ops.Graph().as_default():
features = {
'price1': [[1.], [5.], [7.]], # batchsize = 3
@@ -1771,12 +1521,13 @@ class LinearModelTest(test.TestCase):
with self.assertRaisesRegexp(
ValueError,
'Batch size \(first dimension\) of each feature must be same.'): # pylint: disable=anomalous-backslash-in-string
- fc.linear_model(features, [price1, price2])
+ model = fc.LinearModel([price1, price2])
+ model(features)
def test_subset_of_static_batch_size_mismatch(self):
- price1 = fc_old.numeric_column('price1')
- price2 = fc_old.numeric_column('price2')
- price3 = fc_old.numeric_column('price3')
+ price1 = fc.numeric_column('price1')
+ price2 = fc.numeric_column('price2')
+ price3 = fc.numeric_column('price3')
with ops.Graph().as_default():
features = {
'price1': array_ops.placeholder(dtype=dtypes.int64), # batchsize = 3
@@ -1786,17 +1537,19 @@ class LinearModelTest(test.TestCase):
with self.assertRaisesRegexp(
ValueError,
'Batch size \(first dimension\) of each feature must be same.'): # pylint: disable=anomalous-backslash-in-string
- fc.linear_model(features, [price1, price2, price3])
+ model = fc.LinearModel([price1, price2, price3])
+ model(features)
def test_runtime_batch_size_mismatch(self):
- price1 = fc_old.numeric_column('price1')
- price2 = fc_old.numeric_column('price2')
+ price1 = fc.numeric_column('price1')
+ price2 = fc.numeric_column('price2')
with ops.Graph().as_default():
features = {
'price1': array_ops.placeholder(dtype=dtypes.int64), # batchsize = 3
'price2': [[3.], [4.]] # batchsize = 2
}
- predictions = fc.linear_model(features, [price1, price2])
+ model = fc.LinearModel([price1, price2])
+ predictions = model(features)
with _initialized_session() as sess:
with self.assertRaisesRegexp(errors.OpError,
'must have the same size and shape'):
@@ -1804,14 +1557,15 @@ class LinearModelTest(test.TestCase):
predictions, feed_dict={features['price1']: [[1.], [5.], [7.]]})
def test_runtime_batch_size_matches(self):
- price1 = fc_old.numeric_column('price1')
- price2 = fc_old.numeric_column('price2')
+ price1 = fc.numeric_column('price1')
+ price2 = fc.numeric_column('price2')
with ops.Graph().as_default():
features = {
'price1': array_ops.placeholder(dtype=dtypes.int64), # batchsize = 2
'price2': array_ops.placeholder(dtype=dtypes.int64), # batchsize = 2
}
- predictions = fc.linear_model(features, [price1, price2])
+ model = fc.LinearModel([price1, price2])
+ predictions = model(features)
with _initialized_session() as sess:
sess.run(
predictions,
@@ -1821,14 +1575,14 @@ class LinearModelTest(test.TestCase):
})
def test_with_numpy_input_fn(self):
- price = fc_old.numeric_column('price')
- price_buckets = fc_old.bucketized_column(
+ price = fc.numeric_column('price')
+ price_buckets = fc.bucketized_column(
price, boundaries=[
0.,
10.,
100.,
])
- body_style = fc_old.categorical_column_with_vocabulary_list(
+ body_style = fc.categorical_column_with_vocabulary_list(
'body-style', vocabulary_list=['hardtop', 'wagon', 'sedan'])
input_fn = numpy_io.numpy_input_fn(
@@ -1839,15 +1593,14 @@ class LinearModelTest(test.TestCase):
batch_size=2,
shuffle=False)
features = input_fn()
- net = fc.linear_model(features, [price_buckets, body_style])
+ model = fc.LinearModel([price_buckets, body_style])
+ net = model(features)
# self.assertEqual(1 + 3 + 5, net.shape[1])
with _initialized_session() as sess:
coord = coordinator.Coordinator()
threads = queue_runner_impl.start_queue_runners(sess, coord=coord)
- bias = get_linear_model_bias()
- price_buckets_var = get_linear_model_column_var(price_buckets)
- body_style_var = get_linear_model_column_var(body_style)
+ body_style_var, price_buckets_var, bias = model.variables
sess.run(price_buckets_var.assign([[10.], [100.], [1000.], [10000.]]))
sess.run(body_style_var.assign([[-10.], [-100.], [-1000.]]))
@@ -1859,14 +1612,14 @@ class LinearModelTest(test.TestCase):
coord.join(threads)
def test_with_1d_sparse_tensor(self):
- price = fc_old.numeric_column('price')
- price_buckets = fc_old.bucketized_column(
+ price = fc.numeric_column('price')
+ price_buckets = fc.bucketized_column(
price, boundaries=[
0.,
10.,
100.,
])
- body_style = fc_old.categorical_column_with_vocabulary_list(
+ body_style = fc.categorical_column_with_vocabulary_list(
'body-style', vocabulary_list=['hardtop', 'wagon', 'sedan'])
# Provides 1-dim tensor and dense tensor.
@@ -1880,11 +1633,10 @@ class LinearModelTest(test.TestCase):
self.assertEqual(1, features['price'].shape.ndims)
self.assertEqual(1, features['body-style'].dense_shape.get_shape()[0])
- net = fc.linear_model(features, [price_buckets, body_style])
+ model = fc.LinearModel([price_buckets, body_style])
+ net = model(features)
with _initialized_session() as sess:
- bias = get_linear_model_bias()
- price_buckets_var = get_linear_model_column_var(price_buckets)
- body_style_var = get_linear_model_column_var(body_style)
+ body_style_var, price_buckets_var, bias = model.variables
sess.run(price_buckets_var.assign([[10.], [100.], [1000.], [10000.]]))
sess.run(body_style_var.assign([[-10.], [-100.], [-1000.]]))
@@ -1893,16 +1645,16 @@ class LinearModelTest(test.TestCase):
self.assertAllClose([[10 - 1000 + 5.], [1000 - 10 + 5.]], sess.run(net))
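The expected values follow from bucketizing the prices (boundaries 0., 10., 100. give four buckets) and looking up the body styles: price -1. lands in bucket 0 (weight 10.), price 12. in bucket 2 (weight 1000.), and 'sedan' and 'hardtop' map to -1000. and -10. A check under those assignments (illustrative only):

price_w = [10., 100., 1000., 10000.]  # one weight per bucket
body_w = {'hardtop': -10., 'wagon': -100., 'sedan': -1000.}
row0 = price_w[0] + body_w['sedan'] + 5.    # 10 - 1000 + 5 = -985
row1 = price_w[2] + body_w['hardtop'] + 5.  # 1000 - 10 + 5 = 995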
def test_with_1d_unknown_shape_sparse_tensor(self):
- price = fc_old.numeric_column('price')
- price_buckets = fc_old.bucketized_column(
+ price = fc.numeric_column('price')
+ price_buckets = fc.bucketized_column(
price, boundaries=[
0.,
10.,
100.,
])
- body_style = fc_old.categorical_column_with_vocabulary_list(
+ body_style = fc.categorical_column_with_vocabulary_list(
'body-style', vocabulary_list=['hardtop', 'wagon', 'sedan'])
- country = fc_old.categorical_column_with_vocabulary_list(
+ country = fc.categorical_column_with_vocabulary_list(
'country', vocabulary_list=['US', 'JP', 'CA'])
# Provides 1-dim tensor and dense tensor.
@@ -1921,10 +1673,9 @@ class LinearModelTest(test.TestCase):
dense_shape=(2,))
country_data = np.array(['US', 'CA'])
- net = fc.linear_model(features, [price_buckets, body_style, country])
- bias = get_linear_model_bias()
- price_buckets_var = get_linear_model_column_var(price_buckets)
- body_style_var = get_linear_model_column_var(body_style)
+ model = fc.LinearModel([price_buckets, body_style, country])
+ net = model(features)
+ body_style_var, _, price_buckets_var, bias = model.variables
with _initialized_session() as sess:
sess.run(price_buckets_var.assign([[10.], [100.], [1000.], [10000.]]))
sess.run(body_style_var.assign([[-10.], [-100.], [-1000.]]))
@@ -1940,7 +1691,7 @@ class LinearModelTest(test.TestCase):
}))
def test_with_rank_0_feature(self):
- price = fc_old.numeric_column('price')
+ price = fc.numeric_column('price')
features = {
'price': constant_op.constant(0),
}
@@ -1948,29 +1699,31 @@ class LinearModelTest(test.TestCase):
# Static rank 0 should fail
with self.assertRaisesRegexp(ValueError, 'Feature .* cannot have rank 0'):
- fc.linear_model(features, [price])
+ model = fc.LinearModel([price])
+ model(features)
# Dynamic rank 0 should fail
features = {
'price': array_ops.placeholder(dtypes.float32),
}
- net = fc.linear_model(features, [price])
+ model = fc.LinearModel([price])
+ net = model(features)
self.assertEqual(1, net.shape[1])
with _initialized_session() as sess:
with self.assertRaisesOpError('Feature .* cannot have rank 0'):
sess.run(net, feed_dict={features['price']: np.array(1)})
def test_multiple_linear_models(self):
- price = fc_old.numeric_column('price')
+ price = fc.numeric_column('price')
with ops.Graph().as_default():
features1 = {'price': [[1.], [5.]]}
features2 = {'price': [[2.], [10.]]}
- predictions1 = fc.linear_model(features1, [price])
- predictions2 = fc.linear_model(features2, [price])
- bias1 = get_linear_model_bias(name='linear_model')
- bias2 = get_linear_model_bias(name='linear_model_1')
- price_var1 = get_linear_model_column_var(price, name='linear_model')
- price_var2 = get_linear_model_column_var(price, name='linear_model_1')
+ model1 = fc.LinearModel([price])
+ model2 = fc.LinearModel([price])
+ predictions1 = model1(features1)
+ predictions2 = model2(features2)
+ price_var1, bias1 = model1.variables
+ price_var2, bias2 = model2.variables
with _initialized_session() as sess:
self.assertAllClose([0.], bias1.eval())
sess.run(price_var1.assign([[10.]]))
@@ -1982,664 +1735,6 @@ class LinearModelTest(test.TestCase):
self.assertAllClose([[25.], [105.]], predictions2.eval())
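Because each LinearModel instance is a layer that owns its state, two models over the same column get independent weights, replacing the old scheme of fishing the bias and column variables out of the 'linear_model' and 'linear_model_1' scopes. A sketch of the ownership property the test exercises:

model1 = fc.LinearModel([price])
model2 = fc.LinearModel([price])
model1(features1)
model2(features2)
# Each instance builds its own weights; nothing is shared between the two.
assert not set(model1.variables) & set(model2.variables)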
-class _LinearModelTest(test.TestCase):
-
- def test_raises_if_empty_feature_columns(self):
- with self.assertRaisesRegexp(ValueError,
- 'feature_columns must not be empty'):
- get_keras_linear_model_predictions(features={}, feature_columns=[])
-
- def test_should_be_feature_column(self):
- with self.assertRaisesRegexp(ValueError, 'must be a _FeatureColumn'):
- get_keras_linear_model_predictions(
- features={'a': [[0]]}, feature_columns='NotSupported')
-
- def test_should_be_dense_or_categorical_column(self):
-
- class NotSupportedColumn(fc_old._FeatureColumn):
-
- @property
- def name(self):
- return 'NotSupportedColumn'
-
- def _transform_feature(self, cache):
- pass
-
- @property
- def _parse_example_spec(self):
- pass
-
- with self.assertRaisesRegexp(
- ValueError, 'must be either a _DenseColumn or _CategoricalColumn'):
- get_keras_linear_model_predictions(
- features={'a': [[0]]}, feature_columns=[NotSupportedColumn()])
-
- def test_does_not_support_dict_columns(self):
- with self.assertRaisesRegexp(
- ValueError, 'Expected feature_columns to be iterable, found dict.'):
- fc.linear_model(
- features={'a': [[0]]},
- feature_columns={'a': fc_old.numeric_column('a')})
-
- def test_raises_if_duplicate_name(self):
- with self.assertRaisesRegexp(
- ValueError, 'Duplicate feature column name found for columns'):
- get_keras_linear_model_predictions(
- features={'a': [[0]]},
- feature_columns=[
- fc_old.numeric_column('a'),
- fc_old.numeric_column('a')
- ])
-
- def test_dense_bias(self):
- price = fc_old.numeric_column('price')
- with ops.Graph().as_default():
- features = {'price': [[1.], [5.]]}
- predictions = get_keras_linear_model_predictions(features, [price])
- bias = get_linear_model_bias()
- price_var = get_linear_model_column_var(price)
- with _initialized_session() as sess:
- self.assertAllClose([0.], bias.eval())
- sess.run(price_var.assign([[10.]]))
- sess.run(bias.assign([5.]))
- self.assertAllClose([[15.], [55.]], predictions.eval())
-
- def test_sparse_bias(self):
- wire_cast = fc_old.categorical_column_with_hash_bucket('wire_cast', 4)
- with ops.Graph().as_default():
- wire_tensor = sparse_tensor.SparseTensor(
- values=['omar', 'stringer', 'marlo'], # hashed to = [2, 0, 3]
- indices=[[0, 0], [1, 0], [1, 1]],
- dense_shape=[2, 2])
- features = {'wire_cast': wire_tensor}
- predictions = get_keras_linear_model_predictions(features, [wire_cast])
- bias = get_linear_model_bias()
- wire_cast_var = get_linear_model_column_var(wire_cast)
- with _initialized_session() as sess:
- self.assertAllClose([0.], bias.eval())
- self.assertAllClose([[0.], [0.], [0.], [0.]], wire_cast_var.eval())
- sess.run(wire_cast_var.assign([[10.], [100.], [1000.], [10000.]]))
- sess.run(bias.assign([5.]))
- self.assertAllClose([[1005.], [10015.]], predictions.eval())
-
- def test_dense_and_sparse_bias(self):
- wire_cast = fc_old.categorical_column_with_hash_bucket('wire_cast', 4)
- price = fc_old.numeric_column('price')
- with ops.Graph().as_default():
- wire_tensor = sparse_tensor.SparseTensor(
- values=['omar', 'stringer', 'marlo'], # hashed to = [2, 0, 3]
- indices=[[0, 0], [1, 0], [1, 1]],
- dense_shape=[2, 2])
- features = {'wire_cast': wire_tensor, 'price': [[1.], [5.]]}
- predictions = get_keras_linear_model_predictions(features,
- [wire_cast, price])
- bias = get_linear_model_bias()
- wire_cast_var = get_linear_model_column_var(wire_cast)
- price_var = get_linear_model_column_var(price)
- with _initialized_session() as sess:
- sess.run(wire_cast_var.assign([[10.], [100.], [1000.], [10000.]]))
- sess.run(bias.assign([5.]))
- sess.run(price_var.assign([[10.]]))
- self.assertAllClose([[1015.], [10065.]], predictions.eval())
-
- def test_dense_and_sparse_column(self):
- """When the column is both dense and sparse, uses sparse tensors."""
-
- class _DenseAndSparseColumn(fc_old._DenseColumn, fc_old._CategoricalColumn):
-
- @property
- def name(self):
- return 'dense_and_sparse_column'
-
- @property
- def _parse_example_spec(self):
- return {self.name: parsing_ops.VarLenFeature(self.dtype)}
-
- def _transform_feature(self, inputs):
- return inputs.get(self.name)
-
- @property
- def _variable_shape(self):
- raise ValueError('Should not use this method.')
-
- def _get_dense_tensor(self,
- inputs,
- weight_collections=None,
- trainable=None):
- raise ValueError('Should not use this method.')
-
- @property
- def _num_buckets(self):
- return 4
-
- def _get_sparse_tensors(self,
- inputs,
- weight_collections=None,
- trainable=None):
- sp_tensor = sparse_tensor.SparseTensor(
- indices=[[0, 0], [1, 0], [1, 1]],
- values=[2, 0, 3],
- dense_shape=[2, 2])
- return fc_old._CategoricalColumn.IdWeightPair(sp_tensor, None)
-
- dense_and_sparse_column = _DenseAndSparseColumn()
- with ops.Graph().as_default():
- sp_tensor = sparse_tensor.SparseTensor(
- values=['omar', 'stringer', 'marlo'],
- indices=[[0, 0], [1, 0], [1, 1]],
- dense_shape=[2, 2])
- features = {dense_and_sparse_column.name: sp_tensor}
- predictions = get_keras_linear_model_predictions(
- features, [dense_and_sparse_column])
- bias = get_linear_model_bias()
- dense_and_sparse_column_var = get_linear_model_column_var(
- dense_and_sparse_column)
- with _initialized_session() as sess:
- sess.run(
- dense_and_sparse_column_var.assign([[10.], [100.], [1000.],
- [10000.]]))
- sess.run(bias.assign([5.]))
- self.assertAllClose([[1005.], [10015.]], predictions.eval())
-
- def test_dense_multi_output(self):
- price = fc_old.numeric_column('price')
- with ops.Graph().as_default():
- features = {'price': [[1.], [5.]]}
- predictions = get_keras_linear_model_predictions(
- features, [price], units=3)
- bias = get_linear_model_bias()
- price_var = get_linear_model_column_var(price)
- with _initialized_session() as sess:
- self.assertAllClose(np.zeros((3,)), bias.eval())
- self.assertAllClose(np.zeros((1, 3)), price_var.eval())
- sess.run(price_var.assign([[10., 100., 1000.]]))
- sess.run(bias.assign([5., 6., 7.]))
- self.assertAllClose([[15., 106., 1007.], [55., 506., 5007.]],
- predictions.eval())
-
- def test_sparse_multi_output(self):
- wire_cast = fc_old.categorical_column_with_hash_bucket('wire_cast', 4)
- with ops.Graph().as_default():
- wire_tensor = sparse_tensor.SparseTensor(
- values=['omar', 'stringer', 'marlo'], # hashed to = [2, 0, 3]
- indices=[[0, 0], [1, 0], [1, 1]],
- dense_shape=[2, 2])
- features = {'wire_cast': wire_tensor}
- predictions = get_keras_linear_model_predictions(
- features, [wire_cast], units=3)
- bias = get_linear_model_bias()
- wire_cast_var = get_linear_model_column_var(wire_cast)
- with _initialized_session() as sess:
- self.assertAllClose(np.zeros((3,)), bias.eval())
- self.assertAllClose(np.zeros((4, 3)), wire_cast_var.eval())
- sess.run(
- wire_cast_var.assign([[10., 11., 12.], [100., 110., 120.],
- [1000., 1100.,
- 1200.], [10000., 11000., 12000.]]))
- sess.run(bias.assign([5., 6., 7.]))
- self.assertAllClose([[1005., 1106., 1207.], [10015., 11017., 12019.]],
- predictions.eval())
-
- def test_dense_multi_dimension(self):
- price = fc_old.numeric_column('price', shape=2)
- with ops.Graph().as_default():
- features = {'price': [[1., 2.], [5., 6.]]}
- predictions = get_keras_linear_model_predictions(features, [price])
- price_var = get_linear_model_column_var(price)
- with _initialized_session() as sess:
- self.assertAllClose([[0.], [0.]], price_var.eval())
- sess.run(price_var.assign([[10.], [100.]]))
- self.assertAllClose([[210.], [650.]], predictions.eval())
-
- def test_sparse_multi_rank(self):
- wire_cast = fc_old.categorical_column_with_hash_bucket('wire_cast', 4)
- with ops.Graph().as_default():
- wire_tensor = array_ops.sparse_placeholder(dtypes.string)
- wire_value = sparse_tensor.SparseTensorValue(
- values=['omar', 'stringer', 'marlo', 'omar'], # hashed = [2, 0, 3, 2]
- indices=[[0, 0, 0], [0, 1, 0], [1, 0, 0], [1, 0, 1]],
- dense_shape=[2, 2, 2])
- features = {'wire_cast': wire_tensor}
- predictions = get_keras_linear_model_predictions(features, [wire_cast])
- wire_cast_var = get_linear_model_column_var(wire_cast)
- with _initialized_session() as sess:
- self.assertAllClose(np.zeros((4, 1)), wire_cast_var.eval())
- self.assertAllClose(
- np.zeros((2, 1)),
- predictions.eval(feed_dict={wire_tensor: wire_value}))
- sess.run(wire_cast_var.assign([[10.], [100.], [1000.], [10000.]]))
- self.assertAllClose(
- [[1010.], [11000.]],
- predictions.eval(feed_dict={wire_tensor: wire_value}))
-
- def test_sparse_combiner(self):
- wire_cast = fc_old.categorical_column_with_hash_bucket('wire_cast', 4)
- with ops.Graph().as_default():
- wire_tensor = sparse_tensor.SparseTensor(
- values=['omar', 'stringer', 'marlo'], # hashed to = [2, 0, 3]
- indices=[[0, 0], [1, 0], [1, 1]],
- dense_shape=[2, 2])
- features = {'wire_cast': wire_tensor}
- predictions = get_keras_linear_model_predictions(
- features, [wire_cast], sparse_combiner='mean')
- bias = get_linear_model_bias()
- wire_cast_var = get_linear_model_column_var(wire_cast)
- with _initialized_session() as sess:
- sess.run(wire_cast_var.assign([[10.], [100.], [1000.], [10000.]]))
- sess.run(bias.assign([5.]))
- self.assertAllClose([[1005.], [5010.]], predictions.eval())
-
- def test_dense_multi_dimension_multi_output(self):
- price = fc_old.numeric_column('price', shape=2)
- with ops.Graph().as_default():
- features = {'price': [[1., 2.], [5., 6.]]}
- predictions = get_keras_linear_model_predictions(
- features, [price], units=3)
- bias = get_linear_model_bias()
- price_var = get_linear_model_column_var(price)
- with _initialized_session() as sess:
- self.assertAllClose(np.zeros((3,)), bias.eval())
- self.assertAllClose(np.zeros((2, 3)), price_var.eval())
- sess.run(price_var.assign([[1., 2., 3.], [10., 100., 1000.]]))
- sess.run(bias.assign([2., 3., 4.]))
- self.assertAllClose([[23., 205., 2007.], [67., 613., 6019.]],
- predictions.eval())
-
- def test_raises_if_shape_mismatch(self):
- price = fc_old.numeric_column('price', shape=2)
- with ops.Graph().as_default():
- features = {'price': [[1.], [5.]]}
- with self.assertRaisesRegexp(
- Exception,
- r'Cannot reshape a tensor with 2 elements to shape \[2,2\]'):
- get_keras_linear_model_predictions(features, [price])
-
- def test_dense_reshaping(self):
- price = fc_old.numeric_column('price', shape=[1, 2])
- with ops.Graph().as_default():
- features = {'price': [[[1., 2.]], [[5., 6.]]]}
- predictions = get_keras_linear_model_predictions(features, [price])
- bias = get_linear_model_bias()
- price_var = get_linear_model_column_var(price)
- with _initialized_session() as sess:
- self.assertAllClose([0.], bias.eval())
- self.assertAllClose([[0.], [0.]], price_var.eval())
- self.assertAllClose([[0.], [0.]], predictions.eval())
- sess.run(price_var.assign([[10.], [100.]]))
- self.assertAllClose([[210.], [650.]], predictions.eval())
-
- def test_dense_multi_column(self):
- price1 = fc_old.numeric_column('price1', shape=2)
- price2 = fc_old.numeric_column('price2')
- with ops.Graph().as_default():
- features = {'price1': [[1., 2.], [5., 6.]], 'price2': [[3.], [4.]]}
- predictions = get_keras_linear_model_predictions(features,
- [price1, price2])
- bias = get_linear_model_bias()
- price1_var = get_linear_model_column_var(price1)
- price2_var = get_linear_model_column_var(price2)
- with _initialized_session() as sess:
- self.assertAllClose([0.], bias.eval())
- self.assertAllClose([[0.], [0.]], price1_var.eval())
- self.assertAllClose([[0.]], price2_var.eval())
- self.assertAllClose([[0.], [0.]], predictions.eval())
- sess.run(price1_var.assign([[10.], [100.]]))
- sess.run(price2_var.assign([[1000.]]))
- sess.run(bias.assign([7.]))
- self.assertAllClose([[3217.], [4657.]], predictions.eval())
-
- def test_fills_cols_to_vars(self):
- price1 = fc_old.numeric_column('price1', shape=2)
- price2 = fc_old.numeric_column('price2')
- with ops.Graph().as_default():
- features = {'price1': [[1., 2.], [5., 6.]], 'price2': [[3.], [4.]]}
- cols_to_vars = {}
- get_keras_linear_model_predictions(
- features, [price1, price2], cols_to_vars=cols_to_vars)
- bias = get_linear_model_bias()
- price1_var = get_linear_model_column_var(price1)
- price2_var = get_linear_model_column_var(price2)
- self.assertAllEqual(cols_to_vars['bias'], [bias])
- self.assertAllEqual(cols_to_vars[price1], [price1_var])
- self.assertAllEqual(cols_to_vars[price2], [price2_var])
-
- def test_fills_cols_to_vars_partitioned_variables(self):
- price1 = fc_old.numeric_column('price1', shape=2)
- price2 = fc_old.numeric_column('price2', shape=3)
- with ops.Graph().as_default():
- features = {
- 'price1': [[1., 2.], [6., 7.]],
- 'price2': [[3., 4., 5.], [8., 9., 10.]]
- }
- cols_to_vars = {}
- with variable_scope.variable_scope(
- 'linear',
- partitioner=partitioned_variables.fixed_size_partitioner(2, axis=0)):
- get_keras_linear_model_predictions(
- features, [price1, price2], cols_to_vars=cols_to_vars)
- with _initialized_session():
- self.assertEqual([0.], cols_to_vars['bias'][0].eval())
- # Partitioning shards the [2, 1] price1 var into 2 [1, 1] Variables.
- self.assertAllEqual([[0.]], cols_to_vars[price1][0].eval())
- self.assertAllEqual([[0.]], cols_to_vars[price1][1].eval())
- # Partitioning shards the [3, 1] price2 var into a [2, 1] Variable and
- # a [1, 1] Variable.
- self.assertAllEqual([[0.], [0.]], cols_to_vars[price2][0].eval())
- self.assertAllEqual([[0.]], cols_to_vars[price2][1].eval())
-
- def test_dense_collection(self):
- price = fc_old.numeric_column('price')
- with ops.Graph().as_default() as g:
- features = {'price': [[1.], [5.]]}
- get_keras_linear_model_predictions(
- features, [price], weight_collections=['my-vars'])
- my_vars = g.get_collection('my-vars')
- bias = get_linear_model_bias()
- price_var = get_linear_model_column_var(price)
- self.assertIn(bias, my_vars)
- self.assertIn(price_var, my_vars)
-
- def test_sparse_collection(self):
- wire_cast = fc_old.categorical_column_with_hash_bucket('wire_cast', 4)
- with ops.Graph().as_default() as g:
- wire_tensor = sparse_tensor.SparseTensor(
- values=['omar'], indices=[[0, 0]], dense_shape=[1, 1])
- features = {'wire_cast': wire_tensor}
- get_keras_linear_model_predictions(
- features, [wire_cast], weight_collections=['my-vars'])
- my_vars = g.get_collection('my-vars')
- bias = get_linear_model_bias()
- wire_cast_var = get_linear_model_column_var(wire_cast)
- self.assertIn(bias, my_vars)
- self.assertIn(wire_cast_var, my_vars)
-
- def test_dense_trainable_default(self):
- price = fc_old.numeric_column('price')
- with ops.Graph().as_default() as g:
- features = {'price': [[1.], [5.]]}
- get_keras_linear_model_predictions(features, [price])
- bias = get_linear_model_bias()
- price_var = get_linear_model_column_var(price)
- trainable_vars = g.get_collection(ops.GraphKeys.TRAINABLE_VARIABLES)
- self.assertIn(bias, trainable_vars)
- self.assertIn(price_var, trainable_vars)
-
- def test_sparse_trainable_default(self):
- wire_cast = fc_old.categorical_column_with_hash_bucket('wire_cast', 4)
- with ops.Graph().as_default() as g:
- wire_tensor = sparse_tensor.SparseTensor(
- values=['omar'], indices=[[0, 0]], dense_shape=[1, 1])
- features = {'wire_cast': wire_tensor}
- get_keras_linear_model_predictions(features, [wire_cast])
- trainable_vars = g.get_collection(ops.GraphKeys.TRAINABLE_VARIABLES)
- bias = get_linear_model_bias()
- wire_cast_var = get_linear_model_column_var(wire_cast)
- self.assertIn(bias, trainable_vars)
- self.assertIn(wire_cast_var, trainable_vars)
-
- def test_dense_trainable_false(self):
- price = fc_old.numeric_column('price')
- with ops.Graph().as_default() as g:
- features = {'price': [[1.], [5.]]}
- get_keras_linear_model_predictions(features, [price], trainable=False)
- trainable_vars = g.get_collection(ops.GraphKeys.TRAINABLE_VARIABLES)
- self.assertEqual([], trainable_vars)
-
- def test_sparse_trainable_false(self):
- wire_cast = fc_old.categorical_column_with_hash_bucket('wire_cast', 4)
- with ops.Graph().as_default() as g:
- wire_tensor = sparse_tensor.SparseTensor(
- values=['omar'], indices=[[0, 0]], dense_shape=[1, 1])
- features = {'wire_cast': wire_tensor}
- get_keras_linear_model_predictions(features, [wire_cast], trainable=False)
- trainable_vars = g.get_collection(ops.GraphKeys.TRAINABLE_VARIABLES)
- self.assertEqual([], trainable_vars)
-
- def test_column_order(self):
- price_a = fc_old.numeric_column('price_a')
- price_b = fc_old.numeric_column('price_b')
- wire_cast = fc_old.categorical_column_with_hash_bucket('wire_cast', 4)
- with ops.Graph().as_default() as g:
- features = {
- 'price_a': [[1.]],
- 'price_b': [[3.]],
- 'wire_cast':
- sparse_tensor.SparseTensor(
- values=['omar'], indices=[[0, 0]], dense_shape=[1, 1])
- }
- get_keras_linear_model_predictions(
- features, [price_a, wire_cast, price_b],
- weight_collections=['my-vars'])
- my_vars = g.get_collection('my-vars')
- self.assertIn('price_a', my_vars[0].name)
- self.assertIn('price_b', my_vars[1].name)
- self.assertIn('wire_cast', my_vars[2].name)
-
- with ops.Graph().as_default() as g:
- features = {
- 'price_a': [[1.]],
- 'price_b': [[3.]],
- 'wire_cast':
- sparse_tensor.SparseTensor(
- values=['omar'], indices=[[0, 0]], dense_shape=[1, 1])
- }
- get_keras_linear_model_predictions(
- features, [wire_cast, price_b, price_a],
- weight_collections=['my-vars'])
- my_vars = g.get_collection('my-vars')
- self.assertIn('price_a', my_vars[0].name)
- self.assertIn('price_b', my_vars[1].name)
- self.assertIn('wire_cast', my_vars[2].name)
-
- def test_static_batch_size_mismatch(self):
- price1 = fc_old.numeric_column('price1')
- price2 = fc_old.numeric_column('price2')
- with ops.Graph().as_default():
- features = {
- 'price1': [[1.], [5.], [7.]], # batchsize = 3
- 'price2': [[3.], [4.]] # batchsize = 2
- }
- with self.assertRaisesRegexp(
- ValueError,
- 'Batch size \(first dimension\) of each feature must be same.'): # pylint: disable=anomalous-backslash-in-string
- get_keras_linear_model_predictions(features, [price1, price2])
-
- def test_subset_of_static_batch_size_mismatch(self):
- price1 = fc_old.numeric_column('price1')
- price2 = fc_old.numeric_column('price2')
- price3 = fc_old.numeric_column('price3')
- with ops.Graph().as_default():
- features = {
- 'price1': array_ops.placeholder(dtype=dtypes.int64), # batchsize = 3
- 'price2': [[3.], [4.]], # batchsize = 2
- 'price3': [[3.], [4.], [5.]] # batchsize = 3
- }
- with self.assertRaisesRegexp(
- ValueError,
- 'Batch size \(first dimension\) of each feature must be same.'): # pylint: disable=anomalous-backslash-in-string
- get_keras_linear_model_predictions(features, [price1, price2, price3])
-
- def test_runtime_batch_size_mismatch(self):
- price1 = fc_old.numeric_column('price1')
- price2 = fc_old.numeric_column('price2')
- with ops.Graph().as_default():
- features = {
- 'price1': array_ops.placeholder(dtype=dtypes.int64), # batchsize = 3
- 'price2': [[3.], [4.]] # batchsize = 2
- }
- predictions = get_keras_linear_model_predictions(features,
- [price1, price2])
- with _initialized_session() as sess:
- with self.assertRaisesRegexp(errors.OpError,
- 'must have the same size and shape'):
- sess.run(
- predictions, feed_dict={features['price1']: [[1.], [5.], [7.]]})
-
- def test_runtime_batch_size_matches(self):
- price1 = fc_old.numeric_column('price1')
- price2 = fc_old.numeric_column('price2')
- with ops.Graph().as_default():
- features = {
- 'price1': array_ops.placeholder(dtype=dtypes.int64), # batchsize = 2
- 'price2': array_ops.placeholder(dtype=dtypes.int64), # batchsize = 2
- }
- predictions = get_keras_linear_model_predictions(features,
- [price1, price2])
- with _initialized_session() as sess:
- sess.run(
- predictions,
- feed_dict={
- features['price1']: [[1.], [5.]],
- features['price2']: [[1.], [5.]],
- })
-
- def test_with_numpy_input_fn(self):
- price = fc_old.numeric_column('price')
- price_buckets = fc_old.bucketized_column(
- price, boundaries=[
- 0.,
- 10.,
- 100.,
- ])
- body_style = fc_old.categorical_column_with_vocabulary_list(
- 'body-style', vocabulary_list=['hardtop', 'wagon', 'sedan'])
-
- input_fn = numpy_io.numpy_input_fn(
- x={
- 'price': np.array([-1., 2., 13., 104.]),
- 'body-style': np.array(['sedan', 'hardtop', 'wagon', 'sedan']),
- },
- batch_size=2,
- shuffle=False)
- features = input_fn()
- net = get_keras_linear_model_predictions(features,
- [price_buckets, body_style])
- # self.assertEqual(1 + 3 + 5, net.shape[1])
- with _initialized_session() as sess:
- coord = coordinator.Coordinator()
- threads = queue_runner_impl.start_queue_runners(sess, coord=coord)
-
- bias = get_linear_model_bias()
- price_buckets_var = get_linear_model_column_var(price_buckets)
- body_style_var = get_linear_model_column_var(body_style)
-
- sess.run(price_buckets_var.assign([[10.], [100.], [1000.], [10000.]]))
- sess.run(body_style_var.assign([[-10.], [-100.], [-1000.]]))
- sess.run(bias.assign([5.]))
-
- self.assertAllClose([[10 - 1000 + 5.], [100 - 10 + 5.]], sess.run(net))
-
- coord.request_stop()
- coord.join(threads)
-
- def test_with_1d_sparse_tensor(self):
- price = fc_old.numeric_column('price')
- price_buckets = fc_old.bucketized_column(
- price, boundaries=[
- 0.,
- 10.,
- 100.,
- ])
- body_style = fc_old.categorical_column_with_vocabulary_list(
- 'body-style', vocabulary_list=['hardtop', 'wagon', 'sedan'])
-
- # Provides 1-dim tensor and dense tensor.
- features = {
- 'price':
- constant_op.constant([
- -1.,
- 12.,
- ]),
- 'body-style':
- sparse_tensor.SparseTensor(
- indices=((0,), (1,)),
- values=('sedan', 'hardtop'),
- dense_shape=(2,)),
- }
- self.assertEqual(1, features['price'].shape.ndims)
- self.assertEqual(1, features['body-style'].dense_shape.get_shape()[0])
-
- net = get_keras_linear_model_predictions(features,
- [price_buckets, body_style])
- with _initialized_session() as sess:
- bias = get_linear_model_bias()
- price_buckets_var = get_linear_model_column_var(price_buckets)
- body_style_var = get_linear_model_column_var(body_style)
-
- sess.run(price_buckets_var.assign([[10.], [100.], [1000.], [10000.]]))
- sess.run(body_style_var.assign([[-10.], [-100.], [-1000.]]))
- sess.run(bias.assign([5.]))
-
- self.assertAllClose([[10 - 1000 + 5.], [1000 - 10 + 5.]], sess.run(net))
-
- def test_with_1d_unknown_shape_sparse_tensor(self):
- price = fc_old.numeric_column('price')
- price_buckets = fc_old.bucketized_column(
- price, boundaries=[
- 0.,
- 10.,
- 100.,
- ])
- body_style = fc_old.categorical_column_with_vocabulary_list(
- 'body-style', vocabulary_list=['hardtop', 'wagon', 'sedan'])
- country = fc_old.categorical_column_with_vocabulary_list(
- 'country', vocabulary_list=['US', 'JP', 'CA'])
-
- # Provides 1-dim tensor and dense tensor.
- features = {
- 'price': array_ops.placeholder(dtypes.float32),
- 'body-style': array_ops.sparse_placeholder(dtypes.string),
- 'country': array_ops.placeholder(dtypes.string),
- }
- self.assertIsNone(features['price'].shape.ndims)
- self.assertIsNone(features['body-style'].get_shape().ndims)
-
- price_data = np.array([-1., 12.])
- body_style_data = sparse_tensor.SparseTensorValue(
- indices=((0,), (1,)), values=('sedan', 'hardtop'), dense_shape=(2,))
- country_data = np.array(['US', 'CA'])
-
- net = get_keras_linear_model_predictions(
- features, [price_buckets, body_style, country])
- bias = get_linear_model_bias()
- price_buckets_var = get_linear_model_column_var(price_buckets)
- body_style_var = get_linear_model_column_var(body_style)
- with _initialized_session() as sess:
- sess.run(price_buckets_var.assign([[10.], [100.], [1000.], [10000.]]))
- sess.run(body_style_var.assign([[-10.], [-100.], [-1000.]]))
- sess.run(bias.assign([5.]))
-
- self.assertAllClose([[10 - 1000 + 5.], [1000 - 10 + 5.]],
- sess.run(
- net,
- feed_dict={
- features['price']: price_data,
- features['body-style']: body_style_data,
- features['country']: country_data
- }))
-
- def test_with_rank_0_feature(self):
- price = fc_old.numeric_column('price')
- features = {
- 'price': constant_op.constant(0),
- }
- self.assertEqual(0, features['price'].shape.ndims)
-
- # Static rank 0 should fail
- with self.assertRaisesRegexp(ValueError, 'Feature .* cannot have rank 0'):
- get_keras_linear_model_predictions(features, [price])
-
- # Dynamic rank 0 should fail
- features = {
- 'price': array_ops.placeholder(dtypes.float32),
- }
- net = get_keras_linear_model_predictions(features, [price])
- self.assertEqual(1, net.shape[1])
- with _initialized_session() as sess:
- with self.assertRaisesOpError('Feature .* cannot have rank 0'):
- sess.run(net, feed_dict={features['price']: np.array(1)})
-
-
class FeatureLayerTest(test.TestCase):
@test_util.run_in_graph_and_eager_modes()
@@ -3739,47 +2834,22 @@ class VocabularyFileCategoricalColumnTest(test.TestCase):
id_weight_pair.id_tensor.eval())
def test_linear_model(self):
- wire_column = fc_old.categorical_column_with_vocabulary_file(
- key='wire',
- vocabulary_file=self._wire_vocabulary_file_name,
- vocabulary_size=self._wire_vocabulary_size,
- num_oov_buckets=1)
- self.assertEqual(4, wire_column._num_buckets)
- with ops.Graph().as_default():
- predictions = fc.linear_model({
- wire_column.name: sparse_tensor.SparseTensorValue(
- indices=((0, 0), (1, 0), (1, 1)),
- values=('marlo', 'skywalker', 'omar'),
- dense_shape=(2, 2))
- }, (wire_column,))
- bias = get_linear_model_bias()
- wire_var = get_linear_model_column_var(wire_column)
- with _initialized_session():
- self.assertAllClose((0.,), bias.eval())
- self.assertAllClose(((0.,), (0.,), (0.,), (0.,)), wire_var.eval())
- self.assertAllClose(((0.,), (0.,)), predictions.eval())
- wire_var.assign(((1.,), (2.,), (3.,), (4.,))).eval()
- # 'marlo' -> 2: wire_var[2] = 3
- # 'skywalker' -> 3, 'omar' -> 0: wire_var[3] + wire_var[0] = 4+1 = 5
- self.assertAllClose(((3.,), (5.,)), predictions.eval())
-
- def test_keras_linear_model(self):
- wire_column = fc_old.categorical_column_with_vocabulary_file(
+ wire_column = fc.categorical_column_with_vocabulary_file(
key='wire',
vocabulary_file=self._wire_vocabulary_file_name,
vocabulary_size=self._wire_vocabulary_size,
num_oov_buckets=1)
- self.assertEqual(4, wire_column._num_buckets)
+ self.assertEqual(4, wire_column.num_buckets)
with ops.Graph().as_default():
- predictions = get_keras_linear_model_predictions({
+ model = fc.LinearModel((wire_column,))
+ predictions = model({
wire_column.name:
sparse_tensor.SparseTensorValue(
indices=((0, 0), (1, 0), (1, 1)),
values=('marlo', 'skywalker', 'omar'),
dense_shape=(2, 2))
- }, (wire_column,))
- bias = get_linear_model_bias()
- wire_var = get_linear_model_column_var(wire_column)
+ })
+ wire_var, bias = model.variables
with _initialized_session():
self.assertAllClose((0.,), bias.eval())
self.assertAllClose(((0.,), (0.,), (0.,), (0.,)), wire_var.eval())
@@ -4140,45 +3210,21 @@ class VocabularyListCategoricalColumnTest(test.TestCase):
id_weight_pair.id_tensor.eval())
def test_linear_model(self):
- wire_column = fc_old.categorical_column_with_vocabulary_list(
- key='aaa',
- vocabulary_list=('omar', 'stringer', 'marlo'),
- num_oov_buckets=1)
- self.assertEqual(4, wire_column._num_buckets)
- with ops.Graph().as_default():
- predictions = fc.linear_model({
- wire_column.name: sparse_tensor.SparseTensorValue(
- indices=((0, 0), (1, 0), (1, 1)),
- values=('marlo', 'skywalker', 'omar'),
- dense_shape=(2, 2))
- }, (wire_column,))
- bias = get_linear_model_bias()
- wire_var = get_linear_model_column_var(wire_column)
- with _initialized_session():
- self.assertAllClose((0.,), bias.eval())
- self.assertAllClose(((0.,), (0.,), (0.,), (0.,)), wire_var.eval())
- self.assertAllClose(((0.,), (0.,)), predictions.eval())
- wire_var.assign(((1.,), (2.,), (3.,), (4.,))).eval()
- # 'marlo' -> 2: wire_var[2] = 3
- # 'skywalker' -> 3, 'omar' -> 0: wire_var[3] + wire_var[0] = 4+1 = 5
- self.assertAllClose(((3.,), (5.,)), predictions.eval())
-
- def test_keras_linear_model(self):
- wire_column = fc_old.categorical_column_with_vocabulary_list(
+ wire_column = fc.categorical_column_with_vocabulary_list(
key='aaa',
vocabulary_list=('omar', 'stringer', 'marlo'),
num_oov_buckets=1)
- self.assertEqual(4, wire_column._num_buckets)
+ self.assertEqual(4, wire_column.num_buckets)
with ops.Graph().as_default():
- predictions = get_keras_linear_model_predictions({
+ model = fc.LinearModel((wire_column,))
+ predictions = model({
wire_column.name:
sparse_tensor.SparseTensorValue(
indices=((0, 0), (1, 0), (1, 1)),
values=('marlo', 'skywalker', 'omar'),
dense_shape=(2, 2))
- }, (wire_column,))
- bias = get_linear_model_bias()
- wire_var = get_linear_model_column_var(wire_column)
+ })
+ wire_var, bias = model.variables
with _initialized_session():
self.assertAllClose((0.,), bias.eval())
self.assertAllClose(((0.,), (0.,), (0.,), (0.,)), wire_var.eval())
@@ -4398,39 +3444,18 @@ class IdentityCategoricalColumnTest(test.TestCase):
}))
def test_linear_model(self):
- column = fc_old.categorical_column_with_identity(key='aaa', num_buckets=3)
- self.assertEqual(3, column.num_buckets)
- with ops.Graph().as_default():
- predictions = fc.linear_model({
- column.name: sparse_tensor.SparseTensorValue(
- indices=((0, 0), (1, 0), (1, 1)),
- values=(0, 2, 1),
- dense_shape=(2, 2))
- }, (column,))
- bias = get_linear_model_bias()
- weight_var = get_linear_model_column_var(column)
- with _initialized_session():
- self.assertAllClose((0.,), bias.eval())
- self.assertAllClose(((0.,), (0.,), (0.,)), weight_var.eval())
- self.assertAllClose(((0.,), (0.,)), predictions.eval())
- weight_var.assign(((1.,), (2.,), (3.,))).eval()
- # weight_var[0] = 1
- # weight_var[2] + weight_var[1] = 3+2 = 5
- self.assertAllClose(((1.,), (5.,)), predictions.eval())
-
- def test_keras_linear_model(self):
- column = fc_old.categorical_column_with_identity(key='aaa', num_buckets=3)
+ column = fc.categorical_column_with_identity(key='aaa', num_buckets=3)
self.assertEqual(3, column.num_buckets)
with ops.Graph().as_default():
- predictions = get_keras_linear_model_predictions({
+ model = fc.LinearModel((column,))
+ predictions = model({
column.name:
sparse_tensor.SparseTensorValue(
indices=((0, 0), (1, 0), (1, 1)),
values=(0, 2, 1),
dense_shape=(2, 2))
- }, (column,))
- bias = get_linear_model_bias()
- weight_var = get_linear_model_column_var(column)
+ })
+ weight_var, bias = model.variables
with _initialized_session():
self.assertAllClose((0.,), bias.eval())
self.assertAllClose(((0.,), (0.,), (0.,)), weight_var.eval())
@@ -4656,27 +3681,8 @@ class IndicatorColumnTest(test.TestCase):
self.assertAllEqual([[0., 1., 1.]], indicator_tensor.eval())
def test_linear_model(self):
- animal = fc_old.indicator_column(
- fc_old.categorical_column_with_identity('animal', num_buckets=4))
- with ops.Graph().as_default():
- features = {
- 'animal':
- sparse_tensor.SparseTensor(
- indices=[[0, 0], [0, 1]], values=[1, 2], dense_shape=[1, 2])
- }
-
- predictions = fc.linear_model(features, [animal])
- weight_var = get_linear_model_column_var(animal)
- with _initialized_session():
- # All should be zero-initialized.
- self.assertAllClose([[0.], [0.], [0.], [0.]], weight_var.eval())
- self.assertAllClose([[0.]], predictions.eval())
- weight_var.assign([[1.], [2.], [3.], [4.]]).eval()
- self.assertAllClose([[2. + 3.]], predictions.eval())
-
- def test_keras_linear_model(self):
- animal = fc_old.indicator_column(
- fc_old.categorical_column_with_identity('animal', num_buckets=4))
+ animal = fc.indicator_column(
+ fc.categorical_column_with_identity('animal', num_buckets=4))
with ops.Graph().as_default():
features = {
'animal':
@@ -4684,8 +3690,9 @@ class IndicatorColumnTest(test.TestCase):
indices=[[0, 0], [0, 1]], values=[1, 2], dense_shape=[1, 2])
}
- predictions = get_keras_linear_model_predictions(features, [animal])
- weight_var = get_linear_model_column_var(animal)
+ model = fc.LinearModel([animal])
+ predictions = model(features)
+ weight_var, _ = model.variables
with _initialized_session():
# All should be zero-initialized.
self.assertAllClose([[0.], [0.], [0.], [0.]], weight_var.eval())
@@ -5137,17 +4144,16 @@ class EmbeddingColumnTest(test.TestCase):
return zeros_embedding_values
# Build columns.
- categorical_column = fc_old.categorical_column_with_identity(
+ categorical_column = fc.categorical_column_with_identity(
key='aaa', num_buckets=vocabulary_size)
- embedding_column = fc_old.embedding_column(
+ embedding_column = fc.embedding_column(
categorical_column,
dimension=embedding_dimension,
initializer=_initializer)
with ops.Graph().as_default():
- predictions = fc.linear_model({
- categorical_column.name: sparse_input
- }, (embedding_column,))
+ model = fc.LinearModel((embedding_column,))
+ predictions = model({categorical_column.name: sparse_input})
expected_var_names = (
'linear_model/bias_weights:0',
'linear_model/aaa_embedding/weights:0',
@@ -5189,82 +4195,6 @@ class EmbeddingColumnTest(test.TestCase):
# = [4*7 + 6*11, 4*2 + 6*3.5, 4*0 + 6*0, 4*3 + 6*5] = [94, 29, 0, 42]
self.assertAllClose(((94.,), (29.,), (0.,), (42.,)), predictions.eval())
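The commented arithmetic can be checked directly: each example's embedding lookup (mean-combined when a row has multiple ids) is dotted with the linear weights. A numpy verification of the expected predictions (illustrative only):

import numpy as np

emb = np.array([[1., 2.], [3., 5.], [7., 11.]])  # assigned embedding table
lin = np.array([[4.], [6.]])                     # assigned linear weights
rows = np.stack([emb[2], (emb[0] + emb[1]) / 2, np.zeros(2), emb[1]])
print(rows @ lin)  # [[94.], [29.], [0.], [42.]]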
- def test_keras_linear_model(self):
- # Inputs.
- batch_size = 4
- vocabulary_size = 3
- sparse_input = sparse_tensor.SparseTensorValue(
- # example 0, ids [2]
- # example 1, ids [0, 1]
- # example 2, ids []
- # example 3, ids [1]
- indices=((0, 0), (1, 0), (1, 4), (3, 0)),
- values=(2, 0, 1, 1),
- dense_shape=(batch_size, 5))
-
- # Embedding variable.
- embedding_dimension = 2
- embedding_shape = (vocabulary_size, embedding_dimension)
- zeros_embedding_values = np.zeros(embedding_shape)
-
- def _initializer(shape, dtype, partition_info):
- self.assertAllEqual(embedding_shape, shape)
- self.assertEqual(dtypes.float32, dtype)
- self.assertIsNone(partition_info)
- return zeros_embedding_values
-
- # Build columns.
- categorical_column = fc_old.categorical_column_with_identity(
- key='aaa', num_buckets=vocabulary_size)
- embedding_column = fc_old.embedding_column(
- categorical_column,
- dimension=embedding_dimension,
- initializer=_initializer)
-
- with ops.Graph().as_default():
- predictions = get_keras_linear_model_predictions({
- categorical_column.name: sparse_input
- }, (embedding_column,))
- expected_var_names = (
- 'linear_model/bias_weights:0',
- 'linear_model/aaa_embedding/weights:0',
- 'linear_model/aaa_embedding/embedding_weights:0',
- )
- self.assertItemsEqual(
- expected_var_names,
- [v.name for v in ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES)])
- trainable_vars = {
- v.name: v
- for v in ops.get_collection(ops.GraphKeys.TRAINABLE_VARIABLES)
- }
- self.assertItemsEqual(expected_var_names, trainable_vars.keys())
- bias = trainable_vars['linear_model/bias_weights:0']
- embedding_weights = trainable_vars[
- 'linear_model/aaa_embedding/embedding_weights:0']
- linear_weights = trainable_vars['linear_model/aaa_embedding/weights:0']
- with _initialized_session():
- # Predictions with all zero weights.
- self.assertAllClose(np.zeros((1,)), bias.eval())
- self.assertAllClose(zeros_embedding_values, embedding_weights.eval())
- self.assertAllClose(
- np.zeros((embedding_dimension, 1)), linear_weights.eval())
- self.assertAllClose(np.zeros((batch_size, 1)), predictions.eval())
-
- # Predictions with all non-zero weights.
- embedding_weights.assign((
- (1., 2.), # id 0
- (3., 5.), # id 1
- (7., 11.) # id 2
- )).eval()
- linear_weights.assign(((4.,), (6.,))).eval()
- # example 0, ids [2], embedding[0] = [7, 11]
- # example 1, ids [0, 1], embedding[1] = mean([1, 2] + [3, 5]) = [2, 3.5]
- # example 2, ids [], embedding[2] = [0, 0]
- # example 3, ids [1], embedding[3] = [3, 5]
- # sum(embeddings * linear_weights)
- # = [4*7 + 6*11, 4*2 + 6*3.5, 4*0 + 6*0, 4*3 + 6*5] = [94, 29, 0, 42]
- self.assertAllClose(((94.,), (29.,), (0.,), (42.,)), predictions.eval())
-
def test_feature_layer(self):
# Inputs.
vocabulary_size = 3
@@ -5765,27 +4695,31 @@ class SharedEmbeddingColumnTest(test.TestCase):
return zeros_embedding_values
# Build columns.
- categorical_column_a = fc_old.categorical_column_with_identity(
+ categorical_column_a = fc.categorical_column_with_identity(
key='aaa', num_buckets=vocabulary_size)
- categorical_column_b = fc_old.categorical_column_with_identity(
+ categorical_column_b = fc.categorical_column_with_identity(
key='bbb', num_buckets=vocabulary_size)
- embedding_column_a, embedding_column_b = fc_old.shared_embedding_columns(
+ embedding_column_a, embedding_column_b = fc.shared_embedding_columns_v2(
[categorical_column_a, categorical_column_b],
dimension=embedding_dimension,
initializer=_initializer)
with ops.Graph().as_default():
- predictions = fc.linear_model({
+ model = fc.LinearModel(
+ (embedding_column_a, embedding_column_b),
+ shared_state_manager=fc.SharedEmbeddingStateManager())
+ predictions = model({
categorical_column_a.name: input_a,
- categorical_column_b.name: input_b,
- }, (embedding_column_a, embedding_column_b))
+ categorical_column_b.name: input_b
+ })
+
# Linear weights do not follow the column name. But this is a rare use
# case, and fixing it would add too much complexity to the code.
expected_var_names = (
'linear_model/bias_weights:0',
- 'linear_model/aaa_bbb_shared_embedding/weights:0',
- 'linear_model/aaa_bbb_shared_embedding/embedding_weights:0',
- 'linear_model/aaa_bbb_shared_embedding_1/weights:0',
+ 'linear_model/aaa_shared_embedding/weights:0',
+ 'shared_embedding_state_manager/aaa_bbb_shared_embedding:0',
+ 'linear_model/bbb_shared_embedding/weights:0',
)
self.assertItemsEqual(
expected_var_names,
@@ -5797,102 +4731,11 @@ class SharedEmbeddingColumnTest(test.TestCase):
self.assertItemsEqual(expected_var_names, trainable_vars.keys())
bias = trainable_vars['linear_model/bias_weights:0']
embedding_weights = trainable_vars[
- 'linear_model/aaa_bbb_shared_embedding/embedding_weights:0']
- linear_weights_a = trainable_vars[
- 'linear_model/aaa_bbb_shared_embedding/weights:0']
- linear_weights_b = trainable_vars[
- 'linear_model/aaa_bbb_shared_embedding_1/weights:0']
- with _initialized_session():
- # Predictions with all zero weights.
- self.assertAllClose(np.zeros((1,)), bias.eval())
- self.assertAllClose(zeros_embedding_values, embedding_weights.eval())
- self.assertAllClose(
- np.zeros((embedding_dimension, 1)), linear_weights_a.eval())
- self.assertAllClose(
- np.zeros((embedding_dimension, 1)), linear_weights_b.eval())
- self.assertAllClose(np.zeros((batch_size, 1)), predictions.eval())
-
- # Predictions with all non-zero weights.
- embedding_weights.assign((
- (1., 2.), # id 0
- (3., 5.), # id 1
- (7., 11.) # id 2
- )).eval()
- linear_weights_a.assign(((4.,), (6.,))).eval()
- # example 0, ids [2], embedding[0] = [7, 11]
- # example 1, ids [0, 1], embedding[1] = mean([1, 2] + [3, 5]) = [2, 3.5]
- # sum(embeddings * linear_weights)
- # = [4*7 + 6*11, 4*2 + 6*3.5] = [94, 29]
- linear_weights_b.assign(((3.,), (5.,))).eval()
- # example 0, ids [0], embedding[0] = [1, 2]
- # example 1, ids [], embedding[1] = [0, 0]
- # sum(embeddings * linear_weights)
- # = [3*1 + 5*2, 3*0 + 5*0] = [13, 0]
- self.assertAllClose([[94. + 13.], [29.]], predictions.eval())
-
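# As in the sketch further above, the removed two-column expectation can be
# checked with plain NumPy (illustrative only): predictions sum the
# contributions of both shared-embedding columns.
import numpy as np
emb_a = np.array([[7., 11.], [2., 3.5]])  # combined embeddings for 'aaa'
emb_b = np.array([[1., 2.], [0., 0.]])    # combined embeddings for 'bbb'
w_a = np.array([[4.], [6.]])
w_b = np.array([[3.], [5.]])
print(emb_a @ w_a + emb_b @ w_b)  # [[107.], [29.]], i.e. [[94.+13.], [29.+0.]]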
- def test_keras_linear_model(self):
- # Inputs.
- batch_size = 2
- vocabulary_size = 3
- # -1 values are ignored.
- input_a = np.array([
- [2, -1, -1], # example 0, ids [2]
- [0, 1, -1]
- ]) # example 1, ids [0, 1]
- input_b = np.array([
- [0, -1, -1], # example 0, ids [0]
- [-1, -1, -1]
- ]) # example 1, ids []
-
- # Embedding variable.
- embedding_dimension = 2
- embedding_shape = (vocabulary_size, embedding_dimension)
- zeros_embedding_values = np.zeros(embedding_shape)
-
- def _initializer(shape, dtype, partition_info):
- self.assertAllEqual(embedding_shape, shape)
- self.assertEqual(dtypes.float32, dtype)
- self.assertIsNone(partition_info)
- return zeros_embedding_values
-
- # Build columns.
- categorical_column_a = fc_old.categorical_column_with_identity(
- key='aaa', num_buckets=vocabulary_size)
- categorical_column_b = fc_old.categorical_column_with_identity(
- key='bbb', num_buckets=vocabulary_size)
- embedding_column_a, embedding_column_b = fc_old.shared_embedding_columns(
- [categorical_column_a, categorical_column_b],
- dimension=embedding_dimension,
- initializer=_initializer)
-
- with ops.Graph().as_default():
- predictions = get_keras_linear_model_predictions({
- categorical_column_a.name: input_a,
- categorical_column_b.name: input_b,
- }, (embedding_column_a, embedding_column_b))
- # Linear weights do not follow the column name. But this is a rare use
- # case, and fixing it would add too much complexity to the code.
- expected_var_names = (
- 'linear_model/bias_weights:0',
- 'linear_model/aaa_bbb_shared_embedding/weights:0',
- 'linear_model/aaa_bbb_shared_embedding/embedding_weights:0',
- 'linear_model/aaa_bbb_shared_embedding_1/weights:0',
- )
- self.assertItemsEqual(
- expected_var_names,
- [v.name for v in ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES)])
- trainable_vars = {
- v.name: v
- for v in ops.get_collection(ops.GraphKeys.TRAINABLE_VARIABLES)
- }
- self.assertItemsEqual(expected_var_names, trainable_vars.keys())
- bias = trainable_vars['linear_model/bias_weights:0']
- embedding_weights = trainable_vars[
- 'linear_model/aaa_bbb_shared_embedding/embedding_weights:0']
+ 'shared_embedding_state_manager/aaa_bbb_shared_embedding:0']
linear_weights_a = trainable_vars[
- 'linear_model/aaa_bbb_shared_embedding/weights:0']
+ 'linear_model/aaa_shared_embedding/weights:0']
linear_weights_b = trainable_vars[
- 'linear_model/aaa_bbb_shared_embedding_1/weights:0']
+ 'linear_model/bbb_shared_embedding/weights:0']
with _initialized_session():
# Predictions with all zero weights.
self.assertAllClose(np.zeros((1,)), bias.eval())
@@ -6291,13 +5134,14 @@ class WeightedCategoricalColumnTest(test.TestCase):
dense_shape=(2, 2)),
weight_tensor.eval())
- def test_keras_linear_model(self):
- column = fc_old.weighted_categorical_column(
- categorical_column=fc_old.categorical_column_with_identity(
+ def test_linear_model(self):
+ column = fc.weighted_categorical_column(
+ categorical_column=fc.categorical_column_with_identity(
key='ids', num_buckets=3),
weight_feature_key='values')
with ops.Graph().as_default():
- predictions = get_keras_linear_model_predictions({
+ model = fc.LinearModel((column,))
+ predictions = model({
'ids':
sparse_tensor.SparseTensorValue(
indices=((0, 0), (1, 0), (1, 1)),
@@ -6308,9 +5152,8 @@ class WeightedCategoricalColumnTest(test.TestCase):
indices=((0, 0), (1, 0), (1, 1)),
values=(.5, 1., .1),
dense_shape=(2, 2))
- }, (column,))
- bias = get_linear_model_bias()
- weight_var = get_linear_model_column_var(column)
+ })
+ weight_var, bias = model.variables
with _initialized_session():
self.assertAllClose((0.,), bias.eval())
self.assertAllClose(((0.,), (0.,), (0.,)), weight_var.eval())
@@ -6321,15 +5164,16 @@ class WeightedCategoricalColumnTest(test.TestCase):
# = 3*1 + 2*.1 = 3+.2 = 3.2
self.assertAllClose(((.5,), (3.2,)), predictions.eval())
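# Worked check (plain Python, illustrative) of the weighted-categorical
# expectation above: each prediction sums weight_var[id] * weight over the
# example's (id, weight) pairs.
weight_var = (1., 2., 3.)
example0 = weight_var[0] * .5                       # 1 * .5 = .5
example1 = weight_var[2] * 1. + weight_var[1] * .1  # 3 + .2 = 3.2
print(example0, example1)  # 0.5 3.2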
- def test_keras_linear_model_mismatched_shape(self):
- column = fc_old.weighted_categorical_column(
- categorical_column=fc_old.categorical_column_with_identity(
+ def test_linear_model_mismatched_shape(self):
+ column = fc.weighted_categorical_column(
+ categorical_column=fc.categorical_column_with_identity(
key='ids', num_buckets=3),
weight_feature_key='values')
with ops.Graph().as_default():
- with self.assertRaisesRegexp(ValueError,
- r'Dimensions.*are not compatible'):
- get_keras_linear_model_predictions({
+ with self.assertRaisesRegexp(
+ ValueError, r'Dimensions.*are not compatible'):
+ model = fc.LinearModel((column,))
+ model({
'ids':
sparse_tensor.SparseTensorValue(
indices=((0, 0), (1, 0), (1, 1)),
@@ -6340,122 +5184,23 @@ class WeightedCategoricalColumnTest(test.TestCase):
indices=((0, 0), (0, 1), (1, 0), (1, 1)),
values=(.5, 11., 1., .1),
dense_shape=(2, 2))
- }, (column,))
-
- def test_keras_linear_model_mismatched_dense_values(self):
- column = fc_old.weighted_categorical_column(
- categorical_column=fc_old.categorical_column_with_identity(
- key='ids', num_buckets=3),
- weight_feature_key='values')
- with ops.Graph().as_default():
- predictions = get_keras_linear_model_predictions(
- {
- 'ids':
- sparse_tensor.SparseTensorValue(
- indices=((0, 0), (1, 0), (1, 1)),
- values=(0, 2, 1),
- dense_shape=(2, 2)),
- 'values': ((.5,), (1.,))
- }, (column,),
- sparse_combiner='mean')
- # Disabling the constant folding optimizer here since it makes the
- # error message differ between CPU and GPU.
- config = config_pb2.ConfigProto()
- config.graph_options.rewrite_options.constant_folding = (
- rewriter_config_pb2.RewriterConfig.OFF)
- with _initialized_session(config):
- with self.assertRaisesRegexp(errors.OpError, 'Incompatible shapes'):
- predictions.eval()
+ })
- def test_keras_linear_model_mismatched_dense_shape(self):
- column = fc_old.weighted_categorical_column(
- categorical_column=fc_old.categorical_column_with_identity(
+ def test_linear_model_mismatched_dense_values(self):
+ column = fc.weighted_categorical_column(
+ categorical_column=fc.categorical_column_with_identity(
key='ids', num_buckets=3),
weight_feature_key='values')
with ops.Graph().as_default():
- predictions = get_keras_linear_model_predictions({
+ model = fc.LinearModel((column,), sparse_combiner='mean')
+ predictions = model({
'ids':
sparse_tensor.SparseTensorValue(
indices=((0, 0), (1, 0), (1, 1)),
values=(0, 2, 1),
dense_shape=(2, 2)),
- 'values': ((.5,), (1.,), (.1,))
- }, (column,))
- bias = get_linear_model_bias()
- weight_var = get_linear_model_column_var(column)
- with _initialized_session():
- self.assertAllClose((0.,), bias.eval())
- self.assertAllClose(((0.,), (0.,), (0.,)), weight_var.eval())
- self.assertAllClose(((0.,), (0.,)), predictions.eval())
- weight_var.assign(((1.,), (2.,), (3.,))).eval()
- # weight_var[0] * weights[0, 0] = 1 * .5 = .5
- # weight_var[2] * weights[1, 0] + weight_var[1] * weights[1, 1]
- # = 3*1 + 2*.1 = 3+.2 = 3.2
- self.assertAllClose(((.5,), (3.2,)), predictions.eval())
-
- def test_linear_model(self):
- column = fc_old.weighted_categorical_column(
- categorical_column=fc_old.categorical_column_with_identity(
- key='ids', num_buckets=3),
- weight_feature_key='values')
- with ops.Graph().as_default():
- predictions = fc.linear_model({
- 'ids': sparse_tensor.SparseTensorValue(
- indices=((0, 0), (1, 0), (1, 1)),
- values=(0, 2, 1),
- dense_shape=(2, 2)),
- 'values': sparse_tensor.SparseTensorValue(
- indices=((0, 0), (1, 0), (1, 1)),
- values=(.5, 1., .1),
- dense_shape=(2, 2))
- }, (column,))
- bias = get_linear_model_bias()
- weight_var = get_linear_model_column_var(column)
- with _initialized_session():
- self.assertAllClose((0.,), bias.eval())
- self.assertAllClose(((0.,), (0.,), (0.,)), weight_var.eval())
- self.assertAllClose(((0.,), (0.,)), predictions.eval())
- weight_var.assign(((1.,), (2.,), (3.,))).eval()
- # weight_var[0] * weights[0, 0] = 1 * .5 = .5
- # weight_var[2] * weights[1, 0] + weight_var[1] * weights[1, 1]
- # = 3*1 + 2*.1 = 3+.2 = 3.2
- self.assertAllClose(((.5,), (3.2,)), predictions.eval())
-
- def test_linear_model_mismatched_shape(self):
- column = fc_old.weighted_categorical_column(
- categorical_column=fc_old.categorical_column_with_identity(
- key='ids', num_buckets=3),
- weight_feature_key='values')
- with ops.Graph().as_default():
- with self.assertRaisesRegexp(
- ValueError, r'Dimensions.*are not compatible'):
- fc.linear_model({
- 'ids': sparse_tensor.SparseTensorValue(
- indices=((0, 0), (1, 0), (1, 1)),
- values=(0, 2, 1),
- dense_shape=(2, 2)),
- 'values': sparse_tensor.SparseTensorValue(
- indices=((0, 0), (0, 1), (1, 0), (1, 1)),
- values=(.5, 11., 1., .1),
- dense_shape=(2, 2))
- }, (column,))
-
- def test_linear_model_mismatched_dense_values(self):
- column = fc_old.weighted_categorical_column(
- categorical_column=fc_old.categorical_column_with_identity(
- key='ids', num_buckets=3),
- weight_feature_key='values')
- with ops.Graph().as_default():
- predictions = fc.linear_model(
- {
- 'ids':
- sparse_tensor.SparseTensorValue(
- indices=((0, 0), (1, 0), (1, 1)),
- values=(0, 2, 1),
- dense_shape=(2, 2)),
- 'values': ((.5,), (1.,))
- }, (column,),
- sparse_combiner='mean')
+ 'values': ((.5,), (1.,))
+ })
# Disabling the constant folding optimizer here since it makes the
# error message differ between CPU and GPU.
config = config_pb2.ConfigProto()
@@ -6466,20 +5211,21 @@ class WeightedCategoricalColumnTest(test.TestCase):
predictions.eval()
def test_linear_model_mismatched_dense_shape(self):
- column = fc_old.weighted_categorical_column(
- categorical_column=fc_old.categorical_column_with_identity(
+ column = fc.weighted_categorical_column(
+ categorical_column=fc.categorical_column_with_identity(
key='ids', num_buckets=3),
weight_feature_key='values')
with ops.Graph().as_default():
- predictions = fc.linear_model({
- 'ids': sparse_tensor.SparseTensorValue(
- indices=((0, 0), (1, 0), (1, 1)),
- values=(0, 2, 1),
- dense_shape=(2, 2)),
+ model = fc.LinearModel((column,))
+ predictions = model({
+ 'ids':
+ sparse_tensor.SparseTensorValue(
+ indices=((0, 0), (1, 0), (1, 1)),
+ values=(0, 2, 1),
+ dense_shape=(2, 2)),
'values': ((.5,), (1.,), (.1,))
- }, (column,))
- bias = get_linear_model_bias()
- weight_var = get_linear_model_column_var(column)
+ })
+ weight_var, bias = model.variables
with _initialized_session():
self.assertAllClose((0.,), bias.eval())
self.assertAllClose(((0.,), (0.,), (0.,)), weight_var.eval())
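# The variable-access change applied throughout this patch, as a usage sketch:
# instead of the collection-scanning helpers get_linear_model_bias() and
# get_linear_model_column_var(column), the updated tests read the layer's own
# variables, unpacking them for a single-column model as column weights first,
# then bias (the ordering observed in these tests):
#
#   weight_var, bias = model.variables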