aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python/feature_column
diff options
context:
space:
mode:
authorGravatar Rohan Jain <rohanj@google.com>2018-04-04 15:14:19 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-04-04 15:17:25 -0700
commitbf8ad8277258bdf352ddd1df5200e61ba625f7a2 (patch)
treedac2707245c7a900f337237283a72251081b460f /tensorflow/python/feature_column
parent91bf5524560c5bc0783b43717156c7dbb6f798f5 (diff)
Creates a LinearModel (inherits from keras.training.Model) that creates a linear
model. Had to modify the __call__ method in the base layer class so that it could work with feature style inputs in which case we lazily convert the inputs to tensors instead of providing tensors as inputs upfront. PiperOrigin-RevId: 191655445
Diffstat (limited to 'tensorflow/python/feature_column')
-rw-r--r--tensorflow/python/feature_column/BUILD1
-rw-r--r--tensorflow/python/feature_column/feature_column.py195
-rw-r--r--tensorflow/python/feature_column/feature_column_test.py1230
3 files changed, 1389 insertions, 37 deletions
diff --git a/tensorflow/python/feature_column/BUILD b/tensorflow/python/feature_column/BUILD
index 219105d386..295d4ca094 100644
--- a/tensorflow/python/feature_column/BUILD
+++ b/tensorflow/python/feature_column/BUILD
@@ -43,6 +43,7 @@ py_library(
"//tensorflow/python:util",
"//tensorflow/python:variable_scope",
"//tensorflow/python:variables",
+ "//tensorflow/python/keras",
"@six_archive//:six",
],
)
diff --git a/tensorflow/python/feature_column/feature_column.py b/tensorflow/python/feature_column/feature_column.py
index e116739bc0..3a315e5c2e 100644
--- a/tensorflow/python/feature_column/feature_column.py
+++ b/tensorflow/python/feature_column/feature_column.py
@@ -139,6 +139,8 @@ from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import sparse_tensor as sparse_tensor_lib
from tensorflow.python.framework import tensor_shape
+from tensorflow.python.keras._impl.keras.engine import training
+from tensorflow.python.layers import base
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import check_ops
from tensorflow.python.ops import control_flow_ops
@@ -460,6 +462,154 @@ def linear_model(features,
return predictions
+class _FCLinearWrapper(base.Layer):
+ """Wraps a _FeatureColumn in a layer for use in a linear model.
+
+ See `linear_model` above.
+ """
+
+ def __init__(self,
+ feature_column,
+ units=1,
+ sparse_combiner='sum',
+ weight_collections=None,
+ trainable=True,
+ name=None,
+ **kwargs):
+ super(_FCLinearWrapper, self).__init__(
+ trainable=trainable, name=name, **kwargs)
+ self._feature_column = feature_column
+ self._units = units
+ self._sparse_combiner = sparse_combiner
+ self._weight_collections = weight_collections
+ self._state = {}
+
+ def build(self, _):
+ self._state = self._feature_column._create_state( # pylint: disable=protected-access
+ self._weight_collections, self.add_variable)
+
+ if isinstance(self._feature_column, _CategoricalColumn):
+ weight = self.add_variable(
+ name='weights',
+ shape=(self._feature_column._num_buckets, self._units), # pylint: disable=protected-access
+ initializer=init_ops.zeros_initializer(),
+ trainable=self.trainable)
+ else:
+ num_elements = self._feature_column._variable_shape.num_elements() # pylint: disable=protected-access
+ weight = self.add_variable(
+ name='weights',
+ shape=[num_elements, self._units],
+ initializer=init_ops.zeros_initializer(),
+ trainable=self.trainable)
+ ops.add_to_collections(self._weight_collections, weight)
+ self._weight_var = weight
+ self.built = True
+
+ def call(self, builder):
+ weighted_sum = _create_weighted_sum(
+ column=self._feature_column,
+ builder=builder,
+ units=self._units,
+ sparse_combiner=self._sparse_combiner,
+ weight_collections=self._weight_collections,
+ trainable=self.trainable,
+ weight_var=self._weight_var,
+ state=self._state)
+ return weighted_sum
+
+
+class _BiasLayer(base.Layer):
+ """A layer for the bias term.
+ """
+
+ def __init__(self,
+ units=1,
+ trainable=True,
+ weight_collections=None,
+ name=None,
+ **kwargs):
+ super(_BiasLayer, self).__init__(trainable=trainable, name=name, **kwargs)
+ self._units = units
+ self._weight_collections = weight_collections
+
+ def build(self, _):
+ self._bias_variable = self.add_variable(
+ 'bias_weights',
+ shape=[self._units],
+ initializer=init_ops.zeros_initializer(),
+ trainable=self.trainable)
+ ops.add_to_collections(self._weight_collections, self._bias_variable)
+ self.built = True
+
+ def call(self, _):
+ return self._bias_variable
+
+
+class _LinearModel(training.Model):
+ """Creates a linear model using feature columns.
+ """
+
+ def __init__(self,
+ feature_columns,
+ units=1,
+ sparse_combiner='sum',
+ weight_collections=None,
+ trainable=True,
+ name=None,
+ **kwargs):
+ super(_LinearModel, self).__init__(name=name, **kwargs)
+ self._feature_columns = _clean_feature_columns(feature_columns)
+ self._weight_collections = list(weight_collections or [])
+ if ops.GraphKeys.MODEL_VARIABLES not in self._weight_collections:
+ self._weight_collections.append(ops.GraphKeys.MODEL_VARIABLES)
+
+ column_layers = {}
+ for column in sorted(self._feature_columns, key=lambda x: x.name):
+ with variable_scope.variable_scope(
+ None, default_name=column._var_scope_name) as vs: # pylint: disable=protected-access
+ column_name = vs.name
+ column_layer = _FCLinearWrapper(column, units, sparse_combiner,
+ self._weight_collections, trainable,
+ column_name, **kwargs)
+ column_layers[column_name] = column_layer
+ self._column_layers = self._add_layers(column_layers)
+ self._bias_layer = _BiasLayer(
+ units=units,
+ trainable=trainable,
+ weight_collections=self._weight_collections,
+ name='bias_layer',
+ **kwargs)
+
+ def call(self, features):
+ for column in self._feature_columns:
+ if not isinstance(column, (_DenseColumn, _CategoricalColumn)):
+ raise ValueError(
+ 'Items of feature_columns must be either a '
+ '_DenseColumn or _CategoricalColumn. Given: {}'.format(column))
+ weighted_sums = []
+ ordered_columns = []
+ builder = _LazyBuilder(features)
+ for layer in sorted(self._column_layers.values(), key=lambda x: x.name):
+ ordered_columns.append(layer._feature_column) # pylint: disable=protected-access
+ weighted_sum = layer(builder)
+ weighted_sums.append(weighted_sum)
+
+ _verify_static_batch_size_equality(weighted_sums, ordered_columns)
+ predictions_no_bias = math_ops.add_n(
+ weighted_sums, name='weighted_sum_no_bias')
+ predictions = nn_ops.bias_add(
+ predictions_no_bias, self._bias_layer(builder), name='weighted_sum') # pylint: disable=not-callable
+ return predictions
+
+ def _add_layers(self, layers):
+ # "Magic" required for keras.Model classes to track all the variables in
+ # a list of layers.Layer objects.
+ # TODO(ashankar): Figure out API so user code doesn't have to do this.
+ for name, layer in layers.items():
+ setattr(self, 'layer-%s' % name, layer)
+ return layers
+
+
def _transform_features(features, feature_columns):
"""Returns transformed features based on features columns passed in.
@@ -1713,6 +1863,7 @@ def _create_weighted_sum(column,
sparse_combiner,
weight_collections,
trainable,
+ weight_var=None,
state=None):
"""Creates a weighted sum for a dense or sparse column for linear_model."""
if isinstance(column, _CategoricalColumn):
@@ -1722,7 +1873,8 @@ def _create_weighted_sum(column,
units=units,
sparse_combiner=sparse_combiner,
weight_collections=weight_collections,
- trainable=trainable)
+ trainable=trainable,
+ weight_var=weight_var)
else:
return _create_dense_column_weighted_sum(
column=column,
@@ -1730,6 +1882,7 @@ def _create_weighted_sum(column,
units=units,
weight_collections=weight_collections,
trainable=trainable,
+ weight_var=weight_var,
state=state)
@@ -1738,6 +1891,7 @@ def _create_dense_column_weighted_sum(column,
units,
weight_collections,
trainable,
+ weight_var=None,
state=None):
"""Create a weighted sum of a dense column for linear_model."""
if state is not None:
@@ -1754,12 +1908,15 @@ def _create_dense_column_weighted_sum(column,
num_elements = column._variable_shape.num_elements() # pylint: disable=protected-access
batch_size = array_ops.shape(tensor)[0]
tensor = array_ops.reshape(tensor, shape=(batch_size, num_elements))
- weight = variable_scope.get_variable(
- name='weights',
- shape=[num_elements, units],
- initializer=init_ops.zeros_initializer(),
- trainable=trainable,
- collections=weight_collections)
+ if weight_var is not None:
+ weight = weight_var
+ else:
+ weight = variable_scope.get_variable(
+ name='weights',
+ shape=[num_elements, units],
+ initializer=init_ops.zeros_initializer(),
+ trainable=trainable,
+ collections=weight_collections)
return math_ops.matmul(tensor, weight, name='weighted_sum')
@@ -1809,8 +1966,13 @@ class _CategoricalColumn(_FeatureColumn):
pass
-def _create_categorical_column_weighted_sum(
- column, builder, units, sparse_combiner, weight_collections, trainable):
+def _create_categorical_column_weighted_sum(column,
+ builder,
+ units,
+ sparse_combiner,
+ weight_collections,
+ trainable,
+ weight_var=None):
"""Create a weighted sum of a categorical column for linear_model."""
sparse_tensors = column._get_sparse_tensors( # pylint: disable=protected-access
builder,
@@ -1824,12 +1986,15 @@ def _create_categorical_column_weighted_sum(
weight_tensor = sparse_ops.sparse_reshape(
weight_tensor, [array_ops.shape(weight_tensor)[0], -1])
- weight = variable_scope.get_variable(
- name='weights',
- shape=(column._num_buckets, units), # pylint: disable=protected-access
- initializer=init_ops.zeros_initializer(),
- trainable=trainable,
- collections=weight_collections)
+ if weight_var is not None:
+ weight = weight_var
+ else:
+ weight = variable_scope.get_variable(
+ name='weights',
+ shape=(column._num_buckets, units), # pylint: disable=protected-access
+ initializer=init_ops.zeros_initializer(),
+ trainable=trainable,
+ collections=weight_collections)
return _safe_embedding_lookup_sparse(
weight,
id_tensor,
diff --git a/tensorflow/python/feature_column/feature_column_test.py b/tensorflow/python/feature_column/feature_column_test.py
index 4006a76bb4..07588af37e 100644
--- a/tensorflow/python/feature_column/feature_column_test.py
+++ b/tensorflow/python/feature_column/feature_column_test.py
@@ -34,6 +34,7 @@ from tensorflow.python.feature_column.feature_column import _CategoricalColumn
from tensorflow.python.feature_column.feature_column import _DenseColumn
from tensorflow.python.feature_column.feature_column import _FeatureColumn
from tensorflow.python.feature_column.feature_column import _LazyBuilder
+from tensorflow.python.feature_column.feature_column import _LinearModel
from tensorflow.python.feature_column.feature_column import _transform_features
from tensorflow.python.feature_column.feature_column import InputLayer
from tensorflow.python.framework import constant_op
@@ -339,6 +340,20 @@ class NumericColumnTest(test.TestCase):
sess.run(price_var.assign([[10.]]))
self.assertAllClose([[10.], [50.]], predictions.eval())
+ def test_keras_linear_model(self):
+ price = fc.numeric_column('price')
+ with ops.Graph().as_default():
+ features = {'price': [[1.], [5.]]}
+ predictions = get_keras_linear_model_predictions(features, [price])
+ bias = get_keras_linear_model_bias()
+ price_var = get_linear_model_column_var(price)
+ with _initialized_session() as sess:
+ self.assertAllClose([0.], bias.eval())
+ self.assertAllClose([[0.]], price_var.eval())
+ self.assertAllClose([[0.], [0.]], predictions.eval())
+ sess.run(price_var.assign([[10.]]))
+ self.assertAllClose([[10.], [50.]], predictions.eval())
+
class BucketizedColumnTest(test.TestCase):
@@ -561,6 +576,62 @@ class BucketizedColumnTest(test.TestCase):
sess.run(bias.assign([1.]))
self.assertAllClose([[81.], [141.]], predictions.eval())
+ def test_keras_linear_model_one_input_value(self):
+ """Tests _LinearModel for input with shape=[1]."""
+ price = fc.numeric_column('price', shape=[1])
+ bucketized_price = fc.bucketized_column(price, boundaries=[0, 2, 4, 6])
+ with ops.Graph().as_default():
+ features = {'price': [[-1.], [1.], [5.], [6.]]}
+ predictions = get_keras_linear_model_predictions(features,
+ [bucketized_price])
+ bias = get_keras_linear_model_bias()
+ bucketized_price_var = get_linear_model_column_var(bucketized_price)
+ with _initialized_session() as sess:
+ self.assertAllClose([0.], bias.eval())
+ # One weight variable per bucket, all initialized to zero.
+ self.assertAllClose([[0.], [0.], [0.], [0.], [0.]],
+ bucketized_price_var.eval())
+ self.assertAllClose([[0.], [0.], [0.], [0.]], predictions.eval())
+ sess.run(
+ bucketized_price_var.assign([[10.], [20.], [30.], [40.], [50.]]))
+ # price -1. is in the 0th bucket, whose weight is 10.
+ # price 1. is in the 1st bucket, whose weight is 20.
+ # price 5. is in the 3rd bucket, whose weight is 40.
+ # price 6. is in the 4th bucket, whose weight is 50.
+ self.assertAllClose([[10.], [20.], [40.], [50.]], predictions.eval())
+ sess.run(bias.assign([1.]))
+ self.assertAllClose([[11.], [21.], [41.], [51.]], predictions.eval())
+
+ def test_keras_linear_model_two_input_values(self):
+ """Tests _LinearModel for input with shape=[2]."""
+ price = fc.numeric_column('price', shape=[2])
+ bucketized_price = fc.bucketized_column(price, boundaries=[0, 2, 4, 6])
+ with ops.Graph().as_default():
+ features = {'price': [[-1., 1.], [5., 6.]]}
+ predictions = get_keras_linear_model_predictions(features,
+ [bucketized_price])
+ bias = get_keras_linear_model_bias()
+ bucketized_price_var = get_linear_model_column_var(bucketized_price)
+ with _initialized_session() as sess:
+ self.assertAllClose([0.], bias.eval())
+ # One weight per bucket per input column, all initialized to zero.
+ self.assertAllClose(
+ [[0.], [0.], [0.], [0.], [0.], [0.], [0.], [0.], [0.], [0.]],
+ bucketized_price_var.eval())
+ self.assertAllClose([[0.], [0.]], predictions.eval())
+ sess.run(
+ bucketized_price_var.assign([[10.], [20.], [30.], [40.], [50.],
+ [60.], [70.], [80.], [90.], [100.]]))
+ # 1st example:
+ # price -1. is in the 0th bucket, whose weight is 10.
+ # price 1. is in the 6th bucket, whose weight is 70.
+ # 2nd example:
+ # price 5. is in the 3rd bucket, whose weight is 40.
+ # price 6. is in the 9th bucket, whose weight is 100.
+ self.assertAllClose([[80.], [140.]], predictions.eval())
+ sess.run(bias.assign([1.]))
+ self.assertAllClose([[81.], [141.]], predictions.eval())
+
class HashedCategoricalColumnTest(test.TestCase):
@@ -767,6 +838,28 @@ class HashedCategoricalColumnTest(test.TestCase):
# 'skywalker' -> 2, 'omar' -> 2: wire_var[2] + wire_var[2] = 3+3 = 6
self.assertAllClose(((4.,), (6.,)), predictions.eval())
+ def test_keras_linear_model(self):
+ wire_column = fc.categorical_column_with_hash_bucket('wire', 4)
+ self.assertEqual(4, wire_column._num_buckets)
+ with ops.Graph().as_default():
+ predictions = get_keras_linear_model_predictions({
+ wire_column.name:
+ sparse_tensor.SparseTensorValue(
+ indices=((0, 0), (1, 0), (1, 1)),
+ values=('marlo', 'skywalker', 'omar'),
+ dense_shape=(2, 2))
+ }, (wire_column,))
+ bias = get_keras_linear_model_bias()
+ wire_var = get_linear_model_column_var(wire_column)
+ with _initialized_session():
+ self.assertAllClose((0.,), bias.eval())
+ self.assertAllClose(((0.,), (0.,), (0.,), (0.,)), wire_var.eval())
+ self.assertAllClose(((0.,), (0.,)), predictions.eval())
+ wire_var.assign(((1.,), (2.,), (3.,), (4.,))).eval()
+ # 'marlo' -> 3: wire_var[3] = 4
+ # 'skywalker' -> 2, 'omar' -> 2: wire_var[2] + wire_var[2] = 3+3 = 6
+ self.assertAllClose(((4.,), (6.,)), predictions.eval())
+
class CrossedColumnTest(test.TestCase):
@@ -1060,6 +1153,96 @@ class CrossedColumnTest(test.TestCase):
dense_shape=(2, 2)),
}, (crossed,))
+ def test_keras_linear_model(self):
+ """Tests _LinearModel.
+
+ Uses data from test_get_sparse_tesnsors_simple.
+ """
+ a = fc.numeric_column('a', dtype=dtypes.int32, shape=(2,))
+ b = fc.bucketized_column(a, boundaries=(0, 1))
+ crossed = fc.crossed_column([b, 'c'], hash_bucket_size=5, hash_key=5)
+ with ops.Graph().as_default():
+ predictions = get_keras_linear_model_predictions({
+ 'a':
+ constant_op.constant(((-1., .5), (.5, 1.))),
+ 'c':
+ sparse_tensor.SparseTensor(
+ indices=((0, 0), (1, 0), (1, 1)),
+ values=['cA', 'cB', 'cC'],
+ dense_shape=(2, 2)),
+ }, (crossed,))
+ bias = get_keras_linear_model_bias()
+ crossed_var = get_linear_model_column_var(crossed)
+ with _initialized_session() as sess:
+ self.assertAllClose((0.,), bias.eval())
+ self.assertAllClose(((0.,), (0.,), (0.,), (0.,), (0.,)),
+ crossed_var.eval())
+ self.assertAllClose(((0.,), (0.,)), predictions.eval())
+ sess.run(crossed_var.assign(((1.,), (2.,), (3.,), (4.,), (5.,))))
+ # Expected ids after cross = (1, 0, 1, 3, 4, 2)
+ self.assertAllClose(((3.,), (14.,)), predictions.eval())
+ sess.run(bias.assign((.1,)))
+ self.assertAllClose(((3.1,), (14.1,)), predictions.eval())
+
+ def test_keras_linear_model_with_weights(self):
+
+ class _TestColumnWithWeights(_CategoricalColumn):
+ """Produces sparse IDs and sparse weights."""
+
+ @property
+ def name(self):
+ return 'test_column'
+
+ @property
+ def _parse_example_spec(self):
+ return {
+ self.name:
+ parsing_ops.VarLenFeature(dtypes.int32),
+ '{}_weights'.format(self.name):
+ parsing_ops.VarLenFeature(dtypes.float32),
+ }
+
+ @property
+ def _num_buckets(self):
+ return 5
+
+ def _transform_feature(self, inputs):
+ return (inputs.get(self.name),
+ inputs.get('{}_weights'.format(self.name)))
+
+ def _get_sparse_tensors(self,
+ inputs,
+ weight_collections=None,
+ trainable=None):
+ """Populates both id_tensor and weight_tensor."""
+ ids_and_weights = inputs.get(self)
+ return _CategoricalColumn.IdWeightPair(
+ id_tensor=ids_and_weights[0], weight_tensor=ids_and_weights[1])
+
+ t = _TestColumnWithWeights()
+ crossed = fc.crossed_column([t, 'c'], hash_bucket_size=5, hash_key=5)
+ with ops.Graph().as_default():
+ with self.assertRaisesRegexp(
+ ValueError,
+ 'crossed_column does not support weight_tensor.*{}'.format(t.name)):
+ get_keras_linear_model_predictions({
+ t.name:
+ sparse_tensor.SparseTensor(
+ indices=((0, 0), (1, 0), (1, 1)),
+ values=[0, 1, 2],
+ dense_shape=(2, 2)),
+ '{}_weights'.format(t.name):
+ sparse_tensor.SparseTensor(
+ indices=((0, 0), (1, 0), (1, 1)),
+ values=[1., 10., 2.],
+ dense_shape=(2, 2)),
+ 'c':
+ sparse_tensor.SparseTensor(
+ indices=((0, 0), (1, 0), (1, 1)),
+ values=['cA', 'cB', 'cC'],
+ dense_shape=(2, 2)),
+ }, (crossed,))
+
def get_linear_model_bias():
with variable_scope.variable_scope('linear_model', reuse=True):
@@ -1071,6 +1254,28 @@ def get_linear_model_column_var(column):
'linear_model/' + column.name)[0]
+def get_keras_linear_model_bias():
+ with variable_scope.variable_scope('linear_model', reuse=True):
+ with variable_scope.variable_scope('bias_layer', reuse=True):
+ return variable_scope.get_variable('bias_weights')
+
+
+def get_keras_linear_model_predictions(features,
+ feature_columns,
+ units=1,
+ sparse_combiner='sum',
+ weight_collections=None,
+ trainable=True):
+ keras_linear_model = _LinearModel(
+ feature_columns,
+ units,
+ sparse_combiner,
+ weight_collections,
+ trainable,
+ name='linear_model')
+ return keras_linear_model(features) # pylint: disable=not-callable
+
+
@test_util.with_c_api
class LinearModelTest(test.TestCase):
@@ -1698,6 +1903,629 @@ class LinearModelTest(test.TestCase):
sess.run(net, feed_dict={features['price']: np.array(1)})
+@test_util.with_c_api
+class _LinearModelTest(test.TestCase):
+
+ def test_raises_if_empty_feature_columns(self):
+ with self.assertRaisesRegexp(ValueError,
+ 'feature_columns must not be empty'):
+ get_keras_linear_model_predictions(features={}, feature_columns=[])
+
+ def test_should_be_feature_column(self):
+ with self.assertRaisesRegexp(ValueError, 'must be a _FeatureColumn'):
+ get_keras_linear_model_predictions(
+ features={'a': [[0]]}, feature_columns='NotSupported')
+
+ def test_should_be_dense_or_categorical_column(self):
+
+ class NotSupportedColumn(_FeatureColumn):
+
+ @property
+ def name(self):
+ return 'NotSupportedColumn'
+
+ def _transform_feature(self, cache):
+ pass
+
+ @property
+ def _parse_example_spec(self):
+ pass
+
+ with self.assertRaisesRegexp(
+ ValueError, 'must be either a _DenseColumn or _CategoricalColumn'):
+ get_keras_linear_model_predictions(
+ features={'a': [[0]]}, feature_columns=[NotSupportedColumn()])
+
+ def test_does_not_support_dict_columns(self):
+ with self.assertRaisesRegexp(
+ ValueError, 'Expected feature_columns to be iterable, found dict.'):
+ fc.linear_model(
+ features={'a': [[0]]}, feature_columns={'a': fc.numeric_column('a')})
+
+ def test_raises_if_duplicate_name(self):
+ with self.assertRaisesRegexp(
+ ValueError, 'Duplicate feature column name found for columns'):
+ get_keras_linear_model_predictions(
+ features={'a': [[0]]},
+ feature_columns=[fc.numeric_column('a'),
+ fc.numeric_column('a')])
+
+ def test_dense_bias(self):
+ price = fc.numeric_column('price')
+ with ops.Graph().as_default():
+ features = {'price': [[1.], [5.]]}
+ predictions = get_keras_linear_model_predictions(features, [price])
+ bias = get_keras_linear_model_bias()
+ price_var = get_linear_model_column_var(price)
+ with _initialized_session() as sess:
+ self.assertAllClose([0.], bias.eval())
+ sess.run(price_var.assign([[10.]]))
+ sess.run(bias.assign([5.]))
+ self.assertAllClose([[15.], [55.]], predictions.eval())
+
+ def test_sparse_bias(self):
+ wire_cast = fc.categorical_column_with_hash_bucket('wire_cast', 4)
+ with ops.Graph().as_default():
+ wire_tensor = sparse_tensor.SparseTensor(
+ values=['omar', 'stringer', 'marlo'], # hashed to = [2, 0, 3]
+ indices=[[0, 0], [1, 0], [1, 1]],
+ dense_shape=[2, 2])
+ features = {'wire_cast': wire_tensor}
+ predictions = get_keras_linear_model_predictions(features, [wire_cast])
+ bias = get_keras_linear_model_bias()
+ wire_cast_var = get_linear_model_column_var(wire_cast)
+ with _initialized_session() as sess:
+ self.assertAllClose([0.], bias.eval())
+ self.assertAllClose([[0.], [0.], [0.], [0.]], wire_cast_var.eval())
+ sess.run(wire_cast_var.assign([[10.], [100.], [1000.], [10000.]]))
+ sess.run(bias.assign([5.]))
+ self.assertAllClose([[1005.], [10015.]], predictions.eval())
+
+ def test_dense_and_sparse_bias(self):
+ wire_cast = fc.categorical_column_with_hash_bucket('wire_cast', 4)
+ price = fc.numeric_column('price')
+ with ops.Graph().as_default():
+ wire_tensor = sparse_tensor.SparseTensor(
+ values=['omar', 'stringer', 'marlo'], # hashed to = [2, 0, 3]
+ indices=[[0, 0], [1, 0], [1, 1]],
+ dense_shape=[2, 2])
+ features = {'wire_cast': wire_tensor, 'price': [[1.], [5.]]}
+ predictions = get_keras_linear_model_predictions(features,
+ [wire_cast, price])
+ bias = get_keras_linear_model_bias()
+ wire_cast_var = get_linear_model_column_var(wire_cast)
+ price_var = get_linear_model_column_var(price)
+ with _initialized_session() as sess:
+ sess.run(wire_cast_var.assign([[10.], [100.], [1000.], [10000.]]))
+ sess.run(bias.assign([5.]))
+ sess.run(price_var.assign([[10.]]))
+ self.assertAllClose([[1015.], [10065.]], predictions.eval())
+
+ def test_dense_and_sparse_column(self):
+ """When the column is both dense and sparse, uses sparse tensors."""
+
+ class _DenseAndSparseColumn(_DenseColumn, _CategoricalColumn):
+
+ @property
+ def name(self):
+ return 'dense_and_sparse_column'
+
+ @property
+ def _parse_example_spec(self):
+ return {self.name: parsing_ops.VarLenFeature(self.dtype)}
+
+ def _transform_feature(self, inputs):
+ return inputs.get(self.name)
+
+ @property
+ def _variable_shape(self):
+ raise ValueError('Should not use this method.')
+
+ def _get_dense_tensor(self,
+ inputs,
+ weight_collections=None,
+ trainable=None):
+ raise ValueError('Should not use this method.')
+
+ @property
+ def _num_buckets(self):
+ return 4
+
+ def _get_sparse_tensors(self,
+ inputs,
+ weight_collections=None,
+ trainable=None):
+ sp_tensor = sparse_tensor.SparseTensor(
+ indices=[[0, 0], [1, 0], [1, 1]],
+ values=[2, 0, 3],
+ dense_shape=[2, 2])
+ return _CategoricalColumn.IdWeightPair(sp_tensor, None)
+
+ dense_and_sparse_column = _DenseAndSparseColumn()
+ with ops.Graph().as_default():
+ sp_tensor = sparse_tensor.SparseTensor(
+ values=['omar', 'stringer', 'marlo'],
+ indices=[[0, 0], [1, 0], [1, 1]],
+ dense_shape=[2, 2])
+ features = {dense_and_sparse_column.name: sp_tensor}
+ predictions = get_keras_linear_model_predictions(
+ features, [dense_and_sparse_column])
+ bias = get_keras_linear_model_bias()
+ dense_and_sparse_column_var = get_linear_model_column_var(
+ dense_and_sparse_column)
+ with _initialized_session() as sess:
+ sess.run(
+ dense_and_sparse_column_var.assign([[10.], [100.], [1000.],
+ [10000.]]))
+ sess.run(bias.assign([5.]))
+ self.assertAllClose([[1005.], [10015.]], predictions.eval())
+
+ def test_dense_multi_output(self):
+ price = fc.numeric_column('price')
+ with ops.Graph().as_default():
+ features = {'price': [[1.], [5.]]}
+ predictions = get_keras_linear_model_predictions(
+ features, [price], units=3)
+ bias = get_keras_linear_model_bias()
+ price_var = get_linear_model_column_var(price)
+ with _initialized_session() as sess:
+ self.assertAllClose(np.zeros((3,)), bias.eval())
+ self.assertAllClose(np.zeros((1, 3)), price_var.eval())
+ sess.run(price_var.assign([[10., 100., 1000.]]))
+ sess.run(bias.assign([5., 6., 7.]))
+ self.assertAllClose([[15., 106., 1007.], [55., 506., 5007.]],
+ predictions.eval())
+
+ def test_sparse_multi_output(self):
+ wire_cast = fc.categorical_column_with_hash_bucket('wire_cast', 4)
+ with ops.Graph().as_default():
+ wire_tensor = sparse_tensor.SparseTensor(
+ values=['omar', 'stringer', 'marlo'], # hashed to = [2, 0, 3]
+ indices=[[0, 0], [1, 0], [1, 1]],
+ dense_shape=[2, 2])
+ features = {'wire_cast': wire_tensor}
+ predictions = get_keras_linear_model_predictions(
+ features, [wire_cast], units=3)
+ bias = get_keras_linear_model_bias()
+ wire_cast_var = get_linear_model_column_var(wire_cast)
+ with _initialized_session() as sess:
+ self.assertAllClose(np.zeros((3,)), bias.eval())
+ self.assertAllClose(np.zeros((4, 3)), wire_cast_var.eval())
+ sess.run(
+ wire_cast_var.assign([[10., 11., 12.], [100., 110., 120.],
+ [1000., 1100.,
+ 1200.], [10000., 11000., 12000.]]))
+ sess.run(bias.assign([5., 6., 7.]))
+ self.assertAllClose([[1005., 1106., 1207.], [10015., 11017., 12019.]],
+ predictions.eval())
+
+ def test_dense_multi_dimension(self):
+ price = fc.numeric_column('price', shape=2)
+ with ops.Graph().as_default():
+ features = {'price': [[1., 2.], [5., 6.]]}
+ predictions = get_keras_linear_model_predictions(features, [price])
+ price_var = get_linear_model_column_var(price)
+ with _initialized_session() as sess:
+ self.assertAllClose([[0.], [0.]], price_var.eval())
+ sess.run(price_var.assign([[10.], [100.]]))
+ self.assertAllClose([[210.], [650.]], predictions.eval())
+
+ def test_sparse_multi_rank(self):
+ wire_cast = fc.categorical_column_with_hash_bucket('wire_cast', 4)
+ with ops.Graph().as_default():
+ wire_tensor = array_ops.sparse_placeholder(dtypes.string)
+ wire_value = sparse_tensor.SparseTensorValue(
+ values=['omar', 'stringer', 'marlo', 'omar'], # hashed = [2, 0, 3, 2]
+ indices=[[0, 0, 0], [0, 1, 0], [1, 0, 0], [1, 0, 1]],
+ dense_shape=[2, 2, 2])
+ features = {'wire_cast': wire_tensor}
+ predictions = get_keras_linear_model_predictions(features, [wire_cast])
+ wire_cast_var = get_linear_model_column_var(wire_cast)
+ with _initialized_session() as sess:
+ self.assertAllClose(np.zeros((4, 1)), wire_cast_var.eval())
+ self.assertAllClose(
+ np.zeros((2, 1)),
+ predictions.eval(feed_dict={wire_tensor: wire_value}))
+ sess.run(wire_cast_var.assign([[10.], [100.], [1000.], [10000.]]))
+ self.assertAllClose(
+ [[1010.], [11000.]],
+ predictions.eval(feed_dict={wire_tensor: wire_value}))
+
+ def test_sparse_combiner(self):
+ wire_cast = fc.categorical_column_with_hash_bucket('wire_cast', 4)
+ with ops.Graph().as_default():
+ wire_tensor = sparse_tensor.SparseTensor(
+ values=['omar', 'stringer', 'marlo'], # hashed to = [2, 0, 3]
+ indices=[[0, 0], [1, 0], [1, 1]],
+ dense_shape=[2, 2])
+ features = {'wire_cast': wire_tensor}
+ predictions = get_keras_linear_model_predictions(
+ features, [wire_cast], sparse_combiner='mean')
+ bias = get_keras_linear_model_bias()
+ wire_cast_var = get_linear_model_column_var(wire_cast)
+ with _initialized_session() as sess:
+ sess.run(wire_cast_var.assign([[10.], [100.], [1000.], [10000.]]))
+ sess.run(bias.assign([5.]))
+ self.assertAllClose([[1005.], [5010.]], predictions.eval())
+
+ def test_dense_multi_dimension_multi_output(self):
+ price = fc.numeric_column('price', shape=2)
+ with ops.Graph().as_default():
+ features = {'price': [[1., 2.], [5., 6.]]}
+ predictions = get_keras_linear_model_predictions(
+ features, [price], units=3)
+ bias = get_keras_linear_model_bias()
+ price_var = get_linear_model_column_var(price)
+ with _initialized_session() as sess:
+ self.assertAllClose(np.zeros((3,)), bias.eval())
+ self.assertAllClose(np.zeros((2, 3)), price_var.eval())
+ sess.run(price_var.assign([[1., 2., 3.], [10., 100., 1000.]]))
+ sess.run(bias.assign([2., 3., 4.]))
+ self.assertAllClose([[23., 205., 2007.], [67., 613., 6019.]],
+ predictions.eval())
+
+ def test_raises_if_shape_mismatch(self):
+ price = fc.numeric_column('price', shape=2)
+ with ops.Graph().as_default():
+ features = {'price': [[1.], [5.]]}
+ if ops._USE_C_API:
+ with self.assertRaisesRegexp(
+ Exception,
+ r'Cannot reshape a tensor with 2 elements to shape \[2,2\]'):
+ predictions = get_keras_linear_model_predictions(features, [price])
+ else:
+ predictions = get_keras_linear_model_predictions(features, [price])
+ with _initialized_session():
+ with self.assertRaisesRegexp(Exception, 'requested shape has 4'):
+ predictions.eval()
+
+ def test_dense_reshaping(self):
+ price = fc.numeric_column('price', shape=[1, 2])
+ with ops.Graph().as_default():
+ features = {'price': [[[1., 2.]], [[5., 6.]]]}
+ predictions = get_keras_linear_model_predictions(features, [price])
+ bias = get_keras_linear_model_bias()
+ price_var = get_linear_model_column_var(price)
+ with _initialized_session() as sess:
+ self.assertAllClose([0.], bias.eval())
+ self.assertAllClose([[0.], [0.]], price_var.eval())
+ self.assertAllClose([[0.], [0.]], predictions.eval())
+ sess.run(price_var.assign([[10.], [100.]]))
+ self.assertAllClose([[210.], [650.]], predictions.eval())
+
+ def test_dense_multi_column(self):
+ price1 = fc.numeric_column('price1', shape=2)
+ price2 = fc.numeric_column('price2')
+ with ops.Graph().as_default():
+ features = {'price1': [[1., 2.], [5., 6.]], 'price2': [[3.], [4.]]}
+ predictions = get_keras_linear_model_predictions(features,
+ [price1, price2])
+ bias = get_keras_linear_model_bias()
+ price1_var = get_linear_model_column_var(price1)
+ price2_var = get_linear_model_column_var(price2)
+ with _initialized_session() as sess:
+ self.assertAllClose([0.], bias.eval())
+ self.assertAllClose([[0.], [0.]], price1_var.eval())
+ self.assertAllClose([[0.]], price2_var.eval())
+ self.assertAllClose([[0.], [0.]], predictions.eval())
+ sess.run(price1_var.assign([[10.], [100.]]))
+ sess.run(price2_var.assign([[1000.]]))
+ sess.run(bias.assign([7.]))
+ self.assertAllClose([[3217.], [4657.]], predictions.eval())
+
+ def test_dense_collection(self):
+ price = fc.numeric_column('price')
+ with ops.Graph().as_default() as g:
+ features = {'price': [[1.], [5.]]}
+ get_keras_linear_model_predictions(
+ features, [price], weight_collections=['my-vars'])
+ my_vars = g.get_collection('my-vars')
+ bias = get_keras_linear_model_bias()
+ price_var = get_linear_model_column_var(price)
+ self.assertIn(bias, my_vars)
+ self.assertIn(price_var, my_vars)
+
+ def test_sparse_collection(self):
+ wire_cast = fc.categorical_column_with_hash_bucket('wire_cast', 4)
+ with ops.Graph().as_default() as g:
+ wire_tensor = sparse_tensor.SparseTensor(
+ values=['omar'], indices=[[0, 0]], dense_shape=[1, 1])
+ features = {'wire_cast': wire_tensor}
+ get_keras_linear_model_predictions(
+ features, [wire_cast], weight_collections=['my-vars'])
+ my_vars = g.get_collection('my-vars')
+ bias = get_keras_linear_model_bias()
+ wire_cast_var = get_linear_model_column_var(wire_cast)
+ self.assertIn(bias, my_vars)
+ self.assertIn(wire_cast_var, my_vars)
+
+ def test_dense_trainable_default(self):
+ price = fc.numeric_column('price')
+ with ops.Graph().as_default() as g:
+ features = {'price': [[1.], [5.]]}
+ get_keras_linear_model_predictions(features, [price])
+ bias = get_keras_linear_model_bias()
+ price_var = get_linear_model_column_var(price)
+ trainable_vars = g.get_collection(ops.GraphKeys.TRAINABLE_VARIABLES)
+ self.assertIn(bias, trainable_vars)
+ self.assertIn(price_var, trainable_vars)
+
+ def test_sparse_trainable_default(self):
+ wire_cast = fc.categorical_column_with_hash_bucket('wire_cast', 4)
+ with ops.Graph().as_default() as g:
+ wire_tensor = sparse_tensor.SparseTensor(
+ values=['omar'], indices=[[0, 0]], dense_shape=[1, 1])
+ features = {'wire_cast': wire_tensor}
+ get_keras_linear_model_predictions(features, [wire_cast])
+ trainable_vars = g.get_collection(ops.GraphKeys.TRAINABLE_VARIABLES)
+ bias = get_keras_linear_model_bias()
+ wire_cast_var = get_linear_model_column_var(wire_cast)
+ self.assertIn(bias, trainable_vars)
+ self.assertIn(wire_cast_var, trainable_vars)
+
+ def test_dense_trainable_false(self):
+ price = fc.numeric_column('price')
+ with ops.Graph().as_default() as g:
+ features = {'price': [[1.], [5.]]}
+ get_keras_linear_model_predictions(features, [price], trainable=False)
+ trainable_vars = g.get_collection(ops.GraphKeys.TRAINABLE_VARIABLES)
+ self.assertEqual([], trainable_vars)
+
+ def test_sparse_trainable_false(self):
+ wire_cast = fc.categorical_column_with_hash_bucket('wire_cast', 4)
+ with ops.Graph().as_default() as g:
+ wire_tensor = sparse_tensor.SparseTensor(
+ values=['omar'], indices=[[0, 0]], dense_shape=[1, 1])
+ features = {'wire_cast': wire_tensor}
+ get_keras_linear_model_predictions(features, [wire_cast], trainable=False)
+ trainable_vars = g.get_collection(ops.GraphKeys.TRAINABLE_VARIABLES)
+ self.assertEqual([], trainable_vars)
+
+ def test_column_order(self):
+ price_a = fc.numeric_column('price_a')
+ price_b = fc.numeric_column('price_b')
+ wire_cast = fc.categorical_column_with_hash_bucket('wire_cast', 4)
+ with ops.Graph().as_default() as g:
+ features = {
+ 'price_a': [[1.]],
+ 'price_b': [[3.]],
+ 'wire_cast':
+ sparse_tensor.SparseTensor(
+ values=['omar'], indices=[[0, 0]], dense_shape=[1, 1])
+ }
+ get_keras_linear_model_predictions(
+ features, [price_a, wire_cast, price_b],
+ weight_collections=['my-vars'])
+ my_vars = g.get_collection('my-vars')
+ self.assertIn('price_a', my_vars[0].name)
+ self.assertIn('price_b', my_vars[1].name)
+ self.assertIn('wire_cast', my_vars[2].name)
+
+ with ops.Graph().as_default() as g:
+ features = {
+ 'price_a': [[1.]],
+ 'price_b': [[3.]],
+ 'wire_cast':
+ sparse_tensor.SparseTensor(
+ values=['omar'], indices=[[0, 0]], dense_shape=[1, 1])
+ }
+ get_keras_linear_model_predictions(
+ features, [wire_cast, price_b, price_a],
+ weight_collections=['my-vars'])
+ my_vars = g.get_collection('my-vars')
+ self.assertIn('price_a', my_vars[0].name)
+ self.assertIn('price_b', my_vars[1].name)
+ self.assertIn('wire_cast', my_vars[2].name)
+
+ def test_static_batch_size_mismatch(self):
+ price1 = fc.numeric_column('price1')
+ price2 = fc.numeric_column('price2')
+ with ops.Graph().as_default():
+ features = {
+ 'price1': [[1.], [5.], [7.]], # batchsize = 3
+ 'price2': [[3.], [4.]] # batchsize = 2
+ }
+ with self.assertRaisesRegexp(
+ ValueError,
+ 'Batch size \(first dimension\) of each feature must be same.'): # pylint: disable=anomalous-backslash-in-string
+ get_keras_linear_model_predictions(features, [price1, price2])
+
+ def test_subset_of_static_batch_size_mismatch(self):
+ price1 = fc.numeric_column('price1')
+ price2 = fc.numeric_column('price2')
+ price3 = fc.numeric_column('price3')
+ with ops.Graph().as_default():
+ features = {
+ 'price1': array_ops.placeholder(dtype=dtypes.int64), # batchsize = 3
+ 'price2': [[3.], [4.]], # batchsize = 2
+ 'price3': [[3.], [4.], [5.]] # batchsize = 3
+ }
+ with self.assertRaisesRegexp(
+ ValueError,
+ 'Batch size \(first dimension\) of each feature must be same.'): # pylint: disable=anomalous-backslash-in-string
+ get_keras_linear_model_predictions(features, [price1, price2, price3])
+
+ def test_runtime_batch_size_mismatch(self):
+ price1 = fc.numeric_column('price1')
+ price2 = fc.numeric_column('price2')
+ with ops.Graph().as_default():
+ features = {
+ 'price1': array_ops.placeholder(dtype=dtypes.int64), # batchsize = 3
+ 'price2': [[3.], [4.]] # batchsize = 2
+ }
+ predictions = get_keras_linear_model_predictions(features,
+ [price1, price2])
+ with _initialized_session() as sess:
+ with self.assertRaisesRegexp(errors.OpError,
+ 'must have the same size and shape'):
+ sess.run(
+ predictions, feed_dict={features['price1']: [[1.], [5.], [7.]]})
+
+ def test_runtime_batch_size_matches(self):
+ price1 = fc.numeric_column('price1')
+ price2 = fc.numeric_column('price2')
+ with ops.Graph().as_default():
+ features = {
+ 'price1': array_ops.placeholder(dtype=dtypes.int64), # batchsize = 2
+ 'price2': array_ops.placeholder(dtype=dtypes.int64), # batchsize = 2
+ }
+ predictions = get_keras_linear_model_predictions(features,
+ [price1, price2])
+ with _initialized_session() as sess:
+ sess.run(
+ predictions,
+ feed_dict={
+ features['price1']: [[1.], [5.]],
+ features['price2']: [[1.], [5.]],
+ })
+
+ def test_with_numpy_input_fn(self):
+ price = fc.numeric_column('price')
+ price_buckets = fc.bucketized_column(
+ price, boundaries=[
+ 0.,
+ 10.,
+ 100.,
+ ])
+ body_style = fc.categorical_column_with_vocabulary_list(
+ 'body-style', vocabulary_list=['hardtop', 'wagon', 'sedan'])
+
+ input_fn = numpy_io.numpy_input_fn(
+ x={
+ 'price': np.array([-1., 2., 13., 104.]),
+ 'body-style': np.array(['sedan', 'hardtop', 'wagon', 'sedan']),
+ },
+ batch_size=2,
+ shuffle=False)
+ features = input_fn()
+ net = get_keras_linear_model_predictions(features,
+ [price_buckets, body_style])
+ # self.assertEqual(1 + 3 + 5, net.shape[1])
+ with _initialized_session() as sess:
+ coord = coordinator.Coordinator()
+ threads = queue_runner_impl.start_queue_runners(sess, coord=coord)
+
+ bias = get_keras_linear_model_bias()
+ price_buckets_var = get_linear_model_column_var(price_buckets)
+ body_style_var = get_linear_model_column_var(body_style)
+
+ sess.run(price_buckets_var.assign([[10.], [100.], [1000.], [10000.]]))
+ sess.run(body_style_var.assign([[-10.], [-100.], [-1000.]]))
+ sess.run(bias.assign([5.]))
+
+ self.assertAllClose([[10 - 1000 + 5.], [100 - 10 + 5.]], sess.run(net))
+
+ coord.request_stop()
+ coord.join(threads)
+
+ def test_with_1d_sparse_tensor(self):
+ price = fc.numeric_column('price')
+ price_buckets = fc.bucketized_column(
+ price, boundaries=[
+ 0.,
+ 10.,
+ 100.,
+ ])
+ body_style = fc.categorical_column_with_vocabulary_list(
+ 'body-style', vocabulary_list=['hardtop', 'wagon', 'sedan'])
+
+ # Provides 1-dim tensor and dense tensor.
+ features = {
+ 'price':
+ constant_op.constant([
+ -1.,
+ 12.,
+ ]),
+ 'body-style':
+ sparse_tensor.SparseTensor(
+ indices=((0,), (1,)),
+ values=('sedan', 'hardtop'),
+ dense_shape=(2,)),
+ }
+ self.assertEqual(1, features['price'].shape.ndims)
+ self.assertEqual(1, features['body-style'].dense_shape.get_shape()[0])
+
+ net = get_keras_linear_model_predictions(features,
+ [price_buckets, body_style])
+ with _initialized_session() as sess:
+ bias = get_keras_linear_model_bias()
+ price_buckets_var = get_linear_model_column_var(price_buckets)
+ body_style_var = get_linear_model_column_var(body_style)
+
+ sess.run(price_buckets_var.assign([[10.], [100.], [1000.], [10000.]]))
+ sess.run(body_style_var.assign([[-10.], [-100.], [-1000.]]))
+ sess.run(bias.assign([5.]))
+
+ self.assertAllClose([[10 - 1000 + 5.], [1000 - 10 + 5.]], sess.run(net))
+
+ def test_with_1d_unknown_shape_sparse_tensor(self):
+ price = fc.numeric_column('price')
+ price_buckets = fc.bucketized_column(
+ price, boundaries=[
+ 0.,
+ 10.,
+ 100.,
+ ])
+ body_style = fc.categorical_column_with_vocabulary_list(
+ 'body-style', vocabulary_list=['hardtop', 'wagon', 'sedan'])
+ country = fc.categorical_column_with_vocabulary_list(
+ 'country', vocabulary_list=['US', 'JP', 'CA'])
+
+ # Provides 1-dim tensor and dense tensor.
+ features = {
+ 'price': array_ops.placeholder(dtypes.float32),
+ 'body-style': array_ops.sparse_placeholder(dtypes.string),
+ 'country': array_ops.placeholder(dtypes.string),
+ }
+ self.assertIsNone(features['price'].shape.ndims)
+ self.assertIsNone(features['body-style'].get_shape().ndims)
+
+ price_data = np.array([-1., 12.])
+ body_style_data = sparse_tensor.SparseTensorValue(
+ indices=((0,), (1,)), values=('sedan', 'hardtop'), dense_shape=(2,))
+ country_data = np.array(['US', 'CA'])
+
+ net = get_keras_linear_model_predictions(
+ features, [price_buckets, body_style, country])
+ bias = get_keras_linear_model_bias()
+ price_buckets_var = get_linear_model_column_var(price_buckets)
+ body_style_var = get_linear_model_column_var(body_style)
+ with _initialized_session() as sess:
+ sess.run(price_buckets_var.assign([[10.], [100.], [1000.], [10000.]]))
+ sess.run(body_style_var.assign([[-10.], [-100.], [-1000.]]))
+ sess.run(bias.assign([5.]))
+
+ self.assertAllClose([[10 - 1000 + 5.], [1000 - 10 + 5.]],
+ sess.run(
+ net,
+ feed_dict={
+ features['price']: price_data,
+ features['body-style']: body_style_data,
+ features['country']: country_data
+ }))
+
+ def test_with_rank_0_feature(self):
+ price = fc.numeric_column('price')
+ features = {
+ 'price': constant_op.constant(0),
+ }
+ self.assertEqual(0, features['price'].shape.ndims)
+
+ # Static rank 0 should fail
+ with self.assertRaisesRegexp(ValueError, 'Feature .* cannot have rank 0'):
+ get_keras_linear_model_predictions(features, [price])
+
+ # Dynamic rank 0 should fail
+ features = {
+ 'price': array_ops.placeholder(dtypes.float32),
+ }
+ net = get_keras_linear_model_predictions(features, [price])
+ self.assertEqual(1, net.shape[1])
+ with _initialized_session() as sess:
+ with self.assertRaisesOpError('Feature .* cannot have rank 0'):
+ sess.run(net, feed_dict={features['price']: np.array(1)})
+
+
class InputLayerTest(test.TestCase):
@test_util.run_in_graph_and_eager_modes()
@@ -2715,6 +3543,32 @@ class VocabularyFileCategoricalColumnTest(test.TestCase):
# 'skywalker' -> 3, 'omar' -> 0: wire_var[3] + wire_var[0] = 4+1 = 5
self.assertAllClose(((3.,), (5.,)), predictions.eval())
+ def test_keras_linear_model(self):
+ wire_column = fc.categorical_column_with_vocabulary_file(
+ key='wire',
+ vocabulary_file=self._wire_vocabulary_file_name,
+ vocabulary_size=self._wire_vocabulary_size,
+ num_oov_buckets=1)
+ self.assertEqual(4, wire_column._num_buckets)
+ with ops.Graph().as_default():
+ predictions = get_keras_linear_model_predictions({
+ wire_column.name:
+ sparse_tensor.SparseTensorValue(
+ indices=((0, 0), (1, 0), (1, 1)),
+ values=('marlo', 'skywalker', 'omar'),
+ dense_shape=(2, 2))
+ }, (wire_column,))
+ bias = get_keras_linear_model_bias()
+ wire_var = get_linear_model_column_var(wire_column)
+ with _initialized_session():
+ self.assertAllClose((0.,), bias.eval())
+ self.assertAllClose(((0.,), (0.,), (0.,), (0.,)), wire_var.eval())
+ self.assertAllClose(((0.,), (0.,)), predictions.eval())
+ wire_var.assign(((1.,), (2.,), (3.,), (4.,))).eval()
+ # 'marlo' -> 2: wire_var[2] = 3
+ # 'skywalker' -> 3, 'omar' -> 0: wire_var[3] + wire_var[0] = 4+1 = 5
+ self.assertAllClose(((3.,), (5.,)), predictions.eval())
+
class VocabularyListCategoricalColumnTest(test.TestCase):
@@ -3082,6 +3936,31 @@ class VocabularyListCategoricalColumnTest(test.TestCase):
# 'skywalker' -> 3, 'omar' -> 0: wire_var[3] + wire_var[0] = 4+1 = 5
self.assertAllClose(((3.,), (5.,)), predictions.eval())
+ def test_keras_linear_model(self):
+ wire_column = fc.categorical_column_with_vocabulary_list(
+ key='aaa',
+ vocabulary_list=('omar', 'stringer', 'marlo'),
+ num_oov_buckets=1)
+ self.assertEqual(4, wire_column._num_buckets)
+ with ops.Graph().as_default():
+ predictions = get_keras_linear_model_predictions({
+ wire_column.name:
+ sparse_tensor.SparseTensorValue(
+ indices=((0, 0), (1, 0), (1, 1)),
+ values=('marlo', 'skywalker', 'omar'),
+ dense_shape=(2, 2))
+ }, (wire_column,))
+ bias = get_keras_linear_model_bias()
+ wire_var = get_linear_model_column_var(wire_column)
+ with _initialized_session():
+ self.assertAllClose((0.,), bias.eval())
+ self.assertAllClose(((0.,), (0.,), (0.,), (0.,)), wire_var.eval())
+ self.assertAllClose(((0.,), (0.,)), predictions.eval())
+ wire_var.assign(((1.,), (2.,), (3.,), (4.,))).eval()
+ # 'marlo' -> 2: wire_var[2] = 3
+ # 'skywalker' -> 3, 'omar' -> 0: wire_var[3] + wire_var[0] = 4+1 = 5
+ self.assertAllClose(((3.,), (5.,)), predictions.eval())
+
class IdentityCategoricalColumnTest(test.TestCase):
@@ -3306,6 +4185,28 @@ class IdentityCategoricalColumnTest(test.TestCase):
# weight_var[2] + weight_var[1] = 3+2 = 5
self.assertAllClose(((1.,), (5.,)), predictions.eval())
+ def test_keras_linear_model(self):
+ column = fc.categorical_column_with_identity(key='aaa', num_buckets=3)
+ self.assertEqual(3, column._num_buckets)
+ with ops.Graph().as_default():
+ predictions = get_keras_linear_model_predictions({
+ column.name:
+ sparse_tensor.SparseTensorValue(
+ indices=((0, 0), (1, 0), (1, 1)),
+ values=(0, 2, 1),
+ dense_shape=(2, 2))
+ }, (column,))
+ bias = get_keras_linear_model_bias()
+ weight_var = get_linear_model_column_var(column)
+ with _initialized_session():
+ self.assertAllClose((0.,), bias.eval())
+ self.assertAllClose(((0.,), (0.,), (0.,)), weight_var.eval())
+ self.assertAllClose(((0.,), (0.,)), predictions.eval())
+ weight_var.assign(((1.,), (2.,), (3.,))).eval()
+ # weight_var[0] = 1
+ # weight_var[2] + weight_var[1] = 3+2 = 5
+ self.assertAllClose(((1.,), (5.,)), predictions.eval())
+
class TransformFeaturesTest(test.TestCase):
@@ -3537,6 +4438,25 @@ class IndicatorColumnTest(test.TestCase):
weight_var.assign([[1.], [2.], [3.], [4.]]).eval()
self.assertAllClose([[2. + 3.]], predictions.eval())
+ def test_keras_linear_model(self):
+ animal = fc.indicator_column(
+ fc.categorical_column_with_identity('animal', num_buckets=4))
+ with ops.Graph().as_default():
+ features = {
+ 'animal':
+ sparse_tensor.SparseTensor(
+ indices=[[0, 0], [0, 1]], values=[1, 2], dense_shape=[1, 2])
+ }
+
+ predictions = get_keras_linear_model_predictions(features, [animal])
+ weight_var = get_linear_model_column_var(animal)
+ with _initialized_session():
+ # All should be zero-initialized.
+ self.assertAllClose([[0.], [0.], [0.], [0.]], weight_var.eval())
+ self.assertAllClose([[0.]], predictions.eval())
+ weight_var.assign([[1.], [2.], [3.], [4.]]).eval()
+ self.assertAllClose([[2. + 3.]], predictions.eval())
+
def test_input_layer(self):
animal = fc.indicator_column(
fc.categorical_column_with_identity('animal', num_buckets=4))
@@ -3727,8 +4647,8 @@ class EmbeddingColumnTest(test.TestCase):
# Assert expected embedding variable and lookups.
global_vars = ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES)
- self.assertItemsEqual(
- ('embedding_weights:0',), tuple([v.name for v in global_vars]))
+ self.assertItemsEqual(('embedding_weights:0',),
+ tuple([v.name for v in global_vars]))
with _initialized_session():
self.assertAllEqual(embedding_values, global_vars[0].eval())
self.assertAllEqual(expected_lookups, embedding_lookup.eval())
@@ -3752,6 +4672,7 @@ class EmbeddingColumnTest(test.TestCase):
(3., 5.), # id 1
(7., 11.) # id 2
)
+
def _initializer(shape, dtype, partition_info):
self.assertAllEqual((vocabulary_size, embedding_dimension), shape)
self.assertEqual(dtypes.float32, dtype)
@@ -3774,20 +4695,21 @@ class EmbeddingColumnTest(test.TestCase):
categorical_column = fc.categorical_column_with_identity(
key='aaa', num_buckets=vocabulary_size)
embedding_column = fc.embedding_column(
- categorical_column, dimension=embedding_dimension,
+ categorical_column,
+ dimension=embedding_dimension,
initializer=_initializer)
# Create embedding_weights variable.
- weight_collections = [ops.GraphKeys.GLOBAL_VARIABLES,
- ops.GraphKeys.MODEL_VARIABLES]
+ weight_collections = [
+ ops.GraphKeys.GLOBAL_VARIABLES, ops.GraphKeys.MODEL_VARIABLES
+ ]
state = embedding_column._create_state(weight_collections)
# Provide sparse input and get dense result.
embedding_lookup = embedding_column._get_dense_tensor(
_LazyBuilder({
'aaa': sparse_input
- }),
- state=state)
+ }), state=state)
# Assert expected embedding variable and lookups.
global_vars = ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES)
@@ -4087,6 +5009,82 @@ class EmbeddingColumnTest(test.TestCase):
# = [4*7 + 6*11, 4*2 + 6*3.5, 4*0 + 6*0, 4*3 + 6*5] = [94, 29, 0, 42]
self.assertAllClose(((94.,), (29.,), (0.,), (42.,)), predictions.eval())
+ def test_keras_linear_model(self):
+ # Inputs.
+ batch_size = 4
+ vocabulary_size = 3
+ sparse_input = sparse_tensor.SparseTensorValue(
+ # example 0, ids [2]
+ # example 1, ids [0, 1]
+ # example 2, ids []
+ # example 3, ids [1]
+ indices=((0, 0), (1, 0), (1, 4), (3, 0)),
+ values=(2, 0, 1, 1),
+ dense_shape=(batch_size, 5))
+
+ # Embedding variable.
+ embedding_dimension = 2
+ embedding_shape = (vocabulary_size, embedding_dimension)
+ zeros_embedding_values = np.zeros(embedding_shape)
+
+ def _initializer(shape, dtype, partition_info):
+ self.assertAllEqual(embedding_shape, shape)
+ self.assertEqual(dtypes.float32, dtype)
+ self.assertIsNone(partition_info)
+ return zeros_embedding_values
+
+ # Build columns.
+ categorical_column = fc.categorical_column_with_identity(
+ key='aaa', num_buckets=vocabulary_size)
+ embedding_column = fc.embedding_column(
+ categorical_column,
+ dimension=embedding_dimension,
+ initializer=_initializer)
+
+ with ops.Graph().as_default():
+ predictions = get_keras_linear_model_predictions({
+ categorical_column.name: sparse_input
+ }, (embedding_column,))
+ expected_var_names = (
+ 'linear_model/bias_layer/bias_weights:0',
+ 'linear_model/aaa_embedding/weights:0',
+ 'linear_model/aaa_embedding/embedding_weights:0',
+ )
+ self.assertItemsEqual(
+ expected_var_names,
+ [v.name for v in ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES)])
+ trainable_vars = {
+ v.name: v
+ for v in ops.get_collection(ops.GraphKeys.TRAINABLE_VARIABLES)
+ }
+ self.assertItemsEqual(expected_var_names, trainable_vars.keys())
+ bias = trainable_vars['linear_model/bias_layer/bias_weights:0']
+ embedding_weights = trainable_vars[
+ 'linear_model/aaa_embedding/embedding_weights:0']
+ linear_weights = trainable_vars['linear_model/aaa_embedding/weights:0']
+ with _initialized_session():
+ # Predictions with all zero weights.
+ self.assertAllClose(np.zeros((1,)), bias.eval())
+ self.assertAllClose(zeros_embedding_values, embedding_weights.eval())
+ self.assertAllClose(
+ np.zeros((embedding_dimension, 1)), linear_weights.eval())
+ self.assertAllClose(np.zeros((batch_size, 1)), predictions.eval())
+
+ # Predictions with all non-zero weights.
+ embedding_weights.assign((
+ (1., 2.), # id 0
+ (3., 5.), # id 1
+ (7., 11.) # id 2
+ )).eval()
+ linear_weights.assign(((4.,), (6.,))).eval()
+ # example 0, ids [2], embedding[0] = [7, 11]
+ # example 1, ids [0, 1], embedding[1] = mean([1, 2] + [3, 5]) = [2, 3.5]
+ # example 2, ids [], embedding[2] = [0, 0]
+ # example 3, ids [1], embedding[3] = [3, 5]
+ # sum(embeddings * linear_weights)
+ # = [4*7 + 6*11, 4*2 + 6*3.5, 4*0 + 6*0, 4*3 + 6*5] = [94, 29, 0, 42]
+ self.assertAllClose(((94.,), (29.,), (0.,), (42.,)), predictions.eval())
+
def test_input_layer(self):
# Inputs.
vocabulary_size = 3
@@ -4509,8 +5507,8 @@ class SharedEmbeddingColumnTest(test.TestCase):
# Assert expected embedding variable and lookups.
global_vars = ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES)
- self.assertItemsEqual(
- ('embedding_weights:0',), tuple([v.name for v in global_vars]))
+ self.assertItemsEqual(('embedding_weights:0',),
+ tuple([v.name for v in global_vars]))
embedding_var = global_vars[0]
with _initialized_session():
self.assertAllEqual(embedding_values, embedding_var.eval())
@@ -4521,16 +5519,15 @@ class SharedEmbeddingColumnTest(test.TestCase):
# Inputs.
vocabulary_size = 3
# -1 values are ignored.
- input_a = np.array(
- [[2, -1, -1], # example 0, ids [2]
- [0, 1, -1]]) # example 1, ids [0, 1]
- input_b = np.array(
- [[0, -1, -1], # example 0, ids [0]
- [-1, -1, -1]]) # example 1, ids []
- input_features = {
- 'aaa': input_a,
- 'bbb': input_b
- }
+ input_a = np.array([
+ [2, -1, -1], # example 0, ids [2]
+ [0, 1, -1]
+ ]) # example 1, ids [0, 1]
+ input_b = np.array([
+ [0, -1, -1], # example 0, ids [0]
+ [-1, -1, -1]
+ ]) # example 1, ids []
+ input_features = {'aaa': input_a, 'bbb': input_b}
# Embedding variable.
embedding_dimension = 2
@@ -4539,6 +5536,7 @@ class SharedEmbeddingColumnTest(test.TestCase):
(3., 5.), # id 1
(7., 11.) # id 2
)
+
def _initializer(shape, dtype, partition_info):
self.assertAllEqual((vocabulary_size, embedding_dimension), shape)
self.assertEqual(dtypes.float32, dtype)
@@ -4566,11 +5564,13 @@ class SharedEmbeddingColumnTest(test.TestCase):
key='bbb', num_buckets=vocabulary_size)
embedding_column_a, embedding_column_b = fc.shared_embedding_columns(
[categorical_column_a, categorical_column_b],
- dimension=embedding_dimension, initializer=_initializer)
+ dimension=embedding_dimension,
+ initializer=_initializer)
# Create state.
- weight_collections = [ops.GraphKeys.GLOBAL_VARIABLES,
- ops.GraphKeys.MODEL_VARIABLES]
+ weight_collections = [
+ ops.GraphKeys.GLOBAL_VARIABLES, ops.GraphKeys.MODEL_VARIABLES
+ ]
state = embedding_column_a._create_state(weight_collections)
# Provide sparse input and get dense result.
@@ -4731,6 +5731,97 @@ class SharedEmbeddingColumnTest(test.TestCase):
# = [3*1 + 5*2, 3*0 +5*0] = [13, 0]
self.assertAllClose([[94. + 13.], [29.]], predictions.eval())
+ def test_keras_linear_model(self):
+ # Inputs.
+ batch_size = 2
+ vocabulary_size = 3
+ # -1 values are ignored.
+ input_a = np.array([
+ [2, -1, -1], # example 0, ids [2]
+ [0, 1, -1]
+ ]) # example 1, ids [0, 1]
+ input_b = np.array([
+ [0, -1, -1], # example 0, ids [0]
+ [-1, -1, -1]
+ ]) # example 1, ids []
+
+ # Embedding variable.
+ embedding_dimension = 2
+ embedding_shape = (vocabulary_size, embedding_dimension)
+ zeros_embedding_values = np.zeros(embedding_shape)
+
+ def _initializer(shape, dtype, partition_info):
+ self.assertAllEqual(embedding_shape, shape)
+ self.assertEqual(dtypes.float32, dtype)
+ self.assertIsNone(partition_info)
+ return zeros_embedding_values
+
+ # Build columns.
+ categorical_column_a = fc.categorical_column_with_identity(
+ key='aaa', num_buckets=vocabulary_size)
+ categorical_column_b = fc.categorical_column_with_identity(
+ key='bbb', num_buckets=vocabulary_size)
+ embedding_column_a, embedding_column_b = fc.shared_embedding_columns(
+ [categorical_column_a, categorical_column_b],
+ dimension=embedding_dimension,
+ initializer=_initializer)
+
+ with ops.Graph().as_default():
+ predictions = get_keras_linear_model_predictions({
+ categorical_column_a.name: input_a,
+ categorical_column_b.name: input_b,
+ }, (embedding_column_a, embedding_column_b))
+ # Linear weights do not follow the column name. But this is a rare use
+ # case, and fixing it would add too much complexity to the code.
+ expected_var_names = (
+ 'linear_model/bias_layer/bias_weights:0',
+ 'linear_model/aaa_bbb_shared_embedding/weights:0',
+ 'linear_model/aaa_bbb_shared_embedding/embedding_weights:0',
+ 'linear_model/aaa_bbb_shared_embedding_1/weights:0',
+ )
+ self.assertItemsEqual(
+ expected_var_names,
+ [v.name for v in ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES)])
+ trainable_vars = {
+ v.name: v
+ for v in ops.get_collection(ops.GraphKeys.TRAINABLE_VARIABLES)
+ }
+ self.assertItemsEqual(expected_var_names, trainable_vars.keys())
+ bias = trainable_vars['linear_model/bias_layer/bias_weights:0']
+ embedding_weights = trainable_vars[
+ 'linear_model/aaa_bbb_shared_embedding/embedding_weights:0']
+ linear_weights_a = trainable_vars[
+ 'linear_model/aaa_bbb_shared_embedding/weights:0']
+ linear_weights_b = trainable_vars[
+ 'linear_model/aaa_bbb_shared_embedding_1/weights:0']
+ with _initialized_session():
+ # Predictions with all zero weights.
+ self.assertAllClose(np.zeros((1,)), bias.eval())
+ self.assertAllClose(zeros_embedding_values, embedding_weights.eval())
+ self.assertAllClose(
+ np.zeros((embedding_dimension, 1)), linear_weights_a.eval())
+ self.assertAllClose(
+ np.zeros((embedding_dimension, 1)), linear_weights_b.eval())
+ self.assertAllClose(np.zeros((batch_size, 1)), predictions.eval())
+
+ # Predictions with all non-zero weights.
+ embedding_weights.assign((
+ (1., 2.), # id 0
+ (3., 5.), # id 1
+ (7., 11.) # id 2
+ )).eval()
+ linear_weights_a.assign(((4.,), (6.,))).eval()
+ # example 0, ids [2], embedding[0] = [7, 11]
+ # example 1, ids [0, 1], embedding[1] = mean([1, 2] + [3, 5]) = [2, 3.5]
+ # sum(embeddings * linear_weights)
+ # = [4*7 + 6*11, 4*2 + 6*3.5] = [94, 29]
+ linear_weights_b.assign(((3.,), (5.,))).eval()
+ # example 0, ids [0], embedding[0] = [1, 2]
+ # example 1, ids [], embedding[1] = 0, 0]
+ # sum(embeddings * linear_weights)
+ # = [3*1 + 5*2, 3*0 +5*0] = [13, 0]
+ self.assertAllClose([[94. + 13.], [29.]], predictions.eval())
+
def _test_input_layer(self, trainable=True):
# Inputs.
vocabulary_size = 3
@@ -5016,6 +6107,101 @@ class WeightedCategoricalColumnTest(test.TestCase):
dense_shape=(2, 2)),
weight_tensor.eval())
+ def test_keras_linear_model(self):
+ column = fc.weighted_categorical_column(
+ categorical_column=fc.categorical_column_with_identity(
+ key='ids', num_buckets=3),
+ weight_feature_key='values')
+ with ops.Graph().as_default():
+ predictions = get_keras_linear_model_predictions({
+ 'ids':
+ sparse_tensor.SparseTensorValue(
+ indices=((0, 0), (1, 0), (1, 1)),
+ values=(0, 2, 1),
+ dense_shape=(2, 2)),
+ 'values':
+ sparse_tensor.SparseTensorValue(
+ indices=((0, 0), (1, 0), (1, 1)),
+ values=(.5, 1., .1),
+ dense_shape=(2, 2))
+ }, (column,))
+ bias = get_keras_linear_model_bias()
+ weight_var = get_linear_model_column_var(column)
+ with _initialized_session():
+ self.assertAllClose((0.,), bias.eval())
+ self.assertAllClose(((0.,), (0.,), (0.,)), weight_var.eval())
+ self.assertAllClose(((0.,), (0.,)), predictions.eval())
+ weight_var.assign(((1.,), (2.,), (3.,))).eval()
+ # weight_var[0] * weights[0, 0] = 1 * .5 = .5
+ # weight_var[2] * weights[1, 0] + weight_var[1] * weights[1, 1]
+ # = 3*1 + 2*.1 = 3+.2 = 3.2
+ self.assertAllClose(((.5,), (3.2,)), predictions.eval())
+
+ def test_keras_linear_model_mismatched_shape(self):
+ column = fc.weighted_categorical_column(
+ categorical_column=fc.categorical_column_with_identity(
+ key='ids', num_buckets=3),
+ weight_feature_key='values')
+ with ops.Graph().as_default():
+ with self.assertRaisesRegexp(ValueError,
+ r'Dimensions.*are not compatible'):
+ get_keras_linear_model_predictions({
+ 'ids':
+ sparse_tensor.SparseTensorValue(
+ indices=((0, 0), (1, 0), (1, 1)),
+ values=(0, 2, 1),
+ dense_shape=(2, 2)),
+ 'values':
+ sparse_tensor.SparseTensorValue(
+ indices=((0, 0), (0, 1), (1, 0), (1, 1)),
+ values=(.5, 11., 1., .1),
+ dense_shape=(2, 2))
+ }, (column,))
+
+ def test_keras_linear_model_mismatched_dense_values(self):
+ column = fc.weighted_categorical_column(
+ categorical_column=fc.categorical_column_with_identity(
+ key='ids', num_buckets=3),
+ weight_feature_key='values')
+ with ops.Graph().as_default():
+ predictions = get_keras_linear_model_predictions({
+ 'ids':
+ sparse_tensor.SparseTensorValue(
+ indices=((0, 0), (1, 0), (1, 1)),
+ values=(0, 2, 1),
+ dense_shape=(2, 2)),
+ 'values': ((.5,), (1.,))
+ }, (column,))
+ with _initialized_session():
+ with self.assertRaisesRegexp(errors.OpError, 'Incompatible shapes'):
+ predictions.eval()
+
+ def test_keras_linear_model_mismatched_dense_shape(self):
+ column = fc.weighted_categorical_column(
+ categorical_column=fc.categorical_column_with_identity(
+ key='ids', num_buckets=3),
+ weight_feature_key='values')
+ with ops.Graph().as_default():
+ predictions = get_keras_linear_model_predictions({
+ 'ids':
+ sparse_tensor.SparseTensorValue(
+ indices=((0, 0), (1, 0), (1, 1)),
+ values=(0, 2, 1),
+ dense_shape=(2, 2)),
+ 'values': ((.5,), (1.,), (.1,))
+ }, (column,))
+ bias = get_keras_linear_model_bias()
+ weight_var = get_linear_model_column_var(column)
+ with _initialized_session():
+ self.assertAllClose((0.,), bias.eval())
+ self.assertAllClose(((0.,), (0.,), (0.,)), weight_var.eval())
+ self.assertAllClose(((0.,), (0.,)), predictions.eval())
+ weight_var.assign(((1.,), (2.,), (3.,))).eval()
+ # weight_var[0] * weights[0, 0] = 1 * .5 = .5
+ # weight_var[2] * weights[1, 0] + weight_var[1] * weights[1, 1]
+ # = 3*1 + 2*.1 = 3+.2 = 3.2
+ self.assertAllClose(((.5,), (3.2,)), predictions.eval())
+
def test_linear_model(self):
column = fc.weighted_categorical_column(
categorical_column=fc.categorical_column_with_identity(