author    Mustafa Ispir <ispir@google.com>  2017-04-28 10:57:13 -0800
committer TensorFlower Gardener <gardener@tensorflow.org>  2017-04-28 12:11:32 -0700
commit    744112ccf1bd0628c74d1601becd0c00598e8aa3 (patch)
tree      26bdb5212f8bb5f1cfcf0dc340a2be2924563aa6
parent    13afb3d3f92ffe83fbe8d024e3a16a00863da8e0 (diff)
Started to graduate FeatureColumn from contrib to core.
Change: 154565789
-rw-r--r--  tensorflow/BUILD                                           1
-rw-r--r--  tensorflow/python/BUILD                                    1
-rw-r--r--  tensorflow/python/feature_column/BUILD                    57
-rw-r--r--  tensorflow/python/feature_column/feature_column.py       972
-rw-r--r--  tensorflow/python/feature_column/feature_column_test.py  800
5 files changed, 1831 insertions(+), 0 deletions(-)
diff --git a/tensorflow/BUILD b/tensorflow/BUILD
index d059d227e6..a2f7a9fb63 100644
--- a/tensorflow/BUILD
+++ b/tensorflow/BUILD
@@ -307,6 +307,7 @@ filegroup(
"//tensorflow/python:all_files",
"//tensorflow/python/debug:all_files",
"//tensorflow/python/estimator:all_files",
+ "//tensorflow/python/feature_column:all_files",
"//tensorflow/python/kernel_tests:all_files",
"//tensorflow/python/kernel_tests/distributions:all_files",
"//tensorflow/python/ops/distributions:all_files",
diff --git a/tensorflow/python/BUILD b/tensorflow/python/BUILD
index 28e1adb518..bcb837ad8d 100644
--- a/tensorflow/python/BUILD
+++ b/tensorflow/python/BUILD
@@ -80,6 +80,7 @@ py_library(
":weights_broadcast_ops",
"//third_party/py/numpy",
"//tensorflow/python/estimator:estimator_py",
+ "//tensorflow/python/feature_column:feature_column",
"//tensorflow/python/ops/losses",
"//tensorflow/python/ops/distributions",
"//tensorflow/python/saved_model",
diff --git a/tensorflow/python/feature_column/BUILD b/tensorflow/python/feature_column/BUILD
new file mode 100644
index 0000000000..d5eb20e997
--- /dev/null
+++ b/tensorflow/python/feature_column/BUILD
@@ -0,0 +1,57 @@
+package(
+ default_visibility = [
+ "//tensorflow:internal",
+ ],
+ features = [
+ "-layering_check",
+ "-parse_headers",
+ ],
+)
+
+licenses(["notice"]) # Apache 2.0
+
+load("//tensorflow:tensorflow.bzl", "py_test")
+
+filegroup(
+ name = "all_files",
+ srcs = glob(
+ ["**/*"],
+ exclude = [
+ "**/METADATA",
+ "**/OWNERS",
+ ],
+ ),
+ visibility = ["//tensorflow:__subpackages__"],
+)
+
+py_library(
+ name = "feature_column",
+ srcs = ["feature_column.py"],
+ srcs_version = "PY2AND3",
+ deps = [
+ "//tensorflow/python:embedding_ops",
+ "//tensorflow/python:framework",
+ "//tensorflow/python:init_ops",
+ "//tensorflow/python:math_ops",
+ "//tensorflow/python:parsing_ops",
+ "//tensorflow/python:platform",
+ "//tensorflow/python:sparse_ops",
+ "//tensorflow/python:sparse_tensor",
+ "//tensorflow/python:string_ops",
+ "//tensorflow/python:util",
+ "//tensorflow/python:variable_scope",
+ "//tensorflow/python:variables",
+ ],
+)
+
+py_test(
+ name = "feature_column_test",
+ srcs = ["feature_column_test.py"],
+ srcs_version = "PY2AND3",
+ deps = [
+ ":feature_column",
+ "//tensorflow/python:client_testlib",
+ "//tensorflow/python:framework",
+ "//tensorflow/python:training",
+ ],
+)
diff --git a/tensorflow/python/feature_column/feature_column.py b/tensorflow/python/feature_column/feature_column.py
new file mode 100644
index 0000000000..7d8a42080d
--- /dev/null
+++ b/tensorflow/python/feature_column/feature_column.py
@@ -0,0 +1,972 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""This API defines FeatureColumn abstraction.
+
+FeatureColumns provide a high level abstraction for ingesting and representing
+features. FeatureColumns are also the primary way of encoding features for
+canned ${tf.estimator.Estimator}s.
+
+When using FeatureColumns with `Estimators`, the type of feature column you
+should choose depends on (1) the feature type and (2) the model type.
+
+(1) Feature type:
+
+ * Continuous features can be represented by `numeric_column`.
+ * Categorical features can be represented by any `categorical_column_with_*`
+ column:
+ - `categorical_column_with_keys`
+ - `categorical_column_with_vocabulary_file`
+ - `categorical_column_with_hash_bucket`
+ - `categorical_column_with_integerized_feature`
+
+(2) Model type:
+
+ * Deep neural network models (`DNNClassifier`, `DNNRegressor`).
+
+ Continuous features can be directly fed into deep neural network models.
+
+ age_column = numeric_column("age")
+
+ To feed sparse features into DNN models, wrap the column with
+ `embedding_column` or `indicator_column`. `indicator_column` is recommended
+ for features with only a few possible values. For features with many possible
+ values, `embedding_column` is recommended.
+
+ embedded_dept_column = embedding_column(
+ categorical_column_with_keys("department", ["math", "philosophy", ...]),
+ dimension=10)
+
+ * Wide (aka linear) models (`LinearClassifier`, `LinearRegressor`).
+
+ Sparse features can be fed directly into linear models. They behave like an
+ indicator column but with an efficient implementation.
+
+ dept_column = categorical_column_with_keys("department",
+ ["math", "philosophy", "english"])
+
+ It is recommended that continuous features be bucketized before being
+ fed into linear models.
+
+ bucketized_age_column = bucketized_column(
+ source_column=age_column,
+ boundaries=[18, 25, 30, 35, 40, 45, 50, 55, 60, 65])
+
+ Sparse features can be crossed (also known as conjoined or combined) in
+ order to form non-linearities, and then fed into linear models.
+
+ cross_dept_age_column = crossed_column(
+ columns=[department_column, bucketized_age_column],
+ hash_bucket_size=1000)
+
+Example of building canned `Estimator`s using FeatureColumns:
+
+ # Define features and transformations
+ deep_feature_columns = [age_column, embedded_dept_column]
+ wide_feature_columns = [dept_column, bucketized_age_column,
+ cross_dept_age_column]
+
+ # Build deep model
+ estimator = DNNClassifier(
+ feature_columns=deep_feature_columns,
+ hidden_units=[500, 250, 50])
+ estimator.train(...)
+
+ # Or build a wide model
+ estimator = LinearClassifier(
+ feature_columns=wide_feature_columns)
+ estimator.train(...)
+
+ # Or build a wide and deep model!
+ estimator = DNNLinearCombinedClassifier(
+ linear_feature_columns=wide_feature_columns,
+ dnn_feature_columns=deep_feature_columns,
+ dnn_hidden_units=[500, 250, 50])
+ estimator.train(...)
+
+
+FeatureColumns can also be transformed into a generic input layer for
+custom models using `input_from_feature_columns`.
+
+Example of building a model using FeatureColumns; this can be used in a
+`model_fn` that is given to the ${tf.estimator.Estimator}:
+
+ # Building model via layers
+
+ deep_feature_columns = [age_column, embedded_dept_column]
+ columns_to_tensor = parse_feature_columns_from_examples(
+ serialized=my_data,
+ feature_columns=deep_feature_columns)
+ first_layer = input_from_feature_columns(
+ columns_to_tensors=columns_to_tensor,
+ feature_columns=deep_feature_columns)
+ second_layer = fully_connected(first_layer, ...)
+"""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import abc
+import collections
+
+from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import ops
+from tensorflow.python.framework import sparse_tensor as sparse_tensor_lib
+from tensorflow.python.framework import tensor_shape
+from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import embedding_ops
+from tensorflow.python.ops import init_ops
+from tensorflow.python.ops import math_ops
+from tensorflow.python.ops import nn_ops
+from tensorflow.python.ops import parsing_ops
+from tensorflow.python.ops import sparse_ops
+from tensorflow.python.ops import string_ops
+from tensorflow.python.ops import variable_scope
+from tensorflow.python.ops import variables
+from tensorflow.python.platform import tf_logging as logging
+from tensorflow.python.util import nest
+
+
+def make_linear_model(features,
+ feature_columns,
+ units=1,
+ sparse_combiner='sum',
+ weight_collections=None,
+ trainable=True):
+ """Returns a linear prediction `Tensor` based on given `feature_columns`.
+
+ This function generates a weighted sum of all given features for each of
+ the `units` output dimensions. The weighted sum refers to logits in
+ classification problems. It refers to the prediction itself for linear
+ regression problems.
+
+ The main difference between `make_linear_model` and `make_input_layer` is the
+ handling of categorical columns: `make_linear_model` treats them as
+ `indicator_column`s, while `make_input_layer` explicitly requires wrapping
+ each of them with an `embedding_column` or an `indicator_column`.
+
+ Args:
+ features: A mapping from key to tensors. A `string` key signifies a base
+ feature. A `_FeatureColumn` key is also allowed; it means that the
+ FeatureColumn has already been transformed by the input pipeline.
+ feature_columns: An iterable containing all the FeatureColumns. All items
+ should be instances of classes derived from `_FeatureColumn`.
+ units: An integer, dimensionality of the output space. Defaults to 1.
+ sparse_combiner: A string specifying how to reduce a sparse column when it is
+ multivalent. Currently "mean", "sqrtn" and "sum" are supported, with "sum"
+ the default. "sqrtn" often achieves good accuracy, in particular with
+ bag-of-words columns. Each sparse column is combined independently:
+ * "sum": do not normalize features in the column
+ * "mean": do l1 normalization on features in the column
+ * "sqrtn": do l2 normalization on features in the column
+ weight_collections: A list of collection names to which the Variables will be
+ added. Note that variables will also be added to the collections
+ `tf.GraphKeys.GLOBAL_VARIABLES` and `tf.GraphKeys.MODEL_VARIABLES`.
+ trainable: If `True` also add the variable to the graph collection
+ `GraphKeys.TRAINABLE_VARIABLES` (see `tf.Variable`).
+
+ Returns:
+ A `Tensor` which represents predictions/logits of a linear model. Its shape
+ is (batch_size, units) and its dtype is `float32`.
+
+ Raises:
+ ValueError: if an item in `feature_columns` is neither a `_DenseColumn`
+ nor `_CategoricalColumn`.
+ """
+ _check_feature_columns(feature_columns)
+ for column in feature_columns:
+ if not isinstance(column, (_DenseColumn, _CategoricalColumn)):
+ raise ValueError('Items of feature_columns must be either a _DenseColumn '
+ 'or _CategoricalColumn. Given: {}'.format(column))
+ weight_collections = list(weight_collections or [])
+ weight_collections += [
+ ops.GraphKeys.GLOBAL_VARIABLES, ops.GraphKeys.MODEL_VARIABLES
+ ]
+ with variable_scope.variable_scope(
+ None, default_name='make_linear_model', values=features.values()):
+ weighted_sums = []
+ builder = _LazyBuilder(features)
+ for column in sorted(feature_columns, key=lambda x: x.name):
+ with variable_scope.variable_scope(None, default_name=column.name):
+ if isinstance(column, _DenseColumn):
+ weighted_sums.append(_create_dense_column_weighted_sum(
+ column, builder, units, weight_collections, trainable))
+ else:
+ weighted_sums.append(_create_categorical_column_weighted_sum(
+ column, builder, units, sparse_combiner, weight_collections,
+ trainable))
+ predictions_no_bias = math_ops.add_n(
+ weighted_sums, name='weighted_sum_no_bias')
+ bias = variable_scope.get_variable(
+ 'bias_weight',
+ shape=[units],
+ initializer=init_ops.zeros_initializer(),
+ trainable=trainable,
+ collections=weight_collections)
+ predictions = nn_ops.bias_add(
+ predictions_no_bias, bias, name='weighted_sum')
+
+ return predictions
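As a hedged usage sketch (mirroring the tests added later in this change): the returned predictions start at zero because both the weights and the bias use `zeros_initializer`; the feature values here are made up:

```python
import tensorflow as tf
from tensorflow.python.feature_column import feature_column as fc

price = fc.numeric_column('price')
features = {'price': tf.constant([[1.], [5.]])}
predictions = fc.make_linear_model(features, [price])
with tf.Session() as sess:
  sess.run(tf.global_variables_initializer())
  print(sess.run(predictions))  # [[0.], [0.]] until the weights are trained
```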
+
+
+def numeric_column(key,
+ shape=(1,),
+ default_value=None,
+ dtype=dtypes.float32,
+ normalizer_fn=None):
+ """Represents real valued or numerical features.
+
+ An example:
+ ```python
+ price = numeric_column('price')
+ all_feature_columns = [price, ...]
+ dense_tensor = make_input_layer(features, all_feature_columns)
+
+ # or
+ bucketized_price = bucketized_column(price, boundaries=[...])
+ all_feature_columns = [bucketized_price, ...]
+ linear_prediction = make_linear_model(features, all_feature_columns)
+ ```
+
+ Args:
+ key: A string providing the key to look up the corresponding `Tensor`.
+ shape: An iterable of integers specifying the shape of the `Tensor`. An
+ integer can be given, which means a single-dimension `Tensor` with the
+ given width. The `Tensor` representing the column will have the shape of
+ [batch_size] + `shape`.
+ default_value: A single value compatible with `dtype` or an iterable of
+ values compatible with `dtype` which the column takes on during
+ `tf.Example` parsing if data is missing. A default value of `None` will
+ cause `tf.parse_example` to fail if an example does not contain this
+ column. If a single value is provided, the same value will be applied as
+ the default value for every item. If an iterable of values is provided,
+ the shape of the `default_value` should be equal to the given `shape`.
+ dtype: defines the type of values. Default value is `tf.float32`. Must be a
+ non-quantized, real integer or floating point type.
+ normalizer_fn: If not `None`, a function that can be used to normalize the
+ value of the tensor after `default_value` is applied for parsing.
+ Normalizer function takes the input `Tensor` as its argument, and returns
+ the output `Tensor` (e.g. lambda x: (x - 3.0) / 4.2). Please note that even
+ though the most common use case of this function is normalization, it can be
+ used for any kind of TensorFlow transformation.
+
+ Returns:
+ A _NumericColumn.
+
+ Raises:
+ TypeError: if any dimension in shape is not an int
+ ValueError: if any dimension in shape is not a positive integer
+ TypeError: if `default_value` is an iterable but not compatible with `shape`
+ TypeError: if `default_value` is not compatible with `dtype`.
+ ValueError: if `dtype` is not convertible to `tf.float32`.
+ """
+ shape = _check_shape(shape, key)
+ if not (dtype.is_integer or dtype.is_floating):
+ raise ValueError('dtype must be convertible to float. '
+ 'dtype: {}, key: {}'.format(dtype, key))
+ default_value = _check_default_value(shape, default_value, dtype, key)
+
+ if normalizer_fn is not None and not callable(normalizer_fn):
+ raise TypeError(
+ 'normalizer_fn must be a callable. Given: {}'.format(normalizer_fn))
+
+ return _NumericColumn(
+ key,
+ shape=shape,
+ default_value=default_value,
+ dtype=dtype,
+ normalizer_fn=normalizer_fn)
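For illustration, a small hedged example of `normalizer_fn`; the standardization constants are invented, not derived from real data:

```python
def _standardize(x):  # applied after default_value during parsing
  return (x - 3.0) / 4.2  # illustrative mean/scale

price = numeric_column('price', shape=[2], default_value=3.0,
                       normalizer_fn=_standardize)
```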
+
+
+def categorical_column_with_hash_bucket(key,
+ hash_bucket_size,
+ dtype=dtypes.string):
+ """Represents sparse feature where ids are set by hashing.
+
+ Use this when your sparse features are in string or integer format, and you
+ want to distribute your inputs into a finite number of buckets by hashing:
+ `output_id = Hash(input_feature_string) % bucket_size`.
+
+ An example:
+ ```python
+ keywords = categorical_column_with_hash_bucket("keywords", 10000)
+ all_feature_columns = [keywords, ...]
+ linear_prediction = make_linear_model(features, all_feature_columns)
+
+ # or
+ keywords_embedded = embedding_column(keywords, 16)
+ all_feature_columns = [keywords_embedded, ...]
+ dense_tensor = make_input_layer(features, all_feature_columns)
+ ```
+
+ Args:
+ key: A string providing the key to look up the corresponding `Tensor`.
+ hash_bucket_size: An int >= 1. The number of buckets.
+ dtype: The type of features. Only string and integer types are supported.
+
+ Returns:
+ A `_CategoricalColumnHashed`.
+
+ Raises:
+ ValueError: `hash_bucket_size` is less than one.
+ ValueError: `dtype` is neither string nor integer.
+ """
+ if hash_bucket_size is None:
+ raise ValueError('hash_bucket_size must be set. key: {}'.format(key))
+
+ if hash_bucket_size < 1:
+ raise ValueError('hash_bucket_size must be at least 1. '
+ 'hash_bucket_size: {}, key: {}'.format(
+ hash_bucket_size, key))
+
+ if dtype != dtypes.string and not dtype.is_integer:
+ raise ValueError('dtype must be string or integer. '
+ 'dtype: {}, column_name: {}'.format(dtype, key))
+
+ return _CategoricalColumnHashed(key, hash_bucket_size, dtype)
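A brief hedged sketch of the round trip through parsing, assuming `serialized` holds serialized `tf.Example` protos (the name is hypothetical):

```python
import tensorflow as tf
from tensorflow.python.feature_column import feature_column as fc

keywords = fc.categorical_column_with_hash_bucket('keywords', 10000)
# The column knows its own parsing spec: {'keywords': VarLenFeature(tf.string)}
features = tf.parse_example(serialized,
                            features=keywords._parse_example_config)
# During transformation each string s becomes hash(s) % 10000, as int64 ids.
```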
+
+
+class _FeatureColumn(object):
+ """Represents a feature column abstraction.
+
+ WARNING: Do not subclass this layer unless you know what you are doing:
+ the API is subject to future changes.
+
+ To distinguish between the concept of a feature family and a specific binary
+ feature within a family, we refer to a feature family like "country" as a
+ feature column. The following is an example feature in a `tf.Example` format:
+ {key: "country", value: [ "US" ]}
+ In this example the value of the feature is "US" and "country" refers to the
+ column of the feature.
+
+ This class is abstract. Users should not create instances of it.
+ """
+ __metaclass__ = abc.ABCMeta
+
+ @abc.abstractproperty
+ def name(self):
+ """Returns string. used for variable_scope and naming."""
+ pass
+
+ @abc.abstractmethod
+ def _transform_feature(self, inputs):
+ """Returns transformed `Tensor`, uses `inputs` to access input tensors.
+
+ It uses `inputs` to get either raw feature or transformation of other
+ FeatureColumns.
+
+ Example input access:
+ Let's say a FeatureColumn depends on a raw feature ('raw') and another
+ `_FeatureColumn` (input_fc). To access the corresponding `Tensor`s, `inputs`
+ is used as follows:
+
+ ```python
+ raw_tensor = inputs.get('raw')
+ fc_tensor = inputs.get(input_fc)
+ ```
+
+ Args:
+ inputs: A `_LazyBuilder` object to access inputs.
+
+ Returns:
+ Transformed feature `Tensor`.
+ """
+ pass
+
+ @abc.abstractproperty
+ def _parse_example_config(self):
+ """Returns a `tf.Example` parsing spec as dict.
+
+ It is used to build the parsing spec for `tf.parse_example`. The returned
+ spec is a dict from keys ('string') to `VarLenFeature`, `FixedLenFeature`,
+ and other supported objects. Please check the documentation of
+ ${tf.parse_example} for all supported spec objects.
+
+ Let's say a FeatureColumn depends on a raw feature ('raw') and another
+ `_FeatureColumn` (input_fc). One possible implementation of
+ _parse_example_config is as follows:
+
+ ```python
+ spec = {'raw': tf.FixedLenFeature(...)}
+ spec.update(input_fc._parse_example_config)
+ return spec
+ ```
+ """
+ pass
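To make the contract concrete, here is a minimal illustrative subclass (the class name and the 'raw' base feature are hypothetical, and this API is private and subject to change; real code should use the public factory functions):

```python
class _PlusOneColumn(_FeatureColumn):
  """Toy column: adds one to a dense base feature named 'raw'."""

  @property
  def name(self):
    return 'plus_one'

  def _transform_feature(self, inputs):
    return inputs.get('raw') + 1.0  # access the raw input via the builder

  @property
  def _parse_example_config(self):
    return {'raw': parsing_ops.FixedLenFeature((1,), dtypes.float32)}
```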
+
+
+class _DenseColumn(_FeatureColumn):
+ """Represents a column which can be represented as `Tensor`.
+
+ WARNING: Do not subclass this layer unless you know what you are doing:
+ the API is subject to future changes.
+
+ Some examples of this type are: numeric_column, embedding_column,
+ indicator_column.
+ """
+
+ __metaclass__ = abc.ABCMeta
+
+ @abc.abstractproperty
+ def _variable_shape(self):
+ """Returns shape of variable which is compatible with _get_dense_tensor."""
+ pass
+
+ @abc.abstractmethod
+ def _get_dense_tensor(self, inputs, weight_collections=None, trainable=None):
+ """Returns a `Tensor`.
+
+ The output of this function will be used by model-builder functions. For
+ example, the pseudo code of `make_input_layer` would look like:
+ ```python
+ def make_input_layer(features, feature_columns, ...):
+ outputs = [fc._get_dense_tensor(...) for fc in feature_columns]
+ return tf.concat(outputs)
+ ```
+
+ Args:
+ inputs: A `_LazyBuilder` object to access inputs.
+ weight_collections: List of graph collections to which Variables (if any
+ are created) are added.
+ trainable: If `True` also add variables to the graph collection
+ `GraphKeys.TRAINABLE_VARIABLES` (see ${tf.Variable}).
+ """
+ pass
+
+
+def _create_dense_column_weighted_sum(
+ column, builder, units, weight_collections, trainable):
+ """Create a weighted sum of a dense column for make_linear_model."""
+ tensor = column._get_dense_tensor( # pylint: disable=protected-access
+ builder,
+ weight_collections=weight_collections,
+ trainable=trainable)
+ num_elements = tensor_shape.TensorShape(column._variable_shape).num_elements() # pylint: disable=protected-access
+ batch_size = array_ops.shape(tensor)[0]
+ tensor = array_ops.reshape(tensor, shape=(batch_size, num_elements))
+ weight = variable_scope.get_variable(
+ name='weight',
+ shape=[num_elements, units],
+ initializer=init_ops.zeros_initializer(),
+ trainable=trainable,
+ collections=weight_collections)
+ return math_ops.matmul(tensor, weight, name='weighted_sum')
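In numpy terms the dense branch above is just a flatten followed by a matmul; a sketch with made-up sizes:

```python
import numpy as np

batch_size, num_elements, units = 2, 3, 1
tensor = np.arange(6.).reshape(batch_size, num_elements)  # flattened column
weight = np.zeros((num_elements, units))                  # zeros_initializer
weighted_sum = np.matmul(tensor, weight)                  # (batch_size, units)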
+
+
+class _CategoricalColumn(_FeatureColumn):
+ """Represents a categorical feautre.
+
+ WARNING: Do not subclass this layer unless you know what you are doing:
+ the API is subject to future changes.
+
+ A categorical feature is typically handled with a ${tf.SparseTensor} of IDs.
+ """
+ __metaclass__ = abc.ABCMeta
+
+ IdWeightPair = collections.namedtuple( # pylint: disable=invalid-name
+ 'IdWeightPair', ['id_tensor', 'weight_tensor'])
+
+ @abc.abstractproperty
+ def _num_buckets(self):
+ """Returns number of buckets in this sparse feature."""
+ pass
+
+ @abc.abstractmethod
+ def _get_sparse_tensors(self,
+ inputs,
+ weight_collections=None,
+ trainable=None):
+ """Returns an IdWeightPair.
+
+ `IdWeightPair` is a pair of `SparseTensor`s which represents ids and
+ weights.
+
+ `IdWeightPair.id_tensor` is typically a `batch_size` x `num_buckets`
+ `SparseTensor` of `int64`. `IdWeightPair.weight_tensor` is either a
+ `SparseTensor` of `float` or `None` to indicate all weights should be
+ taken to be 1. If specified, `weight_tensor` must have exactly the same
+ shape and indices as `id_tensor`. The expected `SparseTensor` is the same as
+ the parsing output of a `VarLenFeature`, i.e. a ragged matrix.
+
+ Args:
+ inputs: A `_LazyBuilder` used as a cache to get input tensors required to
+ create `IdWeightPair`.
+ weight_collections: List of graph collections to which variables (if any
+ are created) are added.
+ trainable: If `True` also add variables to the graph collection
+ `GraphKeys.TRAINABLE_VARIABLES` (see ${tf.get_variable}).
+ """
+ pass
+
+
+def _create_categorical_column_weighted_sum(
+ column, builder, units, sparse_combiner, weight_collections, trainable):
+ """Create a weighted sum of a categorical column for make_linear_model."""
+ sparse_tensors = column._get_sparse_tensors( # pylint: disable=protected-access
+ builder,
+ weight_collections=weight_collections,
+ trainable=trainable)
+ weight = variable_scope.get_variable(
+ name='weight',
+ shape=[column._num_buckets, units], # pylint: disable=protected-access
+ initializer=init_ops.zeros_initializer(),
+ trainable=trainable,
+ collections=weight_collections)
+ return _safe_embedding_lookup_sparse(
+ weight,
+ sparse_tensors.id_tensor,
+ sparse_weights=sparse_tensors.weight_tensor,
+ combiner=sparse_combiner,
+ name='weighted_sum')
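The embedding lookup here is mathematically a sparse one-hot matmul: with one weight row per bucket, summing the looked-up rows equals multiplying a multi-hot vector by the weight matrix. A numpy sketch with invented ids:

```python
import numpy as np

num_buckets, units = 4, 1
weight = np.arange(num_buckets * units, dtype=float).reshape(num_buckets,
                                                             units)
ids = [2, 0, 3]  # sparse ids for a single example
multi_hot = np.eye(num_buckets)[ids].sum(axis=0)   # [1., 1., 1., 0.] pattern
assert np.allclose(np.matmul(multi_hot, weight),   # one-hot matmul
                   weight[ids].sum(axis=0))        # embedding-lookup sum
```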
+
+
+class _LazyBuilder(object):
+ """Handles caching of transformations while building the model.
+
+ `FeatureColumn` specifies how to digest an input column to the network. Some
+ feature columns require data transformations. This class caches those
+ transformations.
+
+ Some features may be used in more than one place. For example, one can use a
+ bucketized feature by itself and also in a cross with it. In that case we
+ should create only one bucketization op instead of creating ops for each
+ feature column separately. To handle re-use of transformed columns,
+ `_LazyBuilder` caches all previously transformed columns.
+
+ Example:
+ We're trying to use the following `FeatureColumns`:
+
+ ```python
+ bucketized_age = fc.bucketized_column(fc.numeric_column("age"), ...)
+ keywords = fc.categorical_column_with_hash_bucket("keywords", ...)
+ age_X_keywords = fc.crossed_column([bucketized_age, keywords])
+ ... = make_linear_model(features,
+ [bucketized_age, keywords, age_X_keywords])
+ ```
+
+ If we transform each column independently, then bucketization is duplicated
+ (once inside the cross, once for the bucketized column itself).
+ The `_LazyBuilder` eliminates this duplication.
+ """
+
+ def __init__(self, features):
+ """Creates a `_LazyBuilder`.
+
+ Args:
+ features: A mapping from key to tensors. A `string` key signifies a base
+ feature (not transformed). A `FeatureColumn` key means that this `Tensor`
+ is the output of an existing `FeatureColumn`, which can be reused.
+ """
+ self._columns_to_tensors = features.copy()
+
+ def get(self, key):
+ """Returns a `Tensor` for the given key.
+
+ A `str` key is used to access a base feature (not-transformed). When a
+ `_FeatureColumn` is passed, the transformed feature is returned if it
+ already exists, otherwise the given `_FeatureColumn` is asked to provide its
+ transformed output, which is then cached.
+
+ Args:
+ key: a `str` or a `_FeatureColumn`.
+
+ Returns:
+ The transformed `Tensor` corresponding to the `key`.
+
+ Raises:
+ ValueError: if key is not found or a transformed `Tensor` cannot be
+ computed.
+ """
+ if key in self._columns_to_tensors:
+ # Feature_column is already transformed or it's a raw feature.
+ return self._columns_to_tensors[key]
+
+ if not isinstance(key, (str, _FeatureColumn)):
+ raise TypeError('"key" must be either a "str" or "_FeatureColumn". '
+ 'Provided: {}'.format(key))
+
+ if not isinstance(key, _FeatureColumn):
+ raise ValueError('Feature {} is not in features dictionary.'.format(key))
+
+ column = key
+ logging.debug('Transforming feature_column %s.', column)
+ transformed = column._transform_feature(self) # pylint: disable=protected-access
+ if transformed is None:
+ raise ValueError('Column {} is not supported.'.format(column.name))
+ self._columns_to_tensors[column] = transformed
+ return self._columns_to_tensors[column]
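A hedged sketch of the caching behavior (mirroring `LazyColumnTest` below); the second `get` returns the cached `Tensor` without re-running the transformation:

```python
import tensorflow as tf
from tensorflow.python.feature_column import feature_column as fc

builder = fc._LazyBuilder({'price': tf.constant([[1.], [5.]])})
price = fc.numeric_column('price')
first = builder.get(price)   # runs price._transform_feature(builder)
second = builder.get(price)  # cache hit: the very same Tensor object
assert first is second
```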
+
+
+def _check_feature_columns(feature_columns):
+ if isinstance(feature_columns, dict):
+ raise ValueError('Expected feature_columns to be iterable, found dict.')
+ for column in feature_columns:
+ if not isinstance(column, _FeatureColumn):
+ raise ValueError('Items of feature_columns must be a _FeatureColumn.')
+ name_to_column = dict()
+ for column in feature_columns:
+ if column.name in name_to_column:
+ raise ValueError('Duplicate feature column name found for columns: {} '
+ 'and {}. This usually means that these columns refer to '
+ 'the same base feature. Either one must be discarded or a '
+ 'duplicated but renamed item must be inserted in the '
+ 'features dict.'.format(column,
+ name_to_column[column.name]))
+ name_to_column[column.name] = column
+
+
+class _NumericColumn(_DenseColumn,
+ collections.namedtuple('_NumericColumn', [
+ 'key', 'shape', 'default_value', 'dtype',
+ 'normalizer_fn'
+ ])):
+ """see `numeric_column`."""
+
+ @property
+ def name(self):
+ return self.key
+
+ @property
+ def _parse_example_config(self):
+ return {
+ self.key:
+ parsing_ops.FixedLenFeature(self.shape, self.dtype,
+ self.default_value)
+ }
+
+ def _transform_feature(self, inputs):
+ input_tensor = inputs.get(self.key)
+ if isinstance(input_tensor, sparse_tensor_lib.SparseTensor):
+ raise ValueError(
+ 'The input for a numeric column must be a Tensor. '
+ 'SparseTensor is not supported. key: {}'.format(self.key))
+ if self.normalizer_fn is not None:
+ input_tensor = self.normalizer_fn(input_tensor)
+ return math_ops.to_float(input_tensor)
+
+ @property
+ def _variable_shape(self):
+ return self.shape
+
+ def _get_dense_tensor(self, inputs, weight_collections=None, trainable=None):
+ del weight_collections
+ del trainable
+ return inputs.get(self)
+
+
+def _create_tuple(shape, value):
+ """Returns a tuple with given shape and filled with value."""
+ if shape:
+ return tuple([_create_tuple(shape[1:], value) for _ in range(shape[0])])
+ return value
+
+
+def _as_tuple(value):
+ if not nest.is_sequence(value):
+ return value
+ return tuple([_as_tuple(v) for v in value])
+
+
+def _check_shape(shape, key):
+ """Returns shape if it's valid, raises error otherwise."""
+ assert shape is not None
+ if not nest.is_sequence(shape):
+ shape = [shape]
+ shape = tuple(shape)
+ for dimension in shape:
+ if not isinstance(dimension, int):
+ raise TypeError('shape dimensions must be integer. '
+ 'shape: {}, key: {}'.format(shape, key))
+ if dimension < 1:
+ raise ValueError('shape dimensions must be greater than 0. '
+ 'shape: {}, key: {}'.format(shape, key))
+ return shape
+
+
+def _is_shape_and_default_value_compatible(default_value, shape):
+ """Verifies compatibility of shape and default_value."""
+ # Invalid condition:
+ # * if default_value is an iterable and shape is empty
+ # * or if default_value is a scalar and shape is not empty
+ if nest.is_sequence(default_value) != bool(shape):
+ return False
+ if not shape:
+ return True
+ if len(default_value) != shape[0]:
+ return False
+ for i in range(shape[0]):
+ if not _is_shape_and_default_value_compatible(default_value[i], shape[1:]):
+ return False
+ return True
+
+
+def _check_default_value(shape, default_value, dtype, key):
+ """Returns default value as tuple if it's valid, otherwise raises errors.
+
+ This function verifies that `default_value` is compatible with both `shape`
+ and `dtype`. If it is not compatible, it raises an error. If it is compatible,
+ it casts default_value to a tuple and returns it. `key` is used only
+ for error messages.
+
+ Args:
+ shape: An iterable of integers specifying the shape of the `Tensor`.
+ default_value: If a single value is provided, the same value will be applied
+ as the default value for every item. If an iterable of values is
+ provided, the shape of the `default_value` should be equal to the given
+ `shape`.
+ dtype: defines the type of values. Default value is `tf.float32`. Must be a
+ non-quantized, real integer or floating point type.
+ key: A string providing the key to look up the corresponding `Tensor`.
+
+ Returns:
+ A tuple which will be used as default value.
+
+ Raises:
+ TypeError: if `default_value` is an iterable but not compatible with `shape`
+ TypeError: if `default_value` is not compatible with `dtype`.
+ ValueError: if `dtype` is not convertible to `tf.float32`.
+ """
+ if default_value is None:
+ return None
+
+ if isinstance(default_value, int):
+ return _create_tuple(shape, default_value)
+
+ if isinstance(default_value, float) and dtype.is_floating:
+ return _create_tuple(shape, default_value)
+
+ if callable(getattr(default_value, 'tolist', None)): # Handles numpy arrays
+ default_value = default_value.tolist()
+
+ if nest.is_sequence(default_value):
+ if not _is_shape_and_default_value_compatible(default_value, shape):
+ raise ValueError(
+ 'The shape of default_value must be equal to given shape. '
+ 'default_value: {}, shape: {}, key: {}'.format(
+ default_value, shape, key))
+ # Check if the values in the list are all integers or are convertible to
+ # floats.
+ is_list_all_int = all(
+ isinstance(v, int) for v in nest.flatten(default_value))
+ is_list_has_float = any(
+ isinstance(v, float) for v in nest.flatten(default_value))
+ if is_list_all_int:
+ return _as_tuple(default_value)
+ if is_list_has_float and dtype.is_floating:
+ return _as_tuple(default_value)
+ raise TypeError('default_value must be compatible with dtype. '
+ 'default_value: {}, dtype: {}, key: {}'.format(
+ default_value, dtype, key))
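Illustrative outcomes of the checks above, for shape `(2,)` and `tf.float32` (a sketch, not exhaustive):

```python
_check_default_value((2,), 1.0, dtypes.float32, 'price')     # -> (1.0, 1.0)
_check_default_value((2,), [1, 2], dtypes.float32, 'price')  # -> (1, 2)
_check_default_value((2,), None, dtypes.float32, 'price')    # -> None
_check_default_value((2,), [1., 2., 3.], dtypes.float32,
                     'price')                                # ValueError
```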
+
+
+class _CategoricalColumnHashed(
+ _CategoricalColumn,
+ collections.namedtuple('_CategoricalColumnHashed',
+ ['key', 'hash_bucket_size', 'dtype'])):
+ """see `categorical_column_with_hash_bucket`."""
+
+ @property
+ def name(self):
+ return self.key
+
+ @property
+ def _parse_example_config(self):
+ return {self.key: parsing_ops.VarLenFeature(self.dtype)}
+
+ def _transform_feature(self, inputs):
+ input_tensor = inputs.get(self.key)
+ if not isinstance(input_tensor, sparse_tensor_lib.SparseTensor):
+ raise ValueError('SparseColumn input must be a SparseTensor.')
+
+ if (input_tensor.dtype != dtypes.string and
+ not input_tensor.dtype.is_integer):
+ raise ValueError('Input tensor dtype must be string or integer. '
+ 'dtype: {}, column_name: {}'.format(
+ input_tensor.dtype, self.key))
+
+ if self.dtype.is_integer != input_tensor.dtype.is_integer:
+ raise ValueError(
+ 'Column dtype and SparseTensor dtype must be compatible. '
+ 'key: {}, column dtype: {}, tensor dtype: {}'.format(
+ self.key, self.dtype, input_tensor.dtype))
+
+ if self.dtype == dtypes.string:
+ sparse_values = input_tensor.values
+ else:
+ sparse_values = string_ops.as_string(input_tensor.values)
+
+ sparse_id_values = string_ops.string_to_hash_bucket_fast(
+ sparse_values, self.hash_bucket_size, name='lookup')
+ return sparse_tensor_lib.SparseTensor(
+ input_tensor.indices, sparse_id_values, input_tensor.dense_shape)
+
+ @property
+ def _num_buckets(self):
+ """Returns number of buckets in this sparse feature."""
+ return self.hash_bucket_size
+
+ def _get_sparse_tensors(self, inputs, weight_collections=None,
+ trainable=None):
+ return _CategoricalColumn.IdWeightPair(inputs.get(self), None)
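Tying the pieces together, a hedged sketch of what a hashed column produces (the input strings are invented):

```python
import tensorflow as tf
from tensorflow.python.feature_column import feature_column as fc

wire = fc.categorical_column_with_hash_bucket('wire', 10)
builder = fc._LazyBuilder({
    'wire': tf.SparseTensor(values=['omar', 'stringer'],
                            indices=[[0, 0], [1, 0]], dense_shape=[2, 1])
})
ids, weights = wire._get_sparse_tensors(builder)
# ids: an int64 SparseTensor of hashed bucket ids; weights is None, meaning
# every present id implicitly carries weight 1.
```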
+
+
+# TODO(zakaria): Move this to embedding_ops and make it public.
+def _safe_embedding_lookup_sparse(embedding_weights,
+ sparse_ids,
+ sparse_weights=None,
+ combiner=None,
+ default_id=None,
+ name=None,
+ partition_strategy='div',
+ max_norm=None):
+ """Lookup embedding results, accounting for invalid IDs and empty features.
+
+ The partitioned embedding tensors in `embedding_weights` must all have the
+ same shape except for the first dimension, which is allowed to vary as the
+ vocabulary size is not necessarily a multiple of `P`. `embedding_weights`
+ may be a `PartitionedVariable` as returned by using `tf.get_variable()` with a
+ partitioner.
+
+ Invalid IDs (< 0) are pruned from input IDs and weights, as well as any IDs
+ with non-positive weight. For an entry with no features, the embedding vector
+ for `default_id` is returned, or the 0-vector if `default_id` is not supplied.
+
+ The ids and weights may be multi-dimensional. Embeddings are always aggregated
+ along the last dimension.
+
+ Args:
+ embedding_weights: A list of `P` float tensors or values representing
+ partitioned embedding tensors. Alternatively, a `PartitionedVariable`,
+ created by partitioning along dimension 0. The total unpartitioned
+ shape should be `[e_0, e_1, ..., e_m]`, where `e_0` represents the
+ vocab size and `e_1, ..., e_m` are the embedding dimensions.
+ sparse_ids: `SparseTensor` of shape `[d_0, d_1, ..., d_n]` containing the
+ ids. `d_0` is typically batch size.
+ sparse_weights: `SparseTensor` of same shape as `sparse_ids`, containing
+ float weights corresponding to `sparse_ids`, or `None` if all weights
+ are assumed to be 1.0.
+ combiner: A string specifying how to combine embedding results for each
+ entry. Currently "mean", "sqrtn" and "sum" are supported, with "mean"
+ the default.
+ default_id: The id to use for an entry with no features.
+ name: A name for this operation (optional).
+ partition_strategy: A string specifying the partitioning strategy.
+ Currently `"div"` and `"mod"` are supported. Default is `"div"`.
+ max_norm: If not None, all embeddings are l2-normalized to max_norm before
+ combining.
+
+ Returns:
+ Dense tensor of shape `[d_0, d_1, ..., d_{n-1}, e_1, ..., e_m]`.
+
+ Raises:
+ ValueError: if `embedding_weights` is empty.
+ """
+ if combiner is None:
+ logging.warn('The default value of combiner will change from \"mean\" '
+ 'to \"sqrtn\" after 2016/11/01.')
+ combiner = 'mean'
+ if embedding_weights is None:
+ raise ValueError('Missing embedding_weights %s.' % embedding_weights)
+ if isinstance(embedding_weights, variables.PartitionedVariable):
+ embedding_weights = list(embedding_weights) # get underlying Variables.
+ if not isinstance(embedding_weights, list):
+ embedding_weights = [embedding_weights]
+ if len(embedding_weights) < 1:
+ raise ValueError('Missing embedding_weights %s.' % embedding_weights)
+
+ dtype = sparse_weights.dtype if sparse_weights is not None else None
+ if isinstance(embedding_weights, variables.PartitionedVariable):
+ embedding_weights = list(embedding_weights)
+ embedding_weights = [
+ ops.convert_to_tensor(w, dtype=dtype) for w in embedding_weights
+ ]
+
+ with ops.name_scope(name, 'embedding_lookup',
+ embedding_weights + [sparse_ids,
+ sparse_weights]) as scope:
+ # Reshape higher-rank sparse ids and weights to linear segment ids.
+ original_shape = sparse_ids.dense_shape
+ original_rank_dim = sparse_ids.dense_shape.get_shape()[0]
+ original_rank = (
+ array_ops.size(original_shape)
+ if original_rank_dim.value is None
+ else original_rank_dim.value)
+ sparse_ids = sparse_ops.sparse_reshape(sparse_ids, [
+ math_ops.reduce_prod(
+ array_ops.slice(original_shape, [0], [original_rank - 1])),
+ array_ops.gather(original_shape, original_rank - 1)])
+ if sparse_weights is not None:
+ sparse_weights = sparse_tensor_lib.SparseTensor(
+ sparse_ids.indices,
+ sparse_weights.values, sparse_ids.dense_shape)
+
+ # Prune invalid ids and weights.
+ sparse_ids, sparse_weights = _prune_invalid_ids(sparse_ids, sparse_weights)
+
+ # Fill in dummy values for empty features, if necessary.
+ sparse_ids, is_row_empty = sparse_ops.sparse_fill_empty_rows(sparse_ids,
+ default_id or
+ 0)
+ if sparse_weights is not None:
+ sparse_weights, _ = sparse_ops.sparse_fill_empty_rows(sparse_weights, 1.0)
+
+ result = embedding_ops.embedding_lookup_sparse(
+ embedding_weights,
+ sparse_ids,
+ sparse_weights,
+ combiner=combiner,
+ partition_strategy=partition_strategy,
+ name=None if default_id is None else scope,
+ max_norm=max_norm)
+
+ if default_id is None:
+ # Broadcast is_row_empty to the same shape as embedding_lookup_result,
+ # for use in Select.
+ is_row_empty = array_ops.tile(
+ array_ops.reshape(is_row_empty, [-1, 1]),
+ array_ops.stack([1, array_ops.shape(result)[1]]))
+
+ result = array_ops.where(is_row_empty,
+ array_ops.zeros_like(result),
+ result,
+ name=scope)
+
+ # Reshape back from linear ids into the higher-dimensional dense result.
+ final_result = array_ops.reshape(
+ result,
+ array_ops.concat([
+ array_ops.slice(
+ math_ops.cast(original_shape, dtypes.int32), [0],
+ [original_rank - 1]),
+ array_ops.slice(array_ops.shape(result), [1], [-1])
+ ], 0))
+ final_result.set_shape(tensor_shape.unknown_shape(
+ (original_rank_dim - 1).value).concatenate(result.get_shape()[1:]))
+ return final_result
+
+
+def _prune_invalid_ids(sparse_ids, sparse_weights):
+ """Prune invalid IDs (< 0) from the input ids and weights."""
+ is_id_valid = math_ops.greater_equal(sparse_ids.values, 0)
+ if sparse_weights is not None:
+ is_id_valid = math_ops.logical_and(
+ is_id_valid, math_ops.greater(sparse_weights.values, 0))
+ sparse_ids = sparse_ops.sparse_retain(sparse_ids, is_id_valid)
+ if sparse_weights is not None:
+ sparse_weights = sparse_ops.sparse_retain(sparse_weights, is_id_valid)
+ return sparse_ids, sparse_weights
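A hedged sketch of the safety guarantees above: with no `default_id`, rows whose ids are all invalid or missing come back as zero vectors (sizes and values are invented, and `_safe_embedding_lookup_sparse` is called here only for illustration since it is module-private):

```python
import tensorflow as tf
from tensorflow.python.feature_column import feature_column as fc

embedding = tf.get_variable('emb', shape=[3, 2])  # vocab of 3, dimension 2
ids = tf.SparseTensor(indices=[[0, 0]],
                      values=tf.constant([-1], dtype=tf.int64),  # invalid id
                      dense_shape=[2, 1])                        # row 1 empty
out = fc._safe_embedding_lookup_sparse([embedding], ids, combiner='mean')
# out has shape (2, 2) and both rows are zeros: row 0's only id was pruned
# (< 0), row 1 had no features, and default_id is None.
```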
diff --git a/tensorflow/python/feature_column/feature_column_test.py b/tensorflow/python/feature_column/feature_column_test.py
new file mode 100644
index 0000000000..eefe3b0297
--- /dev/null
+++ b/tensorflow/python/feature_column/feature_column_test.py
@@ -0,0 +1,800 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for feature_column."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import copy
+
+import numpy as np
+
+from tensorflow.core.example import example_pb2
+from tensorflow.core.example import feature_pb2
+from tensorflow.python.client import session
+from tensorflow.python.feature_column import feature_column as fc
+from tensorflow.python.framework import constant_op
+from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import ops
+from tensorflow.python.framework import sparse_tensor
+from tensorflow.python.ops import data_flow_ops
+from tensorflow.python.ops import parsing_ops
+from tensorflow.python.ops import variable_scope
+from tensorflow.python.ops import variables as variables_lib
+from tensorflow.python.platform import test
+
+
+def _initialized_session():
+ sess = session.Session()
+ sess.run(variables_lib.global_variables_initializer())
+ sess.run(data_flow_ops.tables_initializer())
+ return sess
+
+
+class LazyColumnTest(test.TestCase):
+
+ def test_transformations_called_once(self):
+
+ class TransformCounter(fc._FeatureColumn):
+
+ def __init__(self):
+ self.num_transform = 0
+
+ @property
+ def name(self):
+ return 'TransformCounter'
+
+ def _transform_feature(self, cache):
+ self.num_transform += 1 # Count transform calls.
+ return cache.get('a')
+
+ @property
+ def _parse_example_config(self):
+ pass
+
+ builder = fc._LazyBuilder(features={'a': constant_op.constant([[2], [3.]])})
+ column = TransformCounter()
+ self.assertEqual(0, column.num_transform)
+ builder.get(column)
+ self.assertEqual(1, column.num_transform)
+ builder.get(column)
+ self.assertEqual(1, column.num_transform)
+
+ def test_returns_transform_output(self):
+
+ class Transformer(fc._FeatureColumn):
+
+ @property
+ def name(self):
+ return 'Transformer'
+
+ def _transform_feature(self, cache):
+ return 'Output'
+
+ @property
+ def _parse_example_config(self):
+ pass
+
+ builder = fc._LazyBuilder(features={'a': constant_op.constant([[2], [3.]])})
+ column = Transformer()
+ self.assertEqual('Output', builder.get(column))
+ self.assertEqual('Output', builder.get(column))
+
+ def test_does_not_pollute_given_features_dict(self):
+
+ class Transformer(fc._FeatureColumn):
+
+ @property
+ def name(self):
+ return 'Transformer'
+
+ def _transform_feature(self, cache):
+ return 'Output'
+
+ @property
+ def _parse_example_config(self):
+ pass
+
+ features = {'a': constant_op.constant([[2], [3.]])}
+ builder = fc._LazyBuilder(features=features)
+ builder.get(Transformer())
+ self.assertEqual(['a'], list(features.keys()))
+
+ def test_error_if_feature_is_not_found(self):
+ builder = fc._LazyBuilder(features={'a': constant_op.constant([[2], [3.]])})
+ with self.assertRaisesRegexp(ValueError,
+ 'bbb is not in features dictionary'):
+ builder.get('bbb')
+
+ def test_not_supported_feature_column(self):
+
+ class NotAProperColumn(fc._FeatureColumn):
+
+ @property
+ def name(self):
+ return 'NotAProperColumn'
+
+ def _transform_feature(self, cache):
+ # Returning None is not allowed; _LazyBuilder.get will raise.
+ pass
+
+ @property
+ def _parse_example_config(self):
+ pass
+
+ builder = fc._LazyBuilder(features={'a': constant_op.constant([[2], [3.]])})
+ with self.assertRaisesRegexp(ValueError,
+ 'NotAProperColumn is not supported'):
+ builder.get(NotAProperColumn())
+
+ def test_key_should_be_string_or_feature_column(self):
+
+ class NotAFeatureColumn(object):
+ pass
+
+ builder = fc._LazyBuilder(features={'a': constant_op.constant([[2], [3.]])})
+ with self.assertRaisesRegexp(
+ TypeError, '"key" must be either a "str" or "_FeatureColumn".'):
+ builder.get(NotAFeatureColumn())
+
+
+class NumericalColumnTest(test.TestCase):
+
+ def test_defaults(self):
+ a = fc.numeric_column('aaa')
+ self.assertEqual('aaa', a.key)
+ self.assertEqual((1,), a.shape)
+ self.assertIsNone(a.default_value)
+ self.assertEqual(dtypes.float32, a.dtype)
+ self.assertIsNone(a.normalizer_fn)
+
+ def test_shape_saved_as_tuple(self):
+ a = fc.numeric_column('aaa', shape=[1, 2], default_value=[[3, 2.]])
+ self.assertEqual((1, 2), a.shape)
+
+ def test_default_value_saved_as_tuple(self):
+ a = fc.numeric_column('aaa', default_value=4.)
+ self.assertEqual((4.,), a.default_value)
+ a = fc.numeric_column('aaa', shape=[1, 2], default_value=[[3, 2.]])
+ self.assertEqual(((3., 2.),), a.default_value)
+
+ def test_shape_and_default_value_compatibility(self):
+ fc.numeric_column('aaa', shape=[2], default_value=[1, 2.])
+ with self.assertRaisesRegexp(ValueError, 'The shape of default_value'):
+ fc.numeric_column('aaa', shape=[2], default_value=[1, 2, 3.])
+ fc.numeric_column(
+ 'aaa', shape=[3, 2], default_value=[[2, 3], [1, 2], [2, 3.]])
+ with self.assertRaisesRegexp(ValueError, 'The shape of default_value'):
+ fc.numeric_column(
+ 'aaa', shape=[3, 1], default_value=[[2, 3], [1, 2], [2, 3.]])
+ with self.assertRaisesRegexp(ValueError, 'The shape of default_value'):
+ fc.numeric_column(
+ 'aaa', shape=[3, 3], default_value=[[2, 3], [1, 2], [2, 3.]])
+
+ def test_default_value_type_check(self):
+ fc.numeric_column(
+ 'aaa', shape=[2], default_value=[1, 2.], dtype=dtypes.float32)
+ fc.numeric_column(
+ 'aaa', shape=[2], default_value=[1, 2], dtype=dtypes.int32)
+ with self.assertRaisesRegexp(TypeError, 'must be compatible with dtype'):
+ fc.numeric_column(
+ 'aaa', shape=[2], default_value=[1, 2.], dtype=dtypes.int32)
+ with self.assertRaisesRegexp(TypeError,
+ 'default_value must be compatible with dtype'):
+ fc.numeric_column('aaa', default_value=['string'])
+
+ def test_shape_must_be_positive_integer(self):
+ with self.assertRaisesRegexp(TypeError, 'shape dimensions must be integer'):
+ fc.numeric_column(
+ 'aaa', shape=[
+ 1.0,
+ ])
+
+ with self.assertRaisesRegexp(ValueError,
+ 'shape dimensions must be greater than 0'):
+ fc.numeric_column(
+ 'aaa', shape=[
+ 0,
+ ])
+
+ def test_dtype_is_convertible_to_float(self):
+ with self.assertRaisesRegexp(ValueError,
+ 'dtype must be convertible to float'):
+ fc.numeric_column('aaa', dtype=dtypes.string)
+
+ def test_scalar_default_value_fills_the_shape(self):
+ a = fc.numeric_column('aaa', shape=[2, 3], default_value=2.)
+ self.assertEqual(((2., 2., 2.), (2., 2., 2.)), a.default_value)
+
+ def test_parse_config(self):
+ a = fc.numeric_column('aaa', shape=[2, 3], dtype=dtypes.int32)
+ self.assertEqual({
+ 'aaa': parsing_ops.FixedLenFeature((2, 3), dtype=dtypes.int32)
+ }, a._parse_example_config)
+
+ def test_parse_example_no_default_value(self):
+ price = fc.numeric_column('price', shape=[2])
+ data = example_pb2.Example(features=feature_pb2.Features(
+ feature={
+ 'price':
+ feature_pb2.Feature(float_list=feature_pb2.FloatList(
+ value=[20., 110.]))
+ }))
+ features = parsing_ops.parse_example(
+ serialized=[data.SerializeToString()],
+ features=price._parse_example_config)
+ self.assertIn('price', features)
+ with self.test_session():
+ self.assertAllEqual([[20., 110.]], features['price'].eval())
+
+ def test_parse_example_with_default_value(self):
+ price = fc.numeric_column('price', shape=[2], default_value=11.)
+ data = example_pb2.Example(features=feature_pb2.Features(
+ feature={
+ 'price':
+ feature_pb2.Feature(float_list=feature_pb2.FloatList(
+ value=[20., 110.]))
+ }))
+ no_data = example_pb2.Example(features=feature_pb2.Features(
+ feature={
+ 'something_else':
+ feature_pb2.Feature(float_list=feature_pb2.FloatList(
+ value=[20., 110.]))
+ }))
+ features = parsing_ops.parse_example(
+ serialized=[data.SerializeToString(),
+ no_data.SerializeToString()],
+ features=price._parse_example_config)
+ self.assertIn('price', features)
+ with self.test_session():
+ self.assertAllEqual([[20., 110.], [11., 11.]], features['price'].eval())
+
+ def test_normalizer_fn_must_be_callable(self):
+ with self.assertRaisesRegexp(TypeError, 'must be a callable'):
+ fc.numeric_column('price', normalizer_fn='NotACallable')
+
+ def test_normalizer_fn_transform_feature(self):
+
+ def _increment_two(input_tensor):
+ return input_tensor + 2.
+
+ price = fc.numeric_column('price', shape=[2], normalizer_fn=_increment_two)
+ builder = fc._LazyBuilder({
+ 'price': constant_op.constant([[1., 2.], [5., 6.]])
+ })
+ output = builder.get(price)
+ with self.test_session():
+ self.assertAllEqual([[3., 4.], [7., 8.]], output.eval())
+
+ def test_get_dense_tensor(self):
+
+ def _increment_two(input_tensor):
+ return input_tensor + 2.
+
+ price = fc.numeric_column('price', shape=[2], normalizer_fn=_increment_two)
+ builder = fc._LazyBuilder({
+ 'price': constant_op.constant([[1., 2.], [5., 6.]])
+ })
+ self.assertEqual(builder.get(price), price._get_dense_tensor(builder))
+
+ def test_sparse_tensor_not_supported(self):
+ price = fc.numeric_column('price')
+ builder = fc._LazyBuilder({
+ 'price':
+ sparse_tensor.SparseTensor(
+ indices=[[0, 0]], values=[0.3], dense_shape=[1, 1])
+ })
+ with self.assertRaisesRegexp(ValueError, 'must be a Tensor'):
+ price._transform_feature(builder)
+
+ def test_deep_copy(self):
+ a = fc.numeric_column('aaa', shape=[1, 2], default_value=[[3., 2.]])
+ a_copy = copy.deepcopy(a)
+ self.assertEqual(a_copy.name, 'aaa')
+ self.assertEqual(a_copy.shape, (1, 2))
+ self.assertEqual(a_copy.default_value, ((3., 2.),))
+
+ def test_numpy_default_value(self):
+ a = fc.numeric_column(
+ 'aaa', shape=[1, 2], default_value=np.array([[3., 2.]]))
+ self.assertEqual(a.default_value, ((3., 2.),))
+
+ def test_make_linear_model(self):
+ price = fc.numeric_column('price')
+ with ops.Graph().as_default():
+ features = {'price': constant_op.constant([[1.], [5.]])}
+ predictions = fc.make_linear_model(features, [price])
+ bias = get_linear_model_bias()
+ price_var = get_linear_model_column_var(price)
+ with _initialized_session() as sess:
+ self.assertAllClose([0.], bias.eval())
+ self.assertAllClose([[0.]], price_var.eval())
+ self.assertAllClose([[0.], [0.]], predictions.eval())
+ sess.run(price_var.assign([[10.]]))
+ self.assertAllClose([[10.], [50.]], predictions.eval())
+
+
+class SparseColumnHashedTest(test.TestCase):
+
+ def test_defaults(self):
+ a = fc.categorical_column_with_hash_bucket('aaa', 10)
+ self.assertEqual('aaa', a.name)
+ self.assertEqual('aaa', a.key)
+ self.assertEqual(10, a.hash_bucket_size)
+ self.assertEqual(dtypes.string, a.dtype)
+
+ def test_bucket_size_should_be_given(self):
+ with self.assertRaisesRegexp(ValueError, 'hash_bucket_size must be set.'):
+ fc.categorical_column_with_hash_bucket('aaa', None)
+
+ def test_bucket_size_should_be_positive(self):
+ with self.assertRaisesRegexp(ValueError,
+ 'hash_bucket_size must be at least 1'):
+ fc.categorical_column_with_hash_bucket('aaa', 0)
+
+ def test_dtype_should_be_string_or_integer(self):
+ fc.categorical_column_with_hash_bucket('aaa', 10, dtype=dtypes.string)
+ fc.categorical_column_with_hash_bucket('aaa', 10, dtype=dtypes.int32)
+ with self.assertRaisesRegexp(ValueError, 'dtype must be string or integer'):
+ fc.categorical_column_with_hash_bucket('aaa', 10, dtype=dtypes.float32)
+
+ def test_deep_copy(self):
+ """Tests deepcopy of categorical_column_with_hash_bucket."""
+ column = fc.categorical_column_with_hash_bucket('aaa', 10)
+ column_copy = copy.deepcopy(column)
+ self.assertEqual('aaa', column_copy.name)
+ self.assertEqual(10, column_copy.hash_bucket_size)
+ self.assertEqual(dtypes.string, column_copy.dtype)
+
+ def test_parse_config(self):
+ a = fc.categorical_column_with_hash_bucket('aaa', 10)
+ self.assertEqual({
+ 'aaa': parsing_ops.VarLenFeature(dtypes.string)
+ }, a._parse_example_config)
+
+ def test_parse_config_int(self):
+ a = fc.categorical_column_with_hash_bucket('aaa', 10, dtype=dtypes.int32)
+ self.assertEqual({
+ 'aaa': parsing_ops.VarLenFeature(dtypes.int32)
+ }, a._parse_example_config)
+
+ def test_strings_should_be_hashed(self):
+ hashed_sparse = fc.categorical_column_with_hash_bucket('wire', 10)
+ wire_tensor = sparse_tensor.SparseTensor(
+ values=['omar', 'stringer', 'marlo'],
+ indices=[[0, 0], [1, 0], [1, 1]],
+ dense_shape=[2, 2])
+ builder = fc._LazyBuilder({'wire': wire_tensor})
+ output = builder.get(hashed_sparse)
+ # Check exact hashed output. If hashing changes this test will break.
+ expected_values = [6, 4, 1]
+ with self.test_session():
+ self.assertEqual(dtypes.int64, output.values.dtype)
+ self.assertAllEqual(expected_values, output.values.eval())
+ self.assertAllEqual(wire_tensor.indices.eval(), output.indices.eval())
+ self.assertAllEqual(wire_tensor.dense_shape.eval(),
+ output.dense_shape.eval())
+
+ def test_tensor_dtype_should_be_string_or_integer(self):
+ string_fc = fc.categorical_column_with_hash_bucket(
+ 'a_string', 10, dtype=dtypes.string)
+ int_fc = fc.categorical_column_with_hash_bucket(
+ 'a_int', 10, dtype=dtypes.int32)
+ float_fc = fc.categorical_column_with_hash_bucket(
+ 'a_float', 10, dtype=dtypes.string)
+ int_tensor = sparse_tensor.SparseTensor(
+ values=constant_op.constant([101]),
+ indices=[[0, 0]],
+ dense_shape=[1, 1])
+ string_tensor = sparse_tensor.SparseTensor(
+ values=constant_op.constant(['101']),
+ indices=[[0, 0]],
+ dense_shape=[1, 1])
+ float_tensor = sparse_tensor.SparseTensor(
+ values=constant_op.constant([101.]),
+ indices=[[0, 0]],
+ dense_shape=[1, 1])
+ builder = fc._LazyBuilder({
+ 'a_int': int_tensor,
+ 'a_string': string_tensor,
+ 'a_float': float_tensor
+ })
+ builder.get(string_fc)
+ builder.get(int_fc)
+ with self.assertRaisesRegexp(ValueError, 'dtype must be string or integer'):
+ builder.get(float_fc)
+
+ def test_dtype_should_match_with_tensor(self):
+ hashed_sparse = fc.categorical_column_with_hash_bucket(
+ 'wire', 10, dtype=dtypes.int64)
+ wire_tensor = sparse_tensor.SparseTensor(
+ values=['omar'], indices=[[0, 0]], dense_shape=[1, 1])
+ builder = fc._LazyBuilder({'wire': wire_tensor})
+ with self.assertRaisesRegexp(ValueError, 'dtype must be compatible'):
+ builder.get(hashed_sparse)
+
+ def test_ints_should_be_hashed(self):
+ hashed_sparse = fc.categorical_column_with_hash_bucket(
+ 'wire', 10, dtype=dtypes.int64)
+ wire_tensor = sparse_tensor.SparseTensor(
+ values=[101, 201, 301],
+ indices=[[0, 0], [1, 0], [1, 1]],
+ dense_shape=[2, 2])
+ builder = fc._LazyBuilder({'wire': wire_tensor})
+ output = builder.get(hashed_sparse)
+ # Check exact hashed output. If hashing changes this test will break.
+ expected_values = [3, 7, 5]
+ with self.test_session():
+ self.assertAllEqual(expected_values, output.values.eval())
+
+ def test_int32_64_is_compatible(self):
+ hashed_sparse = fc.categorical_column_with_hash_bucket(
+ 'wire', 10, dtype=dtypes.int64)
+ wire_tensor = sparse_tensor.SparseTensor(
+ values=constant_op.constant([101, 201, 301], dtype=dtypes.int32),
+ indices=[[0, 0], [1, 0], [1, 1]],
+ dense_shape=[2, 2])
+ builder = fc._LazyBuilder({'wire': wire_tensor})
+ output = builder.get(hashed_sparse)
+ # Check exact hashed output. If hashing changes this test will break.
+ expected_values = [3, 7, 5]
+ with self.test_session():
+ self.assertAllEqual(expected_values, output.values.eval())
+
+ def test_get_sparse_tensors(self):
+ hashed_sparse = fc.categorical_column_with_hash_bucket('wire', 10)
+ wire_tensor = sparse_tensor.SparseTensor(
+ values=['omar', 'stringer', 'marlo'],
+ indices=[[0, 0], [1, 0], [1, 1]],
+ dense_shape=[2, 2])
+ builder = fc._LazyBuilder({'wire': wire_tensor})
+ self.assertEqual(
+ builder.get(hashed_sparse),
+ hashed_sparse._get_sparse_tensors(builder).id_tensor)
+
+
+def get_linear_model_bias():
+ with variable_scope.variable_scope('make_linear_model', reuse=True):
+ return variable_scope.get_variable('bias_weight')
+
+
+def get_linear_model_column_var(column):
+ return ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES,
+ 'make_linear_model/' + column.name)[0]
+
+
+class MakeLinearModelTest(test.TestCase):
+
+ def test_should_be_feature_column(self):
+ with self.assertRaisesRegexp(ValueError, 'must be a _FeatureColumn'):
+ fc.make_linear_model(
+ features={'a': [[0]]}, feature_columns='NotSupported')
+
+ def test_should_be_dense_or_categorical_column(self):
+
+ class NotSupportedColumn(fc._FeatureColumn):
+
+ @property
+ def name(self):
+ return 'NotSupportedColumn'
+
+ def _transform_feature(self, cache):
+ pass
+
+ @property
+ def _parse_example_config(self):
+ pass
+
+ with self.assertRaisesRegexp(
+ ValueError, 'must be either a _DenseColumn or _CategoricalColumn'):
+ fc.make_linear_model(
+ features={'a': [[0]]}, feature_columns=[NotSupportedColumn()])
+
+ def test_does_not_support_dict_columns(self):
+ with self.assertRaisesRegexp(
+ ValueError, 'Expected feature_columns to be iterable, found dict.'):
+ fc.make_linear_model(
+ features={'a': [[0]]}, feature_columns={'a': fc.numeric_column('a')})
+
+ def test_raises_if_duplicate_name(self):
+ with self.assertRaisesRegexp(
+ ValueError, 'Duplicate feature column name found for columns'):
+ fc.make_linear_model(
+ features={'a': [[0]]},
+ feature_columns=[fc.numeric_column('a'),
+ fc.numeric_column('a')])
+
+ def test_dense_bias(self):
+ price = fc.numeric_column('price')
+ with ops.Graph().as_default():
+ features = {'price': constant_op.constant([[1.], [5.]])}
+ predictions = fc.make_linear_model(features, [price])
+ bias = get_linear_model_bias()
+ price_var = get_linear_model_column_var(price)
+ with _initialized_session() as sess:
+ self.assertAllClose([0.], bias.eval())
+ sess.run(price_var.assign([[10.]]))
+ sess.run(bias.assign([5.]))
+ self.assertAllClose([[15.], [55.]], predictions.eval())
+
+ def test_sparse_bias(self):
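+ # Row 0 ('omar' -> bucket 2): 1000 + 5 = 1005.
+ # Row 1 ('stringer' -> 0, 'marlo' -> 3): 10 + 10000 + 5 = 10015.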
+ wire_cast = fc.categorical_column_with_hash_bucket('wire_cast', 4)
+ with ops.Graph().as_default():
+ wire_tensor = sparse_tensor.SparseTensor(
+ values=['omar', 'stringer', 'marlo'], # hashed to buckets [2, 0, 3]
+ indices=[[0, 0], [1, 0], [1, 1]],
+ dense_shape=[2, 2])
+ features = {'wire_cast': wire_tensor}
+ predictions = fc.make_linear_model(features, [wire_cast])
+ bias = get_linear_model_bias()
+ wire_cast_var = get_linear_model_column_var(wire_cast)
+ with _initialized_session() as sess:
+ self.assertAllClose([0.], bias.eval())
+ self.assertAllClose([[0.], [0.], [0.], [0.]], wire_cast_var.eval())
+ sess.run(wire_cast_var.assign([[10.], [100.], [1000.], [10000.]]))
+ sess.run(bias.assign([5.]))
+ self.assertAllClose([[1005.], [10015.]], predictions.eval())
+
+ def test_dense_and_sparse_bias(self):
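+ # Row 0: 1000 (wire) + 1 * 10 (price) + 5 (bias) = 1015.
+ # Row 1: 10 + 10000 (wire) + 5 * 10 (price) + 5 (bias) = 10065.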
+ wire_cast = fc.categorical_column_with_hash_bucket('wire_cast', 4)
+ price = fc.numeric_column('price')
+ with ops.Graph().as_default():
+ wire_tensor = sparse_tensor.SparseTensor(
+ values=['omar', 'stringer', 'marlo'], # hashed to buckets [2, 0, 3]
+ indices=[[0, 0], [1, 0], [1, 1]],
+ dense_shape=[2, 2])
+ features = {'wire_cast': wire_tensor, 'price': [[1.], [5.]]}
+ predictions = fc.make_linear_model(features, [wire_cast, price])
+ bias = get_linear_model_bias()
+ wire_cast_var = get_linear_model_column_var(wire_cast)
+ price_var = get_linear_model_column_var(price)
+ with _initialized_session() as sess:
+ sess.run(wire_cast_var.assign([[10.], [100.], [1000.], [10000.]]))
+ sess.run(bias.assign([5.]))
+ sess.run(price_var.assign([[10.]]))
+ self.assertAllClose([[1015.], [10065.]], predictions.eval())
+
+ def test_dense_multi_output(self):
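+ # units=3 gives a [1, 3] weight matrix and a [3] bias, one output per unit.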
+ price = fc.numeric_column('price')
+ with ops.Graph().as_default():
+ features = {'price': constant_op.constant([[1.], [5.]])}
+ predictions = fc.make_linear_model(features, [price], units=3)
+ bias = get_linear_model_bias()
+ price_var = get_linear_model_column_var(price)
+ with _initialized_session() as sess:
+ self.assertAllClose([0., 0., 0.], bias.eval())
+ self.assertAllClose([[0., 0., 0.]], price_var.eval())
+ sess.run(price_var.assign([[10., 100., 1000.]]))
+ sess.run(bias.assign([5., 6., 7.]))
+ self.assertAllClose([[15., 106., 1007.], [55., 506., 5007.]],
+ predictions.eval())
+
+ def test_sparse_multi_output(self):
+ wire_cast = fc.categorical_column_with_hash_bucket('wire_cast', 4)
+ with ops.Graph().as_default():
+ wire_tensor = sparse_tensor.SparseTensor(
+ values=['omar', 'stringer', 'marlo'], # hashed to buckets [2, 0, 3]
+ indices=[[0, 0], [1, 0], [1, 1]],
+ dense_shape=[2, 2])
+ features = {'wire_cast': wire_tensor}
+ predictions = fc.make_linear_model(features, [wire_cast], units=3)
+ bias = get_linear_model_bias()
+ wire_cast_var = get_linear_model_column_var(wire_cast)
+ with _initialized_session() as sess:
+ self.assertAllClose([0., 0., 0.], bias.eval())
+ self.assertAllClose([[0.] * 3] * 4, wire_cast_var.eval())
+ sess.run(
+ wire_cast_var.assign([[10., 11., 12.], [100., 110., 120.], [
+ 1000., 1100., 1200.
+ ], [10000., 11000., 12000.]]))
+ sess.run(bias.assign([5., 6., 7.]))
+ self.assertAllClose([[1005., 1106., 1207.], [10015., 11017., 12019.]],
+ predictions.eval())
+
+ def test_dense_multi_dimension(self):
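+ # A shape-2 column gets a [2, 1] weight matrix:
+ # [1*10 + 2*100, 5*10 + 6*100] = [210, 650] (bias stays 0).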
+ price = fc.numeric_column('price', shape=2)
+ with ops.Graph().as_default():
+ features = {'price': constant_op.constant([[1., 2.], [5., 6.]])}
+ predictions = fc.make_linear_model(features, [price])
+ price_var = get_linear_model_column_var(price)
+ with _initialized_session() as sess:
+ self.assertAllClose([[0.], [0.]], price_var.eval())
+ sess.run(price_var.assign([[10.], [100.]]))
+ self.assertAllClose([[210.], [650.]], predictions.eval())
+
+ def test_sparse_combiner(self):
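+ # With sparse_combiner='mean', the weights of the ids in each row are
+ # averaged instead of summed: row 1 = (10 + 10000) / 2 + 5 = 5010.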
+ wire_cast = fc.categorical_column_with_hash_bucket('wire_cast', 4)
+ with ops.Graph().as_default():
+ wire_tensor = sparse_tensor.SparseTensor(
+ values=['omar', 'stringer', 'marlo'], # hashed to buckets [2, 0, 3]
+ indices=[[0, 0], [1, 0], [1, 1]],
+ dense_shape=[2, 2])
+ features = {'wire_cast': wire_tensor}
+ predictions = fc.make_linear_model(
+ features, [wire_cast], sparse_combiner='mean')
+ bias = get_linear_model_bias()
+ wire_cast_var = get_linear_model_column_var(wire_cast)
+ with _initialized_session() as sess:
+ sess.run(wire_cast_var.assign([[10.], [100.], [1000.], [10000.]]))
+ sess.run(bias.assign([5.]))
+ self.assertAllClose([[1005.], [5010.]], predictions.eval())
+
+ def test_dense_multi_dimension_multi_output(self):
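+ # [2, 2] inputs times a [2, 3] weight matrix, plus a [3] bias.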
+ price = fc.numeric_column('price', shape=2)
+ with ops.Graph().as_default():
+ features = {'price': constant_op.constant([[1., 2.], [5., 6.]])}
+ predictions = fc.make_linear_model(features, [price], units=3)
+ bias = get_linear_model_bias()
+ price_var = get_linear_model_column_var(price)
+ with _initialized_session() as sess:
+ self.assertAllClose([0., 0., 0.], bias.eval())
+ self.assertAllClose([[0., 0., 0.], [0., 0., 0.]], price_var.eval())
+ sess.run(price_var.assign([[1., 2., 3.], [10., 100., 1000.]]))
+ sess.run(bias.assign([2., 3., 4.]))
+ self.assertAllClose([[23., 205., 2007.], [67., 613., 6019.]],
+ predictions.eval())
+
+ def test_raises_if_shape_mismatch(self):
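+ # A [2, 1] input cannot be reshaped to the declared shape of 2 per example
+ # (2 examples * 2 values = 4 elements requested).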
+ price = fc.numeric_column('price', shape=2)
+ with ops.Graph().as_default():
+ features = {'price': constant_op.constant([[1.], [5.]])}
+ predictions = fc.make_linear_model(features, [price])
+ with _initialized_session():
+ with self.assertRaisesRegexp(Exception, 'requested shape has 4'):
+ predictions.eval()
+
+ def test_dense_reshaping(self):
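+ # A [batch, 1, 2] input is flattened to [batch, 2] to match shape=[1, 2].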
+ price = fc.numeric_column('price', shape=[1, 2])
+ with ops.Graph().as_default():
+ features = {'price': constant_op.constant([[[1., 2.]], [[5., 6.]]])}
+ predictions = fc.make_linear_model(features, [price])
+ bias = get_linear_model_bias()
+ price_var = get_linear_model_column_var(price)
+ with _initialized_session() as sess:
+ self.assertAllClose([0.], bias.eval())
+ self.assertAllClose([[0.], [0.]], price_var.eval())
+ self.assertAllClose([[0.], [0.]], predictions.eval())
+ sess.run(price_var.assign([[10.], [100.]]))
+ self.assertAllClose([[210.], [650.]], predictions.eval())
+
+ def test_dense_multi_column(self):
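+ # Predictions sum over all columns plus the shared bias:
+ # row 0 = 1*10 + 2*100 + 3*1000 + 7 = 3217.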
+ price1 = fc.numeric_column('price1', shape=2)
+ price2 = fc.numeric_column('price2')
+ with ops.Graph().as_default():
+ features = {
+ 'price1': constant_op.constant([[1., 2.], [5., 6.]]),
+ 'price2': constant_op.constant([[3.], [4.]])
+ }
+ predictions = fc.make_linear_model(features, [price1, price2])
+ bias = get_linear_model_bias()
+ price1_var = get_linear_model_column_var(price1)
+ price2_var = get_linear_model_column_var(price2)
+ with _initialized_session() as sess:
+ self.assertAllClose([0.], bias.eval())
+ self.assertAllClose([[0.], [0.]], price1_var.eval())
+ self.assertAllClose([[0.]], price2_var.eval())
+ self.assertAllClose([[0.], [0.]], predictions.eval())
+ sess.run(price1_var.assign([[10.], [100.]]))
+ sess.run(price2_var.assign([[1000.]]))
+ sess.run(bias.assign([7.]))
+ self.assertAllClose([[3217.], [4657.]], predictions.eval())
+
+ def test_dense_collection(self):
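+ # weight_collections adds the created variables to the given collections.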
+ price = fc.numeric_column('price')
+ with ops.Graph().as_default() as g:
+ features = {'price': constant_op.constant([[1.], [5.]])}
+ fc.make_linear_model(features, [price], weight_collections=['my-vars'])
+ my_vars = g.get_collection('my-vars')
+ bias = get_linear_model_bias()
+ price_var = get_linear_model_column_var(price)
+ self.assertIn(bias, my_vars)
+ self.assertIn(price_var, my_vars)
+
+ def test_sparse_collection(self):
+ wire_cast = fc.categorical_column_with_hash_bucket('wire_cast', 4)
+ with ops.Graph().as_default() as g:
+ wire_tensor = sparse_tensor.SparseTensor(
+ values=['omar'], indices=[[0, 0]], dense_shape=[1, 1])
+ features = {'wire_cast': wire_tensor}
+ fc.make_linear_model(
+ features, [wire_cast], weight_collections=['my-vars'])
+ my_vars = g.get_collection('my-vars')
+ bias = get_linear_model_bias()
+ wire_cast_var = get_linear_model_column_var(wire_cast)
+ self.assertIn(bias, my_vars)
+ self.assertIn(wire_cast_var, my_vars)
+
+ def test_dense_trainable_default(self):
+ price = fc.numeric_column('price')
+ with ops.Graph().as_default() as g:
+ features = {'price': constant_op.constant([[1.], [5.]])}
+ fc.make_linear_model(features, [price])
+ bias = get_linear_model_bias()
+ price_var = get_linear_model_column_var(price)
+ trainable_vars = g.get_collection(ops.GraphKeys.TRAINABLE_VARIABLES)
+ self.assertIn(bias, trainable_vars)
+ self.assertIn(price_var, trainable_vars)
+
+ def test_sparse_trainable_default(self):
+ wire_cast = fc.categorical_column_with_hash_bucket('wire_cast', 4)
+ with ops.Graph().as_default() as g:
+ wire_tensor = sparse_tensor.SparseTensor(
+ values=['omar'], indices=[[0, 0]], dense_shape=[1, 1])
+ features = {'wire_cast': wire_tensor}
+ fc.make_linear_model(features, [wire_cast])
+ trainable_vars = g.get_collection(ops.GraphKeys.TRAINABLE_VARIABLES)
+ bias = get_linear_model_bias()
+ wire_cast_var = get_linear_model_column_var(wire_cast)
+ self.assertIn(bias, trainable_vars)
+ self.assertIn(wire_cast_var, trainable_vars)
+
+ def test_dense_trainable_false(self):
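+ # trainable=False keeps the created variables out of TRAINABLE_VARIABLES.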
+ price = fc.numeric_column('price')
+ with ops.Graph().as_default() as g:
+ features = {'price': constant_op.constant([[1.], [5.]])}
+ fc.make_linear_model(features, [price], trainable=False)
+ trainable_vars = g.get_collection(ops.GraphKeys.TRAINABLE_VARIABLES)
+ self.assertEqual([], trainable_vars)
+
+ def test_sparse_trainable_false(self):
+ wire_cast = fc.categorical_column_with_hash_bucket('wire_cast', 4)
+ with ops.Graph().as_default() as g:
+ wire_tensor = sparse_tensor.SparseTensor(
+ values=['omar'], indices=[[0, 0]], dense_shape=[1, 1])
+ features = {'wire_cast': wire_tensor}
+ fc.make_linear_model(features, [wire_cast], trainable=False)
+ trainable_vars = g.get_collection(ops.GraphKeys.TRAINABLE_VARIABLES)
+ self.assertEqual([], trainable_vars)
+
+ def test_column_order(self):
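+ # Variables are created in name-sorted column order, regardless of the
+ # order in which the columns are passed.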
+ price_a = fc.numeric_column('price_a')
+ price_b = fc.numeric_column('price_b')
+ wire_cast = fc.categorical_column_with_hash_bucket('wire_cast', 4)
+ with ops.Graph().as_default() as g:
+ features = {
+ 'price_a': [[1.]],
+ 'price_b': [[3.]],
+ 'wire_cast':
+ sparse_tensor.SparseTensor(
+ values=['omar'], indices=[[0, 0]], dense_shape=[1, 1])
+ }
+ fc.make_linear_model(
+ features, [price_a, wire_cast, price_b],
+ weight_collections=['my-vars'])
+ my_vars = g.get_collection('my-vars')
+ self.assertIn('price_a', my_vars[0].name)
+ self.assertIn('price_b', my_vars[1].name)
+ self.assertIn('wire_cast', my_vars[2].name)
+
+ with ops.Graph().as_default() as g:
+ features = {
+ 'price_a': [[1.]],
+ 'price_b': [[3.]],
+ 'wire_cast':
+ sparse_tensor.SparseTensor(
+ values=['omar'], indices=[[0, 0]], dense_shape=[1, 1])
+ }
+ fc.make_linear_model(
+ features, [wire_cast, price_b, price_a],
+ weight_collections=['my-vars'])
+ my_vars = g.get_collection('my-vars')
+ self.assertIn('price_a', my_vars[0].name)
+ self.assertIn('price_b', my_vars[1].name)
+ self.assertIn('wire_cast', my_vars[2].name)
+
+
+if __name__ == '__main__':
+ test.main()