aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
-rw-r--r--tensorflow/python/estimator/canned/dnn_linear_combined_test.py107
-rw-r--r--tensorflow/python/estimator/canned/dnn_testing_utils.py109
-rw-r--r--tensorflow/python/estimator/canned/linear_testing_utils.py64
-rw-r--r--tensorflow/python/feature_column/BUILD1
-rw-r--r--tensorflow/python/feature_column/feature_column.py4
-rw-r--r--tensorflow/python/feature_column/feature_column_v2.py869
-rw-r--r--tensorflow/python/feature_column/feature_column_v2_test.py2590
7 files changed, 3420 insertions, 324 deletions
diff --git a/tensorflow/python/estimator/canned/dnn_linear_combined_test.py b/tensorflow/python/estimator/canned/dnn_linear_combined_test.py
index ae968e717a..ab945d7b1a 100644
--- a/tensorflow/python/estimator/canned/dnn_linear_combined_test.py
+++ b/tensorflow/python/estimator/canned/dnn_linear_combined_test.py
@@ -317,16 +317,10 @@ class DNNLinearCombinedRegressorIntegrationTest(test.TestCase):
writer_cache.FileWriterCache.clear()
shutil.rmtree(self._model_dir)
- def _test_complete_flow(self, train_input_fn, eval_input_fn, predict_input_fn,
- input_dimension, label_dimension, batch_size,
- fc_impl):
- linear_feature_columns = [
- fc_impl.numeric_column('x', shape=(input_dimension,))
- ]
- dnn_feature_columns = [
- fc_impl.numeric_column('x', shape=(input_dimension,))
- ]
- feature_columns = linear_feature_columns + dnn_feature_columns
+ def _test_complete_flow_helper(
+ self, linear_feature_columns, dnn_feature_columns, feature_spec,
+ train_input_fn, eval_input_fn, predict_input_fn, input_dimension,
+ label_dimension, batch_size):
est = dnn_linear_combined.DNNLinearCombinedRegressor(
linear_feature_columns=linear_feature_columns,
dnn_hidden_units=(2, 2),
@@ -351,14 +345,63 @@ class DNNLinearCombinedRegressorIntegrationTest(test.TestCase):
self.assertAllEqual((batch_size, label_dimension), predictions.shape)
# EXPORT
- feature_spec = fc_impl.make_parse_example_spec(feature_columns)
serving_input_receiver_fn = export.build_parsing_serving_input_receiver_fn(
feature_spec)
export_dir = est.export_savedmodel(tempfile.mkdtemp(),
serving_input_receiver_fn)
self.assertTrue(gfile.Exists(export_dir))
- def test_numpy_input_fn(self, fc_impl):
+ def _test_complete_flow(self, train_input_fn, eval_input_fn, predict_input_fn,
+ input_dimension, label_dimension, batch_size,
+ fc_impl):
+ linear_feature_columns = [
+ fc_impl.numeric_column('x', shape=(input_dimension,))
+ ]
+ dnn_feature_columns = [
+ fc_impl.numeric_column('x', shape=(input_dimension,))
+ ]
+ feature_columns = linear_feature_columns + dnn_feature_columns
+ feature_spec = fc_impl.make_parse_example_spec(feature_columns)
+ self._test_complete_flow_helper(linear_feature_columns, dnn_feature_columns,
+ feature_spec, train_input_fn, eval_input_fn,
+ predict_input_fn, input_dimension,
+ label_dimension, batch_size)
+
+ def _test_complete_flow_mix1(self, train_input_fn, eval_input_fn,
+ predict_input_fn, input_dimension,
+ label_dimension, batch_size, fc_impl):
+ del fc_impl
+ linear_feature_columns = [
+ feature_column.numeric_column('x', shape=(input_dimension,))
+ ]
+ dnn_feature_columns = [
+ feature_column_v2.numeric_column('x', shape=(input_dimension,))
+ ]
+ feature_columns = linear_feature_columns + dnn_feature_columns
+ feature_spec = feature_column.make_parse_example_spec(feature_columns)
+ self._test_complete_flow_helper(linear_feature_columns, dnn_feature_columns,
+ feature_spec, train_input_fn, eval_input_fn,
+ predict_input_fn, input_dimension,
+ label_dimension, batch_size)
+
+ def _test_complete_flow_mix2(self, train_input_fn, eval_input_fn,
+ predict_input_fn, input_dimension,
+ label_dimension, batch_size, fc_impl):
+ del fc_impl
+ linear_feature_columns = [
+ feature_column_v2.numeric_column('x', shape=(input_dimension,))
+ ]
+ dnn_feature_columns = [
+ feature_column.numeric_column('x', shape=(input_dimension,))
+ ]
+ feature_columns = linear_feature_columns + dnn_feature_columns
+ feature_spec = feature_column.make_parse_example_spec(feature_columns)
+ self._test_complete_flow_helper(linear_feature_columns, dnn_feature_columns,
+ feature_spec, train_input_fn, eval_input_fn,
+ predict_input_fn, input_dimension,
+ label_dimension, batch_size)
+
+ def _test_numpy_input_fn_helper(self, fc_impl, fn_to_run):
"""Tests complete flow with numpy_input_fn."""
label_dimension = 2
batch_size = 10
@@ -381,7 +424,7 @@ class DNNLinearCombinedRegressorIntegrationTest(test.TestCase):
batch_size=batch_size,
shuffle=False)
- self._test_complete_flow(
+ fn_to_run(
train_input_fn=train_input_fn,
eval_input_fn=eval_input_fn,
predict_input_fn=predict_input_fn,
@@ -390,7 +433,16 @@ class DNNLinearCombinedRegressorIntegrationTest(test.TestCase):
batch_size=batch_size,
fc_impl=fc_impl)
- def test_pandas_input_fn(self, fc_impl):
+ def test_numpy_input_fn_basic(self, fc_impl):
+ self._test_numpy_input_fn_helper(fc_impl, self._test_complete_flow)
+
+ def test_numpy_input_fn_mix1(self, fc_impl):
+ self._test_numpy_input_fn_helper(fc_impl, self._test_complete_flow_mix1)
+
+ def test_numpy_input_fn_mix2(self, fc_impl):
+ self._test_numpy_input_fn_helper(fc_impl, self._test_complete_flow_mix2)
+
+ def _test_pandas_input_fn_helper(self, fc_impl, fn_to_run):
"""Tests complete flow with pandas_input_fn."""
if not HAS_PANDAS:
return
@@ -415,7 +467,7 @@ class DNNLinearCombinedRegressorIntegrationTest(test.TestCase):
batch_size=batch_size,
shuffle=False)
- self._test_complete_flow(
+ fn_to_run(
train_input_fn=train_input_fn,
eval_input_fn=eval_input_fn,
predict_input_fn=predict_input_fn,
@@ -424,7 +476,16 @@ class DNNLinearCombinedRegressorIntegrationTest(test.TestCase):
batch_size=batch_size,
fc_impl=fc_impl)
- def test_input_fn_from_parse_example(self, fc_impl):
+ def test_pandas_input_fn_basic(self, fc_impl):
+ self._test_pandas_input_fn_helper(fc_impl, self._test_complete_flow)
+
+ def test_pandas_input_fn_mix1(self, fc_impl):
+ self._test_pandas_input_fn_helper(fc_impl, self._test_complete_flow_mix1)
+
+ def test_pandas_input_fn_mix2(self, fc_impl):
+ self._test_pandas_input_fn_helper(fc_impl, self._test_complete_flow_mix2)
+
+ def _test_input_fn_from_parse_example_helper(self, fc_impl, fn_to_run):
"""Tests complete flow with input_fn constructed from parse_example."""
label_dimension = 2
batch_size = 10
@@ -466,7 +527,7 @@ class DNNLinearCombinedRegressorIntegrationTest(test.TestCase):
features.pop('y')
return features, None
- self._test_complete_flow(
+ fn_to_run(
train_input_fn=_train_input_fn,
eval_input_fn=_eval_input_fn,
predict_input_fn=_predict_input_fn,
@@ -475,6 +536,18 @@ class DNNLinearCombinedRegressorIntegrationTest(test.TestCase):
batch_size=batch_size,
fc_impl=fc_impl)
+ def test_input_fn_from_parse_example_basic(self, fc_impl):
+ self._test_input_fn_from_parse_example_helper(fc_impl,
+ self._test_complete_flow)
+
+ def test_input_fn_from_parse_example_mix1(self, fc_impl):
+ self._test_input_fn_from_parse_example_helper(fc_impl,
+ self._test_complete_flow_mix1)
+
+ def test_input_fn_from_parse_example_mix2(self, fc_impl):
+ self._test_input_fn_from_parse_example_helper(fc_impl,
+ self._test_complete_flow_mix2)
+
# A function to mimic dnn-classifier init reuse same tests.
def _dnn_classifier_fn(hidden_units,
diff --git a/tensorflow/python/estimator/canned/dnn_testing_utils.py b/tensorflow/python/estimator/canned/dnn_testing_utils.py
index cd66d0a3bd..71d7e54783 100644
--- a/tensorflow/python/estimator/canned/dnn_testing_utils.py
+++ b/tensorflow/python/estimator/canned/dnn_testing_utils.py
@@ -34,6 +34,7 @@ from tensorflow.python.estimator.canned import metric_keys
from tensorflow.python.estimator.canned import prediction_keys
from tensorflow.python.estimator.inputs import numpy_io
from tensorflow.python.feature_column import feature_column
+from tensorflow.python.feature_column import feature_column_v2
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
@@ -479,6 +480,60 @@ class BaseDNNModelFnTest(object):
else:
self.fail('Invalid mode: {}'.format(mode))
+ def test_multi_feature_column_mix_multi_dim_logits(self):
+ """Tests multiple feature columns and multi-dimensional logits.
+
+ All numbers are the same as test_multi_dim_input_multi_dim_logits. The only
+ difference is that the input consists of two 1D feature columns, instead of
+ one 2D feature column.
+ """
+ base_global_step = 100
+ create_checkpoint((
+ ([[.6, .5], [-.6, -.5]], [.1, -.1]),
+ ([[1., .8], [-.8, -1.]], [.2, -.2]),
+ ([[-1., 1., .5], [-1., 1., .5]], [.3, -.3, .0]),
+ ), base_global_step, self._model_dir)
+ hidden_units = (2, 2)
+ logits_dimension = 3
+ inputs = ([[10.]], [[8.]])
+ expected_logits = [[-0.48, 0.48, 0.39]]
+
+ for mode in [
+ model_fn.ModeKeys.TRAIN, model_fn.ModeKeys.EVAL,
+ model_fn.ModeKeys.PREDICT
+ ]:
+ with ops.Graph().as_default():
+ training_util.create_global_step()
+ head = mock_head(
+ self,
+ hidden_units=hidden_units,
+ logits_dimension=logits_dimension,
+ expected_logits=expected_logits)
+ estimator_spec = self._dnn_model_fn(
+ features={
+ 'age': constant_op.constant(inputs[0]),
+ 'height': constant_op.constant(inputs[1])
+ },
+ labels=constant_op.constant([[1]]),
+ mode=mode,
+ head=head,
+ hidden_units=hidden_units,
+ feature_columns=[
+ feature_column.numeric_column('age'),
+ feature_column_v2.numeric_column('height')
+ ],
+ optimizer=mock_optimizer(self, hidden_units))
+ with monitored_session.MonitoredTrainingSession(
+ checkpoint_dir=self._model_dir) as sess:
+ if mode == model_fn.ModeKeys.TRAIN:
+ sess.run(estimator_spec.train_op)
+ elif mode == model_fn.ModeKeys.EVAL:
+ sess.run(estimator_spec.loss)
+ elif mode == model_fn.ModeKeys.PREDICT:
+ sess.run(estimator_spec.predictions)
+ else:
+ self.fail('Invalid mode: {}'.format(mode))
+
def test_features_tensor_raises_value_error(self):
"""Tests that passing a Tensor for features raises a ValueError."""
hidden_units = (2, 2)
@@ -806,6 +861,60 @@ class BaseDNNLogitFnTest(object):
checkpoint_dir=self._model_dir) as sess:
self.assertAllClose(expected_logits, sess.run(logits))
+ def test_multi_feature_column_mix_multi_dim_logits(self):
+ """Tests multiple feature columns and multi-dimensional logits.
+
+ All numbers are the same as test_multi_dim_input_multi_dim_logits. The only
+ difference is that the input consists of two 1D feature columns, instead of
+ one 2D feature column.
+ """
+ base_global_step = 100
+ create_checkpoint((
+ ([[.6, .5], [-.6, -.5]], [.1, -.1]),
+ ([[1., .8], [-.8, -1.]], [.2, -.2]),
+ ([[-1., 1., .5], [-1., 1., .5]], [.3, -.3, .0]),
+ ), base_global_step, self._model_dir)
+
+ hidden_units = (2, 2)
+ logits_dimension = 3
+ inputs = ([[10.]], [[8.]])
+ expected_logits = [[-0.48, 0.48, 0.39]]
+
+ for mode in [
+ model_fn.ModeKeys.TRAIN, model_fn.ModeKeys.EVAL,
+ model_fn.ModeKeys.PREDICT
+ ]:
+ with ops.Graph().as_default():
+ # Global step needed for MonitoredSession, which is in turn used to
+ # explicitly set variable weights through a checkpoint.
+ training_util.create_global_step()
+ # Use a variable scope here with 'dnn', emulating the dnn model_fn, so
+ # the checkpoint naming is shared.
+ with variable_scope.variable_scope('dnn'):
+ input_layer_partitioner = (
+ partitioned_variables.min_max_variable_partitioner(
+ max_partitions=0, min_slice_size=64 << 20))
+ logit_fn = self._dnn_logit_fn_builder(
+ units=logits_dimension,
+ hidden_units=hidden_units,
+ feature_columns=[
+ feature_column.numeric_column('age'),
+ feature_column_v2.numeric_column('height')
+ ],
+ activation_fn=nn.relu,
+ dropout=None,
+ input_layer_partitioner=input_layer_partitioner,
+ batch_norm=False)
+ logits = logit_fn(
+ features={
+ 'age': constant_op.constant(inputs[0]),
+ 'height': constant_op.constant(inputs[1])
+ },
+ mode=mode)
+ with monitored_session.MonitoredTrainingSession(
+ checkpoint_dir=self._model_dir) as sess:
+ self.assertAllClose(expected_logits, sess.run(logits))
+
class BaseDNNWarmStartingTest(object):
diff --git a/tensorflow/python/estimator/canned/linear_testing_utils.py b/tensorflow/python/estimator/canned/linear_testing_utils.py
index 827352a70b..2cfa2a8e15 100644
--- a/tensorflow/python/estimator/canned/linear_testing_utils.py
+++ b/tensorflow/python/estimator/canned/linear_testing_utils.py
@@ -400,6 +400,45 @@ class BaseLinearRegressorEvaluationTest(object):
# [213.0, 421.0], while label is [213., 421.]. Loss = 0.
self.assertAlmostEqual(0, eval_metrics[metric_keys.MetricKeys.LOSS])
+ def test_evaluation_for_multiple_feature_columns_mix(self):
+ with ops.Graph().as_default():
+ variables_lib.Variable([[10.0]], name=AGE_WEIGHT_NAME)
+ variables_lib.Variable([[2.0]], name=HEIGHT_WEIGHT_NAME)
+ variables_lib.Variable([5.0], name=BIAS_NAME)
+ variables_lib.Variable(
+ 100, name=ops.GraphKeys.GLOBAL_STEP, dtype=dtypes.int64)
+ save_variables_to_ckpt(self._model_dir)
+
+ batch_size = 2
+ feature_columns = [
+ feature_column.numeric_column('age'),
+ feature_column_v2.numeric_column('height')
+ ]
+
+ def _input_fn():
+ features_ds = dataset_ops.Dataset.from_tensor_slices({
+ 'age': np.array([20, 40]),
+ 'height': np.array([4, 8])
+ })
+ labels_ds = dataset_ops.Dataset.from_tensor_slices(
+ np.array([[213.], [421.]]))
+ return (dataset_ops.Dataset.zip((features_ds, labels_ds))
+ .batch(batch_size).repeat(None))
+
+ est = self._linear_regressor_fn(
+ feature_columns=feature_columns, model_dir=self._model_dir)
+
+ eval_metrics = est.evaluate(input_fn=_input_fn, steps=1)
+ self.assertItemsEqual(
+ (metric_keys.MetricKeys.LOSS, metric_keys.MetricKeys.LOSS_MEAN,
+ metric_keys.MetricKeys.PREDICTION_MEAN,
+ metric_keys.MetricKeys.LABEL_MEAN, ops.GraphKeys.GLOBAL_STEP),
+ eval_metrics.keys())
+
+ # Logit is [(20. * 10.0 + 4 * 2.0 + 5.0), (40. * 10.0 + 8 * 2.0 + 5.0)] =
+ # [213.0, 421.0], while label is [213., 421.]. Loss = 0.
+ self.assertAlmostEqual(0, eval_metrics[metric_keys.MetricKeys.LOSS])
+
class BaseLinearRegressorPredictTest(object):
@@ -497,6 +536,31 @@ class BaseLinearRegressorPredictTest(object):
# x0 * weight0 + x1 * weight1 + bias = 2. * 10. + 3. * 20 + .2 = 80.2
self.assertAllClose([[80.2]], predicted_scores)
+ def testTwoFeatureColumnsMix(self):
+ """Tests predict with two feature columns."""
+ with ops.Graph().as_default():
+ variables_lib.Variable([[10.]], name='linear/linear_model/x0/weights')
+ variables_lib.Variable([[20.]], name='linear/linear_model/x1/weights')
+ variables_lib.Variable([.2], name=BIAS_NAME)
+ variables_lib.Variable(100, name='global_step', dtype=dtypes.int64)
+ save_variables_to_ckpt(self._model_dir)
+
+ linear_regressor = self._linear_regressor_fn(
+ feature_columns=(feature_column.numeric_column('x0'),
+ feature_column_v2.numeric_column('x1')),
+ model_dir=self._model_dir)
+
+ def _predict_input_fn():
+ return dataset_ops.Dataset.from_tensor_slices({
+ 'x0': np.array([[2.]]),
+ 'x1': np.array([[3.]])
+ }).batch(1)
+
+ predictions = linear_regressor.predict(input_fn=_predict_input_fn)
+ predicted_scores = list([x['predictions'] for x in predictions])
+ # x0 * weight0 + x1 * weight1 + bias = 2. * 10. + 3. * 20 + .2 = 80.2
+ self.assertAllClose([[80.2]], predicted_scores)
+
def testSparseCombiner(self):
w_a = 2.0
w_b = 3.0
diff --git a/tensorflow/python/feature_column/BUILD b/tensorflow/python/feature_column/BUILD
index ac53a84eef..82acde584e 100644
--- a/tensorflow/python/feature_column/BUILD
+++ b/tensorflow/python/feature_column/BUILD
@@ -54,6 +54,7 @@ py_library(
srcs = ["feature_column_v2.py"],
srcs_version = "PY2AND3",
deps = [
+ ":feature_column",
"//tensorflow/python:array_ops",
"//tensorflow/python:check_ops",
"//tensorflow/python:control_flow_ops",
diff --git a/tensorflow/python/feature_column/feature_column.py b/tensorflow/python/feature_column/feature_column.py
index 28a8286544..8a11ca142c 100644
--- a/tensorflow/python/feature_column/feature_column.py
+++ b/tensorflow/python/feature_column/feature_column.py
@@ -121,6 +121,10 @@ Example of building model using FeatureColumns, this can be used in a
NOTE: Functions prefixed with "_" indicate experimental or private parts of
the API subject to change, and should not be relied upon!
+
+NOTE: The new feature columns are being developed in feature_column_v2.py and
+are a somewhat duplicate of the code here. Please make sure to update logic
+in both places.
"""
from __future__ import absolute_import
diff --git a/tensorflow/python/feature_column/feature_column_v2.py b/tensorflow/python/feature_column/feature_column_v2.py
index b79373c475..6d089de991 100644
--- a/tensorflow/python/feature_column/feature_column_v2.py
+++ b/tensorflow/python/feature_column/feature_column_v2.py
@@ -136,6 +136,7 @@ import six
from tensorflow.python.eager import context
+from tensorflow.python.feature_column import feature_column as fc_old
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import sparse_tensor as sparse_tensor_lib
@@ -157,9 +158,16 @@ from tensorflow.python.ops import variables
from tensorflow.python.platform import gfile
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.training import checkpoint_utils
+from tensorflow.python.util import deprecation
from tensorflow.python.util import nest
+_FEATURE_COLUMN_DEPRECATION_DATE = '2018-11-30'
+_FEATURE_COLUMN_DEPRECATION = ('The old _FeatureColumn APIs are being '
+ 'deprecated. Please use the new FeatureColumn '
+ 'APIs instead.')
+
+
class StateManager(object):
"""Manages the state associated with FeatureColumns.
@@ -440,10 +448,6 @@ class FeatureLayer(Layer):
return (input_shape[0], total_elements)
-def _strip_leading_slashes(name):
- return name.rsplit('/', 1)[-1]
-
-
class LinearModel(Layer):
"""Produces a linear prediction `Tensor` based on given `feature_columns`.
@@ -775,12 +779,12 @@ def embedding_column(
categorical_column, dimension, combiner='mean', initializer=None,
ckpt_to_load_from=None, tensor_name_in_ckpt=None, max_norm=None,
trainable=True):
- """`_DenseColumn` that converts from sparse, categorical input.
+ """`DenseColumn` that converts from sparse, categorical input.
Use this when your inputs are sparse, but you want to convert them to a dense
representation (e.g., to feed to a DNN).
- Inputs must be a `_CategoricalColumn` created by any of the
+ Inputs must be a `CategoricalColumn` created by any of the
`categorical_column_*` function. Here is an example of using
`embedding_column` with `DNNClassifier`:
@@ -814,12 +818,12 @@ def embedding_column(
```
Args:
- categorical_column: A `_CategoricalColumn` created by a
+ categorical_column: A `CategoricalColumn` created by a
`categorical_column_with_*` function. This column produces the sparse IDs
that are inputs to the embedding lookup.
dimension: An integer specifying dimension of the embedding, must be > 0.
- combiner: A string specifying how to reduce if there are multiple entries
- in a single row. Currently 'mean', 'sqrtn' and 'sum' are supported, with
+ combiner: A string specifying how to reduce if there are multiple entries in
+ a single row. Currently 'mean', 'sqrtn' and 'sum' are supported, with
'mean' the default. 'sqrtn' often achieves good accuracy, in particular
with bag-of-words columns. Each of this can be thought as example level
normalizations on the column. For more information, see
@@ -830,14 +834,14 @@ def embedding_column(
`1/sqrt(dimension)`.
ckpt_to_load_from: String representing checkpoint name/pattern from which to
restore column weights. Required if `tensor_name_in_ckpt` is not `None`.
- tensor_name_in_ckpt: Name of the `Tensor` in `ckpt_to_load_from` from
- which to restore the column weights. Required if `ckpt_to_load_from` is
- not `None`.
+ tensor_name_in_ckpt: Name of the `Tensor` in `ckpt_to_load_from` from which
+ to restore the column weights. Required if `ckpt_to_load_from` is not
+ `None`.
max_norm: If not `None`, embedding values are l2-normalized to this value.
trainable: Whether or not the embedding is trainable. Default is True.
Returns:
- `_DenseColumn` that converts from sparse input.
+ `DenseColumn` that converts from sparse input.
Raises:
ValueError: if `dimension` not > 0.
@@ -1181,7 +1185,7 @@ def bucketized_column(source_column, boundaries):
one-dimensional.
ValueError: If `boundaries` is not a sorted list or tuple.
"""
- if not isinstance(source_column, NumericColumn):
+ if not isinstance(source_column, (NumericColumn, fc_old._NumericColumn)): # pylint: disable=protected-access
raise ValueError(
'source_column must be a column generated with numeric_column(). '
'Given: {}'.format(source_column))
@@ -1390,7 +1394,7 @@ def categorical_column_with_vocabulary_file(key,
def categorical_column_with_vocabulary_list(
key, vocabulary_list, dtype=None, default_value=-1, num_oov_buckets=0):
- """A `_CategoricalColumn` with in-memory vocabulary.
+ """A `CategoricalColumn` with in-memory vocabulary.
Use this when your inputs are in string or integer format, and you have an
in-memory vocabulary mapping each value to an integer ID. By default,
@@ -1439,14 +1443,14 @@ def categorical_column_with_vocabulary_list(
```
Args:
- key: A unique string identifying the input feature. It is used as the
- column name and the dictionary key for feature parsing configs, feature
- `Tensor` objects, and feature columns.
+ key: A unique string identifying the input feature. It is used as the column
+ name and the dictionary key for feature parsing configs, feature `Tensor`
+ objects, and feature columns.
vocabulary_list: An ordered iterable defining the vocabulary. Each feature
is mapped to the index of its value (if present) in `vocabulary_list`.
Must be castable to `dtype`.
- dtype: The type of features. Only string and integer types are supported.
- If `None`, it will be inferred from `vocabulary_list`.
+ dtype: The type of features. Only string and integer types are supported. If
+ `None`, it will be inferred from `vocabulary_list`.
default_value: The integer ID value to return for out-of-vocabulary feature
values, defaults to `-1`. This can not be specified with a positive
`num_oov_buckets`.
@@ -1604,7 +1608,7 @@ def indicator_column(categorical_column):
def weighted_categorical_column(
categorical_column, weight_feature_key, dtype=dtypes.float32):
- """Applies weight values to a `_CategoricalColumn`.
+ """Applies weight values to a `CategoricalColumn`.
Use this when each of your sparse inputs has both an ID and a value. For
example, if you're representing text documents as a collection of word
@@ -1655,7 +1659,7 @@ def weighted_categorical_column(
the same indices and dense shape.
Args:
- categorical_column: A `_CategoricalColumn` created by
+ categorical_column: A `CategoricalColumn` created by
`categorical_column_with_*` functions.
weight_feature_key: String key for weight values.
dtype: Type of weights, such as `tf.float32`. Only float and integer weights
@@ -1788,12 +1792,13 @@ def crossed_column(keys, hash_bucket_size, hash_key=None):
'keys must be a list with length > 1. Given: {}'.format(keys))
for key in keys:
if (not isinstance(key, six.string_types) and
- not isinstance(key, CategoricalColumn)):
+ not isinstance(key, (CategoricalColumn, fc_old._CategoricalColumn))): # pylint: disable=protected-access
raise ValueError(
'Unsupported key type. All keys must be either string, or '
'categorical column except HashedCategoricalColumn. '
'Given: {}'.format(key))
- if isinstance(key, HashedCategoricalColumn):
+ if isinstance(key,
+ (HashedCategoricalColumn, fc_old._HashedCategoricalColumn)): # pylint: disable=protected-access
raise ValueError(
'categorical_column_with_hash_bucket is not supported for crossing. '
'Hashing before crossing will increase probability of collision. '
@@ -1882,6 +1887,16 @@ class FeatureColumn(object):
"""
pass
+ @abc.abstractproperty
+ def _is_v2_column(self):
+ """Returns whether this FeatureColumn is fully conformant to the new API.
+
+ This is needed for composition type cases where an EmbeddingColumn etc.
+ might take in old categorical columns as input and then we want to use the
+ old API.
+ """
+ pass
+
class DenseColumn(FeatureColumn):
"""Represents a column which can be represented as `Tensor`.
@@ -1927,6 +1942,8 @@ def is_feature_column_v2(feature_columns):
for feature_column in feature_columns:
if not isinstance(feature_column, FeatureColumn):
return False
+ if not feature_column._is_v2_column: # pylint: disable=protected-access
+ return False
return True
@@ -2202,19 +2219,6 @@ class FeatureTransformationCache(object):
# TODO(ptucker): Move to third_party/tensorflow/python/ops/sparse_ops.py
-def _shape_offsets(shape):
- """Returns moving offset for each dimension given shape."""
- offsets = []
- for dim in reversed(shape):
- if offsets:
- offsets.append(dim * offsets[-1])
- else:
- offsets.append(dim)
- offsets.reverse()
- return offsets
-
-
-# TODO(ptucker): Move to third_party/tensorflow/python/ops/sparse_ops.py
def _to_sparse_input_and_drop_ignore_values(input_tensor, ignore_value=None):
"""Converts a `Tensor` to a `SparseTensor`, dropping ignore_value cells.
@@ -2306,12 +2310,17 @@ def _normalize_feature_columns(feature_columns):
class NumericColumn(
DenseColumn,
+ fc_old._DenseColumn, # pylint: disable=protected-access
collections.namedtuple(
'NumericColumn',
('key', 'shape', 'default_value', 'dtype', 'normalizer_fn'))):
"""see `numeric_column`."""
@property
+ def _is_v2_column(self):
+ return True
+
+ @property
def name(self):
"""See `FeatureColumn` base class."""
return self.key
@@ -2325,6 +2334,27 @@ class NumericColumn(
self.default_value)
}
+ @property
+ @deprecation.deprecated(_FEATURE_COLUMN_DEPRECATION_DATE,
+ _FEATURE_COLUMN_DEPRECATION)
+ def _parse_example_spec(self):
+ return self.parse_example_spec
+
+ def _transform_input_tensor(self, input_tensor):
+ if isinstance(input_tensor, sparse_tensor_lib.SparseTensor):
+ raise ValueError(
+ 'The corresponding Tensor of numerical column must be a Tensor. '
+ 'SparseTensor is not supported. key: {}'.format(self.key))
+ if self.normalizer_fn is not None:
+ input_tensor = self.normalizer_fn(input_tensor)
+ return math_ops.to_float(input_tensor)
+
+ @deprecation.deprecated(_FEATURE_COLUMN_DEPRECATION_DATE,
+ _FEATURE_COLUMN_DEPRECATION)
+ def _transform_feature(self, inputs):
+ input_tensor = inputs.get(self.key)
+ return self._transform_input_tensor(input_tensor)
+
def transform_feature(self, transformation_cache, state_manager):
"""See `FeatureColumn` base class.
@@ -2342,19 +2372,19 @@ class NumericColumn(
ValueError: If a SparseTensor is passed in.
"""
input_tensor = transformation_cache.get(self.key, state_manager)
- if isinstance(input_tensor, sparse_tensor_lib.SparseTensor):
- raise ValueError(
- 'The corresponding Tensor of numerical column must be a Tensor. '
- 'SparseTensor is not supported. key: {}'.format(self.key))
- if self.normalizer_fn is not None:
- input_tensor = self.normalizer_fn(input_tensor)
- return math_ops.to_float(input_tensor)
+ return self._transform_input_tensor(input_tensor)
@property
def variable_shape(self):
"""See `DenseColumn` base class."""
return tensor_shape.TensorShape(self.shape)
+ @property
+ @deprecation.deprecated(_FEATURE_COLUMN_DEPRECATION_DATE,
+ _FEATURE_COLUMN_DEPRECATION)
+ def _variable_shape(self):
+ return self.variable_shape
+
def get_dense_tensor(self, transformation_cache, state_manager):
"""Returns dense `Tensor` representing numeric feature.
@@ -2371,13 +2401,29 @@ class NumericColumn(
# representation created by _transform_feature.
return transformation_cache.get(self, state_manager)
+ @deprecation.deprecated(_FEATURE_COLUMN_DEPRECATION_DATE,
+ _FEATURE_COLUMN_DEPRECATION)
+ def _get_dense_tensor(self, inputs, weight_collections=None, trainable=None):
+ del weight_collections
+ del trainable
+ return inputs.get(self)
+
-class BucketizedColumn(DenseColumn, CategoricalColumn,
- collections.namedtuple('BucketizedColumn',
- ('source_column', 'boundaries'))):
+class BucketizedColumn(
+ DenseColumn,
+ CategoricalColumn,
+ fc_old._DenseColumn, # pylint: disable=protected-access
+ fc_old._CategoricalColumn, # pylint: disable=protected-access
+ collections.namedtuple('BucketizedColumn',
+ ('source_column', 'boundaries'))):
"""See `bucketized_column`."""
@property
+ def _is_v2_column(self):
+ return (isinstance(self.source_column, FeatureColumn) and
+ self.source_column._is_v2_column) # pylint: disable=protected-access
+
+ @property
def name(self):
"""See `FeatureColumn` base class."""
return '{}_bucketized'.format(self.source_column.name)
@@ -2387,6 +2433,21 @@ class BucketizedColumn(DenseColumn, CategoricalColumn,
"""See `FeatureColumn` base class."""
return self.source_column.parse_example_spec
+ @property
+ @deprecation.deprecated(_FEATURE_COLUMN_DEPRECATION_DATE,
+ _FEATURE_COLUMN_DEPRECATION)
+ def _parse_example_spec(self):
+ return self.source_column._parse_example_spec # pylint: disable=protected-access
+
+ @deprecation.deprecated(_FEATURE_COLUMN_DEPRECATION_DATE,
+ _FEATURE_COLUMN_DEPRECATION)
+ def _transform_feature(self, inputs):
+ """Returns bucketized categorical `source_column` tensor."""
+ source_tensor = inputs.get(self.source_column)
+ return math_ops._bucketize( # pylint: disable=protected-access
+ source_tensor,
+ boundaries=self.boundaries)
+
def transform_feature(self, transformation_cache, state_manager):
"""Returns bucketized categorical `source_column` tensor."""
source_tensor = transformation_cache.get(self.source_column, state_manager)
@@ -2400,24 +2461,45 @@ class BucketizedColumn(DenseColumn, CategoricalColumn,
return tensor_shape.TensorShape(
tuple(self.source_column.shape) + (len(self.boundaries) + 1,))
- def get_dense_tensor(self, transformation_cache, state_manager):
- """Returns one hot encoded dense `Tensor`."""
- input_tensor = transformation_cache.get(self, state_manager)
+ @property
+ @deprecation.deprecated(_FEATURE_COLUMN_DEPRECATION_DATE,
+ _FEATURE_COLUMN_DEPRECATION)
+ def _variable_shape(self):
+ return self.variable_shape
+
+ def _get_dense_tensor_for_input_tensor(self, input_tensor):
return array_ops.one_hot(
indices=math_ops.to_int64(input_tensor),
depth=len(self.boundaries) + 1,
on_value=1.,
off_value=0.)
+ def get_dense_tensor(self, transformation_cache, state_manager):
+ """Returns one hot encoded dense `Tensor`."""
+ input_tensor = transformation_cache.get(self, state_manager)
+ return self._get_dense_tensor_for_input_tensor(input_tensor)
+
+ @deprecation.deprecated(_FEATURE_COLUMN_DEPRECATION_DATE,
+ _FEATURE_COLUMN_DEPRECATION)
+ def _get_dense_tensor(self, inputs, weight_collections=None, trainable=None):
+ del weight_collections
+ del trainable
+ input_tensor = inputs.get(self)
+ return self._get_dense_tensor_for_input_tensor(input_tensor)
+
@property
def num_buckets(self):
"""See `CategoricalColumn` base class."""
# By construction, source_column is always one-dimensional.
return (len(self.boundaries) + 1) * self.source_column.shape[0]
- def get_sparse_tensors(self, transformation_cache, state_manager):
- """Converts dense inputs to SparseTensor so downstream code can use it."""
- input_tensor = transformation_cache.get(self, state_manager)
+ @property
+ @deprecation.deprecated(_FEATURE_COLUMN_DEPRECATION_DATE,
+ _FEATURE_COLUMN_DEPRECATION)
+ def _num_buckets(self):
+ return self.num_buckets
+
+ def _get_sparse_tensors_for_input_tensor(self, input_tensor):
batch_size = array_ops.shape(input_tensor)[0]
# By construction, source_column is always one-dimensional.
source_dimension = self.source_column.shape[0]
@@ -2443,9 +2525,27 @@ class BucketizedColumn(DenseColumn, CategoricalColumn,
dense_shape=dense_shape)
return CategoricalColumn.IdWeightPair(sparse_tensor, None)
+ def get_sparse_tensors(self, transformation_cache, state_manager):
+ """Converts dense inputs to SparseTensor so downstream code can use it."""
+ input_tensor = transformation_cache.get(self, state_manager)
+ return self._get_sparse_tensors_for_input_tensor(input_tensor)
+
+ @deprecation.deprecated(_FEATURE_COLUMN_DEPRECATION_DATE,
+ _FEATURE_COLUMN_DEPRECATION)
+ def _get_sparse_tensors(self, inputs, weight_collections=None,
+ trainable=None):
+ """Converts dense inputs to SparseTensor so downstream code can use it."""
+ del weight_collections
+ del trainable
+ input_tensor = inputs.get(self)
+ return self._get_sparse_tensors_for_input_tensor(input_tensor)
+
class EmbeddingColumn(
- DenseColumn, SequenceDenseColumn,
+ DenseColumn,
+ SequenceDenseColumn,
+ fc_old._DenseColumn, # pylint: disable=protected-access
+ fc_old._SequenceDenseColumn, # pylint: disable=protected-access
collections.namedtuple(
'EmbeddingColumn',
('categorical_column', 'dimension', 'combiner', 'initializer',
@@ -2453,6 +2553,11 @@ class EmbeddingColumn(
"""See `embedding_column`."""
@property
+ def _is_v2_column(self):
+ return (isinstance(self.categorical_column, FeatureColumn) and
+ self.categorical_column._is_v2_column) # pylint: disable=protected-access
+
+ @property
def name(self):
"""See `FeatureColumn` base class."""
return '{}_embedding'.format(self.categorical_column.name)
@@ -2462,18 +2567,35 @@ class EmbeddingColumn(
"""See `FeatureColumn` base class."""
return self.categorical_column.parse_example_spec
+ @property
+ @deprecation.deprecated(_FEATURE_COLUMN_DEPRECATION_DATE,
+ _FEATURE_COLUMN_DEPRECATION)
+ def _parse_example_spec(self):
+ return self.categorical_column._parse_example_spec # pylint: disable=protected-access
+
def transform_feature(self, transformation_cache, state_manager):
"""Transforms underlying `categorical_column`."""
return transformation_cache.get(self.categorical_column, state_manager)
+ @deprecation.deprecated(_FEATURE_COLUMN_DEPRECATION_DATE,
+ _FEATURE_COLUMN_DEPRECATION)
+ def _transform_feature(self, inputs):
+ return inputs.get(self.categorical_column)
+
@property
def variable_shape(self):
"""See `DenseColumn` base class."""
return tensor_shape.vector(self.dimension)
+ @property
+ @deprecation.deprecated(_FEATURE_COLUMN_DEPRECATION_DATE,
+ _FEATURE_COLUMN_DEPRECATION)
+ def _variable_shape(self):
+ return self.variable_shape
+
def create_state(self, state_manager):
"""Creates the embedding lookup variable."""
- embedding_shape = (self.categorical_column.num_buckets, self.dimension)
+ embedding_shape = (self.categorical_column._num_buckets, self.dimension) # pylint: disable=protected-access
state_manager.create_variable(
self,
name='embedding_weights',
@@ -2482,17 +2604,11 @@ class EmbeddingColumn(
trainable=self.trainable,
initializer=self.initializer)
- def _get_dense_tensor_internal(self, transformation_cache, state_manager):
- """Private method that follows the signature of _get_dense_tensor."""
- # Get sparse IDs and weights.
- sparse_tensors = self.categorical_column.get_sparse_tensors(
- transformation_cache, state_manager)
+ def _get_dense_tensor_internal_helper(self, sparse_tensors,
+ embedding_weights):
sparse_ids = sparse_tensors.id_tensor
sparse_weights = sparse_tensors.weight_tensor
- embedding_weights = state_manager.get_variable(
- self, name='embedding_weights')
-
if self.ckpt_to_load_from is not None:
to_restore = embedding_weights
if isinstance(to_restore, variables.PartitionedVariable):
@@ -2510,6 +2626,30 @@ class EmbeddingColumn(
name='%s_weights' % self.name,
max_norm=self.max_norm)
+ def _get_dense_tensor_internal(self, sparse_tensors, state_manager):
+ """Private method that follows the signature of get_dense_tensor."""
+ embedding_weights = state_manager.get_variable(
+ self, name='embedding_weights')
+ return self._get_dense_tensor_internal_helper(sparse_tensors,
+ embedding_weights)
+
+ def _old_get_dense_tensor_internal(self, sparse_tensors, weight_collections,
+ trainable):
+ """Private method that follows the signature of _get_dense_tensor."""
+ embedding_shape = (self.categorical_column._num_buckets, self.dimension) # pylint: disable=protected-access
+ if (weight_collections and
+ ops.GraphKeys.GLOBAL_VARIABLES not in weight_collections):
+ weight_collections.append(ops.GraphKeys.GLOBAL_VARIABLES)
+ embedding_weights = variable_scope.get_variable(
+ name='embedding_weights',
+ shape=embedding_shape,
+ dtype=dtypes.float32,
+ initializer=self.initializer,
+ trainable=self.trainable and trainable,
+ collections=weight_collections)
+ return self._get_dense_tensor_internal_helper(sparse_tensors,
+ embedding_weights)
+
def get_dense_tensor(self, transformation_cache, state_manager):
"""Returns tensor after doing the embedding lookup.
@@ -2535,7 +2675,30 @@ class EmbeddingColumn(
'sequence_input_layer instead of input_layer. '
'Given (type {}): {}'.format(self.name, type(self.categorical_column),
self.categorical_column))
- return self._get_dense_tensor_internal(transformation_cache, state_manager)
+ # Get sparse IDs and weights.
+ sparse_tensors = self.categorical_column.get_sparse_tensors(
+ transformation_cache, state_manager)
+ return self._get_dense_tensor_internal(sparse_tensors, state_manager)
+
+ @deprecation.deprecated(_FEATURE_COLUMN_DEPRECATION_DATE,
+ _FEATURE_COLUMN_DEPRECATION)
+ def _get_dense_tensor(self, inputs, weight_collections=None, trainable=None):
+ if isinstance(
+ self.categorical_column,
+ (SequenceCategoricalColumn, fc_old._SequenceCategoricalColumn)): # pylint: disable=protected-access
+ raise ValueError(
+ 'In embedding_column: {}. '
+ 'categorical_column must not be of type _SequenceCategoricalColumn. '
+ 'Suggested fix A: If you wish to use input_layer, use a '
+ 'non-sequence categorical_column_with_*. '
+ 'Suggested fix B: If you wish to create sequence input, use '
+ 'sequence_input_layer instead of input_layer. '
+ 'Given (type {}): {}'.format(self.name, type(self.categorical_column),
+ self.categorical_column))
+ sparse_tensors = self.categorical_column._get_sparse_tensors( # pylint: disable=protected-access
+ inputs, weight_collections, trainable)
+ return self._old_get_dense_tensor_internal(sparse_tensors,
+ weight_collections, trainable)
def get_sequence_dense_tensor(self, transformation_cache, state_manager):
"""See `SequenceDenseColumn` base class."""
@@ -2547,21 +2710,40 @@ class EmbeddingColumn(
'Suggested fix: Use one of sequence_categorical_column_with_*. '
'Given (type {}): {}'.format(self.name, type(self.categorical_column),
self.categorical_column))
- dense_tensor = self._get_dense_tensor_internal( # pylint: disable=protected-access
+ sparse_tensors = self.categorical_column.get_sequence_sparse_tensors(
transformation_cache, state_manager)
- sparse_tensors = self.categorical_column.get_sparse_tensors(
- transformation_cache, state_manager)
- sequence_length = _sequence_length_from_sparse_tensor(
+ dense_tensor = self._get_dense_tensor_internal(sparse_tensors,
+ state_manager)
+ sequence_length = fc_old._sequence_length_from_sparse_tensor( # pylint: disable=protected-access
sparse_tensors.id_tensor)
return SequenceDenseColumn.TensorSequenceLengthPair(
dense_tensor=dense_tensor, sequence_length=sequence_length)
-
-def _get_graph_for_variable(var):
- if isinstance(var, variables.PartitionedVariable):
- return list(var)[0].graph
- else:
- return var.graph
+ @deprecation.deprecated(_FEATURE_COLUMN_DEPRECATION_DATE,
+ _FEATURE_COLUMN_DEPRECATION)
+ def _get_sequence_dense_tensor(self,
+ inputs,
+ weight_collections=None,
+ trainable=None):
+ if not isinstance(
+ self.categorical_column,
+ (SequenceCategoricalColumn, fc_old._SequenceCategoricalColumn)): # pylint: disable=protected-access
+ raise ValueError(
+ 'In embedding_column: {}. '
+ 'categorical_column must be of type _SequenceCategoricalColumn '
+ 'to use sequence_input_layer. '
+ 'Suggested fix: Use one of sequence_categorical_column_with_*. '
+ 'Given (type {}): {}'.format(self.name, type(self.categorical_column),
+ self.categorical_column))
+ sparse_tensors = self.categorical_column._get_sparse_tensors(inputs) # pylint: disable=protected-access
+ dense_tensor = self._old_get_dense_tensor_internal(
+ sparse_tensors,
+ weight_collections=weight_collections,
+ trainable=trainable)
+ sequence_length = fc_old._sequence_length_from_sparse_tensor( # pylint: disable=protected-access
+ sparse_tensors.id_tensor)
+ return SequenceDenseColumn.TensorSequenceLengthPair(
+ dense_tensor=dense_tensor, sequence_length=sequence_length)
class SharedEmbeddingStateManager(Layer):
@@ -2633,8 +2815,17 @@ def maybe_create_shared_state_manager(feature_columns):
return None
+def _raise_shared_embedding_column_error():
+ raise ValueError('SharedEmbeddingColumns are not supported in '
+ '`linear_model` or `input_layer`. Please use '
+ '`FeatureLayer` or `LinearModel` instead.')
+
+
class SharedEmbeddingColumn(
- DenseColumn, SequenceDenseColumn,
+ DenseColumn,
+ SequenceDenseColumn,
+ fc_old._DenseColumn, # pylint: disable=protected-access
+ fc_old._SequenceDenseColumn, # pylint: disable=protected-access
collections.namedtuple(
'SharedEmbeddingColumn',
('categorical_column', 'dimension', 'combiner', 'initializer',
@@ -2643,6 +2834,10 @@ class SharedEmbeddingColumn(
"""See `embedding_column`."""
@property
+ def _is_v2_column(self):
+ return True
+
+ @property
def name(self):
"""See `FeatureColumn` base class."""
return '{}_shared_embedding'.format(self.categorical_column.name)
@@ -2662,15 +2857,26 @@ class SharedEmbeddingColumn(
"""See `FeatureColumn` base class."""
return self.categorical_column.parse_example_spec
+ @property
+ def _parse_example_spec(self):
+ return _raise_shared_embedding_column_error()
+
def transform_feature(self, transformation_cache, state_manager):
"""See `FeatureColumn` base class."""
return transformation_cache.get(self.categorical_column, state_manager)
+ def _transform_feature(self, inputs):
+ return _raise_shared_embedding_column_error()
+
@property
def variable_shape(self):
"""See `DenseColumn` base class."""
return tensor_shape.vector(self.dimension)
+ @property
+ def _variable_shape(self):
+ return _raise_shared_embedding_column_error()
+
def create_state(self, state_manager):
"""Creates the shared embedding lookup variable."""
if not isinstance(state_manager, SharedEmbeddingStateManager):
@@ -2731,6 +2937,9 @@ class SharedEmbeddingColumn(
self.categorical_column))
return self._get_dense_tensor_internal(transformation_cache, state_manager)
+ def _get_dense_tensor(self, inputs, weight_collections=None, trainable=None):
+ return _raise_shared_embedding_column_error()
+
def get_sequence_dense_tensor(self, transformation_cache, state_manager):
"""See `SequenceDenseColumn` base class."""
if not isinstance(self.categorical_column, SequenceCategoricalColumn):
@@ -2745,11 +2954,17 @@ class SharedEmbeddingColumn(
state_manager)
sparse_tensors = self.categorical_column.get_sparse_tensors(
transformation_cache, state_manager)
- sequence_length = _sequence_length_from_sparse_tensor(
+ sequence_length = fc_old._sequence_length_from_sparse_tensor( # pylint: disable=protected-access
sparse_tensors.id_tensor)
return SequenceDenseColumn.TensorSequenceLengthPair(
dense_tensor=dense_tensor, sequence_length=sequence_length)
+ def _get_sequence_dense_tensor(self,
+ inputs,
+ weight_collections=None,
+ trainable=None):
+ return _raise_shared_embedding_column_error()
+
def _create_tuple(shape, value):
"""Returns a tuple with given shape and filled with value."""
@@ -2858,11 +3073,16 @@ def _check_default_value(shape, default_value, dtype, key):
class HashedCategoricalColumn(
CategoricalColumn,
+ fc_old._CategoricalColumn, # pylint: disable=protected-access
collections.namedtuple('HashedCategoricalColumn',
('key', 'hash_bucket_size', 'dtype'))):
"""see `categorical_column_with_hash_bucket`."""
@property
+ def _is_v2_column(self):
+ return True
+
+ @property
def name(self):
"""See `FeatureColumn` base class."""
return self.key
@@ -2872,10 +3092,14 @@ class HashedCategoricalColumn(
"""See `FeatureColumn` base class."""
return {self.key: parsing_ops.VarLenFeature(self.dtype)}
- def transform_feature(self, transformation_cache, state_manager):
+ @property
+ @deprecation.deprecated(_FEATURE_COLUMN_DEPRECATION_DATE,
+ _FEATURE_COLUMN_DEPRECATION)
+ def _parse_example_spec(self):
+ return self.parse_example_spec
+
+ def _transform_input_tensor(self, input_tensor):
"""Hashes the values in the feature_column."""
- input_tensor = _to_sparse_input_and_drop_ignore_values(
- transformation_cache.get(self.key, state_manager))
if not isinstance(input_tensor, sparse_tensor_lib.SparseTensor):
raise ValueError('SparseColumn input must be a SparseTensor.')
@@ -2899,25 +3123,56 @@ class HashedCategoricalColumn(
return sparse_tensor_lib.SparseTensor(
input_tensor.indices, sparse_id_values, input_tensor.dense_shape)
+ def transform_feature(self, transformation_cache, state_manager):
+ """Hashes the values in the feature_column."""
+ input_tensor = _to_sparse_input_and_drop_ignore_values(
+ transformation_cache.get(self.key, state_manager))
+ return self._transform_input_tensor(input_tensor)
+
+ @deprecation.deprecated(_FEATURE_COLUMN_DEPRECATION_DATE,
+ _FEATURE_COLUMN_DEPRECATION)
+ def _transform_feature(self, inputs):
+ input_tensor = _to_sparse_input_and_drop_ignore_values(inputs.get(self.key))
+ return self._transform_input_tensor(input_tensor)
+
@property
def num_buckets(self):
"""Returns number of buckets in this sparse feature."""
return self.hash_bucket_size
+ @property
+ @deprecation.deprecated(_FEATURE_COLUMN_DEPRECATION_DATE,
+ _FEATURE_COLUMN_DEPRECATION)
+ def _num_buckets(self):
+ return self.num_buckets
+
def get_sparse_tensors(self, transformation_cache, state_manager):
"""See `CategoricalColumn` base class."""
return CategoricalColumn.IdWeightPair(
transformation_cache.get(self, state_manager), None)
+ @deprecation.deprecated(_FEATURE_COLUMN_DEPRECATION_DATE,
+ _FEATURE_COLUMN_DEPRECATION)
+ def _get_sparse_tensors(self, inputs, weight_collections=None,
+ trainable=None):
+ del weight_collections
+ del trainable
+ return CategoricalColumn.IdWeightPair(inputs.get(self), None)
+
class VocabularyFileCategoricalColumn(
CategoricalColumn,
+ fc_old._CategoricalColumn, # pylint: disable=protected-access
collections.namedtuple('VocabularyFileCategoricalColumn',
('key', 'vocabulary_file', 'vocabulary_size',
'num_oov_buckets', 'dtype', 'default_value'))):
"""See `categorical_column_with_vocabulary_file`."""
@property
+ def _is_v2_column(self):
+ return True
+
+ @property
def name(self):
"""See `FeatureColumn` base class."""
return self.key
@@ -2927,11 +3182,14 @@ class VocabularyFileCategoricalColumn(
"""See `FeatureColumn` base class."""
return {self.key: parsing_ops.VarLenFeature(self.dtype)}
- def transform_feature(self, transformation_cache, state_manager):
- """Creates a lookup table for the vocabulary."""
- input_tensor = _to_sparse_input_and_drop_ignore_values(
- transformation_cache.get(self.key, state_manager))
+ @property
+ @deprecation.deprecated(_FEATURE_COLUMN_DEPRECATION_DATE,
+ _FEATURE_COLUMN_DEPRECATION)
+ def _parse_example_spec(self):
+ return self.parse_example_spec
+ def _transform_input_tensor(self, input_tensor):
+ """Creates a lookup table for the vocabulary."""
if self.dtype.is_integer != input_tensor.dtype.is_integer:
raise ValueError(
'Column dtype and SparseTensors dtype must be compatible. '
@@ -2957,19 +3215,46 @@ class VocabularyFileCategoricalColumn(
key_dtype=key_dtype,
name='{}_lookup'.format(self.key)).lookup(input_tensor)
+ def transform_feature(self, transformation_cache, state_manager):
+ """Creates a lookup table for the vocabulary."""
+ input_tensor = _to_sparse_input_and_drop_ignore_values(
+ transformation_cache.get(self.key, state_manager))
+ return self._transform_input_tensor(input_tensor)
+
+ @deprecation.deprecated(_FEATURE_COLUMN_DEPRECATION_DATE,
+ _FEATURE_COLUMN_DEPRECATION)
+ def _transform_feature(self, inputs):
+ input_tensor = _to_sparse_input_and_drop_ignore_values(inputs.get(self.key))
+ return self._transform_input_tensor(input_tensor)
+
@property
def num_buckets(self):
"""Returns number of buckets in this sparse feature."""
return self.vocabulary_size + self.num_oov_buckets
+ @property
+ @deprecation.deprecated(_FEATURE_COLUMN_DEPRECATION_DATE,
+ _FEATURE_COLUMN_DEPRECATION)
+ def _num_buckets(self):
+ return self.num_buckets
+
def get_sparse_tensors(self, transformation_cache, state_manager):
"""See `CategoricalColumn` base class."""
return CategoricalColumn.IdWeightPair(
transformation_cache.get(self, state_manager), None)
+ @deprecation.deprecated(_FEATURE_COLUMN_DEPRECATION_DATE,
+ _FEATURE_COLUMN_DEPRECATION)
+ def _get_sparse_tensors(self, inputs, weight_collections=None,
+ trainable=None):
+ del weight_collections
+ del trainable
+ return CategoricalColumn.IdWeightPair(inputs.get(self), None)
+
class VocabularyListCategoricalColumn(
CategoricalColumn,
+ fc_old._CategoricalColumn, # pylint: disable=protected-access
collections.namedtuple(
'VocabularyListCategoricalColumn',
('key', 'vocabulary_list', 'dtype', 'default_value', 'num_oov_buckets'))
@@ -2977,6 +3262,10 @@ class VocabularyListCategoricalColumn(
"""See `categorical_column_with_vocabulary_list`."""
@property
+ def _is_v2_column(self):
+ return True
+
+ @property
def name(self):
"""See `FeatureColumn` base class."""
return self.key
@@ -2986,11 +3275,14 @@ class VocabularyListCategoricalColumn(
"""See `FeatureColumn` base class."""
return {self.key: parsing_ops.VarLenFeature(self.dtype)}
- def transform_feature(self, transformation_cache, state_manager):
- """Creates a lookup table for the vocabulary list."""
- input_tensor = _to_sparse_input_and_drop_ignore_values(
- transformation_cache.get(self.key, state_manager))
+ @property
+ @deprecation.deprecated(_FEATURE_COLUMN_DEPRECATION_DATE,
+ _FEATURE_COLUMN_DEPRECATION)
+ def _parse_example_spec(self):
+ return self.parse_example_spec
+ def _transform_input_tensor(self, input_tensor):
+ """Creates a lookup table for the vocabulary list."""
if self.dtype.is_integer != input_tensor.dtype.is_integer:
raise ValueError(
'Column dtype and SparseTensors dtype must be compatible. '
@@ -3015,25 +3307,56 @@ class VocabularyListCategoricalColumn(
dtype=key_dtype,
name='{}_lookup'.format(self.key)).lookup(input_tensor)
+ def transform_feature(self, transformation_cache, state_manager):
+ """Creates a lookup table for the vocabulary list."""
+ input_tensor = _to_sparse_input_and_drop_ignore_values(
+ transformation_cache.get(self.key, state_manager))
+ return self._transform_input_tensor(input_tensor)
+
+ @deprecation.deprecated(_FEATURE_COLUMN_DEPRECATION_DATE,
+ _FEATURE_COLUMN_DEPRECATION)
+ def _transform_feature(self, inputs):
+ input_tensor = _to_sparse_input_and_drop_ignore_values(inputs.get(self.key))
+ return self._transform_input_tensor(input_tensor)
+
@property
def num_buckets(self):
"""Returns number of buckets in this sparse feature."""
return len(self.vocabulary_list) + self.num_oov_buckets
+ @property
+ @deprecation.deprecated(_FEATURE_COLUMN_DEPRECATION_DATE,
+ _FEATURE_COLUMN_DEPRECATION)
+ def _num_buckets(self):
+ return self.num_buckets
+
def get_sparse_tensors(self, transformation_cache, state_manager):
"""See `CategoricalColumn` base class."""
return CategoricalColumn.IdWeightPair(
transformation_cache.get(self, state_manager), None)
+ @deprecation.deprecated(_FEATURE_COLUMN_DEPRECATION_DATE,
+ _FEATURE_COLUMN_DEPRECATION)
+ def _get_sparse_tensors(self, inputs, weight_collections=None,
+ trainable=None):
+ del weight_collections
+ del trainable
+ return CategoricalColumn.IdWeightPair(inputs.get(self), None)
+
class IdentityCategoricalColumn(
CategoricalColumn,
+ fc_old._CategoricalColumn, # pylint: disable=protected-access
collections.namedtuple('IdentityCategoricalColumn',
('key', 'number_buckets', 'default_value'))):
"""See `categorical_column_with_identity`."""
@property
+ def _is_v2_column(self):
+ return True
+
+ @property
def name(self):
"""See `FeatureColumn` base class."""
return self.key
@@ -3043,11 +3366,14 @@ class IdentityCategoricalColumn(
"""See `FeatureColumn` base class."""
return {self.key: parsing_ops.VarLenFeature(dtypes.int64)}
- def transform_feature(self, transformation_cache, state_manager):
- """Returns a SparseTensor with identity values."""
- input_tensor = _to_sparse_input_and_drop_ignore_values(
- transformation_cache.get(self.key, state_manager))
+ @property
+ @deprecation.deprecated(_FEATURE_COLUMN_DEPRECATION_DATE,
+ _FEATURE_COLUMN_DEPRECATION)
+ def _parse_example_spec(self):
+ return self.parse_example_spec
+ def _transform_input_tensor(self, input_tensor):
+ """Returns a SparseTensor with identity values."""
if not input_tensor.dtype.is_integer:
raise ValueError(
'Invalid input, not integer. key: {} dtype: {}'.format(
@@ -3082,25 +3408,57 @@ class IdentityCategoricalColumn(
values=values,
dense_shape=input_tensor.dense_shape)
+ def transform_feature(self, transformation_cache, state_manager):
+ """Returns a SparseTensor with identity values."""
+ input_tensor = _to_sparse_input_and_drop_ignore_values(
+ transformation_cache.get(self.key, state_manager))
+ return self._transform_input_tensor(input_tensor)
+
+ @deprecation.deprecated(_FEATURE_COLUMN_DEPRECATION_DATE,
+ _FEATURE_COLUMN_DEPRECATION)
+ def _transform_feature(self, inputs):
+ input_tensor = _to_sparse_input_and_drop_ignore_values(inputs.get(self.key))
+ return self._transform_input_tensor(input_tensor)
+
@property
def num_buckets(self):
"""Returns number of buckets in this sparse feature."""
return self.number_buckets
+ @property
+ @deprecation.deprecated(_FEATURE_COLUMN_DEPRECATION_DATE,
+ _FEATURE_COLUMN_DEPRECATION)
+ def _num_buckets(self):
+ return self.num_buckets
+
def get_sparse_tensors(self, transformation_cache, state_manager):
"""See `CategoricalColumn` base class."""
return CategoricalColumn.IdWeightPair(
transformation_cache.get(self, state_manager), None)
+ @deprecation.deprecated(_FEATURE_COLUMN_DEPRECATION_DATE,
+ _FEATURE_COLUMN_DEPRECATION)
+ def _get_sparse_tensors(self, inputs, weight_collections=None,
+ trainable=None):
+ del weight_collections
+ del trainable
+ return CategoricalColumn.IdWeightPair(inputs.get(self), None)
+
class WeightedCategoricalColumn(
CategoricalColumn,
+ fc_old._CategoricalColumn, # pylint: disable=protected-access
collections.namedtuple(
'WeightedCategoricalColumn',
('categorical_column', 'weight_feature_key', 'dtype'))):
"""See `weighted_categorical_column`."""
@property
+ def _is_v2_column(self):
+ return (isinstance(self.categorical_column, FeatureColumn) and
+ self.categorical_column._is_v2_column) # pylint: disable=protected-access
+
+ @property
def name(self):
"""See `FeatureColumn` base class."""
return '{}_weighted_by_{}'.format(
@@ -3117,14 +3475,28 @@ class WeightedCategoricalColumn(
return config
@property
+ @deprecation.deprecated(_FEATURE_COLUMN_DEPRECATION_DATE,
+ _FEATURE_COLUMN_DEPRECATION)
+ def _parse_example_spec(self):
+ config = self.categorical_column._parse_example_spec # pylint: disable=protected-access
+ if self.weight_feature_key in config:
+ raise ValueError('Parse config {} already exists for {}.'.format(
+ config[self.weight_feature_key], self.weight_feature_key))
+ config[self.weight_feature_key] = parsing_ops.VarLenFeature(self.dtype)
+ return config
+
+ @property
def num_buckets(self):
"""See `DenseColumn` base class."""
return self.categorical_column.num_buckets
- def transform_feature(self, transformation_cache, state_manager):
- """Applies weights to tensor generated from `categorical_column`'."""
- weight_tensor = transformation_cache.get(self.weight_feature_key,
- state_manager)
+ @property
+ @deprecation.deprecated(_FEATURE_COLUMN_DEPRECATION_DATE,
+ _FEATURE_COLUMN_DEPRECATION)
+ def _num_buckets(self):
+ return self.categorical_column._num_buckets # pylint: disable=protected-access
+
+ def _transform_weight_tensor(self, weight_tensor):
if weight_tensor is None:
raise ValueError('Missing weights {}.'.format(self.weight_feature_key))
weight_tensor = sparse_tensor_lib.convert_to_tensor_or_sparse_tensor(
@@ -3138,27 +3510,63 @@ class WeightedCategoricalColumn(
weight_tensor, ignore_value=0.0)
if not weight_tensor.dtype.is_floating:
weight_tensor = math_ops.to_float(weight_tensor)
+ return weight_tensor
+
+ def transform_feature(self, transformation_cache, state_manager):
+ """Applies weights to tensor generated from `categorical_column`'."""
+ weight_tensor = transformation_cache.get(self.weight_feature_key,
+ state_manager)
+ weight_tensor = self._transform_weight_tensor(weight_tensor)
return (transformation_cache.get(self.categorical_column, state_manager),
weight_tensor)
+ @deprecation.deprecated(_FEATURE_COLUMN_DEPRECATION_DATE,
+ _FEATURE_COLUMN_DEPRECATION)
+ def _transform_feature(self, inputs):
+ """Applies weights to tensor generated from `categorical_column`'."""
+ weight_tensor = inputs.get(self.weight_feature_key)
+ weight_tensor = self._transform_weight_tensor(weight_tensor)
+ return (inputs.get(self.categorical_column), weight_tensor)
+
def get_sparse_tensors(self, transformation_cache, state_manager):
"""See `CategoricalColumn` base class."""
tensors = transformation_cache.get(self, state_manager)
return CategoricalColumn.IdWeightPair(tensors[0], tensors[1])
+ @deprecation.deprecated(_FEATURE_COLUMN_DEPRECATION_DATE,
+ _FEATURE_COLUMN_DEPRECATION)
+ def _get_sparse_tensors(self, inputs, weight_collections=None,
+ trainable=None):
+ del weight_collections
+ del trainable
+ tensors = inputs.get(self)
+ return CategoricalColumn.IdWeightPair(tensors[0], tensors[1])
+
class CrossedColumn(
CategoricalColumn,
+ fc_old._CategoricalColumn, # pylint: disable=protected-access
collections.namedtuple('CrossedColumn',
('keys', 'hash_bucket_size', 'hash_key'))):
"""See `crossed_column`."""
@property
+ def _is_v2_column(self):
+ for key in _collect_leaf_level_keys(self):
+ if isinstance(key, six.string_types):
+ continue
+ if not isinstance(key, FeatureColumn):
+ return False
+ if not key._is_v2_column: # pylint: disable=protected-access
+ return False
+ return True
+
+ @property
def name(self):
"""See `FeatureColumn` base class."""
feature_names = []
for key in _collect_leaf_level_keys(self):
- if isinstance(key, FeatureColumn):
+ if isinstance(key, (FeatureColumn, fc_old._FeatureColumn)): # pylint: disable=protected-access
feature_names.append(key.name)
else: # key must be a string
feature_names.append(key)
@@ -3171,17 +3579,25 @@ class CrossedColumn(
for key in self.keys:
if isinstance(key, FeatureColumn):
config.update(key.parse_example_spec)
+ elif isinstance(key, fc_old._FeatureColumn): # pylint: disable=protected-access
+ config.update(key._parse_example_spec) # pylint: disable=protected-access
else: # key must be a string
config.update({key: parsing_ops.VarLenFeature(dtypes.string)})
return config
+ @property
+ @deprecation.deprecated(_FEATURE_COLUMN_DEPRECATION_DATE,
+ _FEATURE_COLUMN_DEPRECATION)
+ def _parse_example_spec(self):
+ return self.parse_example_spec
+
def transform_feature(self, transformation_cache, state_manager):
"""Generates a hashed sparse cross from the input tensors."""
feature_tensors = []
for key in _collect_leaf_level_keys(self):
if isinstance(key, six.string_types):
feature_tensors.append(transformation_cache.get(key, state_manager))
- elif isinstance(key, CategoricalColumn):
+ elif isinstance(key, (fc_old._CategoricalColumn, CategoricalColumn)): # pylint: disable=protected-access
ids_and_weights = key.get_sparse_tensors(transformation_cache,
state_manager)
if ids_and_weights.weight_tensor is not None:
@@ -3197,16 +3613,54 @@ class CrossedColumn(
num_buckets=self.hash_bucket_size,
hash_key=self.hash_key)
+ @deprecation.deprecated(_FEATURE_COLUMN_DEPRECATION_DATE,
+ _FEATURE_COLUMN_DEPRECATION)
+ def _transform_feature(self, inputs):
+ """Generates a hashed sparse cross from the input tensors."""
+ feature_tensors = []
+ for key in _collect_leaf_level_keys(self):
+ if isinstance(key, six.string_types):
+ feature_tensors.append(inputs.get(key))
+ elif isinstance(key, (CategoricalColumn, fc_old._CategoricalColumn)): # pylint: disable=protected-access
+ ids_and_weights = key._get_sparse_tensors(inputs) # pylint: disable=protected-access
+ if ids_and_weights.weight_tensor is not None:
+ raise ValueError(
+ 'crossed_column does not support weight_tensor, but the given '
+ 'column populates weight_tensor. '
+ 'Given column: {}'.format(key.name))
+ feature_tensors.append(ids_and_weights.id_tensor)
+ else:
+ raise ValueError('Unsupported column type. Given: {}'.format(key))
+ return sparse_ops.sparse_cross_hashed(
+ inputs=feature_tensors,
+ num_buckets=self.hash_bucket_size,
+ hash_key=self.hash_key)
+
@property
def num_buckets(self):
"""Returns number of buckets in this sparse feature."""
return self.hash_bucket_size
+ @property
+ @deprecation.deprecated(_FEATURE_COLUMN_DEPRECATION_DATE,
+ _FEATURE_COLUMN_DEPRECATION)
+ def _num_buckets(self):
+ return self.num_buckets
+
def get_sparse_tensors(self, transformation_cache, state_manager):
"""See `CategoricalColumn` base class."""
return CategoricalColumn.IdWeightPair(
transformation_cache.get(self, state_manager), None)
+ @deprecation.deprecated(_FEATURE_COLUMN_DEPRECATION_DATE,
+ _FEATURE_COLUMN_DEPRECATION)
+ def _get_sparse_tensors(self, inputs, weight_collections=None,
+ trainable=None):
+ """See `CategoricalColumn` base class."""
+ del weight_collections
+ del trainable
+ return CategoricalColumn.IdWeightPair(inputs.get(self), None)
+
def _collect_leaf_level_keys(cross):
"""Collects base keys by expanding all nested crosses.
@@ -3382,9 +3836,12 @@ def _prune_invalid_weights(sparse_ids, sparse_weights):
return sparse_ids, sparse_weights
-class IndicatorColumn(DenseColumn, SequenceDenseColumn,
- collections.namedtuple('IndicatorColumn',
- ('categorical_column'))):
+class IndicatorColumn(
+ DenseColumn,
+ SequenceDenseColumn,
+ fc_old._DenseColumn, # pylint: disable=protected-access
+ fc_old._SequenceDenseColumn, # pylint: disable=protected-access
+ collections.namedtuple('IndicatorColumn', ('categorical_column'))):
"""Represents a one-hot column for use in deep networks.
Args:
@@ -3393,27 +3850,16 @@ class IndicatorColumn(DenseColumn, SequenceDenseColumn,
"""
@property
+ def _is_v2_column(self):
+ return (isinstance(self.categorical_column, FeatureColumn) and
+ self.categorical_column._is_v2_column) # pylint: disable=protected-access
+
+ @property
def name(self):
"""See `FeatureColumn` base class."""
return '{}_indicator'.format(self.categorical_column.name)
- def transform_feature(self, transformation_cache, state_manager):
- """Returns dense `Tensor` representing feature.
-
- Args:
- transformation_cache: A `FeatureTransformationCache` object to access
- features.
- state_manager: A `StateManager` to create / access resources such as
- lookup tables.
-
- Returns:
- Transformed feature `Tensor`.
-
- Raises:
- ValueError: if input rank is not known at graph building time.
- """
- id_weight_pair = self.categorical_column.get_sparse_tensors(
- transformation_cache, state_manager)
+ def _transform_id_weight_pair(self, id_weight_pair):
id_tensor = id_weight_pair.id_tensor
weight_tensor = id_weight_pair.weight_tensor
@@ -3422,7 +3868,7 @@ class IndicatorColumn(DenseColumn, SequenceDenseColumn,
weighted_column = sparse_ops.sparse_merge(
sp_ids=id_tensor,
sp_values=weight_tensor,
- vocab_size=int(self.variable_shape[-1]))
+ vocab_size=int(self._variable_shape[-1]))
# Remove (?, -1) index
weighted_column = sparse_ops.sparse_slice(weighted_column, [0, 0],
weighted_column.dense_shape)
@@ -3435,22 +3881,62 @@ class IndicatorColumn(DenseColumn, SequenceDenseColumn,
# input_layer are float32.
one_hot_id_tensor = array_ops.one_hot(
dense_id_tensor,
- depth=self.variable_shape[-1],
+ depth=self._variable_shape[-1],
on_value=1.0,
off_value=0.0)
# Reduce to get a multi-hot per example.
return math_ops.reduce_sum(one_hot_id_tensor, axis=[-2])
+ def transform_feature(self, transformation_cache, state_manager):
+ """Returns dense `Tensor` representing feature.
+
+ Args:
+ transformation_cache: A `FeatureTransformationCache` object to access
+ features.
+ state_manager: A `StateManager` to create / access resources such as
+ lookup tables.
+
+ Returns:
+ Transformed feature `Tensor`.
+
+ Raises:
+ ValueError: if input rank is not known at graph building time.
+ """
+ id_weight_pair = self.categorical_column.get_sparse_tensors(
+ transformation_cache, state_manager)
+ return self._transform_id_weight_pair(id_weight_pair)
+
+ @deprecation.deprecated(_FEATURE_COLUMN_DEPRECATION_DATE,
+ _FEATURE_COLUMN_DEPRECATION)
+ def _transform_feature(self, inputs):
+ id_weight_pair = self.categorical_column._get_sparse_tensors(inputs) # pylint: disable=protected-access
+ return self._transform_id_weight_pair(id_weight_pair)
+
@property
def parse_example_spec(self):
"""See `FeatureColumn` base class."""
return self.categorical_column.parse_example_spec
@property
+ @deprecation.deprecated(_FEATURE_COLUMN_DEPRECATION_DATE,
+ _FEATURE_COLUMN_DEPRECATION)
+ def _parse_example_spec(self):
+ return self.categorical_column._parse_example_spec # pylint: disable=protected-access
+
+ @property
def variable_shape(self):
"""Returns a `TensorShape` representing the shape of the dense `Tensor`."""
- return tensor_shape.TensorShape([1, self.categorical_column.num_buckets])
+ if isinstance(self.categorical_column, FeatureColumn):
+ return tensor_shape.TensorShape([1, self.categorical_column.num_buckets])
+ else:
+ return tensor_shape.TensorShape([1, self.categorical_column._num_buckets]) # pylint: disable=protected-access
+
+ @property
+ @deprecation.deprecated(_FEATURE_COLUMN_DEPRECATION_DATE,
+ _FEATURE_COLUMN_DEPRECATION)
+ def _variable_shape(self):
+ return tensor_shape.TensorShape([1, self.categorical_column._num_buckets]) # pylint: disable=protected-access
def get_dense_tensor(self, transformation_cache, state_manager):
"""Returns dense `Tensor` representing feature.
@@ -3481,6 +3967,27 @@ class IndicatorColumn(DenseColumn, SequenceDenseColumn,
# representation created by transform_feature.
return transformation_cache.get(self, state_manager)
+ @deprecation.deprecated(_FEATURE_COLUMN_DEPRECATION_DATE,
+ _FEATURE_COLUMN_DEPRECATION)
+ def _get_dense_tensor(self, inputs, weight_collections=None, trainable=None):
+ del weight_collections
+ del trainable
+ if isinstance(
+ self.categorical_column,
+ (SequenceCategoricalColumn, fc_old._SequenceCategoricalColumn)): # pylint: disable=protected-access
+ raise ValueError(
+ 'In indicator_column: {}. '
+ 'categorical_column must not be of type _SequenceCategoricalColumn. '
+ 'Suggested fix A: If you wish to use input_layer, use a '
+ 'non-sequence categorical_column_with_*. '
+ 'Suggested fix B: If you wish to create sequence input, use '
+ 'sequence_input_layer instead of input_layer. '
+ 'Given (type {}): {}'.format(self.name, type(self.categorical_column),
+ self.categorical_column))
+ # Feature has been already transformed. Return the intermediate
+ # representation created by transform_feature.
+ return inputs.get(self)
+
def get_sequence_dense_tensor(self, transformation_cache, state_manager):
"""See `SequenceDenseColumn` base class."""
if not isinstance(self.categorical_column, SequenceCategoricalColumn):
@@ -3496,7 +4003,36 @@ class IndicatorColumn(DenseColumn, SequenceDenseColumn,
dense_tensor = transformation_cache.get(self, state_manager)
sparse_tensors = self.categorical_column.get_sparse_tensors(
transformation_cache, state_manager)
- sequence_length = _sequence_length_from_sparse_tensor(
+ sequence_length = fc_old._sequence_length_from_sparse_tensor( # pylint: disable=protected-access
+ sparse_tensors.id_tensor)
+ return SequenceDenseColumn.TensorSequenceLengthPair(
+ dense_tensor=dense_tensor, sequence_length=sequence_length)
+
+ @deprecation.deprecated(_FEATURE_COLUMN_DEPRECATION_DATE,
+ _FEATURE_COLUMN_DEPRECATION)
+ def _get_sequence_dense_tensor(self,
+ inputs,
+ weight_collections=None,
+ trainable=None):
+ # Do nothing with weight_collections and trainable since no variables are
+ # created in this function.
+ del weight_collections
+ del trainable
+ if not isinstance(
+ self.categorical_column,
+ (SequenceCategoricalColumn, fc_old._SequenceCategoricalColumn)): # pylint: disable=protected-access
+ raise ValueError(
+ 'In indicator_column: {}. '
+ 'categorical_column must be of type _SequenceCategoricalColumn '
+ 'to use sequence_input_layer. '
+ 'Suggested fix: Use one of sequence_categorical_column_with_*. '
+ 'Given (type {}): {}'.format(self.name, type(self.categorical_column),
+ self.categorical_column))
+ # Feature has been already transformed. Return the intermediate
+ # representation created by _transform_feature.
+ dense_tensor = inputs.get(self)
+ sparse_tensors = self.categorical_column._get_sparse_tensors(inputs) # pylint: disable=protected-access
+ sequence_length = fc_old._sequence_length_from_sparse_tensor( # pylint: disable=protected-access
sparse_tensors.id_tensor)
return SequenceDenseColumn.TensorSequenceLengthPair(
dense_tensor=dense_tensor, sequence_length=sequence_length)
@@ -3518,28 +4054,19 @@ def _verify_static_batch_size_equality(tensors, columns):
expected_batch_size, tensors[i].shape[0]))
-def _sequence_length_from_sparse_tensor(sp_tensor, num_elements=1):
- """Returns a [batch_size] Tensor with per-example sequence length."""
- with ops.name_scope(None, 'sequence_length') as name_scope:
- row_ids = sp_tensor.indices[:, 0]
- column_ids = sp_tensor.indices[:, 1]
- column_ids += array_ops.ones_like(column_ids)
- seq_length = math_ops.to_int64(
- math_ops.segment_max(column_ids, segment_ids=row_ids) / num_elements)
- # If the last n rows do not have ids, seq_length will have shape
- # [batch_size - n]. Pad the remaining values with zeros.
- n_pad = array_ops.shape(sp_tensor)[:1] - array_ops.shape(seq_length)[:1]
- padding = array_ops.zeros(n_pad, dtype=seq_length.dtype)
- return array_ops.concat([seq_length, padding], axis=0, name=name_scope)
-
-
-class SequenceCategoricalColumn(FeatureColumn,
- collections.namedtuple(
- 'SequenceCategoricalColumn',
- ('categorical_column'))):
+class SequenceCategoricalColumn(
+ FeatureColumn,
+ fc_old._CategoricalColumn, # pylint: disable=protected-access
+ collections.namedtuple('SequenceCategoricalColumn',
+ ('categorical_column'))):
"""Represents sequences of categorical data."""
@property
+ def _is_v2_column(self):
+ return (isinstance(self.categorical_column, FeatureColumn) and
+ self.categorical_column._is_v2_column) # pylint: disable=protected-access
+
+ @property
def name(self):
"""See `FeatureColumn` base class."""
return self.categorical_column.name
@@ -3549,16 +4076,46 @@ class SequenceCategoricalColumn(FeatureColumn,
"""See `FeatureColumn` base class."""
return self.categorical_column.parse_example_spec
+ @property
+ @deprecation.deprecated(_FEATURE_COLUMN_DEPRECATION_DATE,
+ _FEATURE_COLUMN_DEPRECATION)
+ def _parse_example_spec(self):
+ return self.categorical_column._parse_example_spec # pylint: disable=protected-access
+
def transform_feature(self, transformation_cache, state_manager):
"""See `FeatureColumn` base class."""
return self.categorical_column.transform_feature(transformation_cache,
state_manager)
+ @deprecation.deprecated(_FEATURE_COLUMN_DEPRECATION_DATE,
+ _FEATURE_COLUMN_DEPRECATION)
+ def _transform_feature(self, inputs):
+ return self.categorical_column._transform_feature(inputs) # pylint: disable=protected-access
+
@property
def num_buckets(self):
"""Returns number of buckets in this sparse feature."""
return self.categorical_column.num_buckets
+ @property
+ @deprecation.deprecated(_FEATURE_COLUMN_DEPRECATION_DATE,
+ _FEATURE_COLUMN_DEPRECATION)
+ def _num_buckets(self):
+ return self.categorical_column._num_buckets # pylint: disable=protected-access
+
+ def _get_sparse_tensors_helper(self, sparse_tensors):
+ id_tensor = sparse_tensors.id_tensor
+ weight_tensor = sparse_tensors.weight_tensor
+ # Expands third dimension, if necessary so that embeddings are not
+ # combined during embedding lookup. If the tensor is already 3D, leave
+ # as-is.
+ shape = array_ops.shape(id_tensor)
+ target_shape = [shape[0], shape[1], -1]
+ id_tensor = sparse_ops.sparse_reshape(id_tensor, target_shape)
+ if weight_tensor is not None:
+ weight_tensor = sparse_ops.sparse_reshape(weight_tensor, target_shape)
+ return CategoricalColumn.IdWeightPair(id_tensor, weight_tensor)
+
def get_sequence_sparse_tensors(self, transformation_cache, state_manager):
"""Returns an IdWeightPair.
@@ -3580,27 +4137,11 @@ class SequenceCategoricalColumn(FeatureColumn,
"""
sparse_tensors = self.categorical_column.get_sparse_tensors(
transformation_cache, state_manager)
- id_tensor = sparse_tensors.id_tensor
- weight_tensor = sparse_tensors.weight_tensor
- # Expands final dimension, so that embeddings are not combined during
- # embedding lookup.
- check_id_rank = check_ops.assert_equal(
- array_ops.rank(id_tensor), 2,
- data=[
- 'Column {} expected ID tensor of rank 2. '.format(self.name),
- 'id_tensor shape: ', array_ops.shape(id_tensor)])
- with ops.control_dependencies([check_id_rank]):
- id_tensor = sparse_ops.sparse_reshape(
- id_tensor,
- shape=array_ops.concat([id_tensor.dense_shape, [1]], axis=0))
- if weight_tensor is not None:
- check_weight_rank = check_ops.assert_equal(
- array_ops.rank(weight_tensor), 2,
- data=[
- 'Column {} expected weight tensor of rank 2.'.format(self.name),
- 'weight_tensor shape:', array_ops.shape(weight_tensor)])
- with ops.control_dependencies([check_weight_rank]):
- weight_tensor = sparse_ops.sparse_reshape(
- weight_tensor,
- shape=array_ops.concat([weight_tensor.dense_shape, [1]], axis=0))
- return CategoricalColumn.IdWeightPair(id_tensor, weight_tensor)
+ return self._get_sparse_tensors_helper(sparse_tensors)
+
+ @deprecation.deprecated(_FEATURE_COLUMN_DEPRECATION_DATE,
+ _FEATURE_COLUMN_DEPRECATION)
+ def _get_sparse_tensors(self, inputs, weight_collections=None,
+ trainable=None):
+ sparse_tensors = self.categorical_column._get_sparse_tensors(inputs) # pylint: disable=protected-access
+ return self._get_sparse_tensors_helper(sparse_tensors)
diff --git a/tensorflow/python/feature_column/feature_column_v2_test.py b/tensorflow/python/feature_column/feature_column_v2_test.py
index d3787146ed..31bc0485ef 100644
--- a/tensorflow/python/feature_column/feature_column_v2_test.py
+++ b/tensorflow/python/feature_column/feature_column_v2_test.py
@@ -31,12 +31,8 @@ from tensorflow.python.client import session
from tensorflow.python.eager import backprop
from tensorflow.python.eager import context
from tensorflow.python.estimator.inputs import numpy_io
+from tensorflow.python.feature_column import feature_column as fc_old
from tensorflow.python.feature_column import feature_column_v2 as fc
-from tensorflow.python.feature_column.feature_column_v2 import _transform_features
-from tensorflow.python.feature_column.feature_column_v2 import FeatureColumn
-from tensorflow.python.feature_column.feature_column_v2 import FeatureLayer
-from tensorflow.python.feature_column.feature_column_v2 import FeatureTransformationCache
-from tensorflow.python.feature_column.feature_column_v2 import StateManager
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors
@@ -46,6 +42,7 @@ from tensorflow.python.framework import test_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import lookup_ops
from tensorflow.python.ops import parsing_ops
+from tensorflow.python.ops import partitioned_variables
from tensorflow.python.ops import variable_scope
from tensorflow.python.ops import variables as variables_lib
from tensorflow.python.platform import test
@@ -60,16 +57,30 @@ def _initialized_session(config=None):
return sess
+def get_linear_model_bias(name='linear_model'):
+ with variable_scope.variable_scope(name, reuse=True):
+ return variable_scope.get_variable('bias_weights')
+
+
+def get_linear_model_column_var(column, name='linear_model'):
+ return ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES,
+ name + '/' + column.name)[0]
+
+
class LazyColumnTest(test.TestCase):
def test_transformations_called_once(self):
- class TransformCounter(FeatureColumn):
+ class TransformCounter(fc.FeatureColumn):
def __init__(self):
self.num_transform = 0
@property
+ def _is_v2_column(self):
+ return True
+
+ @property
def name(self):
return 'TransformCounter'
@@ -81,7 +92,7 @@ class LazyColumnTest(test.TestCase):
def parse_example_spec(self):
pass
- transformation_cache = FeatureTransformationCache(
+ transformation_cache = fc.FeatureTransformationCache(
features={'a': [[2], [3.]]})
column = TransformCounter()
self.assertEqual(0, column.num_transform)
@@ -92,7 +103,11 @@ class LazyColumnTest(test.TestCase):
def test_returns_transform_output(self):
- class Transformer(FeatureColumn):
+ class Transformer(fc.FeatureColumn):
+
+ @property
+ def _is_v2_column(self):
+ return True
@property
def name(self):
@@ -105,7 +120,7 @@ class LazyColumnTest(test.TestCase):
def parse_example_spec(self):
pass
- transformation_cache = FeatureTransformationCache(
+ transformation_cache = fc.FeatureTransformationCache(
features={'a': [[2], [3.]]})
column = Transformer()
self.assertEqual('Output', transformation_cache.get(column, None))
@@ -113,7 +128,11 @@ class LazyColumnTest(test.TestCase):
def test_does_not_pollute_given_features_dict(self):
- class Transformer(FeatureColumn):
+ class Transformer(fc.FeatureColumn):
+
+ @property
+ def _is_v2_column(self):
+ return True
@property
def name(self):
@@ -127,12 +146,12 @@ class LazyColumnTest(test.TestCase):
pass
features = {'a': [[2], [3.]]}
- transformation_cache = FeatureTransformationCache(features=features)
+ transformation_cache = fc.FeatureTransformationCache(features=features)
transformation_cache.get(Transformer(), None)
self.assertEqual(['a'], list(features.keys()))
def test_error_if_feature_is_not_found(self):
- transformation_cache = FeatureTransformationCache(
+ transformation_cache = fc.FeatureTransformationCache(
features={'a': [[2], [3.]]})
with self.assertRaisesRegexp(ValueError,
'bbb is not in features dictionary'):
@@ -143,7 +162,11 @@ class LazyColumnTest(test.TestCase):
def test_not_supported_feature_column(self):
- class NotAProperColumn(FeatureColumn):
+ class NotAProperColumn(fc.FeatureColumn):
+
+ @property
+ def _is_v2_column(self):
+ return True
@property
def name(self):
@@ -157,7 +180,7 @@ class LazyColumnTest(test.TestCase):
def parse_example_spec(self):
pass
- transformation_cache = FeatureTransformationCache(
+ transformation_cache = fc.FeatureTransformationCache(
features={'a': [[2], [3.]]})
with self.assertRaisesRegexp(ValueError,
'NotAProperColumn is not supported'):
@@ -168,7 +191,7 @@ class LazyColumnTest(test.TestCase):
class NotAFeatureColumn(object):
pass
- transformation_cache = FeatureTransformationCache(
+ transformation_cache = fc.FeatureTransformationCache(
features={'a': [[2], [3.]]})
with self.assertRaisesRegexp(
TypeError, '"key" must be either a "str" or "FeatureColumn".'):
@@ -176,7 +199,7 @@ class LazyColumnTest(test.TestCase):
def test_expand_dim_rank_1_sparse_tensor_empty_batch(self):
# empty 1-D sparse tensor:
- transformation_cache = FeatureTransformationCache(
+ transformation_cache = fc.FeatureTransformationCache(
features={
'a':
sparse_tensor.SparseTensor(
@@ -201,6 +224,7 @@ class NumericColumnTest(test.TestCase):
self.assertIsNone(a.default_value)
self.assertEqual(dtypes.float32, a.dtype)
self.assertIsNone(a.normalizer_fn)
+ self.assertTrue(a._is_v2_column)
def test_key_should_be_string(self):
with self.assertRaisesRegexp(ValueError, 'key must be a string.'):
@@ -317,7 +341,9 @@ class NumericColumnTest(test.TestCase):
return input_tensor + 2.
price = fc.numeric_column('price', shape=[2], normalizer_fn=_increment_two)
- output = _transform_features({'price': [[1., 2.], [5., 6.]]}, [price], None)
+ output = fc._transform_features({
+ 'price': [[1., 2.], [5., 6.]]
+ }, [price], None)
with self.cached_session():
self.assertAllEqual([[3., 4.], [7., 8.]], output[price].eval())
@@ -327,7 +353,7 @@ class NumericColumnTest(test.TestCase):
return input_tensor + 2.
price = fc.numeric_column('price', shape=[2], normalizer_fn=_increment_two)
- transformation_cache = FeatureTransformationCache({
+ transformation_cache = fc.FeatureTransformationCache({
'price': [[1., 2.], [5., 6.]]
})
self.assertEqual(
@@ -336,7 +362,7 @@ class NumericColumnTest(test.TestCase):
def test_sparse_tensor_not_supported(self):
price = fc.numeric_column('price')
- transformation_cache = FeatureTransformationCache({
+ transformation_cache = fc.FeatureTransformationCache({
'price':
sparse_tensor.SparseTensor(
indices=[[0, 0]], values=[0.3], dense_shape=[1, 1])
@@ -370,6 +396,20 @@ class NumericColumnTest(test.TestCase):
sess.run(price_var.assign([[10.]]))
self.assertAllClose([[10.], [50.]], predictions.eval())
+ def test_old_linear_model(self):
+ price = fc.numeric_column('price')
+ with ops.Graph().as_default():
+ features = {'price': [[1.], [5.]]}
+ predictions = fc_old.linear_model(features, [price])
+ bias = get_linear_model_bias()
+ price_var = get_linear_model_column_var(price)
+ with _initialized_session() as sess:
+ self.assertAllClose([0.], bias.eval())
+ self.assertAllClose([[0.]], price_var.eval())
+ self.assertAllClose([[0.], [0.]], predictions.eval())
+ sess.run(price_var.assign([[10.]]))
+ self.assertAllClose([[10.], [50.]], predictions.eval())
+
class BucketizedColumnTest(test.TestCase):
@@ -404,6 +444,13 @@ class BucketizedColumnTest(test.TestCase):
def test_name(self):
a = fc.numeric_column('aaa', dtype=dtypes.int32)
b = fc.bucketized_column(a, boundaries=[0, 1])
+ self.assertTrue(b._is_v2_column)
+ self.assertEqual('aaa_bucketized', b.name)
+
+ def test_is_v2_column_old_numeric(self):
+ a = fc_old.numeric_column('aaa', dtype=dtypes.int32)
+ b = fc.bucketized_column(a, boundaries=[0, 1])
+ self.assertFalse(b._is_v2_column)
self.assertEqual('aaa_bucketized', b.name)
def test_parse_spec(self):
@@ -445,7 +492,7 @@ class BucketizedColumnTest(test.TestCase):
price = fc.numeric_column('price', shape=[2])
bucketized_price = fc.bucketized_column(price, boundaries=[0, 2, 4, 6])
with ops.Graph().as_default():
- transformed_tensor = _transform_features({
+ transformed_tensor = fc._transform_features({
'price': [[-1., 1.], [5., 6.]]
}, [bucketized_price], None)
with _initialized_session():
@@ -457,7 +504,7 @@ class BucketizedColumnTest(test.TestCase):
price = fc.numeric_column('price', shape=[1])
bucketized_price = fc.bucketized_column(price, boundaries=[0, 2, 4, 6])
with ops.Graph().as_default():
- transformation_cache = FeatureTransformationCache({
+ transformation_cache = fc.FeatureTransformationCache({
'price': [[-1.], [1.], [5.], [6.]]
})
with _initialized_session():
@@ -476,7 +523,7 @@ class BucketizedColumnTest(test.TestCase):
price = fc.numeric_column('price', shape=[2])
bucketized_price = fc.bucketized_column(price, boundaries=[0, 2, 4, 6])
with ops.Graph().as_default():
- transformation_cache = FeatureTransformationCache({
+ transformation_cache = fc.FeatureTransformationCache({
'price': [[-1., 1.], [5., 6.]]
})
with _initialized_session():
@@ -493,7 +540,7 @@ class BucketizedColumnTest(test.TestCase):
price = fc.numeric_column('price', shape=[1])
bucketized_price = fc.bucketized_column(price, boundaries=[0, 2, 4, 6])
with ops.Graph().as_default():
- transformation_cache = FeatureTransformationCache({
+ transformation_cache = fc.FeatureTransformationCache({
'price': [[-1.], [1.], [5.], [6.]]
})
with _initialized_session() as sess:
@@ -511,7 +558,7 @@ class BucketizedColumnTest(test.TestCase):
price = fc.numeric_column('price', shape=[2])
bucketized_price = fc.bucketized_column(price, boundaries=[0, 2, 4, 6])
with ops.Graph().as_default():
- transformation_cache = FeatureTransformationCache({
+ transformation_cache = fc.FeatureTransformationCache({
'price': [[-1., 1.], [5., 6.]]
})
with _initialized_session() as sess:
@@ -529,7 +576,7 @@ class BucketizedColumnTest(test.TestCase):
def test_sparse_tensor_input_not_supported(self):
price = fc.numeric_column('price')
bucketized_price = fc.bucketized_column(price, boundaries=[0, 1])
- transformation_cache = FeatureTransformationCache({
+ transformation_cache = fc.FeatureTransformationCache({
'price':
sparse_tensor.SparseTensor(
indices=[[0, 0]], values=[0.3], dense_shape=[1, 1])
@@ -599,6 +646,85 @@ class BucketizedColumnTest(test.TestCase):
sess.run(bias.assign([1.]))
self.assertAllClose([[81.], [141.]], predictions.eval())
+ def test_old_linear_model_one_input_value(self):
+ """Tests linear_model() for input with shape=[1]."""
+ price = fc.numeric_column('price', shape=[1])
+ bucketized_price = fc.bucketized_column(price, boundaries=[0, 2, 4, 6])
+ with ops.Graph().as_default():
+ features = {'price': [[-1.], [1.], [5.], [6.]]}
+ predictions = fc_old.linear_model(features, [bucketized_price])
+ bias = get_linear_model_bias()
+ bucketized_price_var = get_linear_model_column_var(bucketized_price)
+ with _initialized_session() as sess:
+ self.assertAllClose([0.], bias.eval())
+ # One weight variable per bucket, all initialized to zero.
+ self.assertAllClose([[0.], [0.], [0.], [0.], [0.]],
+ bucketized_price_var.eval())
+ self.assertAllClose([[0.], [0.], [0.], [0.]], predictions.eval())
+ sess.run(
+ bucketized_price_var.assign([[10.], [20.], [30.], [40.], [50.]]))
+ # price -1. is in the 0th bucket, whose weight is 10.
+ # price 1. is in the 1st bucket, whose weight is 20.
+ # price 5. is in the 3rd bucket, whose weight is 40.
+ # price 6. is in the 4th bucket, whose weight is 50.
+ self.assertAllClose([[10.], [20.], [40.], [50.]], predictions.eval())
+ sess.run(bias.assign([1.]))
+ self.assertAllClose([[11.], [21.], [41.], [51.]], predictions.eval())
+
+ def test_old_linear_model_two_input_values(self):
+ """Tests linear_model() for input with shape=[2]."""
+ price = fc.numeric_column('price', shape=[2])
+ bucketized_price = fc.bucketized_column(price, boundaries=[0, 2, 4, 6])
+ with ops.Graph().as_default():
+ features = {'price': [[-1., 1.], [5., 6.]]}
+ predictions = fc_old.linear_model(features, [bucketized_price])
+ bias = get_linear_model_bias()
+ bucketized_price_var = get_linear_model_column_var(bucketized_price)
+ with _initialized_session() as sess:
+ self.assertAllClose([0.], bias.eval())
+ # One weight per bucket per input column, all initialized to zero.
+ self.assertAllClose(
+ [[0.], [0.], [0.], [0.], [0.], [0.], [0.], [0.], [0.], [0.]],
+ bucketized_price_var.eval())
+ self.assertAllClose([[0.], [0.]], predictions.eval())
+ sess.run(
+ bucketized_price_var.assign([[10.], [20.], [30.], [40.], [50.],
+ [60.], [70.], [80.], [90.], [100.]]))
+ # 1st example:
+ # price -1. is in the 0th bucket, whose weight is 10.
+ # price 1. is in the 6th bucket, whose weight is 70.
+ # 2nd example:
+ # price 5. is in the 3rd bucket, whose weight is 40.
+ # price 6. is in the 9th bucket, whose weight is 100.
+ self.assertAllClose([[80.], [140.]], predictions.eval())
+ sess.run(bias.assign([1.]))
+ self.assertAllClose([[81.], [141.]], predictions.eval())
+
+ def test_old_linear_model_one_input_value_old_numeric(self):
+ """Tests linear_model() for input with shape=[1]."""
+ price = fc_old.numeric_column('price', shape=[1])
+ bucketized_price = fc.bucketized_column(price, boundaries=[0, 2, 4, 6])
+ with ops.Graph().as_default():
+ features = {'price': [[-1.], [1.], [5.], [6.]]}
+ predictions = fc_old.linear_model(features, [bucketized_price])
+ bias = get_linear_model_bias()
+ bucketized_price_var = get_linear_model_column_var(bucketized_price)
+ with _initialized_session() as sess:
+ self.assertAllClose([0.], bias.eval())
+ # One weight variable per bucket, all initialized to zero.
+ self.assertAllClose([[0.], [0.], [0.], [0.], [0.]],
+ bucketized_price_var.eval())
+ self.assertAllClose([[0.], [0.], [0.], [0.]], predictions.eval())
+ sess.run(
+ bucketized_price_var.assign([[10.], [20.], [30.], [40.], [50.]]))
+ # price -1. is in the 0th bucket, whose weight is 10.
+ # price 1. is in the 1st bucket, whose weight is 20.
+ # price 5. is in the 3rd bucket, whose weight is 40.
+ # price 6. is in the 4th bucket, whose weight is 50.
+ self.assertAllClose([[10.], [20.], [40.], [50.]], predictions.eval())
+ sess.run(bias.assign([1.]))
+ self.assertAllClose([[11.], [21.], [41.], [51.]], predictions.eval())
+
class HashedCategoricalColumnTest(test.TestCase):
@@ -608,6 +734,7 @@ class HashedCategoricalColumnTest(test.TestCase):
self.assertEqual('aaa', a.key)
self.assertEqual(10, a.hash_bucket_size)
self.assertEqual(dtypes.string, a.dtype)
+ self.assertTrue(a._is_v2_column)
def test_key_should_be_string(self):
with self.assertRaisesRegexp(ValueError, 'key must be a string.'):
@@ -675,7 +802,9 @@ class HashedCategoricalColumnTest(test.TestCase):
values=['omar', 'stringer', 'marlo'],
indices=[[0, 0], [1, 0], [1, 1]],
dense_shape=[2, 2])
- outputs = _transform_features({'wire': wire_tensor}, [hashed_sparse], None)
+ outputs = fc._transform_features({
+ 'wire': wire_tensor
+ }, [hashed_sparse], None)
output = outputs[hashed_sparse]
# Check exact hashed output. If hashing changes this test will break.
expected_values = [6, 4, 1]
@@ -705,7 +834,7 @@ class HashedCategoricalColumnTest(test.TestCase):
values=[101.],
indices=[[0, 0]],
dense_shape=[1, 1])
- transformation_cache = FeatureTransformationCache({
+ transformation_cache = fc.FeatureTransformationCache({
'a_int': int_tensor,
'a_string': string_tensor,
'a_float': float_tensor
@@ -720,7 +849,7 @@ class HashedCategoricalColumnTest(test.TestCase):
'wire', 10, dtype=dtypes.int64)
wire_tensor = sparse_tensor.SparseTensor(
values=['omar'], indices=[[0, 0]], dense_shape=[1, 1])
- transformation_cache = FeatureTransformationCache({'wire': wire_tensor})
+ transformation_cache = fc.FeatureTransformationCache({'wire': wire_tensor})
with self.assertRaisesRegexp(ValueError, 'dtype must be compatible'):
transformation_cache.get(hashed_sparse, None)
@@ -731,7 +860,7 @@ class HashedCategoricalColumnTest(test.TestCase):
values=[101, 201, 301],
indices=[[0, 0], [1, 0], [1, 1]],
dense_shape=[2, 2])
- transformation_cache = FeatureTransformationCache({'wire': wire_tensor})
+ transformation_cache = fc.FeatureTransformationCache({'wire': wire_tensor})
output = transformation_cache.get(hashed_sparse, None)
# Check exact hashed output. If hashing changes this test will break.
expected_values = [3, 7, 5]
@@ -745,7 +874,7 @@ class HashedCategoricalColumnTest(test.TestCase):
values=constant_op.constant([101, 201, 301], dtype=dtypes.int32),
indices=[[0, 0], [1, 0], [1, 1]],
dense_shape=[2, 2])
- transformation_cache = FeatureTransformationCache({'wire': wire_tensor})
+ transformation_cache = fc.FeatureTransformationCache({'wire': wire_tensor})
output = transformation_cache.get(hashed_sparse, None)
# Check exact hashed output. If hashing changes this test will break.
expected_values = [3, 7, 5]
@@ -754,7 +883,7 @@ class HashedCategoricalColumnTest(test.TestCase):
def test_get_sparse_tensors(self):
hashed_sparse = fc.categorical_column_with_hash_bucket('wire', 10)
- transformation_cache = FeatureTransformationCache({
+ transformation_cache = fc.FeatureTransformationCache({
'wire':
sparse_tensor.SparseTensor(
values=['omar', 'stringer', 'marlo'],
@@ -769,7 +898,7 @@ class HashedCategoricalColumnTest(test.TestCase):
def test_get_sparse_tensors_dense_input(self):
hashed_sparse = fc.categorical_column_with_hash_bucket('wire', 10)
- transformation_cache = FeatureTransformationCache({
+ transformation_cache = fc.FeatureTransformationCache({
'wire': (('omar', ''), ('stringer', 'marlo'))
})
id_weight_pair = hashed_sparse.get_sparse_tensors(transformation_cache,
@@ -800,6 +929,28 @@ class HashedCategoricalColumnTest(test.TestCase):
# 'skywalker' -> 2, 'omar' -> 2: wire_var[2] + wire_var[2] = 3+3 = 6
self.assertAllClose(((4.,), (6.,)), predictions.eval())
+ def test_old_linear_model(self):
+ wire_column = fc.categorical_column_with_hash_bucket('wire', 4)
+ self.assertEqual(4, wire_column.num_buckets)
+ with ops.Graph().as_default():
+ predictions = fc_old.linear_model({
+ wire_column.name:
+ sparse_tensor.SparseTensorValue(
+ indices=((0, 0), (1, 0), (1, 1)),
+ values=('marlo', 'skywalker', 'omar'),
+ dense_shape=(2, 2))
+ }, (wire_column,))
+ bias = get_linear_model_bias()
+ wire_var = get_linear_model_column_var(wire_column)
+ with _initialized_session():
+ self.assertAllClose((0.,), bias.eval())
+ self.assertAllClose(((0.,), (0.,), (0.,), (0.,)), wire_var.eval())
+ self.assertAllClose(((0.,), (0.,)), predictions.eval())
+ wire_var.assign(((1.,), (2.,), (3.,), (4.,))).eval()
+ # 'marlo' -> 3: wire_var[3] = 4
+ # 'skywalker' -> 2, 'omar' -> 2: wire_var[2] + wire_var[2] = 3+3 = 6
+ self.assertAllClose(((4.,), (6.,)), predictions.eval())
+
class CrossedColumnTest(test.TestCase):
@@ -841,8 +992,20 @@ class CrossedColumnTest(test.TestCase):
a = fc.numeric_column('a', dtype=dtypes.int32)
b = fc.bucketized_column(a, boundaries=[0, 1])
crossed1 = fc.crossed_column(['d1', 'd2'], 10)
+ self.assertTrue(crossed1._is_v2_column)
+
+ crossed2 = fc.crossed_column([b, 'c', crossed1], 10)
+ self.assertTrue(crossed2._is_v2_column)
+ self.assertEqual('a_bucketized_X_c_X_d1_X_d2', crossed2.name)
+
+ def test_is_v2_column(self):
+ a = fc_old.numeric_column('a', dtype=dtypes.int32)
+ b = fc.bucketized_column(a, boundaries=[0, 1])
+ crossed1 = fc.crossed_column(['d1', 'd2'], 10)
+ self.assertTrue(crossed1._is_v2_column)
crossed2 = fc.crossed_column([b, 'c', crossed1], 10)
+ self.assertFalse(crossed2._is_v2_column)
self.assertEqual('a_bucketized_X_c_X_d1_X_d2', crossed2.name)
def test_name_ordered_alphabetically(self):
@@ -927,7 +1090,7 @@ class CrossedColumnTest(test.TestCase):
indices=[[0, 0], [1, 0], [1, 1]],
dense_shape=[2, 2]),
}
- outputs = _transform_features(features, [price_cross_wire], None)
+ outputs = fc._transform_features(features, [price_cross_wire], None)
output = outputs[price_cross_wire]
with self.cached_session() as sess:
output_val = sess.run(output)
@@ -943,7 +1106,7 @@ class CrossedColumnTest(test.TestCase):
crossed1 = fc.crossed_column(['d1', 'd2'], 10)
crossed2 = fc.crossed_column([b, 'c', crossed1], 15, hash_key=5)
with ops.Graph().as_default():
- transformation_cache = FeatureTransformationCache({
+ transformation_cache = fc.FeatureTransformationCache({
'a':
constant_op.constant(((-1., .5), (.5, 1.))),
'c':
@@ -983,7 +1146,7 @@ class CrossedColumnTest(test.TestCase):
b = fc.bucketized_column(a, boundaries=(0, 1))
crossed = fc.crossed_column([b, 'c'], hash_bucket_size=5, hash_key=5)
with ops.Graph().as_default():
- transformation_cache = FeatureTransformationCache({
+ transformation_cache = fc.FeatureTransformationCache({
'a':
constant_op.constant(((-1., .5), (.5, 1.))),
'c':
@@ -1041,6 +1204,10 @@ class CrossedColumnTest(test.TestCase):
"""Produces sparse IDs and sparse weights."""
@property
+ def _is_v2_column(self):
+ return True
+
+ @property
def name(self):
return 'test_column'
@@ -1092,6 +1259,146 @@ class CrossedColumnTest(test.TestCase):
dense_shape=(2, 2)),
})
+ def test_old_linear_model(self):
+ """Tests linear_model.
+
+ Uses data from test_get_sparse_tesnsors_simple.
+ """
+ a = fc.numeric_column('a', dtype=dtypes.int32, shape=(2,))
+ b = fc.bucketized_column(a, boundaries=(0, 1))
+ crossed = fc.crossed_column([b, 'c'], hash_bucket_size=5, hash_key=5)
+ with ops.Graph().as_default():
+ predictions = fc_old.linear_model({
+ 'a':
+ constant_op.constant(((-1., .5), (.5, 1.))),
+ 'c':
+ sparse_tensor.SparseTensor(
+ indices=((0, 0), (1, 0), (1, 1)),
+ values=['cA', 'cB', 'cC'],
+ dense_shape=(2, 2)),
+ }, (crossed,))
+ bias = get_linear_model_bias()
+ crossed_var = get_linear_model_column_var(crossed)
+ with _initialized_session() as sess:
+ self.assertAllClose((0.,), bias.eval())
+ self.assertAllClose(((0.,), (0.,), (0.,), (0.,), (0.,)),
+ crossed_var.eval())
+ self.assertAllClose(((0.,), (0.,)), predictions.eval())
+ sess.run(crossed_var.assign(((1.,), (2.,), (3.,), (4.,), (5.,))))
+ # Expected ids after cross = (1, 0, 1, 3, 4, 2)
+ self.assertAllClose(((3.,), (14.,)), predictions.eval())
+ sess.run(bias.assign((.1,)))
+ self.assertAllClose(((3.1,), (14.1,)), predictions.eval())
+
+ def test_old_linear_model_with_weights(self):
+
+ class _TestColumnWithWeights(fc.CategoricalColumn,
+ fc_old._CategoricalColumn):
+ """Produces sparse IDs and sparse weights."""
+
+ @property
+ def _is_v2_column(self):
+ return True
+
+ @property
+ def name(self):
+ return 'test_column'
+
+ @property
+ def parse_example_spec(self):
+ return {
+ self.name:
+ parsing_ops.VarLenFeature(dtypes.int32),
+ '{}_weights'.format(self.name):
+ parsing_ops.VarLenFeature(dtypes.float32),
+ }
+
+ @property
+ def _parse_example_spec(self):
+ return self.parse_example_spec
+
+ @property
+ def num_buckets(self):
+ return 5
+
+ @property
+ def _num_buckets(self):
+ return self.num_buckets
+
+ def transform_feature(self, transformation_cache, state_manager):
+ raise ValueError('Should not be called.')
+
+ def _transform_feature(self, inputs):
+ return (inputs.get(self.name),
+ inputs.get('{}_weights'.format(self.name)))
+
+ def get_sparse_tensors(self, transformation_cache, state_manager):
+ raise ValueError('Should not be called.')
+
+ def _get_sparse_tensors(self,
+ inputs,
+ weight_collections=None,
+ trainable=None):
+ """Populates both id_tensor and weight_tensor."""
+ ids_and_weights = inputs.get(self)
+ return fc.CategoricalColumn.IdWeightPair(
+ id_tensor=ids_and_weights[0], weight_tensor=ids_and_weights[1])
+
+ t = _TestColumnWithWeights()
+ crossed = fc.crossed_column([t, 'c'], hash_bucket_size=5, hash_key=5)
+ with ops.Graph().as_default():
+ with self.assertRaisesRegexp(
+ ValueError,
+ 'crossed_column does not support weight_tensor.*{}'.format(t.name)):
+ fc_old.linear_model({
+ t.name:
+ sparse_tensor.SparseTensor(
+ indices=((0, 0), (1, 0), (1, 1)),
+ values=[0, 1, 2],
+ dense_shape=(2, 2)),
+ '{}_weights'.format(t.name):
+ sparse_tensor.SparseTensor(
+ indices=((0, 0), (1, 0), (1, 1)),
+ values=[1., 10., 2.],
+ dense_shape=(2, 2)),
+ 'c':
+ sparse_tensor.SparseTensor(
+ indices=((0, 0), (1, 0), (1, 1)),
+ values=['cA', 'cB', 'cC'],
+ dense_shape=(2, 2)),
+ }, (crossed,))
+
+ def test_old_linear_model_old_numeric(self):
+ """Tests linear_model.
+
+ Uses data from test_get_sparse_tesnsors_simple.
+ """
+ a = fc_old.numeric_column('a', dtype=dtypes.int32, shape=(2,))
+ b = fc.bucketized_column(a, boundaries=(0, 1))
+ crossed = fc.crossed_column([b, 'c'], hash_bucket_size=5, hash_key=5)
+ with ops.Graph().as_default():
+ predictions = fc_old.linear_model({
+ 'a':
+ constant_op.constant(((-1., .5), (.5, 1.))),
+ 'c':
+ sparse_tensor.SparseTensor(
+ indices=((0, 0), (1, 0), (1, 1)),
+ values=['cA', 'cB', 'cC'],
+ dense_shape=(2, 2)),
+ }, (crossed,))
+ bias = get_linear_model_bias()
+ crossed_var = get_linear_model_column_var(crossed)
+ with _initialized_session() as sess:
+ self.assertAllClose((0.,), bias.eval())
+ self.assertAllClose(((0.,), (0.,), (0.,), (0.,), (0.,)),
+ crossed_var.eval())
+ self.assertAllClose(((0.,), (0.,)), predictions.eval())
+ sess.run(crossed_var.assign(((1.,), (2.,), (3.,), (4.,), (5.,))))
+ # Expected ids after cross = (1, 0, 1, 3, 4, 2)
+ self.assertAllClose(((3.,), (14.,)), predictions.eval())
+ sess.run(bias.assign((.1,)))
+ self.assertAllClose(((3.1,), (14.1,)), predictions.eval())
+
class LinearModelTest(test.TestCase):
@@ -1109,6 +1416,10 @@ class LinearModelTest(test.TestCase):
class NotSupportedColumn(fc.FeatureColumn):
@property
+ def _is_v2_column(self):
+ return True
+
+ @property
def name(self):
return 'NotSupportedColumn'
@@ -1190,6 +1501,10 @@ class LinearModelTest(test.TestCase):
class _DenseAndSparseColumn(fc.DenseColumn, fc.CategoricalColumn):
@property
+ def _is_v2_column(self):
+ return True
+
+ @property
def name(self):
return 'dense_and_sparse_column'
@@ -1735,12 +2050,868 @@ class LinearModelTest(test.TestCase):
self.assertAllClose([[25.], [105.]], predictions2.eval())
+class OldLinearModelTest(test.TestCase):
+
+ def test_raises_if_empty_feature_columns(self):
+ with self.assertRaisesRegexp(ValueError,
+ 'feature_columns must not be empty'):
+ fc_old.linear_model(features={}, feature_columns=[])
+
+ def test_should_be_feature_column(self):
+ with self.assertRaisesRegexp(ValueError, 'must be a _FeatureColumn'):
+ fc_old.linear_model(features={'a': [[0]]}, feature_columns='NotSupported')
+
+ def test_should_be_dense_or_categorical_column(self):
+
+ class NotSupportedColumn(fc.FeatureColumn, fc_old._FeatureColumn):
+
+ @property
+ def _is_v2_column(self):
+ return True
+
+ @property
+ def name(self):
+ return 'NotSupportedColumn'
+
+ def transform_feature(self, transformation_cache, state_manager):
+ pass
+
+ def _transform_feature(self, inputs):
+ pass
+
+ @property
+ def parse_example_spec(self):
+ pass
+
+ @property
+ def _parse_example_spec(self):
+ pass
+
+ with self.assertRaisesRegexp(
+ ValueError, 'must be either a _DenseColumn or _CategoricalColumn'):
+ fc_old.linear_model(
+ features={'a': [[0]]}, feature_columns=[NotSupportedColumn()])
+
+ def test_does_not_support_dict_columns(self):
+ with self.assertRaisesRegexp(
+ ValueError, 'Expected feature_columns to be iterable, found dict.'):
+ fc_old.linear_model(
+ features={'a': [[0]]}, feature_columns={'a': fc.numeric_column('a')})
+
+ def test_raises_if_duplicate_name(self):
+ with self.assertRaisesRegexp(
+ ValueError, 'Duplicate feature column name found for columns'):
+ fc_old.linear_model(
+ features={'a': [[0]]},
+ feature_columns=[fc.numeric_column('a'),
+ fc.numeric_column('a')])
+
+ def test_dense_bias(self):
+ price = fc.numeric_column('price')
+ with ops.Graph().as_default():
+ features = {'price': [[1.], [5.]]}
+ predictions = fc_old.linear_model(features, [price])
+ bias = get_linear_model_bias()
+ price_var = get_linear_model_column_var(price)
+ with _initialized_session() as sess:
+ self.assertAllClose([0.], bias.eval())
+ sess.run(price_var.assign([[10.]]))
+ sess.run(bias.assign([5.]))
+ self.assertAllClose([[15.], [55.]], predictions.eval())
+
+ def test_sparse_bias(self):
+ wire_cast = fc.categorical_column_with_hash_bucket('wire_cast', 4)
+ with ops.Graph().as_default():
+ wire_tensor = sparse_tensor.SparseTensor(
+ values=['omar', 'stringer', 'marlo'], # hashed to = [2, 0, 3]
+ indices=[[0, 0], [1, 0], [1, 1]],
+ dense_shape=[2, 2])
+ features = {'wire_cast': wire_tensor}
+ predictions = fc_old.linear_model(features, [wire_cast])
+ bias = get_linear_model_bias()
+ wire_cast_var = get_linear_model_column_var(wire_cast)
+ with _initialized_session() as sess:
+ self.assertAllClose([0.], bias.eval())
+ self.assertAllClose([[0.], [0.], [0.], [0.]], wire_cast_var.eval())
+ sess.run(wire_cast_var.assign([[10.], [100.], [1000.], [10000.]]))
+ sess.run(bias.assign([5.]))
+ self.assertAllClose([[1005.], [10015.]], predictions.eval())
+
+ def test_dense_and_sparse_bias(self):
+ wire_cast = fc.categorical_column_with_hash_bucket('wire_cast', 4)
+ price = fc.numeric_column('price')
+ with ops.Graph().as_default():
+ wire_tensor = sparse_tensor.SparseTensor(
+ values=['omar', 'stringer', 'marlo'], # hashed to = [2, 0, 3]
+ indices=[[0, 0], [1, 0], [1, 1]],
+ dense_shape=[2, 2])
+ features = {'wire_cast': wire_tensor, 'price': [[1.], [5.]]}
+ predictions = fc_old.linear_model(features, [wire_cast, price])
+ bias = get_linear_model_bias()
+ wire_cast_var = get_linear_model_column_var(wire_cast)
+ price_var = get_linear_model_column_var(price)
+ with _initialized_session() as sess:
+ sess.run(wire_cast_var.assign([[10.], [100.], [1000.], [10000.]]))
+ sess.run(bias.assign([5.]))
+ sess.run(price_var.assign([[10.]]))
+ self.assertAllClose([[1015.], [10065.]], predictions.eval())
+
+ def test_dense_and_sparse_column(self):
+ """When the column is both dense and sparse, uses sparse tensors."""
+
+ class _DenseAndSparseColumn(fc.DenseColumn, fc.CategoricalColumn,
+ fc_old._DenseColumn, fc_old._CategoricalColumn):
+
+ @property
+ def _is_v2_column(self):
+ return True
+
+ @property
+ def name(self):
+ return 'dense_and_sparse_column'
+
+ @property
+ def parse_example_spec(self):
+ return {self.name: parsing_ops.VarLenFeature(self.dtype)}
+
+ @property
+ def _parse_example_spec(self):
+ return self.parse_example_spec
+
+ def transform_feature(self, transformation_cache, state_manager):
+ raise ValueError('Should not use this method.')
+
+ def _transform_feature(self, inputs):
+ return inputs.get(self.name)
+
+ @property
+ def variable_shape(self):
+ return self.variable_shape
+
+ @property
+ def _variable_shape(self):
+ return self.variable_shape
+
+ def get_dense_tensor(self, transformation_cache, state_manager):
+ raise ValueError('Should not use this method.')
+
+ def _get_dense_tensor(self, inputs):
+ raise ValueError('Should not use this method.')
+
+ @property
+ def num_buckets(self):
+ return 4
+
+ @property
+ def _num_buckets(self):
+ return self.num_buckets
+
+ def get_sparse_tensors(self, transformation_cache, state_manager):
+ raise ValueError('Should not use this method.')
+
+ def _get_sparse_tensors(self,
+ inputs,
+ weight_collections=None,
+ trainable=None):
+ sp_tensor = sparse_tensor.SparseTensor(
+ indices=[[0, 0], [1, 0], [1, 1]],
+ values=[2, 0, 3],
+ dense_shape=[2, 2])
+ return fc.CategoricalColumn.IdWeightPair(sp_tensor, None)
+
+ dense_and_sparse_column = _DenseAndSparseColumn()
+ with ops.Graph().as_default():
+ sp_tensor = sparse_tensor.SparseTensor(
+ values=['omar', 'stringer', 'marlo'],
+ indices=[[0, 0], [1, 0], [1, 1]],
+ dense_shape=[2, 2])
+ features = {dense_and_sparse_column.name: sp_tensor}
+ predictions = fc_old.linear_model(features, [dense_and_sparse_column])
+ bias = get_linear_model_bias()
+ dense_and_sparse_column_var = get_linear_model_column_var(
+ dense_and_sparse_column)
+ with _initialized_session() as sess:
+ sess.run(
+ dense_and_sparse_column_var.assign([[10.], [100.], [1000.],
+ [10000.]]))
+ sess.run(bias.assign([5.]))
+ self.assertAllClose([[1005.], [10015.]], predictions.eval())
+
+ def test_dense_multi_output(self):
+ price = fc.numeric_column('price')
+ with ops.Graph().as_default():
+ features = {'price': [[1.], [5.]]}
+ predictions = fc_old.linear_model(features, [price], units=3)
+ bias = get_linear_model_bias()
+ price_var = get_linear_model_column_var(price)
+ with _initialized_session() as sess:
+ self.assertAllClose(np.zeros((3,)), bias.eval())
+ self.assertAllClose(np.zeros((1, 3)), price_var.eval())
+ sess.run(price_var.assign([[10., 100., 1000.]]))
+ sess.run(bias.assign([5., 6., 7.]))
+ self.assertAllClose([[15., 106., 1007.], [55., 506., 5007.]],
+ predictions.eval())
+
+ def test_sparse_multi_output(self):
+ wire_cast = fc.categorical_column_with_hash_bucket('wire_cast', 4)
+ with ops.Graph().as_default():
+ wire_tensor = sparse_tensor.SparseTensor(
+ values=['omar', 'stringer', 'marlo'], # hashed to = [2, 0, 3]
+ indices=[[0, 0], [1, 0], [1, 1]],
+ dense_shape=[2, 2])
+ features = {'wire_cast': wire_tensor}
+ predictions = fc_old.linear_model(features, [wire_cast], units=3)
+ bias = get_linear_model_bias()
+ wire_cast_var = get_linear_model_column_var(wire_cast)
+ with _initialized_session() as sess:
+ self.assertAllClose(np.zeros((3,)), bias.eval())
+ self.assertAllClose(np.zeros((4, 3)), wire_cast_var.eval())
+ sess.run(
+ wire_cast_var.assign([[10., 11., 12.], [100., 110., 120.],
+ [1000., 1100., 1200.],
+ [10000., 11000., 12000.]]))
+ sess.run(bias.assign([5., 6., 7.]))
+ self.assertAllClose([[1005., 1106., 1207.], [10015., 11017., 12019.]],
+ predictions.eval())
+
+ def test_dense_multi_dimension(self):
+ price = fc.numeric_column('price', shape=2)
+ with ops.Graph().as_default():
+ features = {'price': [[1., 2.], [5., 6.]]}
+ predictions = fc_old.linear_model(features, [price])
+ price_var = get_linear_model_column_var(price)
+ with _initialized_session() as sess:
+ self.assertAllClose([[0.], [0.]], price_var.eval())
+ sess.run(price_var.assign([[10.], [100.]]))
+ self.assertAllClose([[210.], [650.]], predictions.eval())
+
+ def test_sparse_multi_rank(self):
+ wire_cast = fc.categorical_column_with_hash_bucket('wire_cast', 4)
+ with ops.Graph().as_default():
+ wire_tensor = array_ops.sparse_placeholder(dtypes.string)
+ wire_value = sparse_tensor.SparseTensorValue(
+ values=['omar', 'stringer', 'marlo', 'omar'], # hashed = [2, 0, 3, 2]
+ indices=[[0, 0, 0], [0, 1, 0], [1, 0, 0], [1, 0, 1]],
+ dense_shape=[2, 2, 2])
+ features = {'wire_cast': wire_tensor}
+ predictions = fc_old.linear_model(features, [wire_cast])
+ wire_cast_var = get_linear_model_column_var(wire_cast)
+ with _initialized_session() as sess:
+ self.assertAllClose(np.zeros((4, 1)), wire_cast_var.eval())
+ self.assertAllClose(
+ np.zeros((2, 1)),
+ predictions.eval(feed_dict={wire_tensor: wire_value}))
+ sess.run(wire_cast_var.assign([[10.], [100.], [1000.], [10000.]]))
+ self.assertAllClose(
+ [[1010.], [11000.]],
+ predictions.eval(feed_dict={wire_tensor: wire_value}))
+
+ def test_sparse_combiner(self):
+ wire_cast = fc.categorical_column_with_hash_bucket('wire_cast', 4)
+ with ops.Graph().as_default():
+ wire_tensor = sparse_tensor.SparseTensor(
+ values=['omar', 'stringer', 'marlo'], # hashed to = [2, 0, 3]
+ indices=[[0, 0], [1, 0], [1, 1]],
+ dense_shape=[2, 2])
+ features = {'wire_cast': wire_tensor}
+ predictions = fc_old.linear_model(
+ features, [wire_cast], sparse_combiner='mean')
+ bias = get_linear_model_bias()
+ wire_cast_var = get_linear_model_column_var(wire_cast)
+ with _initialized_session() as sess:
+ sess.run(wire_cast_var.assign([[10.], [100.], [1000.], [10000.]]))
+ sess.run(bias.assign([5.]))
+ self.assertAllClose([[1005.], [5010.]], predictions.eval())
+
+ def test_sparse_combiner_with_negative_weights(self):
+ wire_cast = fc.categorical_column_with_hash_bucket('wire_cast', 4)
+ wire_cast_weights = fc.weighted_categorical_column(wire_cast, 'weights')
+
+ with ops.Graph().as_default():
+ wire_tensor = sparse_tensor.SparseTensor(
+ values=['omar', 'stringer', 'marlo'], # hashed to = [2, 0, 3]
+ indices=[[0, 0], [1, 0], [1, 1]],
+ dense_shape=[2, 2])
+ features = {
+ 'wire_cast': wire_tensor,
+ 'weights': constant_op.constant([[1., 1., -1.0]])
+ }
+ predictions = fc_old.linear_model(
+ features, [wire_cast_weights], sparse_combiner='sum')
+ bias = get_linear_model_bias()
+ wire_cast_var = get_linear_model_column_var(wire_cast)
+ with _initialized_session() as sess:
+ sess.run(wire_cast_var.assign([[10.], [100.], [1000.], [10000.]]))
+ sess.run(bias.assign([5.]))
+ self.assertAllClose([[1005.], [-9985.]], predictions.eval())
+
+ def test_dense_multi_dimension_multi_output(self):
+ price = fc.numeric_column('price', shape=2)
+ with ops.Graph().as_default():
+ features = {'price': [[1., 2.], [5., 6.]]}
+ predictions = fc_old.linear_model(features, [price], units=3)
+ bias = get_linear_model_bias()
+ price_var = get_linear_model_column_var(price)
+ with _initialized_session() as sess:
+ self.assertAllClose(np.zeros((3,)), bias.eval())
+ self.assertAllClose(np.zeros((2, 3)), price_var.eval())
+ sess.run(price_var.assign([[1., 2., 3.], [10., 100., 1000.]]))
+ sess.run(bias.assign([2., 3., 4.]))
+ self.assertAllClose([[23., 205., 2007.], [67., 613., 6019.]],
+ predictions.eval())
+
+ def test_raises_if_shape_mismatch(self):
+ price = fc.numeric_column('price', shape=2)
+ with ops.Graph().as_default():
+ features = {'price': [[1.], [5.]]}
+ with self.assertRaisesRegexp(
+ Exception,
+ r'Cannot reshape a tensor with 2 elements to shape \[2,2\]'):
+ fc_old.linear_model(features, [price])
+
+ def test_dense_reshaping(self):
+ price = fc.numeric_column('price', shape=[1, 2])
+ with ops.Graph().as_default():
+ features = {'price': [[[1., 2.]], [[5., 6.]]]}
+ predictions = fc_old.linear_model(features, [price])
+ bias = get_linear_model_bias()
+ price_var = get_linear_model_column_var(price)
+ with _initialized_session() as sess:
+ self.assertAllClose([0.], bias.eval())
+ self.assertAllClose([[0.], [0.]], price_var.eval())
+ self.assertAllClose([[0.], [0.]], predictions.eval())
+ sess.run(price_var.assign([[10.], [100.]]))
+ self.assertAllClose([[210.], [650.]], predictions.eval())
+
+ def test_dense_multi_column(self):
+ price1 = fc.numeric_column('price1', shape=2)
+ price2 = fc.numeric_column('price2')
+ with ops.Graph().as_default():
+ features = {'price1': [[1., 2.], [5., 6.]], 'price2': [[3.], [4.]]}
+ predictions = fc_old.linear_model(features, [price1, price2])
+ bias = get_linear_model_bias()
+ price1_var = get_linear_model_column_var(price1)
+ price2_var = get_linear_model_column_var(price2)
+ with _initialized_session() as sess:
+ self.assertAllClose([0.], bias.eval())
+ self.assertAllClose([[0.], [0.]], price1_var.eval())
+ self.assertAllClose([[0.]], price2_var.eval())
+ self.assertAllClose([[0.], [0.]], predictions.eval())
+ sess.run(price1_var.assign([[10.], [100.]]))
+ sess.run(price2_var.assign([[1000.]]))
+ sess.run(bias.assign([7.]))
+ self.assertAllClose([[3217.], [4657.]], predictions.eval())
+
+ def test_fills_cols_to_vars(self):
+ price1 = fc.numeric_column('price1', shape=2)
+ price2 = fc.numeric_column('price2')
+ with ops.Graph().as_default():
+ features = {'price1': [[1., 2.], [5., 6.]], 'price2': [[3.], [4.]]}
+ cols_to_vars = {}
+ fc_old.linear_model(features, [price1, price2], cols_to_vars=cols_to_vars)
+ bias = get_linear_model_bias()
+ price1_var = get_linear_model_column_var(price1)
+ price2_var = get_linear_model_column_var(price2)
+ self.assertAllEqual(cols_to_vars['bias'], [bias])
+ self.assertAllEqual(cols_to_vars[price1], [price1_var])
+ self.assertAllEqual(cols_to_vars[price2], [price2_var])
+
+ def test_fills_cols_to_vars_partitioned_variables(self):
+ price1 = fc.numeric_column('price1', shape=2)
+ price2 = fc.numeric_column('price2', shape=3)
+ with ops.Graph().as_default():
+ features = {
+ 'price1': [[1., 2.], [6., 7.]],
+ 'price2': [[3., 4., 5.], [8., 9., 10.]]
+ }
+ cols_to_vars = {}
+ with variable_scope.variable_scope(
+ 'linear',
+ partitioner=partitioned_variables.fixed_size_partitioner(2, axis=0)):
+ fc_old.linear_model(
+ features, [price1, price2], cols_to_vars=cols_to_vars)
+ with _initialized_session():
+ self.assertEqual([0.], cols_to_vars['bias'][0].eval())
+ # Partitioning shards the [2, 1] price1 var into 2 [1, 1] Variables.
+ self.assertAllEqual([[0.]], cols_to_vars[price1][0].eval())
+ self.assertAllEqual([[0.]], cols_to_vars[price1][1].eval())
+ # Partitioning shards the [3, 1] price2 var into a [2, 1] Variable and
+ # a [1, 1] Variable.
+ self.assertAllEqual([[0.], [0.]], cols_to_vars[price2][0].eval())
+ self.assertAllEqual([[0.]], cols_to_vars[price2][1].eval())
+
+ def test_fills_cols_to_output_tensors(self):
+ # Provide three _DenseColumn's to input_layer: a _NumericColumn, a
+ # _BucketizedColumn, and an _EmbeddingColumn. Only the _EmbeddingColumn
+ # creates a Variable.
+ apple_numeric_column = fc.numeric_column('apple_numeric_column')
+ banana_dense_feature = fc.numeric_column('banana_dense_feature')
+ banana_dense_feature_bucketized = fc.bucketized_column(
+ banana_dense_feature, boundaries=[0.])
+ cherry_sparse_column = fc.categorical_column_with_hash_bucket(
+ 'cherry_sparse_feature', hash_bucket_size=5)
+ dragonfruit_embedding_column = fc.embedding_column(
+ cherry_sparse_column, dimension=10)
+ with ops.Graph().as_default():
+ features = {
+ 'apple_numeric_column': [[3.], [4.]],
+ 'banana_dense_feature': [[-1.], [4.]],
+ 'cherry_sparse_feature': [['a'], ['x']],
+ }
+ cols_to_output_tensors = {}
+ all_cols = [
+ apple_numeric_column, banana_dense_feature_bucketized,
+ dragonfruit_embedding_column
+ ]
+ input_layer = fc_old.input_layer(
+ features, all_cols, cols_to_output_tensors=cols_to_output_tensors)
+
+ # We check the mapping by checking that we have the right keys,
+ # and that the values (output_tensors) were indeed the ones used to
+ # form the input layer.
+ self.assertItemsEqual(all_cols, cols_to_output_tensors.keys())
+ input_layer_inputs = [tensor for tensor in input_layer.op.inputs[:-1]]
+ output_tensors = [tensor for tensor in cols_to_output_tensors.values()]
+ self.assertItemsEqual(input_layer_inputs, output_tensors)
+
+ def test_dense_collection(self):
+ price = fc.numeric_column('price')
+ with ops.Graph().as_default() as g:
+ features = {'price': [[1.], [5.]]}
+ fc_old.linear_model(features, [price], weight_collections=['my-vars'])
+ my_vars = g.get_collection('my-vars')
+ bias = get_linear_model_bias()
+ price_var = get_linear_model_column_var(price)
+ self.assertIn(bias, my_vars)
+ self.assertIn(price_var, my_vars)
+
+ def test_sparse_collection(self):
+ wire_cast = fc.categorical_column_with_hash_bucket('wire_cast', 4)
+ with ops.Graph().as_default() as g:
+ wire_tensor = sparse_tensor.SparseTensor(
+ values=['omar'], indices=[[0, 0]], dense_shape=[1, 1])
+ features = {'wire_cast': wire_tensor}
+ fc_old.linear_model(features, [wire_cast], weight_collections=['my-vars'])
+ my_vars = g.get_collection('my-vars')
+ bias = get_linear_model_bias()
+ wire_cast_var = get_linear_model_column_var(wire_cast)
+ self.assertIn(bias, my_vars)
+ self.assertIn(wire_cast_var, my_vars)
+
+ def test_dense_trainable_default(self):
+ price = fc.numeric_column('price')
+ with ops.Graph().as_default() as g:
+ features = {'price': [[1.], [5.]]}
+ fc_old.linear_model(features, [price])
+ bias = get_linear_model_bias()
+ price_var = get_linear_model_column_var(price)
+ trainable_vars = g.get_collection(ops.GraphKeys.TRAINABLE_VARIABLES)
+ self.assertIn(bias, trainable_vars)
+ self.assertIn(price_var, trainable_vars)
+
+ def test_sparse_trainable_default(self):
+ wire_cast = fc.categorical_column_with_hash_bucket('wire_cast', 4)
+ with ops.Graph().as_default() as g:
+ wire_tensor = sparse_tensor.SparseTensor(
+ values=['omar'], indices=[[0, 0]], dense_shape=[1, 1])
+ features = {'wire_cast': wire_tensor}
+ fc_old.linear_model(features, [wire_cast])
+ trainable_vars = g.get_collection(ops.GraphKeys.TRAINABLE_VARIABLES)
+ bias = get_linear_model_bias()
+ wire_cast_var = get_linear_model_column_var(wire_cast)
+ self.assertIn(bias, trainable_vars)
+ self.assertIn(wire_cast_var, trainable_vars)
+
+ def test_dense_trainable_false(self):
+ price = fc.numeric_column('price')
+ with ops.Graph().as_default() as g:
+ features = {'price': [[1.], [5.]]}
+ fc_old.linear_model(features, [price], trainable=False)
+ trainable_vars = g.get_collection(ops.GraphKeys.TRAINABLE_VARIABLES)
+ self.assertEqual([], trainable_vars)
+
+ def test_sparse_trainable_false(self):
+ wire_cast = fc.categorical_column_with_hash_bucket('wire_cast', 4)
+ with ops.Graph().as_default() as g:
+ wire_tensor = sparse_tensor.SparseTensor(
+ values=['omar'], indices=[[0, 0]], dense_shape=[1, 1])
+ features = {'wire_cast': wire_tensor}
+ fc_old.linear_model(features, [wire_cast], trainable=False)
+ trainable_vars = g.get_collection(ops.GraphKeys.TRAINABLE_VARIABLES)
+ self.assertEqual([], trainable_vars)
+
+ def test_column_order(self):
+ price_a = fc.numeric_column('price_a')
+ price_b = fc.numeric_column('price_b')
+ wire_cast = fc.categorical_column_with_hash_bucket('wire_cast', 4)
+ with ops.Graph().as_default() as g:
+ features = {
+ 'price_a': [[1.]],
+ 'price_b': [[3.]],
+ 'wire_cast':
+ sparse_tensor.SparseTensor(
+ values=['omar'], indices=[[0, 0]], dense_shape=[1, 1])
+ }
+ fc_old.linear_model(
+ features, [price_a, wire_cast, price_b],
+ weight_collections=['my-vars'])
+ my_vars = g.get_collection('my-vars')
+ self.assertIn('price_a', my_vars[0].name)
+ self.assertIn('price_b', my_vars[1].name)
+ self.assertIn('wire_cast', my_vars[2].name)
+
+ with ops.Graph().as_default() as g:
+ features = {
+ 'price_a': [[1.]],
+ 'price_b': [[3.]],
+ 'wire_cast':
+ sparse_tensor.SparseTensor(
+ values=['omar'], indices=[[0, 0]], dense_shape=[1, 1])
+ }
+ fc_old.linear_model(
+ features, [wire_cast, price_b, price_a],
+ weight_collections=['my-vars'])
+ my_vars = g.get_collection('my-vars')
+ self.assertIn('price_a', my_vars[0].name)
+ self.assertIn('price_b', my_vars[1].name)
+ self.assertIn('wire_cast', my_vars[2].name)
+
+ def test_static_batch_size_mismatch(self):
+ price1 = fc.numeric_column('price1')
+ price2 = fc.numeric_column('price2')
+ with ops.Graph().as_default():
+ features = {
+ 'price1': [[1.], [5.], [7.]], # batchsize = 3
+ 'price2': [[3.], [4.]] # batchsize = 2
+ }
+ with self.assertRaisesRegexp(
+ ValueError,
+ 'Batch size \(first dimension\) of each feature must be same.'): # pylint: disable=anomalous-backslash-in-string
+ fc_old.linear_model(features, [price1, price2])
+
+ def test_subset_of_static_batch_size_mismatch(self):
+ price1 = fc.numeric_column('price1')
+ price2 = fc.numeric_column('price2')
+ price3 = fc.numeric_column('price3')
+ with ops.Graph().as_default():
+ features = {
+ 'price1': array_ops.placeholder(dtype=dtypes.int64), # batchsize = 3
+ 'price2': [[3.], [4.]], # batchsize = 2
+ 'price3': [[3.], [4.], [5.]] # batchsize = 3
+ }
+ with self.assertRaisesRegexp(
+ ValueError,
+ 'Batch size \(first dimension\) of each feature must be same.'): # pylint: disable=anomalous-backslash-in-string
+ fc_old.linear_model(features, [price1, price2, price3])
+
+ def test_runtime_batch_size_mismatch(self):
+ price1 = fc.numeric_column('price1')
+ price2 = fc.numeric_column('price2')
+ with ops.Graph().as_default():
+ features = {
+ 'price1': array_ops.placeholder(dtype=dtypes.int64), # batchsize = 3
+ 'price2': [[3.], [4.]] # batchsize = 2
+ }
+ predictions = fc_old.linear_model(features, [price1, price2])
+ with _initialized_session() as sess:
+ with self.assertRaisesRegexp(errors.OpError,
+ 'must have the same size and shape'):
+ sess.run(
+ predictions, feed_dict={features['price1']: [[1.], [5.], [7.]]})
+
+ def test_runtime_batch_size_matches(self):
+ price1 = fc.numeric_column('price1')
+ price2 = fc.numeric_column('price2')
+ with ops.Graph().as_default():
+ features = {
+ 'price1': array_ops.placeholder(dtype=dtypes.int64), # batchsize = 2
+ 'price2': array_ops.placeholder(dtype=dtypes.int64), # batchsize = 2
+ }
+ predictions = fc_old.linear_model(features, [price1, price2])
+ with _initialized_session() as sess:
+ sess.run(
+ predictions,
+ feed_dict={
+ features['price1']: [[1.], [5.]],
+ features['price2']: [[1.], [5.]],
+ })
+
+ def test_with_1d_sparse_tensor(self):
+ price = fc.numeric_column('price')
+ price_buckets = fc.bucketized_column(
+ price, boundaries=[
+ 0.,
+ 10.,
+ 100.,
+ ])
+ body_style = fc.categorical_column_with_vocabulary_list(
+ 'body-style', vocabulary_list=['hardtop', 'wagon', 'sedan'])
+
+ # Provides 1-dim tensor and dense tensor.
+ features = {
+ 'price':
+ constant_op.constant([
+ -1.,
+ 12.,
+ ]),
+ 'body-style':
+ sparse_tensor.SparseTensor(
+ indices=((0,), (1,)),
+ values=('sedan', 'hardtop'),
+ dense_shape=(2,)),
+ }
+ self.assertEqual(1, features['price'].shape.ndims)
+ self.assertEqual(1, features['body-style'].dense_shape.get_shape()[0])
+
+ net = fc_old.linear_model(features, [price_buckets, body_style])
+ with _initialized_session() as sess:
+ bias = get_linear_model_bias()
+ price_buckets_var = get_linear_model_column_var(price_buckets)
+ body_style_var = get_linear_model_column_var(body_style)
+
+ sess.run(price_buckets_var.assign([[10.], [100.], [1000.], [10000.]]))
+ sess.run(body_style_var.assign([[-10.], [-100.], [-1000.]]))
+ sess.run(bias.assign([5.]))
+
+ self.assertAllClose([[10 - 1000 + 5.], [1000 - 10 + 5.]], sess.run(net))
+
+ def test_with_1d_unknown_shape_sparse_tensor(self):
+ price = fc.numeric_column('price')
+ price_buckets = fc.bucketized_column(
+ price, boundaries=[
+ 0.,
+ 10.,
+ 100.,
+ ])
+ body_style = fc.categorical_column_with_vocabulary_list(
+ 'body-style', vocabulary_list=['hardtop', 'wagon', 'sedan'])
+ country = fc.categorical_column_with_vocabulary_list(
+ 'country', vocabulary_list=['US', 'JP', 'CA'])
+
+ # Provides 1-dim tensor and dense tensor.
+ features = {
+ 'price': array_ops.placeholder(dtypes.float32),
+ 'body-style': array_ops.sparse_placeholder(dtypes.string),
+ 'country': array_ops.placeholder(dtypes.string),
+ }
+ self.assertIsNone(features['price'].shape.ndims)
+ self.assertIsNone(features['body-style'].get_shape().ndims)
+
+ price_data = np.array([-1., 12.])
+ body_style_data = sparse_tensor.SparseTensorValue(
+ indices=((0,), (1,)), values=('sedan', 'hardtop'), dense_shape=(2,))
+ country_data = np.array(['US', 'CA'])
+
+ net = fc_old.linear_model(features, [price_buckets, body_style, country])
+ bias = get_linear_model_bias()
+ price_buckets_var = get_linear_model_column_var(price_buckets)
+ body_style_var = get_linear_model_column_var(body_style)
+ with _initialized_session() as sess:
+ sess.run(price_buckets_var.assign([[10.], [100.], [1000.], [10000.]]))
+ sess.run(body_style_var.assign([[-10.], [-100.], [-1000.]]))
+ sess.run(bias.assign([5.]))
+
+ self.assertAllClose([[10 - 1000 + 5.], [1000 - 10 + 5.]],
+ sess.run(
+ net,
+ feed_dict={
+ features['price']: price_data,
+ features['body-style']: body_style_data,
+ features['country']: country_data
+ }))
+
+ def test_with_rank_0_feature(self):
+ price = fc.numeric_column('price')
+ features = {
+ 'price': constant_op.constant(0),
+ }
+ self.assertEqual(0, features['price'].shape.ndims)
+
+ # Static rank 0 should fail
+ with self.assertRaisesRegexp(ValueError, 'Feature .* cannot have rank 0'):
+ fc_old.linear_model(features, [price])
+
+ # Dynamic rank 0 should fail
+ features = {
+ 'price': array_ops.placeholder(dtypes.float32),
+ }
+ net = fc_old.linear_model(features, [price])
+ self.assertEqual(1, net.shape[1])
+ with _initialized_session() as sess:
+ with self.assertRaisesOpError('Feature .* cannot have rank 0'):
+ sess.run(net, feed_dict={features['price']: np.array(1)})
+
+ def test_multiple_linear_models(self):
+ price = fc.numeric_column('price')
+ with ops.Graph().as_default():
+ features1 = {'price': [[1.], [5.]]}
+ features2 = {'price': [[2.], [10.]]}
+ predictions1 = fc_old.linear_model(features1, [price])
+ predictions2 = fc_old.linear_model(features2, [price])
+ bias1 = get_linear_model_bias(name='linear_model')
+ bias2 = get_linear_model_bias(name='linear_model_1')
+ price_var1 = get_linear_model_column_var(price, name='linear_model')
+ price_var2 = get_linear_model_column_var(price, name='linear_model_1')
+ with _initialized_session() as sess:
+ self.assertAllClose([0.], bias1.eval())
+ sess.run(price_var1.assign([[10.]]))
+ sess.run(bias1.assign([5.]))
+ self.assertAllClose([[15.], [55.]], predictions1.eval())
+ self.assertAllClose([0.], bias2.eval())
+ sess.run(price_var2.assign([[10.]]))
+ sess.run(bias2.assign([5.]))
+ self.assertAllClose([[25.], [105.]], predictions2.eval())
+
+ def test_linear_model_v1_shared_embedding_all_other_v2(self):
+ price = fc.numeric_column('price') # v2
+ some_sparse_column = fc.categorical_column_with_hash_bucket(
+ 'sparse_feature', hash_bucket_size=5) # v2
+ some_embedding_column = fc.embedding_column(
+ some_sparse_column, dimension=10) # v2
+ categorical_column_a = fc_old.categorical_column_with_identity(
+ key='aaa', num_buckets=3) # v2
+ categorical_column_b = fc_old.categorical_column_with_identity(
+ key='bbb', num_buckets=3) # v2
+ shared_embedding_a, shared_embedding_b = fc_old.shared_embedding_columns(
+ [categorical_column_a, categorical_column_b], dimension=2) # v1
+ all_cols = [
+ price, some_embedding_column, shared_embedding_a, shared_embedding_b
+ ]
+
+ with ops.Graph().as_default():
+ features = {
+ 'price': [[3.], [4.]],
+ 'sparse_feature': [['a'], ['x']],
+ 'aaa':
+ sparse_tensor.SparseTensor(
+ indices=((0, 0), (1, 0), (1, 1)),
+ values=(0, 1, 0),
+ dense_shape=(2, 2)),
+ 'bbb':
+ sparse_tensor.SparseTensor(
+ indices=((0, 0), (1, 0), (1, 1)),
+ values=(1, 2, 1),
+ dense_shape=(2, 2)),
+ }
+ fc_old.linear_model(features, all_cols)
+ bias = get_linear_model_bias()
+ with _initialized_session():
+ self.assertAllClose([0.], bias.eval())
+
+ def test_linear_model_v1_shared_embedding_with_v2_cat_all_other_v2(self):
+ price = fc.numeric_column('price') # v2
+ some_sparse_column = fc.categorical_column_with_hash_bucket(
+ 'sparse_feature', hash_bucket_size=5) # v2
+ some_embedding_column = fc.embedding_column(
+ some_sparse_column, dimension=10) # v2
+ categorical_column_a = fc.categorical_column_with_identity(
+ key='aaa', num_buckets=3) # v2
+ categorical_column_b = fc.categorical_column_with_identity(
+ key='bbb', num_buckets=3) # v2
+ shared_embedding_a, shared_embedding_b = fc_old.shared_embedding_columns(
+ [categorical_column_a, categorical_column_b], dimension=2) # v1
+ all_cols = [
+ price, some_embedding_column, shared_embedding_a, shared_embedding_b
+ ]
+
+ with ops.Graph().as_default():
+ features = {
+ 'price': [[3.], [4.]],
+ 'sparse_feature': [['a'], ['x']],
+ 'aaa':
+ sparse_tensor.SparseTensor(
+ indices=((0, 0), (1, 0), (1, 1)),
+ values=(0, 1, 0),
+ dense_shape=(2, 2)),
+ 'bbb':
+ sparse_tensor.SparseTensor(
+ indices=((0, 0), (1, 0), (1, 1)),
+ values=(1, 2, 1),
+ dense_shape=(2, 2)),
+ }
+ fc_old.linear_model(features, all_cols)
+ bias = get_linear_model_bias()
+ with _initialized_session():
+ self.assertAllClose([0.], bias.eval())
+
+ def test_linear_model_v1_v2_mix(self):
+ price = fc.numeric_column('price') # v2
+ some_sparse_column = fc_old.categorical_column_with_hash_bucket(
+ 'sparse_feature', hash_bucket_size=5) # v1
+ some_embedding_column = fc_old.embedding_column(
+ some_sparse_column, dimension=10) # v1
+ categorical_column_a = fc_old.categorical_column_with_identity(
+ key='aaa', num_buckets=3) # v2
+ categorical_column_b = fc_old.categorical_column_with_identity(
+ key='bbb', num_buckets=3) # v2
+ shared_embedding_a, shared_embedding_b = fc_old.shared_embedding_columns(
+ [categorical_column_a, categorical_column_b], dimension=2) # v1
+ all_cols = [
+ price, some_embedding_column, shared_embedding_a, shared_embedding_b
+ ]
+
+ with ops.Graph().as_default():
+ features = {
+ 'price': [[3.], [4.]],
+ 'sparse_feature': [['a'], ['x']],
+ 'aaa':
+ sparse_tensor.SparseTensor(
+ indices=((0, 0), (1, 0), (1, 1)),
+ values=(0, 1, 0),
+ dense_shape=(2, 2)),
+ 'bbb':
+ sparse_tensor.SparseTensor(
+ indices=((0, 0), (1, 0), (1, 1)),
+ values=(1, 2, 1),
+ dense_shape=(2, 2)),
+ }
+ fc_old.linear_model(features, all_cols)
+ bias = get_linear_model_bias()
+ with _initialized_session():
+ self.assertAllClose([0.], bias.eval())
+
+ def test_linear_model_v2_shared_embedding_all_other_v1(self):
+ price = fc_old.numeric_column('price') # v1
+ some_sparse_column = fc_old.categorical_column_with_hash_bucket(
+ 'sparse_feature', hash_bucket_size=5) # v1
+ some_embedding_column = fc_old.embedding_column(
+ some_sparse_column, dimension=10) # v1
+ categorical_column_a = fc.categorical_column_with_identity(
+ key='aaa', num_buckets=3) # v2
+ categorical_column_b = fc.categorical_column_with_identity(
+ key='bbb', num_buckets=3) # v2
+ shared_embedding_a, shared_embedding_b = fc.shared_embedding_columns_v2(
+ [categorical_column_a, categorical_column_b], dimension=2) # v2
+ all_cols = [
+ price, some_embedding_column, shared_embedding_a, shared_embedding_b
+ ]
+
+ with ops.Graph().as_default():
+ features = {
+ 'price': [[3.], [4.]],
+ 'sparse_feature': [['a'], ['x']],
+ 'aaa':
+ sparse_tensor.SparseTensor(
+ indices=((0, 0), (1, 0), (1, 1)),
+ values=(0, 1, 0),
+ dense_shape=(2, 2)),
+ 'bbb':
+ sparse_tensor.SparseTensor(
+ indices=((0, 0), (1, 0), (1, 1)),
+ values=(1, 2, 1),
+ dense_shape=(2, 2)),
+ }
+ with self.assertRaisesRegexp(ValueError,
+ 'SharedEmbeddingColumns are not supported'):
+ fc_old.linear_model(features, all_cols)
+
+
class FeatureLayerTest(test.TestCase):
@test_util.run_in_graph_and_eager_modes()
def test_retrieving_input(self):
features = {'a': [0.]}
- feature_layer = FeatureLayer(fc.numeric_column('a'))
+ feature_layer = fc.FeatureLayer(fc.numeric_column('a'))
inputs = self.evaluate(feature_layer(features))
self.assertAllClose([[0.]], inputs)
@@ -1770,7 +2941,7 @@ class FeatureLayerTest(test.TestCase):
dimension=embedding_dimension,
initializer=_embedding_column_initializer)
- feature_layer = FeatureLayer([embedding_column])
+ feature_layer = fc.FeatureLayer([embedding_column])
features = {'a': sparse_input}
inputs = feature_layer(features)
@@ -1815,7 +2986,7 @@ class FeatureLayerTest(test.TestCase):
dimension=embedding_dimension,
initializer=_embedding_column_initializer)
- feature_layer = FeatureLayer([embedding_column])
+ feature_layer = fc.FeatureLayer([embedding_column])
features = {'a': sparse_input}
def scale_matrix():
@@ -1837,11 +3008,11 @@ class FeatureLayerTest(test.TestCase):
def test_raises_if_empty_feature_columns(self):
with self.assertRaisesRegexp(ValueError,
'feature_columns must not be empty'):
- FeatureLayer(feature_columns=[])(features={})
+ fc.FeatureLayer(feature_columns=[])(features={})
def test_should_be_dense_column(self):
with self.assertRaisesRegexp(ValueError, 'must be a DenseColumn'):
- FeatureLayer(feature_columns=[
+ fc.FeatureLayer(feature_columns=[
fc.categorical_column_with_hash_bucket('wire_cast', 4)
])(
features={
@@ -1851,7 +3022,7 @@ class FeatureLayerTest(test.TestCase):
def test_does_not_support_dict_columns(self):
with self.assertRaisesRegexp(
ValueError, 'Expected feature_columns to be iterable, found dict.'):
- FeatureLayer(feature_columns={'a': fc.numeric_column('a')})(
+ fc.FeatureLayer(feature_columns={'a': fc.numeric_column('a')})(
features={
'a': [[0]]
})
@@ -1859,7 +3030,7 @@ class FeatureLayerTest(test.TestCase):
def test_bare_column(self):
with ops.Graph().as_default():
features = features = {'a': [0.]}
- net = FeatureLayer(fc.numeric_column('a'))(features)
+ net = fc.FeatureLayer(fc.numeric_column('a'))(features)
with _initialized_session():
self.assertAllClose([[0.]], net.eval())
@@ -1867,14 +3038,14 @@ class FeatureLayerTest(test.TestCase):
with ops.Graph().as_default():
features = features = {'a': [0.], 'b': [1.]}
columns = (fc.numeric_column(key) for key in features)
- net = FeatureLayer(columns)(features)
+ net = fc.FeatureLayer(columns)(features)
with _initialized_session():
self.assertAllClose([[0., 1.]], net.eval())
def test_raises_if_duplicate_name(self):
with self.assertRaisesRegexp(
ValueError, 'Duplicate feature column name found for columns'):
- FeatureLayer(
+ fc.FeatureLayer(
feature_columns=[fc.numeric_column('a'),
fc.numeric_column('a')])(
features={
@@ -1885,7 +3056,7 @@ class FeatureLayerTest(test.TestCase):
price = fc.numeric_column('price')
with ops.Graph().as_default():
features = {'price': [[1.], [5.]]}
- net = FeatureLayer([price])(features)
+ net = fc.FeatureLayer([price])(features)
with _initialized_session():
self.assertAllClose([[1.], [5.]], net.eval())
@@ -1893,7 +3064,7 @@ class FeatureLayerTest(test.TestCase):
price = fc.numeric_column('price', shape=2)
with ops.Graph().as_default():
features = {'price': [[1., 2.], [5., 6.]]}
- net = FeatureLayer([price])(features)
+ net = fc.FeatureLayer([price])(features)
with _initialized_session():
self.assertAllClose([[1., 2.], [5., 6.]], net.eval())
@@ -1905,7 +3076,7 @@ class FeatureLayerTest(test.TestCase):
'price1': [[1., 2.], [5., 6.]],
'price2': [[3., 4., 5., 6.], [7., 8., 9., 10.]]
}
- feature_layer = FeatureLayer([price1, price2])
+ feature_layer = fc.FeatureLayer([price1, price2])
self.assertEqual((None, 6), feature_layer.compute_output_shape((None,)))
net = feature_layer(features)
with _initialized_session():
@@ -1919,13 +3090,13 @@ class FeatureLayerTest(test.TestCase):
with self.assertRaisesRegexp(
Exception,
r'Cannot reshape a tensor with 2 elements to shape \[2,2\]'):
- FeatureLayer([price])(features)
+ fc.FeatureLayer([price])(features)
def test_reshaping(self):
price = fc.numeric_column('price', shape=[1, 2])
with ops.Graph().as_default():
features = {'price': [[[1., 2.]], [[5., 6.]]]}
- net = FeatureLayer([price])(features)
+ net = fc.FeatureLayer([price])(features)
with _initialized_session():
self.assertAllClose([[1., 2.], [5., 6.]], net.eval())
@@ -1937,7 +3108,7 @@ class FeatureLayerTest(test.TestCase):
'price1': [[1., 2.], [5., 6.]],
'price2': [[3.], [4.]]
}
- net = FeatureLayer([price1, price2])(features)
+ net = fc.FeatureLayer([price1, price2])(features)
with _initialized_session():
self.assertAllClose([[1., 2., 3.], [5., 6., 4.]], net.eval())
@@ -1947,7 +3118,7 @@ class FeatureLayerTest(test.TestCase):
with ops.Graph().as_default():
cols_dict = {}
features = {'price1': [[1., 2.], [5., 6.]], 'price2': [[3.], [4.]]}
- feature_layer = FeatureLayer([price1, price2])
+ feature_layer = fc.FeatureLayer([price1, price2])
net = feature_layer(features, cols_dict)
with _initialized_session():
self.assertAllClose([[1., 2.], [5., 6.]], cols_dict[price1].eval())
@@ -1962,8 +3133,8 @@ class FeatureLayerTest(test.TestCase):
'price_a': [[1.]],
'price_b': [[3.]],
}
- net1 = FeatureLayer([price_a, price_b])(features)
- net2 = FeatureLayer([price_b, price_a])(features)
+ net1 = fc.FeatureLayer([price_a, price_b])(features)
+ net2 = fc.FeatureLayer([price_b, price_a])(features)
with _initialized_session():
self.assertAllClose([[1., 3.]], net1.eval())
self.assertAllClose([[1., 3.]], net2.eval())
@@ -1977,7 +3148,7 @@ class FeatureLayerTest(test.TestCase):
indices=[[0, 0], [0, 1]], values=[1, 2], dense_shape=[1, 2])
}
with self.assertRaisesRegexp(Exception, 'must be a DenseColumn'):
- FeatureLayer([animal])(features)
+ fc.FeatureLayer([animal])(features)
def test_static_batch_size_mismatch(self):
price1 = fc.numeric_column('price1')
@@ -1990,7 +3161,7 @@ class FeatureLayerTest(test.TestCase):
with self.assertRaisesRegexp(
ValueError,
'Batch size \(first dimension\) of each feature must be same.'): # pylint: disable=anomalous-backslash-in-string
- FeatureLayer([price1, price2])(features)
+ fc.FeatureLayer([price1, price2])(features)
def test_subset_of_static_batch_size_mismatch(self):
price1 = fc.numeric_column('price1')
@@ -2005,7 +3176,7 @@ class FeatureLayerTest(test.TestCase):
with self.assertRaisesRegexp(
ValueError,
'Batch size \(first dimension\) of each feature must be same.'): # pylint: disable=anomalous-backslash-in-string
- FeatureLayer([price1, price2, price3])(features)
+ fc.FeatureLayer([price1, price2, price3])(features)
def test_runtime_batch_size_mismatch(self):
price1 = fc.numeric_column('price1')
@@ -2015,7 +3186,7 @@ class FeatureLayerTest(test.TestCase):
'price1': array_ops.placeholder(dtype=dtypes.int64), # batchsize = 3
'price2': [[3.], [4.]] # batchsize = 2
}
- net = FeatureLayer([price1, price2])(features)
+ net = fc.FeatureLayer([price1, price2])(features)
with _initialized_session() as sess:
with self.assertRaisesRegexp(errors.OpError,
'Dimensions of inputs should match'):
@@ -2029,7 +3200,7 @@ class FeatureLayerTest(test.TestCase):
'price1': array_ops.placeholder(dtype=dtypes.int64), # batchsize = 2
'price2': array_ops.placeholder(dtype=dtypes.int64), # batchsize = 2
}
- net = FeatureLayer([price1, price2])(features)
+ net = fc.FeatureLayer([price1, price2])(features)
with _initialized_session() as sess:
sess.run(
net,
@@ -2049,8 +3220,8 @@ class FeatureLayerTest(test.TestCase):
'sparse_feature': [['a'], ['x']],
}
all_cols = [some_embedding_column]
- FeatureLayer(all_cols)(features)
- FeatureLayer(all_cols)(features)
+ fc.FeatureLayer(all_cols)(features)
+ fc.FeatureLayer(all_cols)(features)
# Make sure that 2 variables get created in this case.
self.assertEqual(2, len(
ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES)))
@@ -2088,10 +3259,10 @@ class FeatureLayerTest(test.TestCase):
dense_shape=(2, 2)),
}
all_cols = [embedding_column_a, embedding_column_b]
- FeatureLayer(
+ fc.FeatureLayer(
all_cols, shared_state_manager=shared_state_manager)(
features)
- FeatureLayer(
+ fc.FeatureLayer(
all_cols, shared_state_manager=shared_state_manager)(
features)
# Make sure that only 1 variable gets created in this case.
@@ -2127,7 +3298,7 @@ class FeatureLayerTest(test.TestCase):
values=(1, 2, 1),
dense_shape=(2, 2)),
}
- FeatureLayer(
+ fc.FeatureLayer(
all_cols, shared_state_manager=shared_state_manager1)(
features)
# Make sure that only 1 variable gets created in this case.
@@ -2150,7 +3321,7 @@ class FeatureLayerTest(test.TestCase):
dense_shape=(2, 2)),
}
- FeatureLayer(
+ fc.FeatureLayer(
all_cols, shared_state_manager=shared_state_manager2)(
features1)
# Make sure that only 1 variable gets created in this case.
@@ -2188,7 +3359,7 @@ class FeatureLayerTest(test.TestCase):
batch_size=2,
shuffle=False)
features = input_fn()
- net = FeatureLayer([price, one_hot_body_style, embedded_body_style])(
+ net = fc.FeatureLayer([price, one_hot_body_style, embedded_body_style])(
features)
self.assertEqual(1 + 3 + 5, net.shape[1])
with _initialized_session() as sess:
@@ -2243,7 +3414,8 @@ class FeatureLayerTest(test.TestCase):
self.assertEqual(1, features['body-style'].dense_shape.get_shape()[0])
self.assertEqual(1, features['country'].shape.ndims)
- net = FeatureLayer([price, one_hot_body_style, embedded_country])(features)
+ net = fc.FeatureLayer([price, one_hot_body_style, embedded_country])(
+ features)
self.assertEqual(1 + 3 + 5, net.shape[1])
with _initialized_session() as sess:
@@ -2296,7 +3468,8 @@ class FeatureLayerTest(test.TestCase):
dense_shape=(2,))
country_data = np.array([['US'], ['CA']])
- net = FeatureLayer([price, one_hot_body_style, embedded_country])(features)
+ net = fc.FeatureLayer([price, one_hot_body_style, embedded_country])(
+ features)
self.assertEqual(1 + 3 + 2, net.shape[1])
with _initialized_session() as sess:
@@ -2322,13 +3495,563 @@ class FeatureLayerTest(test.TestCase):
# Static rank 0 should fail
with self.assertRaisesRegexp(ValueError, 'Feature .* cannot have rank 0'):
- FeatureLayer([price])(features)
+ fc.FeatureLayer([price])(features)
+
+ # Dynamic rank 0 should fail
+ features = {
+ 'price': array_ops.placeholder(dtypes.float32),
+ }
+ net = fc.FeatureLayer([price])(features)
+ self.assertEqual(1, net.shape[1])
+ with _initialized_session() as sess:
+ with self.assertRaisesOpError('Feature .* cannot have rank 0'):
+ sess.run(net, feed_dict={features['price']: np.array(1)})
+
+
+class InputLayerTest(test.TestCase):
+
+ @test_util.run_in_graph_and_eager_modes
+ def test_retrieving_input(self):
+ features = {'a': [0.]}
+ input_layer = fc_old.InputLayer(fc.numeric_column('a'))
+ inputs = self.evaluate(input_layer(features))
+ self.assertAllClose([[0.]], inputs)
+
+ def test_reuses_variables(self):
+ with context.eager_mode():
+ sparse_input = sparse_tensor.SparseTensor(
+ indices=((0, 0), (1, 0), (2, 0)),
+ values=(0, 1, 2),
+ dense_shape=(3, 3))
+
+ # Create feature columns (categorical and embedding).
+ categorical_column = fc.categorical_column_with_identity(
+ key='a', num_buckets=3)
+ embedding_dimension = 2
+
+ def _embedding_column_initializer(shape, dtype, partition_info):
+ del shape # unused
+ del dtype # unused
+ del partition_info # unused
+ embedding_values = (
+ (1, 0), # id 0
+ (0, 1), # id 1
+ (1, 1)) # id 2
+ return embedding_values
+
+ embedding_column = fc.embedding_column(
+ categorical_column,
+ dimension=embedding_dimension,
+ initializer=_embedding_column_initializer)
+
+ input_layer = fc_old.InputLayer([embedding_column])
+ features = {'a': sparse_input}
+
+ inputs = input_layer(features)
+ variables = input_layer.variables
+
+ # Sanity check: test that the inputs are correct.
+ self.assertAllEqual([[1, 0], [0, 1], [1, 1]], inputs)
+
+ # Check that only one variable was created.
+ self.assertEqual(1, len(variables))
+
+ # Check that invoking input_layer on the same features does not create
+ # additional variables
+ _ = input_layer(features)
+ self.assertEqual(1, len(variables))
+ self.assertEqual(variables[0], input_layer.variables[0])
+
+ def test_feature_column_input_layer_gradient(self):
+ with context.eager_mode():
+ sparse_input = sparse_tensor.SparseTensor(
+ indices=((0, 0), (1, 0), (2, 0)),
+ values=(0, 1, 2),
+ dense_shape=(3, 3))
+
+ # Create feature columns (categorical and embedding).
+ categorical_column = fc.categorical_column_with_identity(
+ key='a', num_buckets=3)
+ embedding_dimension = 2
+
+ def _embedding_column_initializer(shape, dtype, partition_info):
+ del shape # unused
+ del dtype # unused
+ del partition_info # unused
+ embedding_values = (
+ (1, 0), # id 0
+ (0, 1), # id 1
+ (1, 1)) # id 2
+ return embedding_values
+
+ embedding_column = fc.embedding_column(
+ categorical_column,
+ dimension=embedding_dimension,
+ initializer=_embedding_column_initializer)
+
+ input_layer = fc_old.InputLayer([embedding_column])
+ features = {'a': sparse_input}
+
+ def scale_matrix():
+ matrix = input_layer(features)
+ return 2 * matrix
+
+ # Sanity check: Verify that scale_matrix returns the correct output.
+ self.assertAllEqual([[2, 0], [0, 2], [2, 2]], scale_matrix())
+
+ # Check that the returned gradient is correct.
+ grad_function = backprop.implicit_grad(scale_matrix)
+ grads_and_vars = grad_function()
+ indexed_slice = grads_and_vars[0][0]
+ gradient = grads_and_vars[0][0].values
+
+ self.assertAllEqual([0, 1, 2], indexed_slice.indices)
+ self.assertAllEqual([[2, 2], [2, 2], [2, 2]], gradient)
+
+
+class FunctionalInputLayerTest(test.TestCase):
+
+ def test_raises_if_empty_feature_columns(self):
+ with self.assertRaisesRegexp(ValueError,
+ 'feature_columns must not be empty'):
+ fc_old.input_layer(features={}, feature_columns=[])
+
+ def test_should_be_dense_column(self):
+ with self.assertRaisesRegexp(ValueError, 'must be a _DenseColumn'):
+ fc_old.input_layer(
+ features={'a': [[0]]},
+ feature_columns=[
+ fc.categorical_column_with_hash_bucket('wire_cast', 4)
+ ])
+
+ def test_does_not_support_dict_columns(self):
+ with self.assertRaisesRegexp(
+ ValueError, 'Expected feature_columns to be iterable, found dict.'):
+ fc_old.input_layer(
+ features={'a': [[0]]}, feature_columns={'a': fc.numeric_column('a')})
+
+ def test_bare_column(self):
+ with ops.Graph().as_default():
+ features = features = {'a': [0.]}
+ net = fc_old.input_layer(features, fc.numeric_column('a'))
+ with _initialized_session():
+ self.assertAllClose([[0.]], net.eval())
+
+ def test_column_generator(self):
+ with ops.Graph().as_default():
+ features = features = {'a': [0.], 'b': [1.]}
+ columns = (fc.numeric_column(key) for key in features)
+ net = fc_old.input_layer(features, columns)
+ with _initialized_session():
+ self.assertAllClose([[0., 1.]], net.eval())
+
+ def test_raises_if_duplicate_name(self):
+ with self.assertRaisesRegexp(
+ ValueError, 'Duplicate feature column name found for columns'):
+ fc_old.input_layer(
+ features={'a': [[0]]},
+ feature_columns=[fc.numeric_column('a'),
+ fc.numeric_column('a')])
+
+ def test_one_column(self):
+ price = fc.numeric_column('price')
+ with ops.Graph().as_default():
+ features = {'price': [[1.], [5.]]}
+ net = fc_old.input_layer(features, [price])
+ with _initialized_session():
+ self.assertAllClose([[1.], [5.]], net.eval())
+
+ def test_multi_dimension(self):
+ price = fc.numeric_column('price', shape=2)
+ with ops.Graph().as_default():
+ features = {'price': [[1., 2.], [5., 6.]]}
+ net = fc_old.input_layer(features, [price])
+ with _initialized_session():
+ self.assertAllClose([[1., 2.], [5., 6.]], net.eval())
+
+ def test_raises_if_shape_mismatch(self):
+ price = fc.numeric_column('price', shape=2)
+ with ops.Graph().as_default():
+ features = {'price': [[1.], [5.]]}
+ with self.assertRaisesRegexp(
+ Exception,
+ r'Cannot reshape a tensor with 2 elements to shape \[2,2\]'):
+ fc_old.input_layer(features, [price])
+
+ def test_reshaping(self):
+ price = fc.numeric_column('price', shape=[1, 2])
+ with ops.Graph().as_default():
+ features = {'price': [[[1., 2.]], [[5., 6.]]]}
+ net = fc_old.input_layer(features, [price])
+ with _initialized_session():
+ self.assertAllClose([[1., 2.], [5., 6.]], net.eval())
+
+ def test_multi_column(self):
+ price1 = fc.numeric_column('price1', shape=2)
+ price2 = fc.numeric_column('price2')
+ with ops.Graph().as_default():
+ features = {'price1': [[1., 2.], [5., 6.]], 'price2': [[3.], [4.]]}
+ net = fc_old.input_layer(features, [price1, price2])
+ with _initialized_session():
+ self.assertAllClose([[1., 2., 3.], [5., 6., 4.]], net.eval())
+
+ def test_fills_cols_to_vars(self):
+ # Provide three _DenseColumn's to input_layer: a _NumericColumn, a
+ # _BucketizedColumn, and an _EmbeddingColumn. Only the _EmbeddingColumn
+ # creates a Variable.
+ price1 = fc.numeric_column('price1')
+ dense_feature = fc.numeric_column('dense_feature')
+ dense_feature_bucketized = fc.bucketized_column(
+ dense_feature, boundaries=[0.])
+ some_sparse_column = fc.categorical_column_with_hash_bucket(
+ 'sparse_feature', hash_bucket_size=5)
+ some_embedding_column = fc.embedding_column(
+ some_sparse_column, dimension=10)
+ with ops.Graph().as_default():
+ features = {
+ 'price1': [[3.], [4.]],
+ 'dense_feature': [[-1.], [4.]],
+ 'sparse_feature': [['a'], ['x']],
+ }
+ cols_to_vars = {}
+ all_cols = [price1, dense_feature_bucketized, some_embedding_column]
+ fc_old.input_layer(features, all_cols, cols_to_vars=cols_to_vars)
+ self.assertItemsEqual(list(cols_to_vars.keys()), all_cols)
+ self.assertEqual(0, len(cols_to_vars[price1]))
+ self.assertEqual(0, len(cols_to_vars[dense_feature_bucketized]))
+ self.assertEqual(1, len(cols_to_vars[some_embedding_column]))
+ self.assertIsInstance(cols_to_vars[some_embedding_column][0],
+ variables_lib.Variable)
+ self.assertAllEqual(cols_to_vars[some_embedding_column][0].shape, [5, 10])
+
+ def test_fills_cols_to_vars_shared_embedding(self):
+ # Provide 5 DenseColumn's to input_layer: a NumericColumn, a
+ # BucketizedColumn, an EmbeddingColumn, two SharedEmbeddingColumns. The
+ # EmbeddingColumn creates a Variable and the two SharedEmbeddingColumns
+ # shared one variable.
+ price1 = fc.numeric_column('price1')
+ dense_feature = fc.numeric_column('dense_feature')
+ dense_feature_bucketized = fc.bucketized_column(
+ dense_feature, boundaries=[0.])
+ some_sparse_column = fc.categorical_column_with_hash_bucket(
+ 'sparse_feature', hash_bucket_size=5)
+ some_embedding_column = fc.embedding_column(
+ some_sparse_column, dimension=10)
+ categorical_column_a = fc_old.categorical_column_with_identity(
+ key='aaa', num_buckets=3)
+ categorical_column_b = fc_old.categorical_column_with_identity(
+ key='bbb', num_buckets=3)
+ shared_embedding_a, shared_embedding_b = fc_old.shared_embedding_columns(
+ [categorical_column_a, categorical_column_b], dimension=2)
+ with ops.Graph().as_default():
+ features = {
+ 'price1': [[3.], [4.]],
+ 'dense_feature': [[-1.], [4.]],
+ 'sparse_feature': [['a'], ['x']],
+ 'aaa':
+ sparse_tensor.SparseTensor(
+ indices=((0, 0), (1, 0), (1, 1)),
+ values=(0, 1, 0),
+ dense_shape=(2, 2)),
+ 'bbb':
+ sparse_tensor.SparseTensor(
+ indices=((0, 0), (1, 0), (1, 1)),
+ values=(1, 2, 1),
+ dense_shape=(2, 2)),
+ }
+ cols_to_vars = {}
+ all_cols = [
+ price1, dense_feature_bucketized, some_embedding_column,
+ shared_embedding_a, shared_embedding_b
+ ]
+ fc_old.input_layer(features, all_cols, cols_to_vars=cols_to_vars)
+ self.assertItemsEqual(list(cols_to_vars.keys()), all_cols)
+ self.assertEqual(0, len(cols_to_vars[price1]))
+ self.assertEqual(0, len(cols_to_vars[dense_feature_bucketized]))
+ self.assertEqual(1, len(cols_to_vars[some_embedding_column]))
+ self.assertEqual(1, len(cols_to_vars[shared_embedding_a]))
+ # This is a bug in the current implementation and should be fixed in the
+ # new one.
+ self.assertEqual(0, len(cols_to_vars[shared_embedding_b]))
+ self.assertIsInstance(cols_to_vars[some_embedding_column][0],
+ variables_lib.Variable)
+ self.assertAllEqual(cols_to_vars[some_embedding_column][0].shape, [5, 10])
+ self.assertIsInstance(cols_to_vars[shared_embedding_a][0],
+ variables_lib.Variable)
+ self.assertAllEqual(cols_to_vars[shared_embedding_a][0].shape, [3, 2])
+
+ def test_fills_cols_to_vars_partitioned_variables(self):
+ price1 = fc.numeric_column('price1')
+ dense_feature = fc.numeric_column('dense_feature')
+ dense_feature_bucketized = fc.bucketized_column(
+ dense_feature, boundaries=[0.])
+ some_sparse_column = fc.categorical_column_with_hash_bucket(
+ 'sparse_feature', hash_bucket_size=5)
+ some_embedding_column = fc.embedding_column(
+ some_sparse_column, dimension=10)
+ with ops.Graph().as_default():
+ features = {
+ 'price1': [[3.], [4.]],
+ 'dense_feature': [[-1.], [4.]],
+ 'sparse_feature': [['a'], ['x']],
+ }
+ cols_to_vars = {}
+ all_cols = [price1, dense_feature_bucketized, some_embedding_column]
+ with variable_scope.variable_scope(
+ 'input_from_feature_columns',
+ partitioner=partitioned_variables.fixed_size_partitioner(3, axis=0)):
+ fc_old.input_layer(features, all_cols, cols_to_vars=cols_to_vars)
+ self.assertItemsEqual(list(cols_to_vars.keys()), all_cols)
+ self.assertEqual(0, len(cols_to_vars[price1]))
+ self.assertEqual(0, len(cols_to_vars[dense_feature_bucketized]))
+ self.assertEqual(3, len(cols_to_vars[some_embedding_column]))
+ self.assertEqual(
+ 'input_from_feature_columns/input_layer/sparse_feature_embedding/'
+ 'embedding_weights/part_0:0',
+ cols_to_vars[some_embedding_column][0].name)
+ self.assertAllEqual(cols_to_vars[some_embedding_column][0].shape, [2, 10])
+ self.assertAllEqual(cols_to_vars[some_embedding_column][1].shape, [2, 10])
+ self.assertAllEqual(cols_to_vars[some_embedding_column][2].shape, [1, 10])
+
+ def test_column_order(self):
+ price_a = fc.numeric_column('price_a')
+ price_b = fc.numeric_column('price_b')
+ with ops.Graph().as_default():
+ features = {
+ 'price_a': [[1.]],
+ 'price_b': [[3.]],
+ }
+ net1 = fc_old.input_layer(features, [price_a, price_b])
+ net2 = fc_old.input_layer(features, [price_b, price_a])
+ with _initialized_session():
+ self.assertAllClose([[1., 3.]], net1.eval())
+ self.assertAllClose([[1., 3.]], net2.eval())
+
+ def test_fails_for_categorical_column(self):
+ animal = fc.categorical_column_with_identity('animal', num_buckets=4)
+ with ops.Graph().as_default():
+ features = {
+ 'animal':
+ sparse_tensor.SparseTensor(
+ indices=[[0, 0], [0, 1]], values=[1, 2], dense_shape=[1, 2])
+ }
+ with self.assertRaisesRegexp(Exception, 'must be a _DenseColumn'):
+ fc_old.input_layer(features, [animal])
+
+ def test_static_batch_size_mismatch(self):
+ price1 = fc.numeric_column('price1')
+ price2 = fc.numeric_column('price2')
+ with ops.Graph().as_default():
+ features = {
+ 'price1': [[1.], [5.], [7.]], # batchsize = 3
+ 'price2': [[3.], [4.]] # batchsize = 2
+ }
+ with self.assertRaisesRegexp(
+ ValueError,
+ 'Batch size \(first dimension\) of each feature must be same.'): # pylint: disable=anomalous-backslash-in-string
+ fc_old.input_layer(features, [price1, price2])
+
+ def test_subset_of_static_batch_size_mismatch(self):
+ price1 = fc.numeric_column('price1')
+ price2 = fc.numeric_column('price2')
+ price3 = fc.numeric_column('price3')
+ with ops.Graph().as_default():
+ features = {
+ 'price1': array_ops.placeholder(dtype=dtypes.int64), # batchsize = 3
+ 'price2': [[3.], [4.]], # batchsize = 2
+ 'price3': [[3.], [4.], [5.]] # batchsize = 3
+ }
+ with self.assertRaisesRegexp(
+ ValueError,
+ 'Batch size \(first dimension\) of each feature must be same.'): # pylint: disable=anomalous-backslash-in-string
+ fc_old.input_layer(features, [price1, price2, price3])
+
+ def test_runtime_batch_size_mismatch(self):
+ price1 = fc.numeric_column('price1')
+ price2 = fc.numeric_column('price2')
+ with ops.Graph().as_default():
+ features = {
+ 'price1': array_ops.placeholder(dtype=dtypes.int64), # batchsize = 3
+ 'price2': [[3.], [4.]] # batchsize = 2
+ }
+ net = fc_old.input_layer(features, [price1, price2])
+ with _initialized_session() as sess:
+ with self.assertRaisesRegexp(errors.OpError,
+ 'Dimensions of inputs should match'):
+ sess.run(net, feed_dict={features['price1']: [[1.], [5.], [7.]]})
+
+ def test_runtime_batch_size_matches(self):
+ price1 = fc.numeric_column('price1')
+ price2 = fc.numeric_column('price2')
+ with ops.Graph().as_default():
+ features = {
+ 'price1': array_ops.placeholder(dtype=dtypes.int64), # batchsize = 2
+ 'price2': array_ops.placeholder(dtype=dtypes.int64), # batchsize = 2
+ }
+ net = fc_old.input_layer(features, [price1, price2])
+ with _initialized_session() as sess:
+ sess.run(
+ net,
+ feed_dict={
+ features['price1']: [[1.], [5.]],
+ features['price2']: [[1.], [5.]],
+ })
+
+ def test_multiple_layers_with_same_embedding_column(self):
+ some_sparse_column = fc.categorical_column_with_hash_bucket(
+ 'sparse_feature', hash_bucket_size=5)
+ some_embedding_column = fc.embedding_column(
+ some_sparse_column, dimension=10)
+
+ with ops.Graph().as_default():
+ features = {
+ 'sparse_feature': [['a'], ['x']],
+ }
+ all_cols = [some_embedding_column]
+ fc_old.input_layer(features, all_cols)
+ fc_old.input_layer(features, all_cols)
+ # Make sure that 2 variables get created in this case.
+ self.assertEqual(2, len(
+ ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES)))
+ expected_var_names = [
+ 'input_layer/sparse_feature_embedding/embedding_weights:0',
+ 'input_layer_1/sparse_feature_embedding/embedding_weights:0'
+ ]
+ self.assertItemsEqual(
+ expected_var_names,
+ [v.name for v in ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES)])
+
+ def test_with_1d_sparse_tensor(self):
+ embedding_values = (
+ (1., 2., 3., 4., 5.), # id 0
+ (6., 7., 8., 9., 10.), # id 1
+ (11., 12., 13., 14., 15.) # id 2
+ )
+
+ def _initializer(shape, dtype, partition_info):
+ del shape, dtype, partition_info
+ return embedding_values
+
+ # price has 1 dimension in input_layer
+ price = fc.numeric_column('price')
+
+ # one_hot_body_style has 3 dims in input_layer.
+ body_style = fc.categorical_column_with_vocabulary_list(
+ 'body-style', vocabulary_list=['hardtop', 'wagon', 'sedan'])
+ one_hot_body_style = fc.indicator_column(body_style)
+
+ # embedded_body_style has 5 dims in input_layer.
+ country = fc.categorical_column_with_vocabulary_list(
+ 'country', vocabulary_list=['US', 'JP', 'CA'])
+ embedded_country = fc.embedding_column(
+ country, dimension=5, initializer=_initializer)
+
+ # Provides 1-dim tensor and dense tensor.
+ features = {
+ 'price':
+ constant_op.constant([
+ 11.,
+ 12.,
+ ]),
+ 'body-style':
+ sparse_tensor.SparseTensor(
+ indices=((0,), (1,)),
+ values=('sedan', 'hardtop'),
+ dense_shape=(2,)),
+ # This is dense tensor for the categorical_column.
+ 'country':
+ constant_op.constant(['CA', 'US']),
+ }
+ self.assertEqual(1, features['price'].shape.ndims)
+ self.assertEqual(1, features['body-style'].dense_shape.get_shape()[0])
+ self.assertEqual(1, features['country'].shape.ndims)
+
+ net = fc_old.input_layer(features,
+ [price, one_hot_body_style, embedded_country])
+ self.assertEqual(1 + 3 + 5, net.shape[1])
+ with _initialized_session() as sess:
+
+ # Each row is formed by concatenating `embedded_body_style`,
+ # `one_hot_body_style`, and `price` in order.
+ self.assertAllEqual([[0., 0., 1., 11., 12., 13., 14., 15., 11.],
+ [1., 0., 0., 1., 2., 3., 4., 5., 12.]],
+ sess.run(net))
+
+ def test_with_1d_unknown_shape_sparse_tensor(self):
+ embedding_values = (
+ (1., 2.), # id 0
+ (6., 7.), # id 1
+ (11., 12.) # id 2
+ )
+
+ def _initializer(shape, dtype, partition_info):
+ del shape, dtype, partition_info
+ return embedding_values
+
+ # price has 1 dimension in input_layer
+ price = fc.numeric_column('price')
+
+ # one_hot_body_style has 3 dims in input_layer.
+ body_style = fc.categorical_column_with_vocabulary_list(
+ 'body-style', vocabulary_list=['hardtop', 'wagon', 'sedan'])
+ one_hot_body_style = fc.indicator_column(body_style)
+
+ # embedded_body_style has 5 dims in input_layer.
+ country = fc.categorical_column_with_vocabulary_list(
+ 'country', vocabulary_list=['US', 'JP', 'CA'])
+ embedded_country = fc.embedding_column(
+ country, dimension=2, initializer=_initializer)
+
+ # Provides 1-dim tensor and dense tensor.
+ features = {
+ 'price': array_ops.placeholder(dtypes.float32),
+ 'body-style': array_ops.sparse_placeholder(dtypes.string),
+ # This is dense tensor for the categorical_column.
+ 'country': array_ops.placeholder(dtypes.string),
+ }
+ self.assertIsNone(features['price'].shape.ndims)
+ self.assertIsNone(features['body-style'].get_shape().ndims)
+ self.assertIsNone(features['country'].shape.ndims)
+
+ price_data = np.array([11., 12.])
+ body_style_data = sparse_tensor.SparseTensorValue(
+ indices=((0,), (1,)), values=('sedan', 'hardtop'), dense_shape=(2,))
+ country_data = np.array([['US'], ['CA']])
+
+ net = fc_old.input_layer(features,
+ [price, one_hot_body_style, embedded_country])
+ self.assertEqual(1 + 3 + 2, net.shape[1])
+ with _initialized_session() as sess:
+
+ # Each row is formed by concatenating `embedded_body_style`,
+ # `one_hot_body_style`, and `price` in order.
+ self.assertAllEqual(
+ [[0., 0., 1., 1., 2., 11.], [1., 0., 0., 11., 12., 12.]],
+ sess.run(
+ net,
+ feed_dict={
+ features['price']: price_data,
+ features['body-style']: body_style_data,
+ features['country']: country_data
+ }))
+
+ def test_with_rank_0_feature(self):
+ # price has 1 dimension in input_layer
+ price = fc.numeric_column('price')
+ features = {
+ 'price': constant_op.constant(0),
+ }
+ self.assertEqual(0, features['price'].shape.ndims)
+
+ # Static rank 0 should fail
+ with self.assertRaisesRegexp(ValueError, 'Feature .* cannot have rank 0'):
+ fc_old.input_layer(features, [price])
# Dynamic rank 0 should fail
features = {
'price': array_ops.placeholder(dtypes.float32),
}
- net = FeatureLayer([price])(features)
+ net = fc_old.input_layer(features, [price])
self.assertEqual(1, net.shape[1])
with _initialized_session() as sess:
with self.assertRaisesOpError('Feature .* cannot have rank 0'):
@@ -2337,11 +4060,15 @@ class FeatureLayerTest(test.TestCase):
class MakeParseExampleSpecTest(test.TestCase):
- class _TestFeatureColumn(FeatureColumn,
+ class _TestFeatureColumn(fc.FeatureColumn,
collections.namedtuple('_TestFeatureColumn',
('parse_spec'))):
@property
+ def _is_v2_column(self):
+ return True
+
+ @property
def name(self):
return '_TestFeatureColumn'
@@ -2458,6 +4185,7 @@ class VocabularyFileCategoricalColumnTest(test.TestCase):
self.assertEqual({
'aaa': parsing_ops.VarLenFeature(dtypes.string)
}, column.parse_example_spec)
+ self.assertTrue(column._is_v2_column)
def test_key_should_be_string(self):
with self.assertRaisesRegexp(ValueError, 'key must be a string.'):
@@ -2501,7 +4229,10 @@ class VocabularyFileCategoricalColumnTest(test.TestCase):
indices=((0, 0), (1, 0), (1, 1)),
values=('marlo', 'skywalker', 'omar'),
dense_shape=(2, 2))
- column.get_sparse_tensors(FeatureTransformationCache({'aaa': inputs}), None)
+ column.get_sparse_tensors(
+ fc.FeatureTransformationCache({
+ 'aaa': inputs
+ }), None)
with self.assertRaisesRegexp(errors.OpError, 'file_does_not_exist'):
with self.cached_session():
lookup_ops.tables_initializer().run()
@@ -2525,7 +4256,10 @@ class VocabularyFileCategoricalColumnTest(test.TestCase):
indices=((0, 0), (1, 0), (1, 1)),
values=('marlo', 'skywalker', 'omar'),
dense_shape=(2, 2))
- column.get_sparse_tensors(FeatureTransformationCache({'aaa': inputs}), None)
+ column.get_sparse_tensors(
+ fc.FeatureTransformationCache({
+ 'aaa': inputs
+ }), None)
with self.assertRaisesRegexp(errors.OpError, 'Invalid vocab_size'):
with self.cached_session():
lookup_ops.tables_initializer().run()
@@ -2564,7 +4298,7 @@ class VocabularyFileCategoricalColumnTest(test.TestCase):
dense_shape=(2, 2))
with self.assertRaisesRegexp(ValueError, 'dtype must be compatible'):
column.get_sparse_tensors(
- FeatureTransformationCache({
+ fc.FeatureTransformationCache({
'aaa': inputs
}), None)
@@ -2580,7 +4314,7 @@ class VocabularyFileCategoricalColumnTest(test.TestCase):
dense_shape=(2, 2))
with self.assertRaisesRegexp(ValueError, 'dtype must be compatible'):
column.get_sparse_tensors(
- FeatureTransformationCache({
+ fc.FeatureTransformationCache({
'aaa': inputs
}), None)
@@ -2616,7 +4350,7 @@ class VocabularyFileCategoricalColumnTest(test.TestCase):
values=('marlo', 'skywalker', 'omar'),
dense_shape=(2, 2))
id_weight_pair = column.get_sparse_tensors(
- FeatureTransformationCache({
+ fc.FeatureTransformationCache({
'aaa': inputs
}), None)
self.assertIsNone(id_weight_pair.weight_tensor)
@@ -2637,7 +4371,7 @@ class VocabularyFileCategoricalColumnTest(test.TestCase):
values=('marlo', 'skywalker', 'omar'),
dense_shape=(2, 2))
id_weight_pair = column.get_sparse_tensors(
- FeatureTransformationCache({
+ fc.FeatureTransformationCache({
'aaa': inputs
}), None)
self.assertIsNone(id_weight_pair.weight_tensor)
@@ -2659,7 +4393,7 @@ class VocabularyFileCategoricalColumnTest(test.TestCase):
indices=((0, 0), (1, 0), (1, 1)),
values=('marlo', 'skywalker', 'omar'),
dense_shape=(2, 2))
- id_tensor = _transform_features({'aaa': inputs}, [column], None)[column]
+ id_tensor = fc._transform_features({'aaa': inputs}, [column], None)[column]
with _initialized_session():
_assert_sparse_tensor_value(self,
sparse_tensor.SparseTensorValue(
@@ -2675,7 +4409,7 @@ class VocabularyFileCategoricalColumnTest(test.TestCase):
vocabulary_file=self._wire_vocabulary_file_name,
vocabulary_size=self._wire_vocabulary_size)
id_weight_pair = column.get_sparse_tensors(
- FeatureTransformationCache({
+ fc.FeatureTransformationCache({
'aaa': (('marlo', ''), ('skywalker', 'omar'))
}), None)
self.assertIsNone(id_weight_pair.weight_tensor)
@@ -2699,7 +4433,7 @@ class VocabularyFileCategoricalColumnTest(test.TestCase):
values=('marlo', 'skywalker', 'omar'),
dense_shape=(2, 2))
id_weight_pair = column.get_sparse_tensors(
- FeatureTransformationCache({
+ fc.FeatureTransformationCache({
'aaa': inputs
}), None)
self.assertIsNone(id_weight_pair.weight_tensor)
@@ -2723,7 +4457,7 @@ class VocabularyFileCategoricalColumnTest(test.TestCase):
values=('marlo', 'skywalker', 'omar', 'heisenberg'),
dense_shape=(2, 3))
id_weight_pair = column.get_sparse_tensors(
- FeatureTransformationCache({
+ fc.FeatureTransformationCache({
'aaa': inputs
}), None)
self.assertIsNone(id_weight_pair.weight_tensor)
@@ -2749,7 +4483,7 @@ class VocabularyFileCategoricalColumnTest(test.TestCase):
values=('marlo', 'skywalker', 'omar'),
dense_shape=(2, 2))
id_weight_pair = column.get_sparse_tensors(
- FeatureTransformationCache({
+ fc.FeatureTransformationCache({
'aaa': inputs
}), None)
self.assertIsNone(id_weight_pair.weight_tensor)
@@ -2773,7 +4507,7 @@ class VocabularyFileCategoricalColumnTest(test.TestCase):
values=(11, 100, 30, 22),
dense_shape=(3, 3))
id_weight_pair = column.get_sparse_tensors(
- FeatureTransformationCache({
+ fc.FeatureTransformationCache({
'aaa': inputs
}), None)
self.assertIsNone(id_weight_pair.weight_tensor)
@@ -2795,7 +4529,7 @@ class VocabularyFileCategoricalColumnTest(test.TestCase):
dtype=dtypes.int32,
default_value=default_value)
id_weight_pair = column.get_sparse_tensors(
- FeatureTransformationCache({
+ fc.FeatureTransformationCache({
'aaa': ((11, -1, -1), (100, 30, -1), (-1, -1, 22))
}), None)
self.assertIsNone(id_weight_pair.weight_tensor)
@@ -2820,7 +4554,7 @@ class VocabularyFileCategoricalColumnTest(test.TestCase):
values=(11, 100, 30, 22),
dense_shape=(3, 3))
id_weight_pair = column.get_sparse_tensors(
- FeatureTransformationCache({
+ fc.FeatureTransformationCache({
'aaa': inputs
}), None)
self.assertIsNone(id_weight_pair.weight_tensor)
@@ -2859,6 +4593,32 @@ class VocabularyFileCategoricalColumnTest(test.TestCase):
# 'skywalker' -> 3, 'omar' -> 0: wire_var[3] + wire_var[0] = 4+1 = 5
self.assertAllClose(((3.,), (5.,)), predictions.eval())
+ def test_old_linear_model(self):
+ wire_column = fc.categorical_column_with_vocabulary_file(
+ key='wire',
+ vocabulary_file=self._wire_vocabulary_file_name,
+ vocabulary_size=self._wire_vocabulary_size,
+ num_oov_buckets=1)
+ self.assertEqual(4, wire_column.num_buckets)
+ with ops.Graph().as_default():
+ predictions = fc_old.linear_model({
+ wire_column.name:
+ sparse_tensor.SparseTensorValue(
+ indices=((0, 0), (1, 0), (1, 1)),
+ values=('marlo', 'skywalker', 'omar'),
+ dense_shape=(2, 2))
+ }, (wire_column,))
+ bias = get_linear_model_bias()
+ wire_var = get_linear_model_column_var(wire_column)
+ with _initialized_session():
+ self.assertAllClose((0.,), bias.eval())
+ self.assertAllClose(((0.,), (0.,), (0.,), (0.,)), wire_var.eval())
+ self.assertAllClose(((0.,), (0.,)), predictions.eval())
+ wire_var.assign(((1.,), (2.,), (3.,), (4.,))).eval()
+ # 'marlo' -> 2: wire_var[2] = 3
+ # 'skywalker' -> 3, 'omar' -> 0: wire_var[3] + wire_var[0] = 4+1 = 5
+ self.assertAllClose(((3.,), (5.,)), predictions.eval())
+
class VocabularyListCategoricalColumnTest(test.TestCase):
@@ -2871,6 +4631,7 @@ class VocabularyListCategoricalColumnTest(test.TestCase):
self.assertEqual({
'aaa': parsing_ops.VarLenFeature(dtypes.string)
}, column.parse_example_spec)
+ self.assertTrue(column._is_v2_column)
def test_key_should_be_string(self):
with self.assertRaisesRegexp(ValueError, 'key must be a string.'):
@@ -2973,7 +4734,7 @@ class VocabularyListCategoricalColumnTest(test.TestCase):
dense_shape=(2, 2))
with self.assertRaisesRegexp(ValueError, 'dtype must be compatible'):
column.get_sparse_tensors(
- FeatureTransformationCache({
+ fc.FeatureTransformationCache({
'aaa': inputs
}), None)
@@ -2987,7 +4748,7 @@ class VocabularyListCategoricalColumnTest(test.TestCase):
dense_shape=(2, 2))
with self.assertRaisesRegexp(ValueError, 'dtype must be compatible'):
column.get_sparse_tensors(
- FeatureTransformationCache({
+ fc.FeatureTransformationCache({
'aaa': inputs
}), None)
@@ -3044,7 +4805,7 @@ class VocabularyListCategoricalColumnTest(test.TestCase):
values=('marlo', 'skywalker', 'omar'),
dense_shape=(2, 2))
id_weight_pair = column.get_sparse_tensors(
- FeatureTransformationCache({
+ fc.FeatureTransformationCache({
'aaa': inputs
}), None)
self.assertIsNone(id_weight_pair.weight_tensor)
@@ -3065,7 +4826,7 @@ class VocabularyListCategoricalColumnTest(test.TestCase):
indices=((0, 0), (1, 0), (1, 1)),
values=('marlo', 'skywalker', 'omar'),
dense_shape=(2, 2))
- id_tensor = _transform_features({'aaa': inputs}, [column], None)[column]
+ id_tensor = fc._transform_features({'aaa': inputs}, [column], None)[column]
with _initialized_session():
_assert_sparse_tensor_value(
self,
@@ -3080,7 +4841,7 @@ class VocabularyListCategoricalColumnTest(test.TestCase):
key='aaa',
vocabulary_list=('omar', 'stringer', 'marlo'))
id_weight_pair = column.get_sparse_tensors(
- FeatureTransformationCache({
+ fc.FeatureTransformationCache({
'aaa': (('marlo', ''), ('skywalker', 'omar'))
}), None)
self.assertIsNone(id_weight_pair.weight_tensor)
@@ -3103,7 +4864,7 @@ class VocabularyListCategoricalColumnTest(test.TestCase):
values=('marlo', 'skywalker', 'omar'),
dense_shape=(2, 2))
id_weight_pair = column.get_sparse_tensors(
- FeatureTransformationCache({
+ fc.FeatureTransformationCache({
'aaa': inputs
}), None)
self.assertIsNone(id_weight_pair.weight_tensor)
@@ -3126,7 +4887,7 @@ class VocabularyListCategoricalColumnTest(test.TestCase):
values=('marlo', 'skywalker', 'omar', 'heisenberg'),
dense_shape=(2, 3))
id_weight_pair = column.get_sparse_tensors(
- FeatureTransformationCache({
+ fc.FeatureTransformationCache({
'aaa': inputs
}), None)
self.assertIsNone(id_weight_pair.weight_tensor)
@@ -3149,7 +4910,7 @@ class VocabularyListCategoricalColumnTest(test.TestCase):
values=np.array((11, 100, 30, 22), dtype=np.int32),
dense_shape=(3, 3))
id_weight_pair = column.get_sparse_tensors(
- FeatureTransformationCache({
+ fc.FeatureTransformationCache({
'aaa': inputs
}), None)
self.assertIsNone(id_weight_pair.weight_tensor)
@@ -3170,10 +4931,10 @@ class VocabularyListCategoricalColumnTest(test.TestCase):
dtype=dtypes.int32,
default_value=default_value)
id_weight_pair = column.get_sparse_tensors(
- FeatureTransformationCache({
+ fc.FeatureTransformationCache({
'aaa':
- np.array(
- ((11, -1, -1), (100, 30, -1), (-1, -1, 22)), dtype=np.int32)
+ np.array(((11, -1, -1), (100, 30, -1), (-1, -1, 22)),
+ dtype=np.int32)
}), None)
self.assertIsNone(id_weight_pair.weight_tensor)
with _initialized_session():
@@ -3196,7 +4957,7 @@ class VocabularyListCategoricalColumnTest(test.TestCase):
values=(11, 100, 30, 22),
dense_shape=(3, 3))
id_weight_pair = column.get_sparse_tensors(
- FeatureTransformationCache({
+ fc.FeatureTransformationCache({
'aaa': inputs
}), None)
self.assertIsNone(id_weight_pair.weight_tensor)
@@ -3234,6 +4995,31 @@ class VocabularyListCategoricalColumnTest(test.TestCase):
# 'skywalker' -> 3, 'omar' -> 0: wire_var[3] + wire_var[0] = 4+1 = 5
self.assertAllClose(((3.,), (5.,)), predictions.eval())
+ def test_old_linear_model(self):
+ wire_column = fc.categorical_column_with_vocabulary_list(
+ key='aaa',
+ vocabulary_list=('omar', 'stringer', 'marlo'),
+ num_oov_buckets=1)
+ self.assertEqual(4, wire_column.num_buckets)
+ with ops.Graph().as_default():
+ predictions = fc_old.linear_model({
+ wire_column.name:
+ sparse_tensor.SparseTensorValue(
+ indices=((0, 0), (1, 0), (1, 1)),
+ values=('marlo', 'skywalker', 'omar'),
+ dense_shape=(2, 2))
+ }, (wire_column,))
+ bias = get_linear_model_bias()
+ wire_var = get_linear_model_column_var(wire_column)
+ with _initialized_session():
+ self.assertAllClose((0.,), bias.eval())
+ self.assertAllClose(((0.,), (0.,), (0.,), (0.,)), wire_var.eval())
+ self.assertAllClose(((0.,), (0.,)), predictions.eval())
+ wire_var.assign(((1.,), (2.,), (3.,), (4.,))).eval()
+ # 'marlo' -> 2: wire_var[2] = 3
+ # 'skywalker' -> 3, 'omar' -> 0: wire_var[3] + wire_var[0] = 4+1 = 5
+ self.assertAllClose(((3.,), (5.,)), predictions.eval())
+
class IdentityCategoricalColumnTest(test.TestCase):
@@ -3245,6 +5031,7 @@ class IdentityCategoricalColumnTest(test.TestCase):
self.assertEqual({
'aaa': parsing_ops.VarLenFeature(dtypes.int64)
}, column.parse_example_spec)
+ self.assertTrue(column._is_v2_column)
def test_key_should_be_string(self):
with self.assertRaisesRegexp(ValueError, 'key must be a string.'):
@@ -3285,7 +5072,7 @@ class IdentityCategoricalColumnTest(test.TestCase):
dense_shape=(2, 2))
with self.assertRaisesRegexp(ValueError, 'Invalid input, not integer'):
column.get_sparse_tensors(
- FeatureTransformationCache({
+ fc.FeatureTransformationCache({
'aaa': inputs
}), None)
@@ -3317,7 +5104,7 @@ class IdentityCategoricalColumnTest(test.TestCase):
values=(0, 1, 0),
dense_shape=(2, 2))
id_weight_pair = column.get_sparse_tensors(
- FeatureTransformationCache({
+ fc.FeatureTransformationCache({
'aaa': inputs
}), None)
self.assertIsNone(id_weight_pair.weight_tensor)
@@ -3336,7 +5123,7 @@ class IdentityCategoricalColumnTest(test.TestCase):
indices=((0, 0), (1, 0), (1, 1)),
values=(0, 1, 0),
dense_shape=(2, 2))
- id_tensor = _transform_features({'aaa': inputs}, [column], None)[column]
+ id_tensor = fc._transform_features({'aaa': inputs}, [column], None)[column]
with _initialized_session():
_assert_sparse_tensor_value(
self,
@@ -3349,7 +5136,7 @@ class IdentityCategoricalColumnTest(test.TestCase):
def test_get_sparse_tensors_dense_input(self):
column = fc.categorical_column_with_identity(key='aaa', num_buckets=3)
id_weight_pair = column.get_sparse_tensors(
- FeatureTransformationCache({
+ fc.FeatureTransformationCache({
'aaa': ((0, -1), (1, 0))
}), None)
self.assertIsNone(id_weight_pair.weight_tensor)
@@ -3369,7 +5156,7 @@ class IdentityCategoricalColumnTest(test.TestCase):
values=(1, -1, 0),
dense_shape=(2, 2))
id_weight_pair = column.get_sparse_tensors(
- FeatureTransformationCache({
+ fc.FeatureTransformationCache({
'aaa': inputs
}), None)
self.assertIsNone(id_weight_pair.weight_tensor)
@@ -3385,7 +5172,7 @@ class IdentityCategoricalColumnTest(test.TestCase):
values=(1, 99, 0),
dense_shape=(2, 2))
id_weight_pair = column.get_sparse_tensors(
- FeatureTransformationCache({
+ fc.FeatureTransformationCache({
'aaa': inputs
}), None)
self.assertIsNone(id_weight_pair.weight_tensor)
@@ -3402,7 +5189,7 @@ class IdentityCategoricalColumnTest(test.TestCase):
values=(1, -1, 99),
dense_shape=(2, 2))
id_weight_pair = column.get_sparse_tensors(
- FeatureTransformationCache({
+ fc.FeatureTransformationCache({
'aaa': inputs
}), None)
self.assertIsNone(id_weight_pair.weight_tensor)
@@ -3426,7 +5213,7 @@ class IdentityCategoricalColumnTest(test.TestCase):
values=input_values,
dense_shape=input_shape)
id_weight_pair = column.get_sparse_tensors(
- FeatureTransformationCache({
+ fc.FeatureTransformationCache({
'aaa': inputs
}), None)
self.assertIsNone(id_weight_pair.weight_tensor)
@@ -3465,6 +5252,28 @@ class IdentityCategoricalColumnTest(test.TestCase):
# weight_var[2] + weight_var[1] = 3+2 = 5
self.assertAllClose(((1.,), (5.,)), predictions.eval())
+ def test_old_linear_model(self):
+ column = fc.categorical_column_with_identity(key='aaa', num_buckets=3)
+ self.assertEqual(3, column.num_buckets)
+ with ops.Graph().as_default():
+ predictions = fc_old.linear_model({
+ column.name:
+ sparse_tensor.SparseTensorValue(
+ indices=((0, 0), (1, 0), (1, 1)),
+ values=(0, 2, 1),
+ dense_shape=(2, 2))
+ }, (column,))
+ bias = get_linear_model_bias()
+ weight_var = get_linear_model_column_var(column)
+ with _initialized_session():
+ self.assertAllClose((0.,), bias.eval())
+ self.assertAllClose(((0.,), (0.,), (0.,)), weight_var.eval())
+ self.assertAllClose(((0.,), (0.,)), predictions.eval())
+ weight_var.assign(((1.,), (2.,), (3.,))).eval()
+ # weight_var[0] = 1
+ # weight_var[2] + weight_var[1] = 3+2 = 5
+ self.assertAllClose(((1.,), (5.,)), predictions.eval())
+
class TransformFeaturesTest(test.TestCase):
@@ -3483,8 +5292,8 @@ class TransformFeaturesTest(test.TestCase):
indices=[[0, 0], [1, 0], [1, 1]],
dense_shape=[2, 2])
}
- transformed = _transform_features(features,
- [bucketized_price, hashed_sparse], None)
+ transformed = fc._transform_features(
+ features, [bucketized_price, hashed_sparse], None)
with _initialized_session():
self.assertIn(bucketized_price.name, transformed[bucketized_price].name)
self.assertAllEqual([[0], [3]], transformed[bucketized_price].eval())
@@ -3494,12 +5303,16 @@ class TransformFeaturesTest(test.TestCase):
def test_column_order(self):
"""When the column is both dense and sparse, uses sparse tensors."""
- class _LoggerColumn(FeatureColumn):
+ class _LoggerColumn(fc.FeatureColumn):
def __init__(self, name):
self._name = name
@property
+ def _is_v2_column(self):
+ return True
+
+ @property
def name(self):
return self._name
@@ -3516,12 +5329,12 @@ class TransformFeaturesTest(test.TestCase):
column1 = _LoggerColumn('1')
column2 = _LoggerColumn('2')
call_logger = {'count': 0}
- _transform_features({}, [column1, column2], None)
+ fc._transform_features({}, [column1, column2], None)
self.assertEqual(0, column1.call_order)
self.assertEqual(1, column2.call_order)
call_logger = {'count': 0}
- _transform_features({}, [column2, column1], None)
+ fc._transform_features({}, [column2, column1], None)
self.assertEqual(0, column1.call_order)
self.assertEqual(1, column2.call_order)
@@ -3534,17 +5347,19 @@ class IndicatorColumnTest(test.TestCase):
self.assertEqual(indicator_a.categorical_column.name, 'a')
self.assertEqual(indicator_a.name, 'a_indicator')
self.assertEqual(indicator_a.variable_shape, [1, 4])
+ self.assertTrue(indicator_a._is_v2_column)
- b = fc.categorical_column_with_hash_bucket('b', hash_bucket_size=100)
+ b = fc_old.categorical_column_with_hash_bucket('b', hash_bucket_size=100)
indicator_b = fc.indicator_column(b)
self.assertEqual(indicator_b.categorical_column.name, 'b')
self.assertEqual(indicator_b.name, 'b_indicator')
self.assertEqual(indicator_b.variable_shape, [1, 100])
+ self.assertFalse(indicator_b._is_v2_column)
def test_1D_shape_succeeds(self):
animal = fc.indicator_column(
fc.categorical_column_with_hash_bucket('animal', 4))
- transformation_cache = FeatureTransformationCache({
+ transformation_cache = fc.FeatureTransformationCache({
'animal': ['fox', 'fox']
})
output = transformation_cache.get(animal, None)
@@ -3555,7 +5370,7 @@ class IndicatorColumnTest(test.TestCase):
# TODO(ispir/cassandrax): Swith to categorical_column_with_keys when ready.
animal = fc.indicator_column(
fc.categorical_column_with_hash_bucket('animal', 4))
- transformation_cache = FeatureTransformationCache({
+ transformation_cache = fc.FeatureTransformationCache({
'animal':
sparse_tensor.SparseTensor(
indices=[[0, 0], [1, 0]],
@@ -3570,7 +5385,7 @@ class IndicatorColumnTest(test.TestCase):
animal = fc.indicator_column(
fc.categorical_column_with_identity('animal', num_buckets=4))
- transformation_cache = FeatureTransformationCache({
+ transformation_cache = fc.FeatureTransformationCache({
'animal':
sparse_tensor.SparseTensor(
indices=[[0, 0], [0, 1]], values=[1, 1], dense_shape=[1, 2])
@@ -3582,7 +5397,7 @@ class IndicatorColumnTest(test.TestCase):
def test_multi_hot2(self):
animal = fc.indicator_column(
fc.categorical_column_with_identity('animal', num_buckets=4))
- transformation_cache = FeatureTransformationCache({
+ transformation_cache = fc.FeatureTransformationCache({
'animal':
sparse_tensor.SparseTensor(
indices=[[0, 0], [0, 1]], values=[1, 2], dense_shape=[1, 2])
@@ -3632,8 +5447,8 @@ class IndicatorColumnTest(test.TestCase):
values=('marlo', 'skywalker', 'omar'),
dense_shape=(2, 2))
}
- indicator_tensor = _transform_features(features, [a_indicator],
- None)[a_indicator]
+ indicator_tensor = fc._transform_features(features, [a_indicator],
+ None)[a_indicator]
with _initialized_session():
self.assertAllEqual([[0, 0, 1], [1, 0, 0]], indicator_tensor.eval())
@@ -3647,8 +5462,8 @@ class IndicatorColumnTest(test.TestCase):
'ids': constant_op.constant([['c', 'b', 'a']]),
'weights': constant_op.constant([[2., 4., 6.]])
}
- indicator_tensor = _transform_features(features, [indicator],
- None)[indicator]
+ indicator_tensor = fc._transform_features(features, [indicator],
+ None)[indicator]
with _initialized_session():
self.assertAllEqual([[6., 4., 2.]], indicator_tensor.eval())
@@ -3662,8 +5477,8 @@ class IndicatorColumnTest(test.TestCase):
'ids': constant_op.constant([['c', 'b', 'unknown']]),
'weights': constant_op.constant([[2., 4., 6.]])
}
- indicator_tensor = _transform_features(features, [indicator],
- None)[indicator]
+ indicator_tensor = fc._transform_features(features, [indicator],
+ None)[indicator]
with _initialized_session():
self.assertAllEqual([[0., 4., 2.]], indicator_tensor.eval())
@@ -3675,8 +5490,8 @@ class IndicatorColumnTest(test.TestCase):
features = {
'ids': constant_op.constant([['c', 'b', 'unknown']]),
}
- indicator_tensor = _transform_features(features, [indicator],
- None)[indicator]
+ indicator_tensor = fc._transform_features(features, [indicator],
+ None)[indicator]
with _initialized_session():
self.assertAllEqual([[0., 1., 1.]], indicator_tensor.eval())
@@ -3700,6 +5515,44 @@ class IndicatorColumnTest(test.TestCase):
weight_var.assign([[1.], [2.], [3.], [4.]]).eval()
self.assertAllClose([[2. + 3.]], predictions.eval())
+ def test_old_linear_model(self):
+ animal = fc.indicator_column(
+ fc.categorical_column_with_identity('animal', num_buckets=4))
+ with ops.Graph().as_default():
+ features = {
+ 'animal':
+ sparse_tensor.SparseTensor(
+ indices=[[0, 0], [0, 1]], values=[1, 2], dense_shape=[1, 2])
+ }
+
+ predictions = fc_old.linear_model(features, [animal])
+ weight_var = get_linear_model_column_var(animal)
+ with _initialized_session():
+ # All should be zero-initialized.
+ self.assertAllClose([[0.], [0.], [0.], [0.]], weight_var.eval())
+ self.assertAllClose([[0.]], predictions.eval())
+ weight_var.assign([[1.], [2.], [3.], [4.]]).eval()
+ self.assertAllClose([[2. + 3.]], predictions.eval())
+
+ def test_old_linear_model_old_categorical(self):
+ animal = fc.indicator_column(
+ fc_old.categorical_column_with_identity('animal', num_buckets=4))
+ with ops.Graph().as_default():
+ features = {
+ 'animal':
+ sparse_tensor.SparseTensor(
+ indices=[[0, 0], [0, 1]], values=[1, 2], dense_shape=[1, 2])
+ }
+
+ predictions = fc_old.linear_model(features, [animal])
+ weight_var = get_linear_model_column_var(animal)
+ with _initialized_session():
+ # All should be zero-initialized.
+ self.assertAllClose([[0.], [0.], [0.], [0.]], weight_var.eval())
+ self.assertAllClose([[0.]], predictions.eval())
+ weight_var.assign([[1.], [2.], [3.], [4.]]).eval()
+ self.assertAllClose([[2. + 3.]], predictions.eval())
+
def test_feature_layer(self):
animal = fc.indicator_column(
fc.categorical_column_with_identity('animal', num_buckets=4))
@@ -3709,12 +5562,38 @@ class IndicatorColumnTest(test.TestCase):
sparse_tensor.SparseTensor(
indices=[[0, 0], [0, 1]], values=[1, 2], dense_shape=[1, 2])
}
- net = FeatureLayer([animal])(features)
+ net = fc.FeatureLayer([animal])(features)
with _initialized_session():
self.assertAllClose([[0., 1., 1., 0.]], net.eval())
+ def test_input_layer(self):
+ animal = fc.indicator_column(
+ fc.categorical_column_with_identity('animal', num_buckets=4))
+ with ops.Graph().as_default():
+ features = {
+ 'animal':
+ sparse_tensor.SparseTensor(
+ indices=[[0, 0], [0, 1]], values=[1, 2], dense_shape=[1, 2])
+ }
+ net = fc_old.input_layer(features, [animal])
+ with _initialized_session():
+ self.assertAllClose([[0., 1., 1., 0.]], net.eval())
-class _TestStateManager(StateManager):
+ def test_input_layer_old_categorical(self):
+ animal = fc.indicator_column(
+ fc_old.categorical_column_with_identity('animal', num_buckets=4))
+ with ops.Graph().as_default():
+ features = {
+ 'animal':
+ sparse_tensor.SparseTensor(
+ indices=[[0, 0], [0, 1]], values=[1, 2], dense_shape=[1, 2])
+ }
+ net = fc_old.input_layer(features, [animal])
+ with _initialized_session():
+ self.assertAllClose([[0., 1., 1., 0.]], net.eval())
+
+
+class _TestStateManager(fc.StateManager):
def __init__(self, trainable=True):
# Dict of feature_column to a dict of variables.
@@ -3771,6 +5650,15 @@ class EmbeddingColumnTest(test.TestCase):
self.assertEqual({
'aaa': parsing_ops.VarLenFeature(dtypes.int64)
}, embedding_column.parse_example_spec)
+ self.assertTrue(embedding_column._is_v2_column)
+
+ def test_is_v2_column(self):
+ categorical_column = fc_old.categorical_column_with_identity(
+ key='aaa', num_buckets=3)
+ embedding_dimension = 2
+ embedding_column = fc.embedding_column(
+ categorical_column, dimension=embedding_dimension)
+ self.assertFalse(embedding_column._is_v2_column)
def test_all_constructor_args(self):
categorical_column = fc.categorical_column_with_identity(
@@ -3860,7 +5748,7 @@ class EmbeddingColumnTest(test.TestCase):
values=(0, 1, 0),
dense_shape=(2, 2))
}
- outputs = _transform_features(features, [a, a_embedded], None)
+ outputs = fc._transform_features(features, [a, a_embedded], None)
output_a = outputs[a]
output_embedded = outputs[a_embedded]
with _initialized_session():
@@ -3915,7 +5803,7 @@ class EmbeddingColumnTest(test.TestCase):
# Provide sparse input and get dense result.
embedding_lookup = embedding_column.get_dense_tensor(
- FeatureTransformationCache({
+ fc.FeatureTransformationCache({
'aaa': sparse_input
}), state_manager)
@@ -3927,6 +5815,66 @@ class EmbeddingColumnTest(test.TestCase):
self.assertAllEqual(embedding_values, global_vars[0].eval())
self.assertAllEqual(expected_lookups, embedding_lookup.eval())
+ def test_get_dense_tensor_old_categorical(self):
+ # Inputs.
+ vocabulary_size = 3
+ sparse_input = sparse_tensor.SparseTensorValue(
+ # example 0, ids [2]
+ # example 1, ids [0, 1]
+ # example 2, ids []
+ # example 3, ids [1]
+ indices=((0, 0), (1, 0), (1, 4), (3, 0)),
+ values=(2, 0, 1, 1),
+ dense_shape=(4, 5))
+
+ # Embedding variable.
+ embedding_dimension = 2
+ embedding_values = (
+ (1., 2.), # id 0
+ (3., 5.), # id 1
+ (7., 11.) # id 2
+ )
+
+ def _initializer(shape, dtype, partition_info):
+ self.assertAllEqual((vocabulary_size, embedding_dimension), shape)
+ self.assertEqual(dtypes.float32, dtype)
+ self.assertIsNone(partition_info)
+ return embedding_values
+
+ # Expected lookup result, using combiner='mean'.
+ expected_lookups = (
+ # example 0, ids [2], embedding = [7, 11]
+ (7., 11.),
+ # example 1, ids [0, 1], embedding = mean([1, 2] + [3, 5]) = [2, 3.5]
+ (2., 3.5),
+ # example 2, ids [], embedding = [0, 0]
+ (0., 0.),
+ # example 3, ids [1], embedding = [3, 5]
+ (3., 5.),
+ )
+
+ # Build columns.
+ categorical_column = fc_old.categorical_column_with_identity(
+ key='aaa', num_buckets=vocabulary_size)
+ embedding_column = fc.embedding_column(
+ categorical_column,
+ dimension=embedding_dimension,
+ initializer=_initializer)
+
+ # Provide sparse input and get dense result.
+ embedding_lookup = embedding_column._get_dense_tensor(
+ fc_old._LazyBuilder({
+ 'aaa': sparse_input
+ }))
+
+ # Assert expected embedding variable and lookups.
+ global_vars = ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES)
+ self.assertItemsEqual(('embedding_weights:0',),
+ tuple([v.name for v in global_vars]))
+ with _initialized_session():
+ self.assertAllEqual(embedding_values, global_vars[0].eval())
+ self.assertAllEqual(expected_lookups, embedding_lookup.eval())
+
def test_get_dense_tensor_3d(self):
# Inputs.
vocabulary_size = 4
@@ -3977,7 +5925,7 @@ class EmbeddingColumnTest(test.TestCase):
# Provide sparse input and get dense result.
embedding_lookup = embedding_column.get_dense_tensor(
- FeatureTransformationCache({
+ fc.FeatureTransformationCache({
'aaa': sparse_input
}), state_manager)
@@ -4040,7 +5988,7 @@ class EmbeddingColumnTest(test.TestCase):
input_values = array_ops.placeholder(dtype=dtypes.int64)
input_shape = array_ops.placeholder(dtype=dtypes.int64)
embedding_lookup = embedding_column.get_dense_tensor(
- FeatureTransformationCache({
+ fc.FeatureTransformationCache({
'aaa':
sparse_tensor.SparseTensorValue(
indices=input_indices,
@@ -4108,7 +6056,7 @@ class EmbeddingColumnTest(test.TestCase):
# Provide sparse input and get dense result.
embedding_lookup = embedding_column.get_dense_tensor(
- FeatureTransformationCache({
+ fc.FeatureTransformationCache({
'aaa': sparse_input
}), state_manager)
@@ -4241,7 +6189,7 @@ class EmbeddingColumnTest(test.TestCase):
initializer=_initializer)
# Provide sparse input and get dense result.
- l = FeatureLayer((embedding_column,))
+ l = fc.FeatureLayer((embedding_column,))
feature_layer = l({'aaa': sparse_input})
# Assert expected embedding variable and lookups.
@@ -4302,7 +6250,7 @@ class EmbeddingColumnTest(test.TestCase):
trainable=False)
# Provide sparse input and get dense result.
- feature_layer = FeatureLayer((embedding_column,))({'aaa': sparse_input})
+ feature_layer = fc.FeatureLayer((embedding_column,))({'aaa': sparse_input})
# Assert expected embedding variable and lookups.
global_vars = ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES)
@@ -4314,6 +6262,220 @@ class EmbeddingColumnTest(test.TestCase):
self.assertAllEqual(embedding_values, global_vars[0].eval())
self.assertAllEqual(expected_lookups, feature_layer.eval())
+ def test_input_layer(self):
+ # Inputs.
+ vocabulary_size = 3
+ sparse_input = sparse_tensor.SparseTensorValue(
+ # example 0, ids [2]
+ # example 1, ids [0, 1]
+ # example 2, ids []
+ # example 3, ids [1]
+ indices=((0, 0), (1, 0), (1, 4), (3, 0)),
+ values=(2, 0, 1, 1),
+ dense_shape=(4, 5))
+
+ # Embedding variable.
+ embedding_dimension = 2
+ embedding_values = (
+ (1., 2.), # id 0
+ (3., 5.), # id 1
+ (7., 11.) # id 2
+ )
+
+ def _initializer(shape, dtype, partition_info):
+ self.assertAllEqual((vocabulary_size, embedding_dimension), shape)
+ self.assertEqual(dtypes.float32, dtype)
+ self.assertIsNone(partition_info)
+ return embedding_values
+
+ # Expected lookup result, using combiner='mean'.
+ expected_lookups = (
+ # example 0, ids [2], embedding = [7, 11]
+ (7., 11.),
+ # example 1, ids [0, 1], embedding = mean([1, 2] + [3, 5]) = [2, 3.5]
+ (2., 3.5),
+ # example 2, ids [], embedding = [0, 0]
+ (0., 0.),
+ # example 3, ids [1], embedding = [3, 5]
+ (3., 5.),
+ )
+
+ # Build columns.
+ categorical_column = fc.categorical_column_with_identity(
+ key='aaa', num_buckets=vocabulary_size)
+ embedding_column = fc.embedding_column(
+ categorical_column,
+ dimension=embedding_dimension,
+ initializer=_initializer)
+
+ # Provide sparse input and get dense result.
+ feature_layer = fc_old.input_layer({
+ 'aaa': sparse_input
+ }, (embedding_column,))
+
+ # Assert expected embedding variable and lookups.
+ global_vars = ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES)
+ self.assertItemsEqual(('input_layer/aaa_embedding/embedding_weights:0',),
+ tuple([v.name for v in global_vars]))
+ trainable_vars = ops.get_collection(ops.GraphKeys.TRAINABLE_VARIABLES)
+ self.assertItemsEqual(('input_layer/aaa_embedding/embedding_weights:0',),
+ tuple([v.name for v in trainable_vars]))
+ with _initialized_session():
+ self.assertAllEqual(embedding_values, trainable_vars[0].eval())
+ self.assertAllEqual(expected_lookups, feature_layer.eval())
+
+ def test_old_linear_model(self):
+ # Inputs.
+ batch_size = 4
+ vocabulary_size = 3
+ sparse_input = sparse_tensor.SparseTensorValue(
+ # example 0, ids [2]
+ # example 1, ids [0, 1]
+ # example 2, ids []
+ # example 3, ids [1]
+ indices=((0, 0), (1, 0), (1, 4), (3, 0)),
+ values=(2, 0, 1, 1),
+ dense_shape=(batch_size, 5))
+
+ # Embedding variable.
+ embedding_dimension = 2
+ embedding_shape = (vocabulary_size, embedding_dimension)
+ zeros_embedding_values = np.zeros(embedding_shape)
+
+ def _initializer(shape, dtype, partition_info):
+ self.assertAllEqual(embedding_shape, shape)
+ self.assertEqual(dtypes.float32, dtype)
+ self.assertIsNone(partition_info)
+ return zeros_embedding_values
+
+ # Build columns.
+ categorical_column = fc.categorical_column_with_identity(
+ key='aaa', num_buckets=vocabulary_size)
+ embedding_column = fc.embedding_column(
+ categorical_column,
+ dimension=embedding_dimension,
+ initializer=_initializer)
+
+ with ops.Graph().as_default():
+ predictions = fc_old.linear_model({
+ categorical_column.name: sparse_input
+ }, (embedding_column,))
+ expected_var_names = (
+ 'linear_model/bias_weights:0',
+ 'linear_model/aaa_embedding/weights:0',
+ 'linear_model/aaa_embedding/embedding_weights:0',
+ )
+ self.assertItemsEqual(
+ expected_var_names,
+ [v.name for v in ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES)])
+ trainable_vars = {
+ v.name: v
+ for v in ops.get_collection(ops.GraphKeys.TRAINABLE_VARIABLES)
+ }
+ self.assertItemsEqual(expected_var_names, trainable_vars.keys())
+ bias = trainable_vars['linear_model/bias_weights:0']
+ embedding_weights = trainable_vars[
+ 'linear_model/aaa_embedding/embedding_weights:0']
+ linear_weights = trainable_vars['linear_model/aaa_embedding/weights:0']
+ with _initialized_session():
+ # Predictions with all zero weights.
+ self.assertAllClose(np.zeros((1,)), bias.eval())
+ self.assertAllClose(zeros_embedding_values, embedding_weights.eval())
+ self.assertAllClose(
+ np.zeros((embedding_dimension, 1)), linear_weights.eval())
+ self.assertAllClose(np.zeros((batch_size, 1)), predictions.eval())
+
+ # Predictions with all non-zero weights.
+ embedding_weights.assign((
+ (1., 2.), # id 0
+ (3., 5.), # id 1
+ (7., 11.) # id 2
+ )).eval()
+ linear_weights.assign(((4.,), (6.,))).eval()
+ # example 0, ids [2], embedding[0] = [7, 11]
+ # example 1, ids [0, 1], embedding[1] = mean([1, 2] + [3, 5]) = [2, 3.5]
+ # example 2, ids [], embedding[2] = [0, 0]
+ # example 3, ids [1], embedding[3] = [3, 5]
+ # sum(embeddings * linear_weights)
+ # = [4*7 + 6*11, 4*2 + 6*3.5, 4*0 + 6*0, 4*3 + 6*5] = [94, 29, 0, 42]
+ self.assertAllClose(((94.,), (29.,), (0.,), (42.,)), predictions.eval())
+
+ def test_old_linear_model_old_categorical(self):
+ # Inputs.
+ batch_size = 4
+ vocabulary_size = 3
+ sparse_input = sparse_tensor.SparseTensorValue(
+ # example 0, ids [2]
+ # example 1, ids [0, 1]
+ # example 2, ids []
+ # example 3, ids [1]
+ indices=((0, 0), (1, 0), (1, 4), (3, 0)),
+ values=(2, 0, 1, 1),
+ dense_shape=(batch_size, 5))
+
+ # Embedding variable.
+ embedding_dimension = 2
+ embedding_shape = (vocabulary_size, embedding_dimension)
+ zeros_embedding_values = np.zeros(embedding_shape)
+
+ def _initializer(shape, dtype, partition_info):
+ self.assertAllEqual(embedding_shape, shape)
+ self.assertEqual(dtypes.float32, dtype)
+ self.assertIsNone(partition_info)
+ return zeros_embedding_values
+
+ # Build columns.
+ categorical_column = fc_old.categorical_column_with_identity(
+ key='aaa', num_buckets=vocabulary_size)
+ embedding_column = fc.embedding_column(
+ categorical_column,
+ dimension=embedding_dimension,
+ initializer=_initializer)
+
+ with ops.Graph().as_default():
+ predictions = fc_old.linear_model({
+ categorical_column.name: sparse_input
+ }, (embedding_column,))
+ expected_var_names = (
+ 'linear_model/bias_weights:0',
+ 'linear_model/aaa_embedding/weights:0',
+ 'linear_model/aaa_embedding/embedding_weights:0',
+ )
+ self.assertItemsEqual(
+ expected_var_names,
+ [v.name for v in ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES)])
+ trainable_vars = {
+ v.name: v
+ for v in ops.get_collection(ops.GraphKeys.TRAINABLE_VARIABLES)
+ }
+ self.assertItemsEqual(expected_var_names, trainable_vars.keys())
+ bias = trainable_vars['linear_model/bias_weights:0']
+ embedding_weights = trainable_vars[
+ 'linear_model/aaa_embedding/embedding_weights:0']
+ linear_weights = trainable_vars['linear_model/aaa_embedding/weights:0']
+ with _initialized_session():
+ # Predictions with all zero weights.
+ self.assertAllClose(np.zeros((1,)), bias.eval())
+ self.assertAllClose(zeros_embedding_values, embedding_weights.eval())
+ self.assertAllClose(
+ np.zeros((embedding_dimension, 1)), linear_weights.eval())
+ self.assertAllClose(np.zeros((batch_size, 1)), predictions.eval())
+
+ # Predictions with all non-zero weights.
+ embedding_weights.assign((
+ (1., 2.), # id 0
+ (3., 5.), # id 1
+ (7., 11.) # id 2
+ )).eval()
+ linear_weights.assign(((4.,), (6.,))).eval()
+ # example 0, ids [2], embedding[0] = [7, 11]
+ # example 1, ids [0, 1], embedding[1] = mean([1, 2] + [3, 5]) = [2, 3.5]
+ # example 2, ids [], embedding[2] = [0, 0]
+ # example 3, ids [1], embedding[3] = [3, 5]
+ # sum(embeddings * linear_weights)
+ # = [4*7 + 6*11, 4*2 + 6*3.5, 4*0 + 6*0, 4*3 + 6*5] = [94, 29, 0, 42]
+ self.assertAllClose(((94.,), (29.,), (0.,), (42.,)), predictions.eval())
+
class SharedEmbeddingColumnTest(test.TestCase):
@@ -4530,8 +6692,8 @@ class SharedEmbeddingColumnTest(test.TestCase):
values=(1, 2, 1),
dense_shape=(2, 2)),
}
- outputs = _transform_features(features, [a, a_embedded, b, b_embedded],
- None)
+ outputs = fc._transform_features(features, [a, a_embedded, b, b_embedded],
+ None)
output_a = outputs[a]
output_a_embedded = outputs[a_embedded]
output_b = outputs[b]
@@ -4599,9 +6761,9 @@ class SharedEmbeddingColumnTest(test.TestCase):
# Provide sparse input and get dense result.
embedding_lookup_a = embedding_column_a.get_dense_tensor(
- FeatureTransformationCache(input_features), state_manager)
+ fc.FeatureTransformationCache(input_features), state_manager)
embedding_lookup_b = embedding_column_b.get_dense_tensor(
- FeatureTransformationCache(input_features), state_manager)
+ fc.FeatureTransformationCache(input_features), state_manager)
# Assert expected embedding variable and lookups.
global_vars = ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES)
@@ -4665,9 +6827,9 @@ class SharedEmbeddingColumnTest(test.TestCase):
# Provide sparse input and get dense result.
embedding_lookup_a = embedding_column_a.get_dense_tensor(
- FeatureTransformationCache(input_features), state_manager)
+ fc.FeatureTransformationCache(input_features), state_manager)
embedding_lookup_b = embedding_column_b.get_dense_tensor(
- FeatureTransformationCache(input_features), state_manager)
+ fc.FeatureTransformationCache(input_features), state_manager)
with _initialized_session() as sess:
sess.run([embedding_lookup_a, embedding_lookup_b], feed_dict=feed_dict)
@@ -4852,7 +7014,7 @@ class SharedEmbeddingColumnTest(test.TestCase):
}
# Provide sparse input and get dense result.
- feature_layer = FeatureLayer(
+ feature_layer = fc.FeatureLayer(
feature_columns=(embedding_column_b, embedding_column_a,
embedding_column_c, embedding_column_d),
shared_state_manager=shared_state_manager)(
@@ -4946,6 +7108,14 @@ class WeightedCategoricalColumnTest(test.TestCase):
'ids': parsing_ops.VarLenFeature(dtypes.int64),
'values': parsing_ops.VarLenFeature(dtypes.float32)
}, column.parse_example_spec)
+ self.assertTrue(column._is_v2_column)
+
+ def test_is_v2_column(self):
+ column = fc.weighted_categorical_column(
+ categorical_column=fc_old.categorical_column_with_identity(
+ key='ids', num_buckets=3),
+ weight_feature_key='values')
+ self.assertFalse(column._is_v2_column)
def test_deep_copy(self):
"""Tests deepcopy of categorical_column_with_hash_bucket."""
@@ -4987,7 +7157,10 @@ class WeightedCategoricalColumnTest(test.TestCase):
values=('omar', 'stringer', 'marlo'),
dense_shape=(2, 2))
with self.assertRaisesRegexp(ValueError, 'Bad dtype'):
- _transform_features({'ids': strings, 'values': strings}, (column,), None)
+ fc._transform_features({
+ 'ids': strings,
+ 'values': strings
+ }, (column,), None)
def test_column_name_collision(self):
with self.assertRaisesRegexp(ValueError, r'Parse config.*already exists'):
@@ -5007,7 +7180,7 @@ class WeightedCategoricalColumnTest(test.TestCase):
dense_shape=(2, 2))
with self.assertRaisesRegexp(
ValueError, 'values is not in features dictionary'):
- _transform_features({'ids': inputs}, (column,), None)
+ fc._transform_features({'ids': inputs}, (column,), None)
def test_parse_example(self):
a = fc.categorical_column_with_vocabulary_list(
@@ -5056,7 +7229,7 @@ class WeightedCategoricalColumnTest(test.TestCase):
indices=((0, 0), (1, 0), (1, 1)),
values=(0.5, 1.0, 0.1),
dense_shape=(2, 2))
- id_tensor, weight_tensor = _transform_features({
+ id_tensor, weight_tensor = fc._transform_features({
'ids': inputs,
'values': weights,
}, (column,), None)[column]
@@ -5085,7 +7258,7 @@ class WeightedCategoricalColumnTest(test.TestCase):
indices=((0, 0), (1, 0), (1, 1)),
values=(0.5, 1.0, 0.1),
dense_shape=(2, 2))
- id_tensor, weight_tensor = _transform_features({
+ id_tensor, weight_tensor = fc._transform_features({
'ids': ((0, -1), (1, 0)),
'values': weights,
}, (column,), None)[column]
@@ -5114,7 +7287,7 @@ class WeightedCategoricalColumnTest(test.TestCase):
indices=((0, 0), (1, 0), (1, 1)),
values=(2, 1, 0),
dense_shape=(2, 2))
- id_tensor, weight_tensor = _transform_features({
+ id_tensor, weight_tensor = fc._transform_features({
'ids': inputs,
'values': ((.5, 0.), (1., .1)),
}, (column,), None)[column]
@@ -5236,6 +7409,137 @@ class WeightedCategoricalColumnTest(test.TestCase):
# = 3*1 + 2*.1 = 3+.2 = 3.2
self.assertAllClose(((.5,), (3.2,)), predictions.eval())
+ def test_old_linear_model(self):
+ column = fc.weighted_categorical_column(
+ categorical_column=fc.categorical_column_with_identity(
+ key='ids', num_buckets=3),
+ weight_feature_key='values')
+ with ops.Graph().as_default():
+ predictions = fc_old.linear_model({
+ 'ids':
+ sparse_tensor.SparseTensorValue(
+ indices=((0, 0), (1, 0), (1, 1)),
+ values=(0, 2, 1),
+ dense_shape=(2, 2)),
+ 'values':
+ sparse_tensor.SparseTensorValue(
+ indices=((0, 0), (1, 0), (1, 1)),
+ values=(.5, 1., .1),
+ dense_shape=(2, 2))
+ }, (column,))
+ bias = get_linear_model_bias()
+ weight_var = get_linear_model_column_var(column)
+ with _initialized_session():
+ self.assertAllClose((0.,), bias.eval())
+ self.assertAllClose(((0.,), (0.,), (0.,)), weight_var.eval())
+ self.assertAllClose(((0.,), (0.,)), predictions.eval())
+ weight_var.assign(((1.,), (2.,), (3.,))).eval()
+ # weight_var[0] * weights[0, 0] = 1 * .5 = .5
+ # weight_var[2] * weights[1, 0] + weight_var[1] * weights[1, 1]
+ # = 3*1 + 2*.1 = 3+.2 = 3.2
+ self.assertAllClose(((.5,), (3.2,)), predictions.eval())
+
+ def test_old_linear_model_mismatched_shape(self):
+ column = fc.weighted_categorical_column(
+ categorical_column=fc.categorical_column_with_identity(
+ key='ids', num_buckets=3),
+ weight_feature_key='values')
+ with ops.Graph().as_default():
+ with self.assertRaisesRegexp(ValueError,
+ r'Dimensions.*are not compatible'):
+ fc_old.linear_model({
+ 'ids':
+ sparse_tensor.SparseTensorValue(
+ indices=((0, 0), (1, 0), (1, 1)),
+ values=(0, 2, 1),
+ dense_shape=(2, 2)),
+ 'values':
+ sparse_tensor.SparseTensorValue(
+ indices=((0, 0), (0, 1), (1, 0), (1, 1)),
+ values=(.5, 11., 1., .1),
+ dense_shape=(2, 2))
+ }, (column,))
+
+ def test_old_linear_model_mismatched_dense_values(self):
+ column = fc.weighted_categorical_column(
+ categorical_column=fc.categorical_column_with_identity(
+ key='ids', num_buckets=3),
+ weight_feature_key='values')
+ with ops.Graph().as_default():
+ predictions = fc_old.linear_model({
+ 'ids':
+ sparse_tensor.SparseTensorValue(
+ indices=((0, 0), (1, 0), (1, 1)),
+ values=(0, 2, 1),
+ dense_shape=(2, 2)),
+ 'values': ((.5,), (1.,))
+ }, (column,),
+ sparse_combiner='mean')
+ # Disabling the constant folding optimizer here since it changes the
+ # error message differently on CPU and GPU.
+ config = config_pb2.ConfigProto()
+ config.graph_options.rewrite_options.constant_folding = (
+ rewriter_config_pb2.RewriterConfig.OFF)
+ with _initialized_session(config):
+ with self.assertRaisesRegexp(errors.OpError, 'Incompatible shapes'):
+ predictions.eval()
+
+ def test_old_linear_model_mismatched_dense_shape(self):
+ column = fc.weighted_categorical_column(
+ categorical_column=fc.categorical_column_with_identity(
+ key='ids', num_buckets=3),
+ weight_feature_key='values')
+ with ops.Graph().as_default():
+ predictions = fc_old.linear_model({
+ 'ids':
+ sparse_tensor.SparseTensorValue(
+ indices=((0, 0), (1, 0), (1, 1)),
+ values=(0, 2, 1),
+ dense_shape=(2, 2)),
+ 'values': ((.5,), (1.,), (.1,))
+ }, (column,))
+ bias = get_linear_model_bias()
+ weight_var = get_linear_model_column_var(column)
+ with _initialized_session():
+ self.assertAllClose((0.,), bias.eval())
+ self.assertAllClose(((0.,), (0.,), (0.,)), weight_var.eval())
+ self.assertAllClose(((0.,), (0.,)), predictions.eval())
+ weight_var.assign(((1.,), (2.,), (3.,))).eval()
+ # weight_var[0] * weights[0, 0] = 1 * .5 = .5
+ # weight_var[2] * weights[1, 0] + weight_var[1] * weights[1, 1]
+ # = 3*1 + 2*.1 = 3+.2 = 3.2
+ self.assertAllClose(((.5,), (3.2,)), predictions.eval())
+
+ def test_old_linear_model_old_categorical(self):
+ column = fc.weighted_categorical_column(
+ categorical_column=fc_old.categorical_column_with_identity(
+ key='ids', num_buckets=3),
+ weight_feature_key='values')
+ with ops.Graph().as_default():
+ predictions = fc_old.linear_model({
+ 'ids':
+ sparse_tensor.SparseTensorValue(
+ indices=((0, 0), (1, 0), (1, 1)),
+ values=(0, 2, 1),
+ dense_shape=(2, 2)),
+ 'values':
+ sparse_tensor.SparseTensorValue(
+ indices=((0, 0), (1, 0), (1, 1)),
+ values=(.5, 1., .1),
+ dense_shape=(2, 2))
+ }, (column,))
+ bias = get_linear_model_bias()
+ weight_var = get_linear_model_column_var(column)
+ with _initialized_session():
+ self.assertAllClose((0.,), bias.eval())
+ self.assertAllClose(((0.,), (0.,), (0.,)), weight_var.eval())
+ self.assertAllClose(((0.,), (0.,)), predictions.eval())
+ weight_var.assign(((1.,), (2.,), (3.,))).eval()
+ # weight_var[0] * weights[0, 0] = 1 * .5 = .5
+ # weight_var[2] * weights[1, 0] + weight_var[1] * weights[1, 1]
+ # = 3*1 + 2*.1 = 3+.2 = 3.2
+ self.assertAllClose(((.5,), (3.2,)), predictions.eval())
+
# TODO(ptucker): Add test with embedding of weighted categorical.
if __name__ == '__main__':