aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/linear_optimizer
diff options
context:
space:
mode:
authorGravatar Petros Mol <pmol@google.com>2017-03-16 10:02:12 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2017-03-16 11:26:32 -0700
commitf5398a32b7a3e5d8249a1d03016eab2168dd324d (patch)
tree0a807b05a12d5d5ab5319a15ee2dc85dfd3e68c5 /tensorflow/contrib/linear_optimizer
parentc15c8e766b280fcaa0ed09617842762116c8fe4d (diff)
Adding SDCA-based tf.learn estimators.
Change: 150340985
Diffstat (limited to 'tensorflow/contrib/linear_optimizer')
-rw-r--r--tensorflow/contrib/linear_optimizer/BUILD32
-rw-r--r--tensorflow/contrib/linear_optimizer/python/sdca_estimator.py567
-rw-r--r--tensorflow/contrib/linear_optimizer/python/sdca_estimator_test.py502
-rw-r--r--tensorflow/contrib/linear_optimizer/python/sdca_optimizer.py20
4 files changed, 1121 insertions, 0 deletions
diff --git a/tensorflow/contrib/linear_optimizer/BUILD b/tensorflow/contrib/linear_optimizer/BUILD
index fbd7959c39..d87066f6f6 100644
--- a/tensorflow/contrib/linear_optimizer/BUILD
+++ b/tensorflow/contrib/linear_optimizer/BUILD
@@ -104,6 +104,38 @@ py_test(
],
)
+py_library(
+ name = "sdca_estimator_py",
+ srcs = ["python/sdca_estimator.py"],
+ srcs_version = "PY2AND3",
+ deps = [
+ ":sdca_ops_py",
+ ":sparse_feature_column_py",
+ "//tensorflow/contrib/layers:layers_py",
+ "//tensorflow/contrib/learn",
+ "//tensorflow/python:array_ops",
+ "//tensorflow/python:framework",
+ "//tensorflow/python:math_ops",
+ "//tensorflow/python:variables",
+ ],
+)
+
+py_test(
+ name = "sdca_estimator_test",
+ srcs = ["python/sdca_estimator_test.py"],
+ srcs_version = "PY2AND3",
+ deps = [
+ ":sdca_estimator_py",
+ "//tensorflow/contrib/layers:layers_py",
+ "//tensorflow/python:client_testlib",
+ "//tensorflow/python:framework",
+ "//tensorflow/python:framework_for_generated_wrappers",
+ "//tensorflow/python:framework_test_lib",
+ "//tensorflow/python:platform_test",
+ "//third_party/py/numpy",
+ ],
+)
+
filegroup(
name = "all_files",
srcs = glob(
diff --git a/tensorflow/contrib/linear_optimizer/python/sdca_estimator.py b/tensorflow/contrib/linear_optimizer/python/sdca_estimator.py
new file mode 100644
index 0000000000..b6074c856e
--- /dev/null
+++ b/tensorflow/contrib/linear_optimizer/python/sdca_estimator.py
@@ -0,0 +1,567 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Linear Estimators."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from tensorflow.contrib import layers
+from tensorflow.contrib.framework.python.ops import variables as contrib_variables
+from tensorflow.contrib.learn.python.learn.estimators import estimator
+from tensorflow.contrib.learn.python.learn.estimators import head as head_lib
+from tensorflow.contrib.learn.python.learn.estimators import prediction_key
+from tensorflow.contrib.linear_optimizer.python import sdca_optimizer
+from tensorflow.contrib.linear_optimizer.python.ops import sdca_ops
+from tensorflow.contrib.linear_optimizer.python.ops.sparse_feature_column import SparseFeatureColumn
+from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import sparse_tensor
+from tensorflow.python.framework import tensor_util
+from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import math_ops
+from tensorflow.python.ops import variable_scope
+from tensorflow.python.training import session_run_hook
+
+
+def _head_is_valid_for_sdca(head):
+ """Returns true if the provided head is supported by SDCAOptimizer."""
+ # pylint: disable=protected-access
+ return isinstance(head, head_lib._BinaryLogisticHead) or isinstance(
+ head, head_lib._BinarySvmHead) or isinstance(head,
+ head_lib._RegressionHead)
+ # pylint: enable=protected-access
+
+
+def _add_bias_column(feature_columns, columns_to_tensors, bias_variable,
+ columns_to_variables):
+ """Adds a fake bias feature column filled with all 1s."""
+ # TODO(b/31008490): Move definition to a common constants place.
+ bias_column_name = "tf_virtual_bias_column"
+ if any(col.name is bias_column_name for col in feature_columns):
+ raise ValueError("%s is a reserved column name." % bias_column_name)
+ if not feature_columns:
+ raise ValueError("feature_columns can't be empty.")
+
+ # Loop through input tensors until we can figure out batch_size.
+ batch_size = None
+ for column in columns_to_tensors.values():
+ if isinstance(column, tuple):
+ column = column[0]
+ if isinstance(column, sparse_tensor.SparseTensor):
+ shape = tensor_util.constant_value(column.dense_shape)
+ if shape is not None:
+ batch_size = shape[0]
+ break
+ else:
+ batch_size = array_ops.shape(column)[0]
+ break
+ if batch_size is None:
+ raise ValueError("Could not infer batch size from input features.")
+
+ bias_column = layers.real_valued_column(bias_column_name)
+ columns_to_tensors[bias_column] = array_ops.ones(
+ [batch_size, 1], dtype=dtypes.float32)
+ columns_to_variables[bias_column] = [bias_variable]
+
+
+def _get_sdca_train_step(optimizer, columns_to_variables, weight_column_name,
+ loss_type, features, targets, global_step):
+ """Returns the training operation of an SdcaModel optimizer."""
+
+ def _dense_tensor_to_sparse_feature_column(dense_tensor):
+ """Returns SparseFeatureColumn for the input dense_tensor."""
+ ignore_value = 0.0
+ sparse_indices = array_ops.where(
+ math_ops.not_equal(dense_tensor,
+ math_ops.cast(ignore_value, dense_tensor.dtype)))
+ sparse_values = array_ops.gather_nd(dense_tensor, sparse_indices)
+ # TODO(sibyl-Aix6ihai, sibyl-vie3Poto): Makes this efficient, as now SDCA supports
+ # very sparse features with weights and not weights.
+ return SparseFeatureColumn(
+ array_ops.reshape(
+ array_ops.split(value=sparse_indices, num_or_size_splits=2,
+ axis=1)[0], [-1]),
+ array_ops.reshape(
+ array_ops.split(value=sparse_indices, num_or_size_splits=2,
+ axis=1)[1], [-1]),
+ array_ops.reshape(math_ops.to_float(sparse_values), [-1]))
+
+ def _training_examples_and_variables():
+ """Returns dictionaries for training examples and variables."""
+ batch_size = targets.get_shape()[0]
+
+ # Iterate over all feature columns and create appropriate lists for dense
+ # and sparse features as well as dense and sparse weights (variables) for
+ # SDCA.
+ # TODO(sibyl-vie3Poto): Reshape variables stored as values in column_to_variables
+ # dict as 1-dimensional tensors.
+ dense_features, sparse_features, sparse_feature_with_values = [], [], []
+ dense_feature_weights = []
+ sparse_feature_weights, sparse_feature_with_values_weights = [], []
+ for column in sorted(columns_to_variables.keys(), key=lambda x: x.key):
+ transformed_tensor = features[column]
+ if isinstance(column, layers.feature_column._RealValuedColumn): # pylint: disable=protected-access
+ # A real-valued column corresponds to a dense feature in SDCA. A
+ # transformed tensor corresponding to a RealValuedColumn has rank 2
+ # (its shape is typically [batch_size, column.dimension]) and so it
+ # can be passed to SDCA as is.
+ dense_features.append(transformed_tensor)
+ # For real valued columns, the variables list contains exactly one
+ # element.
+ dense_feature_weights.append(columns_to_variables[column][0])
+ elif isinstance(column, layers.feature_column._BucketizedColumn): # pylint: disable=protected-access
+ # A bucketized column corresponds to a sparse feature in SDCA. The
+ # bucketized feature is "sparsified" for SDCA by converting it to a
+ # SparseFeatureColumn respresenting the one-hot encoding of the
+ # bucketized feature.
+ #
+ # TODO(sibyl-vie3Poto): Explore whether it is more efficient to translate a
+ # bucketized feature column to a dense feature in SDCA. This will likely
+ # depend on the number of buckets.
+ dense_bucket_tensor = column._to_dnn_input_layer(transformed_tensor) # pylint: disable=protected-access
+ sparse_feature_column = _dense_tensor_to_sparse_feature_column(
+ dense_bucket_tensor)
+ sparse_feature_with_values.append(sparse_feature_column)
+ # For bucketized columns, the variables list contains exactly one
+ # element.
+ sparse_feature_with_values_weights.append(
+ columns_to_variables[column][0])
+ elif isinstance(
+ column,
+ (
+ layers.feature_column._CrossedColumn, # pylint: disable=protected-access
+ layers.feature_column._SparseColumn)): # pylint: disable=protected-access
+ sparse_features.append(
+ SparseFeatureColumn(
+ array_ops.reshape(
+ array_ops.split(
+ value=transformed_tensor.indices,
+ num_or_size_splits=2,
+ axis=1)[0], [-1]),
+ array_ops.reshape(transformed_tensor.values, [-1]), None))
+ sparse_feature_weights.append(columns_to_variables[column][0])
+ elif isinstance(column, layers.feature_column._WeightedSparseColumn): # pylint: disable=protected-access
+ id_tensor = column.id_tensor(transformed_tensor)
+ weight_tensor = column.weight_tensor(transformed_tensor)
+ sparse_feature_with_values.append(
+ SparseFeatureColumn(
+ array_ops.reshape(
+ array_ops.split(
+ value=id_tensor.indices, num_or_size_splits=2, axis=1)[
+ 0], [-1]),
+ array_ops.reshape(id_tensor.values, [-1]),
+ array_ops.reshape(weight_tensor.values, [-1])))
+ sparse_feature_with_values_weights.append(
+ columns_to_variables[column][0])
+ else:
+ raise ValueError("SDCAOptimizer does not support column type {}".format(
+ type(column).__name__))
+
+ example_weights = array_ops.reshape(
+ features[weight_column_name],
+ shape=[-1]) if weight_column_name else array_ops.ones([batch_size])
+ example_ids = features[optimizer.example_id_column]
+ sparse_feature_with_values.extend(sparse_features)
+ sparse_feature_with_values_weights.extend(sparse_feature_weights)
+ examples = dict(
+ sparse_features=sparse_feature_with_values,
+ dense_features=dense_features,
+ example_labels=math_ops.to_float(
+ array_ops.reshape(targets, shape=[-1])),
+ example_weights=example_weights,
+ example_ids=example_ids)
+ sdca_variables = dict(
+ sparse_features_weights=sparse_feature_with_values_weights,
+ dense_features_weights=dense_feature_weights)
+ return examples, sdca_variables
+
+ training_examples, training_variables = _training_examples_and_variables()
+ sdca_model = sdca_ops.SdcaModel(
+ examples=training_examples,
+ variables=training_variables,
+ options=dict(
+ symmetric_l1_regularization=optimizer.symmetric_l1_regularization,
+ symmetric_l2_regularization=optimizer.symmetric_l2_regularization,
+ num_loss_partitions=optimizer.num_loss_partitions,
+ num_table_shards=optimizer.num_table_shards,
+ loss_type=loss_type))
+ train_op = sdca_model.minimize(global_step=global_step)
+ return sdca_model, train_op
+
+
+def sdca_model_fn(features, labels, mode, params, config=None):
+ """A model_fn for linear models that use the SDCA optimizer.
+
+ Args:
+ features: A dict of `Tensor` keyed by column name.
+ labels: `Tensor` of shape [batch_size, 1] or [batch_size] labels of
+ dtype `int32` or `int64` with values in the set {0, 1}.
+ mode: Defines whether this is training, evaluation or prediction.
+ See `ModeKeys`.
+ params: A dict of hyperparameters.
+ The following hyperparameters are expected:
+ * head: A `Head` instance. Type must be one of `_BinarySvmHead`,
+ `_RegressionHead` or `_BinaryLogisticHead`.
+ * feature_columns: An iterable containing all the feature columns used by
+ the model.
+ * optimizer: An `SDCAOptimizer` instance.
+ * weight_column_name: A string defining the weight feature column, or
+ None if there are no weights.
+ * update_weights_hook: A `SessionRunHook` object or None. Used to update
+ model weights.
+ config: `RunConfig` object to configure the runtime settings.
+
+ Returns:
+ A `ModelFnOps` instance.
+
+ Raises:
+ ValueError: If `optimizer` is not an `SDCAOptimizer` instance.
+ ValueError: If the type of head is neither `_BinarySvmHead`, nor
+ `_RegressionHead` nor `_MultiClassHead`.
+ ValueError: If mode is not any of the `ModeKeys`.
+ """
+ head = params["head"]
+ feature_columns = params["feature_columns"]
+ example_id_column = params["example_id_column"]
+ l1_regularization = params["l1_regularization"]
+ l2_regularization = params["l2_regularization"]
+ num_loss_partitions = params["num_loss_partitions"]
+ weight_column_name = params["weight_column_name"]
+ update_weights_hook = params.get("update_weights_hook", None)
+
+ loss_type = None
+ if isinstance(head, head_lib._BinarySvmHead): # pylint: disable=protected-access
+ loss_type = "hinge_loss"
+ elif isinstance(head, head_lib._BinaryLogisticHead): # pylint: disable=protected-access
+ loss_type = "logistic_loss"
+ elif isinstance(head, head_lib._RegressionHead): # pylint: disable=protected-access
+ loss_type = "squared_loss"
+ else:
+ raise ValueError("Unsupported head type: {}".format(type(head)))
+
+ assert head.logits_dimension == 1, (
+ "SDCA only applies to logits_dimension=1.")
+
+ # Update num_loss_partitions based on number of workers.
+ n_loss_partitions = num_loss_partitions or max(1, config.num_worker_replicas)
+ optimizer = sdca_optimizer.SDCAOptimizer(
+ example_id_column=example_id_column,
+ num_loss_partitions=n_loss_partitions,
+ symmetric_l1_regularization=l1_regularization,
+ symmetric_l2_regularization=l2_regularization)
+
+ parent_scope = "linear"
+
+ with variable_scope.variable_op_scope(features.values(),
+ parent_scope) as scope:
+ logits, columns_to_variables, bias = (
+ layers.weighted_sum_from_feature_columns(
+ columns_to_tensors=features,
+ feature_columns=feature_columns,
+ num_outputs=1,
+ scope=scope))
+
+ _add_bias_column(feature_columns, features, bias, columns_to_variables)
+
+ def _train_op_fn(unused_loss):
+ global_step = contrib_variables.get_global_step()
+ sdca_model, train_op = _get_sdca_train_step(optimizer, columns_to_variables,
+ weight_column_name, loss_type,
+ features, labels, global_step)
+ if update_weights_hook is not None:
+ update_weights_hook.set_parameters(sdca_model, train_op)
+ return train_op
+
+ model_fn_ops = head.create_model_fn_ops(
+ features=features,
+ labels=labels,
+ mode=mode,
+ train_op_fn=_train_op_fn,
+ logits=logits)
+ if update_weights_hook is not None:
+ return model_fn_ops._replace(training_chief_hooks=(
+ model_fn_ops.training_chief_hooks + [update_weights_hook]))
+ return model_fn_ops
+
+
+class _SdcaUpdateWeightsHook(session_run_hook.SessionRunHook):
+ """SessionRunHook to update and shrink SDCA model weights."""
+
+ def __init__(self):
+ pass
+
+ def set_parameters(self, sdca_model, train_op):
+ self._sdca_model = sdca_model
+ self._train_op = train_op
+
+ def begin(self):
+ """Construct the update_weights op.
+
+ The op is implicitly added to the default graph.
+ """
+ self._update_op = self._sdca_model.update_weights(self._train_op)
+
+ def before_run(self, run_context):
+ """Return the update_weights op so that it is executed during this run."""
+ return session_run_hook.SessionRunArgs(self._update_op)
+
+
+class _SDCAEstimator(estimator.Estimator):
+ """Base estimator class for linear models using the SDCA optimizer.
+
+ This class should not be used directly. Rather, users should call one of the
+ derived estimators.
+
+ The input_fn provided to `fit`, `evaluate` and predict_* methods should have
+ the following features, otherwise there will be a `KeyError`:
+ - a feature with `key=example_id_column` whose value is a `Tensor` of dtype
+ string.
+ - if `weight_column_name` is not `None`, a feature with
+ `key=weight_column_name` whose value is a `Tensor`.
+ - for each `column` in `feature_columns`:
+ - if `column` is a `SparseColumn`, a feature with `key=column.name`
+ whose `value` is a `SparseTensor`.
+ - if `column` is a `RealValuedColumn, a feature with `key=column.name`
+ whose `value` is a `Tensor`.
+ """
+
+ def __init__(self,
+ example_id_column,
+ feature_columns,
+ weight_column_name=None,
+ model_dir=None,
+ head=None,
+ l1_regularization=0.0,
+ l2_regularization=1.0,
+ num_loss_partitions=None,
+ config=None,
+ feature_engineering_fn=None):
+ """Construct a `_SDCAEstimator` estimator object.
+
+ Args:
+ example_id_column: A string defining the feature column name representing
+ example ids. Used to initialize the underlying SDCA optimizer.
+ feature_columns: An iterable containing all the feature columns used by
+ the model. All items in the set should be instances of classes derived
+ from `FeatureColumn`.
+ weight_column_name: A string defining feature column name representing
+ weights. It is used to down weight or boost examples during training. It
+ will be multiplied by the loss of the example.
+ model_dir: Directory to save model parameters, graph etc. This can also be
+ used to load checkpoints from the directory into an estimator to
+ continue training a previously saved model.
+ head: type of head. Currently, _BinaryLogisticHead and _BinarySvmHead are
+ supported for classification and _RegressionHead for regression. It
+ should be a subclass of _SingleHead.
+ l1_regularization: L1-regularization parameter. Refers to global L1
+ regularization (across all examples).
+ l2_regularization: L2-regularization parameter. Refers to global L2
+ regularization (across all examples).
+ num_loss_partitions: number of partitions of the (global) loss function
+ optimized by the underlying optimizer (SDCAOptimizer).
+ config: `RunConfig` object to configure the runtime settings.
+ feature_engineering_fn: Feature engineering function. Takes features and
+ labels which are the output of `input_fn` and returns features and
+ labels which will be fed into the model.
+
+ Returns:
+ A `_SDCAEstimator` estimator.
+
+ Raises:
+ ValueError: if head is not supported by SDCA.
+ """
+ self._feature_columns = tuple(feature_columns or [])
+ assert self._feature_columns
+
+ if not _head_is_valid_for_sdca(head):
+ raise ValueError(
+ "head type: {} is not supported. Supported head types: "
+ "_BinaryLogisticHead, _BinarySvmHead and _RegressionHead.".format(
+ type(head)))
+ assert head.logits_dimension == 1
+
+ params = {
+ "head": head,
+ "feature_columns": feature_columns,
+ "example_id_column": example_id_column,
+ "num_loss_partitions": num_loss_partitions,
+ "l1_regularization": l1_regularization,
+ "l2_regularization": l2_regularization,
+ "weight_column_name": weight_column_name,
+ "update_weights_hook": _SdcaUpdateWeightsHook(),
+ }
+
+ super(_SDCAEstimator, self).__init__(
+ model_fn=sdca_model_fn,
+ model_dir=model_dir,
+ config=config,
+ params=params,
+ feature_engineering_fn=feature_engineering_fn)
+
+
+class SDCALogisticClassifier(_SDCAEstimator):
+ """Logistic regression binary classifier using the SDCA optimizer.
+
+ Example usage:
+
+ ```python
+ sparse_column_a = sparse_column_with_hash_bucket(...)
+ sparse_column_b = sparse_column_with_hash_bucket(...)
+
+ sparse_feature_a_x_sparse_feature_b = crossed_column(...)
+
+ estimator = SDCALogisticClassifier(
+ example_id_column='example_id',
+ feature_columns=[sparse_column_a, sparse_feature_a_x_sparse_feature_b]),
+ weight_column_name=...,
+ l2_regularization=...,
+ num_loss_partitions=...,
+ )
+
+ # Input builders
+ # returns x, y (where y is the label Tensor (with 0/1 values)
+ def input_fn_{train, eval}:
+
+ # returns x (features dict)
+ def input_fn_test:
+ ...
+ estimator.fit(input_fn=input_fn_train)
+ estimator.evaluate(input_fn=input_fn_eval)
+ estimator.predict_classes(input_fn=input_fn_test) # returns predicted classes.
+ estimator.predict_proba(input_fn=input_fn_test) # returns predicted prob/ties.
+ ```
+ """
+
+ def __init__(self,
+ example_id_column,
+ feature_columns,
+ weight_column_name=None,
+ model_dir=None,
+ l1_regularization=0.0,
+ l2_regularization=1.0,
+ num_loss_partitions=None,
+ config=None,
+ feature_engineering_fn=None):
+ """Construct a `SDCALogisticClassifier` object. See _SDCAEstimator."""
+ super(SDCALogisticClassifier, self).__init__(
+ example_id_column=example_id_column,
+ feature_columns=feature_columns,
+ weight_column_name=weight_column_name,
+ model_dir=model_dir,
+ head=head_lib.multi_class_head(
+ n_classes=2, weight_column_name=weight_column_name),
+ l1_regularization=l1_regularization,
+ l2_regularization=l2_regularization,
+ num_loss_partitions=num_loss_partitions,
+ config=config,
+ feature_engineering_fn=None)
+
+ def predict_classes(self, input_fn=None):
+ """Runs inference to determine the predicted class.
+
+ Args:
+ input_fn: The input function providing features.
+
+ Returns:
+ A generator of predicted classes for the features provided by input_fn.
+ """
+ key = prediction_key.PredictionKey.CLASSES
+ predictions = super(SDCALogisticClassifier, self).predict(
+ input_fn=input_fn, outputs=[key])
+ return (pred[key] for pred in predictions)
+
+ def predict_proba(self, input_fn=None):
+ """Runs inference to determine the class probability predictions.
+
+ Args:
+ input_fn: The input function providing features.
+
+ Returns:
+ A generator of predicted class probabilities for the features provided by
+ input_fn.
+ """
+ key = prediction_key.PredictionKey.PROBABILITIES
+ predictions = super(SDCALogisticClassifier, self).predict(
+ input_fn=input_fn, outputs=[key])
+ return (pred[key] for pred in predictions)
+
+
+class SDCARegressor(_SDCAEstimator):
+ """Linear regressor model using SDCA to solve the underlying optimization.
+
+ Example usage:
+
+ ```python
+ sparse_column_a = sparse_column_with_hash_bucket(...)
+ sparse_column_b = sparse_column_with_hash_bucket(...)
+
+ sparse_feature_a_x_sparse_feature_b = crossed_column(...)
+
+ estimator = SDCARegressor(
+ example_id_column='example_id',
+ feature_columns=[sparse_column_a, sparse_feature_a_x_sparse_feature_b]),
+ weight_column_name=...,
+ l2_regularization=...,
+ num_loss_partitions=...,
+ )
+
+ # Input builders
+ # returns x, y (where y is the label Tensor (with 0/1 values)
+ def input_fn_{train, eval}:
+
+ # returns x (features dict)
+ def input_fn_test:
+ ...
+ estimator.fit(input_fn=input_fn_train)
+ estimator.evaluate(input_fn=input_fn_eval)
+ estimator.predict_scores(input_fn=input_fn_test) # returns predicted scores.
+ """
+
+ def __init__(self,
+ example_id_column,
+ feature_columns,
+ weight_column_name=None,
+ model_dir=None,
+ l1_regularization=0.0,
+ l2_regularization=1.0,
+ num_loss_partitions=None,
+ config=None,
+ feature_engineering_fn=None):
+ """Construct a `SDCARegressor` estimator object. See _SDCAEstimator."""
+ super(SDCARegressor, self).__init__(
+ example_id_column=example_id_column,
+ feature_columns=feature_columns,
+ weight_column_name=weight_column_name,
+ model_dir=model_dir,
+ head=head_lib.regression_head(weight_column_name=weight_column_name),
+ l1_regularization=l1_regularization,
+ l2_regularization=l2_regularization,
+ num_loss_partitions=num_loss_partitions,
+ config=config,
+ feature_engineering_fn=None)
+
+ def predict_scores(self, input_fn):
+ """Returns predicted scores for given features.
+
+ Args:
+ input_fn: The input function providing features.
+
+ Returns:
+ A generator of predicted scores for the features provided by input_fn.
+ """
+ key = prediction_key.PredictionKey.SCORES
+ predictions = super(SDCARegressor, self).predict(
+ input_fn=input_fn, outputs=[key])
+ return (pred[key] for pred in predictions)
diff --git a/tensorflow/contrib/linear_optimizer/python/sdca_estimator_test.py b/tensorflow/contrib/linear_optimizer/python/sdca_estimator_test.py
new file mode 100644
index 0000000000..081651df8d
--- /dev/null
+++ b/tensorflow/contrib/linear_optimizer/python/sdca_estimator_test.py
@@ -0,0 +1,502 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for linear_optimizer.sdca_estimator."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import numpy as np
+
+from tensorflow.contrib.layers.python.layers import feature_column as feature_column_lib
+from tensorflow.contrib.linear_optimizer.python.sdca_estimator import SDCALogisticClassifier
+from tensorflow.contrib.linear_optimizer.python.sdca_estimator import SDCARegressor
+from tensorflow.python.framework import constant_op
+from tensorflow.python.framework import sparse_tensor
+from tensorflow.python.platform import test
+
+
+class SDCALogisticClassifierTest(test.TestCase):
+
+ def testRealValuedFeatures(self):
+ """Tests SDCALogisticClassifier works with real valued features."""
+
+ def input_fn():
+ return {
+ 'example_id': constant_op.constant(['1', '2']),
+ 'maintenance_cost': constant_op.constant([[500.0], [200.0]]),
+ 'sq_footage': constant_op.constant([[800.0], [600.0]]),
+ 'weights': constant_op.constant([[1.0], [1.0]])
+ }, constant_op.constant([[0], [1]])
+
+ maintenance_cost = feature_column_lib.real_valued_column('maintenance_cost')
+ sq_footage = feature_column_lib.real_valued_column('sq_footage')
+ classifier = SDCALogisticClassifier(
+ example_id_column='example_id',
+ feature_columns=[maintenance_cost, sq_footage],
+ weight_column_name='weights')
+ classifier.fit(input_fn=input_fn, steps=100)
+ loss = classifier.evaluate(input_fn=input_fn, steps=1)['loss']
+ self.assertLess(loss, 0.05)
+
+ def testRealValuedFeatureWithHigherDimension(self):
+ """Tests SDCALogisticClassifier with high-dimension real valued features."""
+
+ # input_fn is identical to the one in testRealValuedFeatures where 2
+ # 1-dimensional dense features are replaced by a 2-dimensional feature.
+ def input_fn():
+ return {
+ 'example_id':
+ constant_op.constant(['1', '2']),
+ 'dense_feature':
+ constant_op.constant([[500.0, 800.0], [200.0, 600.0]])
+ }, constant_op.constant([[0], [1]])
+
+ dense_feature = feature_column_lib.real_valued_column(
+ 'dense_feature', dimension=2)
+ classifier = SDCALogisticClassifier(
+ example_id_column='example_id', feature_columns=[dense_feature])
+ classifier.fit(input_fn=input_fn, steps=100)
+ loss = classifier.evaluate(input_fn=input_fn, steps=1)['loss']
+ self.assertLess(loss, 0.05)
+
+ def testBucketizedFeatures(self):
+ """Tests SDCALogisticClassifier with bucketized features."""
+
+ def input_fn():
+ return {
+ 'example_id': constant_op.constant(['1', '2', '3']),
+ 'price': constant_op.constant([[600.0], [1000.0], [400.0]]),
+ 'sq_footage': constant_op.constant([[1000.0], [600.0], [700.0]]),
+ 'weights': constant_op.constant([[1.0], [1.0], [1.0]])
+ }, constant_op.constant([[1], [0], [1]])
+
+ price_bucket = feature_column_lib.bucketized_column(
+ feature_column_lib.real_valued_column('price'),
+ boundaries=[500.0, 700.0])
+ sq_footage_bucket = feature_column_lib.bucketized_column(
+ feature_column_lib.real_valued_column('sq_footage'), boundaries=[650.0])
+ classifier = SDCALogisticClassifier(
+ example_id_column='example_id',
+ feature_columns=[price_bucket, sq_footage_bucket],
+ weight_column_name='weights',
+ l2_regularization=1.0)
+ classifier.fit(input_fn=input_fn, steps=50)
+ metrics = classifier.evaluate(input_fn=input_fn, steps=1)
+ self.assertGreater(metrics['accuracy'], 0.9)
+
+ def testSparseFeatures(self):
+ """Tests SDCALogisticClassifier with sparse features."""
+
+ def input_fn():
+ return {
+ 'example_id':
+ constant_op.constant(['1', '2', '3']),
+ 'price':
+ constant_op.constant([[0.4], [0.6], [0.3]]),
+ 'country':
+ sparse_tensor.SparseTensor(
+ values=['IT', 'US', 'GB'],
+ indices=[[0, 0], [1, 3], [2, 1]],
+ dense_shape=[3, 5]),
+ 'weights':
+ constant_op.constant([[1.0], [1.0], [1.0]])
+ }, constant_op.constant([[1], [0], [1]])
+
+ price = feature_column_lib.real_valued_column('price')
+ country = feature_column_lib.sparse_column_with_hash_bucket(
+ 'country', hash_bucket_size=5)
+ classifier = SDCALogisticClassifier(
+ example_id_column='example_id',
+ feature_columns=[price, country],
+ weight_column_name='weights')
+ classifier.fit(input_fn=input_fn, steps=50)
+ metrics = classifier.evaluate(input_fn=input_fn, steps=1)
+ self.assertGreater(metrics['accuracy'], 0.9)
+
+ def testWeightedSparseFeatures(self):
+ """Tests SDCALogisticClassifier with weighted sparse features."""
+
+ def input_fn():
+ return {
+ 'example_id':
+ constant_op.constant(['1', '2', '3']),
+ 'price':
+ sparse_tensor.SparseTensor(
+ values=[2., 3., 1.],
+ indices=[[0, 0], [1, 0], [2, 0]],
+ dense_shape=[3, 5]),
+ 'country':
+ sparse_tensor.SparseTensor(
+ values=['IT', 'US', 'GB'],
+ indices=[[0, 0], [1, 0], [2, 0]],
+ dense_shape=[3, 5])
+ }, constant_op.constant([[1], [0], [1]])
+
+ country = feature_column_lib.sparse_column_with_hash_bucket(
+ 'country', hash_bucket_size=5)
+ country_weighted_by_price = feature_column_lib.weighted_sparse_column(
+ country, 'price')
+ classifier = SDCALogisticClassifier(
+ example_id_column='example_id',
+ feature_columns=[country_weighted_by_price])
+ classifier.fit(input_fn=input_fn, steps=50)
+ metrics = classifier.evaluate(input_fn=input_fn, steps=1)
+ self.assertGreater(metrics['accuracy'], 0.9)
+
+ def testCrossedFeatures(self):
+ """Tests SDCALogisticClassifier with crossed features."""
+
+ def input_fn():
+ return {
+ 'example_id':
+ constant_op.constant(['1', '2', '3']),
+ 'language':
+ sparse_tensor.SparseTensor(
+ values=['english', 'italian', 'spanish'],
+ indices=[[0, 0], [1, 0], [2, 0]],
+ dense_shape=[3, 1]),
+ 'country':
+ sparse_tensor.SparseTensor(
+ values=['US', 'IT', 'MX'],
+ indices=[[0, 0], [1, 0], [2, 0]],
+ dense_shape=[3, 1])
+ }, constant_op.constant([[0], [0], [1]])
+
+ language = feature_column_lib.sparse_column_with_hash_bucket(
+ 'language', hash_bucket_size=5)
+ country = feature_column_lib.sparse_column_with_hash_bucket(
+ 'country', hash_bucket_size=5)
+ country_language = feature_column_lib.crossed_column(
+ [language, country], hash_bucket_size=10)
+ classifier = SDCALogisticClassifier(
+ example_id_column='example_id', feature_columns=[country_language])
+ classifier.fit(input_fn=input_fn, steps=10)
+ metrics = classifier.evaluate(input_fn=input_fn, steps=1)
+ self.assertGreater(metrics['accuracy'], 0.9)
+
+ def testMixedFeatures(self):
+ """Tests SDCALogisticClassifier with a mix of features."""
+
+ def input_fn():
+ return {
+ 'example_id':
+ constant_op.constant(['1', '2', '3']),
+ 'price':
+ constant_op.constant([[0.6], [0.8], [0.3]]),
+ 'sq_footage':
+ constant_op.constant([[900.0], [700.0], [600.0]]),
+ 'country':
+ sparse_tensor.SparseTensor(
+ values=['IT', 'US', 'GB'],
+ indices=[[0, 0], [1, 3], [2, 1]],
+ dense_shape=[3, 5]),
+ 'weights':
+ constant_op.constant([[3.0], [1.0], [1.0]])
+ }, constant_op.constant([[1], [0], [1]])
+
+ price = feature_column_lib.real_valued_column('price')
+ sq_footage_bucket = feature_column_lib.bucketized_column(
+ feature_column_lib.real_valued_column('sq_footage'),
+ boundaries=[650.0, 800.0])
+ country = feature_column_lib.sparse_column_with_hash_bucket(
+ 'country', hash_bucket_size=5)
+ sq_footage_country = feature_column_lib.crossed_column(
+ [sq_footage_bucket, country], hash_bucket_size=10)
+ classifier = SDCALogisticClassifier(
+ example_id_column='example_id',
+ feature_columns=[price, sq_footage_bucket, country, sq_footage_country],
+ weight_column_name='weights')
+ classifier.fit(input_fn=input_fn, steps=50)
+ metrics = classifier.evaluate(input_fn=input_fn, steps=1)
+ self.assertGreater(metrics['accuracy'], 0.9)
+
+
+class SDCARegressorTest(test.TestCase):
+
+ def testRealValuedLinearFeatures(self):
+ """Tests SDCARegressor works with real valued features."""
+ x = [[1.2, 2.0, -1.5], [-2.0, 3.0, -0.5], [1.0, -0.5, 4.0]]
+ weights = [[3.0], [-1.2], [0.5]]
+ y = np.dot(x, weights)
+
+ def input_fn():
+ return {
+ 'example_id': constant_op.constant(['1', '2', '3']),
+ 'x': constant_op.constant(x),
+ 'weights': constant_op.constant([[10.0], [10.0], [10.0]])
+ }, constant_op.constant(y)
+
+ x_column = feature_column_lib.real_valued_column('x', dimension=3)
+ regressor = SDCARegressor(
+ example_id_column='example_id',
+ feature_columns=[x_column],
+ weight_column_name='weights')
+ regressor.fit(input_fn=input_fn, steps=20)
+ loss = regressor.evaluate(input_fn=input_fn, steps=1)['loss']
+ self.assertLess(loss, 0.01)
+ self.assertIn('linear/x/weight', regressor.get_variable_names())
+ regressor_weights = regressor.get_variable_value('linear/x/weight')
+ self.assertAllClose(
+ [w[0] for w in weights], regressor_weights.flatten(), rtol=0.1)
+
+ def testMixedFeaturesArbitraryWeights(self):
+ """Tests SDCARegressor works with a mix of features."""
+
+ def input_fn():
+ return {
+ 'example_id':
+ constant_op.constant(['1', '2', '3']),
+ 'price':
+ constant_op.constant([[0.6], [0.8], [0.3]]),
+ 'sq_footage':
+ constant_op.constant([[900.0], [700.0], [600.0]]),
+ 'country':
+ sparse_tensor.SparseTensor(
+ values=['IT', 'US', 'GB'],
+ indices=[[0, 0], [1, 3], [2, 1]],
+ dense_shape=[3, 5]),
+ 'weights':
+ constant_op.constant([[3.0], [5.0], [7.0]])
+ }, constant_op.constant([[1.55], [-1.25], [-3.0]])
+
+ price = feature_column_lib.real_valued_column('price')
+ sq_footage_bucket = feature_column_lib.bucketized_column(
+ feature_column_lib.real_valued_column('sq_footage'),
+ boundaries=[650.0, 800.0])
+ country = feature_column_lib.sparse_column_with_hash_bucket(
+ 'country', hash_bucket_size=5)
+ sq_footage_country = feature_column_lib.crossed_column(
+ [sq_footage_bucket, country], hash_bucket_size=10)
+ regressor = SDCARegressor(
+ example_id_column='example_id',
+ feature_columns=[price, sq_footage_bucket, country, sq_footage_country],
+ l2_regularization=1.0,
+ weight_column_name='weights')
+ regressor.fit(input_fn=input_fn, steps=20)
+ loss = regressor.evaluate(input_fn=input_fn, steps=1)['loss']
+ self.assertLess(loss, 0.05)
+
+ def testSdcaOptimizerSparseFeaturesWithL1Reg(self):
+ """Tests SDCARegressor works with sparse features and L1 regularization."""
+
+ def input_fn():
+ return {
+ 'example_id':
+ constant_op.constant(['1', '2', '3']),
+ 'price':
+ constant_op.constant([[0.4], [0.6], [0.3]]),
+ 'country':
+ sparse_tensor.SparseTensor(
+ values=['IT', 'US', 'GB'],
+ indices=[[0, 0], [1, 3], [2, 1]],
+ dense_shape=[3, 5]),
+ 'weights':
+ constant_op.constant([[10.0], [10.0], [10.0]])
+ }, constant_op.constant([[1.4], [-0.8], [2.6]])
+
+ price = feature_column_lib.real_valued_column('price')
+ country = feature_column_lib.sparse_column_with_hash_bucket(
+ 'country', hash_bucket_size=5)
+ # Regressor with no L1 regularization.
+ regressor = SDCARegressor(
+ example_id_column='example_id',
+ feature_columns=[price, country],
+ weight_column_name='weights')
+ regressor.fit(input_fn=input_fn, steps=20)
+ no_l1_reg_loss = regressor.evaluate(input_fn=input_fn, steps=1)['loss']
+ variable_names = regressor.get_variable_names()
+ self.assertIn('linear/price/weight', variable_names)
+ self.assertIn('linear/country/weights', variable_names)
+ no_l1_reg_weights = {
+ 'linear/price/weight':
+ regressor.get_variable_value('linear/price/weight'),
+ 'linear/country/weights':
+ regressor.get_variable_value('linear/country/weights'),
+ }
+
+ # Regressor with L1 regularization.
+ regressor = SDCARegressor(
+ example_id_column='example_id',
+ feature_columns=[price, country],
+ l1_regularization=1.0,
+ weight_column_name='weights')
+ regressor.fit(input_fn=input_fn, steps=20)
+ l1_reg_loss = regressor.evaluate(input_fn=input_fn, steps=1)['loss']
+ l1_reg_weights = {
+ 'linear/price/weight':
+ regressor.get_variable_value('linear/price/weight'),
+ 'linear/country/weights':
+ regressor.get_variable_value('linear/country/weights'),
+ }
+
+ # Unregularized loss is lower when there is no L1 regularization.
+ self.assertLess(no_l1_reg_loss, l1_reg_loss)
+ self.assertLess(no_l1_reg_loss, 0.05)
+
+ # But weights returned by the regressor with L1 regularization have smaller
+ # L1 norm.
+ l1_reg_weights_norm, no_l1_reg_weights_norm = 0.0, 0.0
+ for var_name in sorted(l1_reg_weights):
+ l1_reg_weights_norm += sum(
+ np.absolute(l1_reg_weights[var_name].flatten()))
+ no_l1_reg_weights_norm += sum(
+ np.absolute(no_l1_reg_weights[var_name].flatten()))
+ print('Var name: %s, value: %s' % (var_name,
+ no_l1_reg_weights[var_name].flatten()))
+ self.assertLess(l1_reg_weights_norm, no_l1_reg_weights_norm)
+
+ def testBiasOnly(self):
+ """Tests SDCARegressor has a valid bias weight."""
+
+ def input_fn():
+ """Testing the bias weight when it's the only feature present.
+
+ All of the instances in this input only have the bias feature, and a
+ 1/4 of the labels are positive. This means that the expected weight for
+ the bias should be close to the average prediction, i.e 0.25.
+ Returns:
+ Training data for the test.
+ """
+ num_examples = 40
+ return {
+ 'example_id':
+ constant_op.constant([str(x + 1) for x in range(num_examples)]),
+ # place_holder is an empty column which is always 0 (absent), because
+ # LinearClassifier requires at least one column.
+ 'place_holder':
+ constant_op.constant([[0.0]] * num_examples),
+ }, constant_op.constant([[1 if i % 4 is 0 else 0]
+ for i in range(num_examples)])
+
+ place_holder = feature_column_lib.real_valued_column('place_holder')
+ regressor = SDCARegressor(
+ example_id_column='example_id', feature_columns=[place_holder])
+ regressor.fit(input_fn=input_fn, steps=100)
+ self.assertNear(
+ regressor.get_variable_value('linear/bias_weight')[0], 0.25, err=0.1)
+
+ def testBiasAndOtherColumns(self):
+ """SDCARegressor has valid bias weight when other columns are present."""
+
+ def input_fn():
+ """Testing the bias weight when there are other features present.
+
+ 1/2 of the instances in this input have feature 'a', the rest have
+ feature 'b', and we expect the bias to be added to each instance as well.
+ 0.4 of all instances that have feature 'a' are positive, and 0.2 of all
+ instances that have feature 'b' are positive. The labels in the dataset
+ are ordered to appear shuffled since SDCA expects shuffled data, and
+ converges faster with this pseudo-random ordering.
+ If the bias was centered we would expect the weights to be:
+ bias: 0.3
+ a: 0.1
+ b: -0.1
+ Until b/29339026 is resolved, the bias gets regularized with the same
+ global value for the other columns, and so the expected weights get
+ shifted and are:
+ bias: 0.2
+ a: 0.2
+ b: 0.0
+ Returns:
+ The test dataset.
+ """
+ num_examples = 200
+ half = int(num_examples / 2)
+ return {
+ 'example_id':
+ constant_op.constant([str(x + 1) for x in range(num_examples)]),
+ 'a':
+ constant_op.constant([[1]] * int(half) + [[0]] * int(half)),
+ 'b':
+ constant_op.constant([[0]] * int(half) + [[1]] * int(half)),
+ }, constant_op.constant(
+ [[x]
+ for x in [1, 0, 0, 1, 1, 0, 0, 0, 1, 0] * int(half / 10) +
+ [0, 1, 0, 0, 0, 0, 0, 0, 1, 0] * int(half / 10)])
+
+ regressor = SDCARegressor(
+ example_id_column='example_id',
+ feature_columns=[
+ feature_column_lib.real_valued_column('a'),
+ feature_column_lib.real_valued_column('b')
+ ])
+
+ regressor.fit(input_fn=input_fn, steps=200)
+
+ variable_names = regressor.get_variable_names()
+ self.assertIn('linear/bias_weight', variable_names)
+ self.assertIn('linear/a/weight', variable_names)
+ self.assertIn('linear/b/weight', variable_names)
+ # TODO(b/29339026): Change the expected results to expect a centered bias.
+ self.assertNear(
+ regressor.get_variable_value('linear/bias_weight')[0], 0.2, err=0.05)
+ self.assertNear(
+ regressor.get_variable_value('linear/a/weight')[0], 0.2, err=0.05)
+ self.assertNear(
+ regressor.get_variable_value('linear/b/weight')[0], 0.0, err=0.05)
+
+ def testBiasAndOtherColumnsFabricatedCentered(self):
+ """SDCARegressor has valid bias weight when instances are centered."""
+
+ def input_fn():
+ """Testing the bias weight when there are other features present.
+
+ 1/2 of the instances in this input have feature 'a', the rest have
+ feature 'b', and we expect the bias to be added to each instance as well.
+ 0.1 of all instances that have feature 'a' have a label of 1, and 0.1 of
+ all instances that have feature 'b' have a label of -1.
+ We can expect the weights to be:
+ bias: 0.0
+ a: 0.1
+ b: -0.1
+ Returns:
+ The test dataset.
+ """
+ num_examples = 200
+ half = int(num_examples / 2)
+ return {
+ 'example_id':
+ constant_op.constant([str(x + 1) for x in range(num_examples)]),
+ 'a':
+ constant_op.constant([[1]] * int(half) + [[0]] * int(half)),
+ 'b':
+ constant_op.constant([[0]] * int(half) + [[1]] * int(half)),
+ }, constant_op.constant([[1 if x % 10 == 0 else 0] for x in range(half)] +
+ [[-1 if x % 10 == 0 else 0] for x in range(half)])
+
+ regressor = SDCARegressor(
+ example_id_column='example_id',
+ feature_columns=[
+ feature_column_lib.real_valued_column('a'),
+ feature_column_lib.real_valued_column('b')
+ ])
+
+ regressor.fit(input_fn=input_fn, steps=100)
+
+ variable_names = regressor.get_variable_names()
+ self.assertIn('linear/bias_weight', variable_names)
+ self.assertIn('linear/a/weight', variable_names)
+ self.assertIn('linear/b/weight', variable_names)
+ self.assertNear(
+ regressor.get_variable_value('linear/bias_weight')[0], 0.0, err=0.05)
+ self.assertNear(
+ regressor.get_variable_value('linear/a/weight')[0], 0.1, err=0.05)
+ self.assertNear(
+ regressor.get_variable_value('linear/b/weight')[0], -0.1, err=0.05)
+
+
+if __name__ == '__main__':
+ test.main()
diff --git a/tensorflow/contrib/linear_optimizer/python/sdca_optimizer.py b/tensorflow/contrib/linear_optimizer/python/sdca_optimizer.py
index 9edb00e7b0..afa0b3b833 100644
--- a/tensorflow/contrib/linear_optimizer/python/sdca_optimizer.py
+++ b/tensorflow/contrib/linear_optimizer/python/sdca_optimizer.py
@@ -74,6 +74,26 @@ class SDCAOptimizer(object):
def get_name(self):
return 'SDCAOptimizer'
+ @property
+ def example_id_column(self):
+ return self._example_id_column
+
+ @property
+ def num_loss_partitions(self):
+ return self._num_loss_partitions
+
+ @property
+ def num_table_shards(self):
+ return self._num_table_shards
+
+ @property
+ def symmetric_l1_regularization(self):
+ return self._symmetric_l1_regularization
+
+ @property
+ def symmetric_l2_regularization(self):
+ return self._symmetric_l2_regularization
+
def get_train_step(self, columns_to_variables,
weight_column_name, loss_type, features, targets,
global_step):