author    A. Unique TensorFlower <gardener@tensorflow.org>  2018-05-23 15:46:03 -0700
committer TensorFlower Gardener <gardener@tensorflow.org>   2018-05-23 15:48:49 -0700
commit  78c3a8870d2f748f356415e8d7acf9748d09c197 (patch)
tree    c4c0eab4e79e14e1ddfa3e9b77de42ea1b2c96f8 /tensorflow/contrib/linear_optimizer/python
parent  f504a2445051c4c48eb9edd6a023b1f33a2793f2 (diff)
Add support for partitioned variables to SDCA.
PiperOrigin-RevId: 197803127
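
Before the diff itself, a minimal usage sketch of what this change enables. The feature column and its import path are illustrative (mirroring the tests below), not part of the commit:

    # Sketch: train a contrib SDCA estimator with sharded primal weights.
    # Only `partitioner` is new in this commit; the column definition and
    # import paths are assumptions based on the tests in this diff.
    from tensorflow.contrib.layers.python.layers import feature_column as feature_column_lib
    from tensorflow.contrib.linear_optimizer.python import sdca_estimator
    from tensorflow.python.ops import partitioned_variables

    price = feature_column_lib.real_valued_column('price')
    regressor = sdca_estimator.SDCALinearRegressor(
        example_id_column='example_id',
        feature_columns=[price],
        partitioner=partitioned_variables.fixed_size_partitioner(
            num_shards=2, axis=0))
    # regressor.fit(input_fn=..., steps=...) now creates the primal weights
    # as two shards and trains them with the `div` partitioning strategy.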
Diffstat (limited to 'tensorflow/contrib/linear_optimizer/python')
-rw-r--r--  tensorflow/contrib/linear_optimizer/python/kernel_tests/sdca_ops_test.py |  71
-rw-r--r--  tensorflow/contrib/linear_optimizer/python/ops/sdca_ops.py               | 252
-rw-r--r--  tensorflow/contrib/linear_optimizer/python/sdca_estimator.py             |  29
-rw-r--r--  tensorflow/contrib/linear_optimizer/python/sdca_estimator_test.py        |  84
-rw-r--r--  tensorflow/contrib/linear_optimizer/python/sdca_optimizer.py             |  29
5 files changed, 396 insertions, 69 deletions
diff --git a/tensorflow/contrib/linear_optimizer/python/kernel_tests/sdca_ops_test.py b/tensorflow/contrib/linear_optimizer/python/kernel_tests/sdca_ops_test.py
index b5741967ab..d0c32b43cc 100644
--- a/tensorflow/contrib/linear_optimizer/python/kernel_tests/sdca_ops_test.py
+++ b/tensorflow/contrib/linear_optimizer/python/kernel_tests/sdca_ops_test.py
@@ -35,6 +35,8 @@ from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import gen_sdca_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import parsing_ops
+from tensorflow.python.ops import partitioned_variables
+from tensorflow.python.ops import variable_scope
from tensorflow.python.ops import variables as variables_lib
from tensorflow.python.platform import googletest
@@ -132,15 +134,22 @@ def make_random_examples_and_variables_dicts(num_examples, dim, num_non_zero):
return examples_dict, variables_dict
-def make_variable_dict(max_age, max_gender):
+def make_variable_dict(max_age, max_gender, partitioned=False):
# TODO(sibyl-toe9oF2e): Figure out how to derive max_age & max_gender from
# examples_dict.
- age_weights = variables_lib.Variable(
- array_ops.zeros(
- [max_age + 1], dtype=dtypes.float32))
- gender_weights = variables_lib.Variable(
- array_ops.zeros(
- [max_gender + 1], dtype=dtypes.float32))
+ partitioner = None
+ if partitioned:
+ partitioner = partitioned_variables.fixed_size_partitioner(num_shards=2,
+ axis=0)
+ with variable_scope.variable_scope(
+ name_or_scope='variables',
+ partitioner=partitioner):
+ age_weights = variables_lib.Variable(
+ array_ops.zeros(
+ [max_age + 1], dtype=dtypes.float32))
+ gender_weights = variables_lib.Variable(
+ array_ops.zeros(
+ [max_gender + 1], dtype=dtypes.float32))
return dict(
sparse_features_weights=[age_weights, gender_weights],
dense_features_weights=[])
@@ -265,6 +274,54 @@ class SdcaWithLogisticLossTest(SdcaModelTest):
self.assertAllClose(
0.01, lr.approximate_duality_gap().eval(), rtol=1e-2, atol=1e-2)
+ def testPartitionedPrimals(self):
+ # Setup test data
+ example_protos = [
+ make_example_proto({
+ 'age': [0],
+ 'gender': [0]
+ }, 0),
+ make_example_proto({
+ 'age': [1],
+ 'gender': [1]
+ }, 1),
+ ]
+ example_weights = [1.0, 1.0]
+ for num_shards in _SHARD_NUMBERS:
+ with self._single_threaded_test_session():
+ examples = make_example_dict(example_protos, example_weights)
+ variables = make_variable_dict(1, 1, partitioned=True)
+ options = dict(
+ symmetric_l2_regularization=1,
+ symmetric_l1_regularization=0,
+ num_table_shards=num_shards,
+ loss_type='logistic_loss')
+
+ lr = SdcaModel(examples, variables, options)
+ variables_lib.global_variables_initializer().run()
+ unregularized_loss = lr.unregularized_loss(examples)
+ loss = lr.regularized_loss(examples)
+ predictions = lr.predictions(examples)
+ self.assertAllClose(0.693147, unregularized_loss.eval())
+ self.assertAllClose(0.693147, loss.eval())
+ train_op = lr.minimize()
+ for _ in range(_MAX_ITERATIONS):
+ train_op.run()
+ lr.update_weights(train_op).run()
+ # The high tolerance in unregularized_loss comparisons is due to the
+ # fact that it's possible to trade off unregularized_loss vs.
+ # regularization and still have a sum that is quite close to the
+ # optimal regularized_loss value. SDCA's duality gap only ensures that
+ # the regularized_loss is within 0.01 of optimal.
+ # 0.525457 is the optimal regularized_loss.
+ # 0.411608 is the unregularized_loss at that optimum.
+ self.assertAllClose(0.411608, unregularized_loss.eval(), atol=0.05)
+ self.assertAllClose(0.525457, loss.eval(), atol=0.01)
+ predicted_labels = get_binary_predictions_for_logistic(predictions)
+ self.assertAllEqual([0, 1], predicted_labels.eval())
+ self.assertAllClose(
+ 0.01, lr.approximate_duality_gap().eval(), rtol=1e-2, atol=1e-2)
+
def testSparseRandom(self):
dim = 20
num_examples = 1000
diff --git a/tensorflow/contrib/linear_optimizer/python/ops/sdca_ops.py b/tensorflow/contrib/linear_optimizer/python/ops/sdca_ops.py
index f980746a19..0047d5753a 100644
--- a/tensorflow/contrib/linear_optimizer/python/ops/sdca_ops.py
+++ b/tensorflow/contrib/linear_optimizer/python/ops/sdca_ops.py
@@ -22,12 +22,14 @@ import collections
from six.moves import range
from tensorflow.contrib.linear_optimizer.python.ops.sharded_mutable_dense_hashtable import ShardedMutableDenseHashTable
+from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework.ops import internal_convert_to_tensor
from tensorflow.python.framework.ops import name_scope
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
+from tensorflow.python.ops import data_flow_ops
from tensorflow.python.ops import gen_sdca_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import nn_ops
@@ -43,9 +45,6 @@ __all__ = ['SdcaModel']
class SdcaModel(object):
"""Stochastic dual coordinate ascent solver for linear models.
- This class currently only supports a single machine (multi-threaded)
- implementation. We expect the weights and duals to fit in a single machine.
-
Loss functions supported:
* Binary logistic loss
@@ -182,18 +181,41 @@ class SdcaModel(object):
# TODO(sibyl-Aix6ihai): Use optimizer interface to make use of slot creation logic.
def _create_slots(self):
- # Make internal variables which have the updates before applying L1
- # regularization.
+ """Make unshrinked internal variables (slots)."""
+ # Unshrinked variables have the updates before applying L1 regularization.
+  # Each unshrinked slot variable is either a `Variable` or a list of
+  # `Variable`s, depending on whether its corresponding primary variable is
+  # partitioned.
+ # We avoid using `PartitionedVariable` for the unshrinked slots since we do
+ # not need any of the extra information.
self._slots = collections.defaultdict(list)
for name in ['sparse_features_weights', 'dense_features_weights']:
for var in self._variables[name]:
- with ops.device(var.device):
- # TODO(andreasst): remove SDCAOptimizer suffix once bug 30843109 is
- # fixed
- self._slots['unshrinked_' + name].append(
- var_ops.Variable(
- array_ops.zeros_like(var.initialized_value(), dtypes.float32),
- name=var.op.name + '_unshrinked/SDCAOptimizer'))
+ # Our primary variable may be either a PartitionedVariable, or a list
+ # of Variables (each representing a partition).
+ if (isinstance(var, var_ops.PartitionedVariable) or
+ isinstance(var, list)):
+ var_list = []
+ # pylint: disable=protected-access
+ for v in var:
+ with ops.colocate_with(v):
+ # TODO(andreasst): remove SDCAOptimizer suffix once bug 30843109
+ # is fixed.
+ slot_var = var_ops.Variable(
+ initial_value=array_ops.zeros_like(v.initialized_value(),
+ dtypes.float32),
+ name=v.op.name + '_unshrinked/SDCAOptimizer')
+ var_list.append(slot_var)
+ self._slots['unshrinked_' + name].append(var_list)
+ # pylint: enable=protected-access
+ else:
+ with ops.device(var.device):
+ # TODO(andreasst): remove SDCAOptimizer suffix once bug 30843109 is
+ # fixed.
+ self._slots['unshrinked_' + name].append(
+ var_ops.Variable(
+ array_ops.zeros_like(var.initialized_value(),
+ dtypes.float32),
+ name=var.op.name + '_unshrinked/SDCAOptimizer'))
def _assertSpecified(self, items, check_in):
for x in items:
@@ -205,16 +227,25 @@ class SdcaModel(object):
if not isinstance(check_in[x], list):
raise ValueError(x + ' must be a list.')
+ def _var_to_list(self, var):
+ """Wraps var in a list if it is not a list or PartitionedVariable."""
+ if not (isinstance(var, list) or
+ isinstance(var, var_ops.PartitionedVariable)):
+ var = [var]
+ return var
+
def _l1_loss(self):
"""Computes the (un-normalized) l1 loss of the model."""
with name_scope('sdca/l1_loss'):
sums = []
for name in ['sparse_features_weights', 'dense_features_weights']:
- for weights in self._convert_n_to_tensor(self._variables[name]):
- with ops.device(weights.device):
- sums.append(
- math_ops.reduce_sum(
- math_ops.abs(math_ops.cast(weights, dtypes.float64))))
+ for var in self._variables[name]:
+ for v in self._var_to_list(var):
+ weights = internal_convert_to_tensor(v)
+ with ops.device(weights.device):
+ sums.append(
+ math_ops.reduce_sum(
+ math_ops.abs(math_ops.cast(weights, dtypes.float64))))
# SDCA L1 regularization cost is: l1 * sum(|weights|)
return self._options['symmetric_l1_regularization'] * math_ops.add_n(sums)
@@ -223,17 +254,37 @@ class SdcaModel(object):
with name_scope('sdca/l2_loss'):
sums = []
for name in ['sparse_features_weights', 'dense_features_weights']:
- for weights in self._convert_n_to_tensor(self._variables[name]):
- with ops.device(weights.device):
- sums.append(
- math_ops.reduce_sum(
- math_ops.square(math_ops.cast(weights, dtypes.float64))))
+ for var in self._variables[name]:
+ for v in self._var_to_list(var):
+ weights = internal_convert_to_tensor(v)
+ with ops.device(weights.device):
+ sums.append(math_ops.reduce_sum(math_ops.square(math_ops.cast(
+ weights, dtypes.float64))))
# SDCA L2 regularization cost is: l2 * sum(weights^2) / 2
return l2 * math_ops.add_n(sums) / 2.0
def _convert_n_to_tensor(self, input_list, as_ref=False):
"""Converts input list to a set of tensors."""
- return [internal_convert_to_tensor(x, as_ref=as_ref) for x in input_list]
+ # input_list can be a list of Variables (that are implicitly partitioned),
+ # in which case the underlying logic in internal_convert_to_tensor will not
+    # concatenate the partitions together. This method takes care of the
+    # concatenation (we only allow partitioning on the first axis).
+ output_list = []
+ for x in input_list:
+ tensor_to_convert = x
+ if isinstance(x, list) or isinstance(x, var_ops.PartitionedVariable):
+ # We only allow for partitioning on the first axis.
+ tensor_to_convert = array_ops.concat(x, axis=0)
+ output_list.append(internal_convert_to_tensor(
+ tensor_to_convert, as_ref=as_ref))
+ return output_list
+
+ def _get_first_dimension_size_statically(self, w, num_partitions):
+ """Compute the static size of the first dimension for a sharded variable."""
+ dim_0_size = w[0].get_shape()[0]
+ for p in range(1, num_partitions):
+ dim_0_size += w[p].get_shape()[0]
+ return dim_0_size
def _linear_predictions(self, examples):
"""Returns predictions of the form w*x."""
@@ -286,6 +337,28 @@ class SdcaModel(object):
result = math_ops.sigmoid(result)
return result
+ def _get_partitioned_update_ops(self,
+ v_num,
+ num_partitions_by_var,
+ p_assignments_by_var,
+ gather_ids_by_var,
+ weights,
+ full_update,
+ p_assignments,
+ num_partitions):
+ """Get updates for partitioned variables."""
+ num_partitions = num_partitions_by_var[v_num]
+ p_assignments = p_assignments_by_var[v_num]
+ gather_ids = gather_ids_by_var[v_num]
+ updates = data_flow_ops.dynamic_partition(
+ full_update, p_assignments, num_partitions)
+ update_ops = []
+ for p in range(num_partitions):
+ with ops.colocate_with(weights[p]):
+ result = state_ops.scatter_add(weights[p], gather_ids[p], updates[p])
+ update_ops.append(result)
+ return update_ops
+
def minimize(self, global_step=None, name=None):
"""Add operations to train a linear model by minimizing the loss function.
@@ -318,18 +391,89 @@ class SdcaModel(object):
# Solver returns example_state_update, new delta sparse_feature_weights
# and delta dense_feature_weights.
- weights_tensor = self._convert_n_to_tensor(self._slots[
- 'unshrinked_sparse_features_weights'])
sparse_weights = []
sparse_indices = []
- for w, i in zip(weights_tensor, sparse_feature_indices):
- # Find the feature ids to lookup in the variables.
- with ops.device(w.device):
- sparse_indices.append(
- math_ops.cast(
- array_ops.unique(math_ops.cast(i, dtypes.int32))[0],
- dtypes.int64))
- sparse_weights.append(array_ops.gather(w, sparse_indices[-1]))
+ # If we have partitioned variables, keep a few lists of Tensors around
+    # that we need for the scatter_add after the op call to
+ # gen_sdca_ops.sdca_optimizer().
+ num_partitions_by_var = []
+ p_assignments_by_var = []
+ gather_ids_by_var = []
+ for w, i in zip(self._slots['unshrinked_sparse_features_weights'],
+ sparse_feature_indices):
+ # Append the sparse_indices (in full-variable space).
+ sparse_idx = math_ops.cast(
+ array_ops.unique(math_ops.cast(i, dtypes.int32))[0],
+ dtypes.int64)
+ sparse_indices.append(sparse_idx)
+ if isinstance(w, list) or isinstance(w, var_ops.PartitionedVariable):
+ num_partitions = len(w)
+ flat_ids = array_ops.reshape(sparse_idx, [-1])
+ # We use div partitioning, which is easiest to support downstream.
+ # Compute num_total_ids as the sum of dim-0 of w, then assign
+ # to partitions based on a constant number of ids per partition.
+ # Optimize if we already know the full shape statically.
+ dim_0_size = self._get_first_dimension_size_statically(
+ w, num_partitions)
+
+ if dim_0_size.value:
+ num_total_ids = constant_op.constant(dim_0_size.value,
+ flat_ids.dtype)
+ else:
+ dim_0_sizes = []
+ for p in range(num_partitions):
+ if w[p].get_shape()[0].value is not None:
+ dim_0_sizes.append(w[p].get_shape()[0].value)
+ else:
+ with ops.colocate_with(w[p]):
+ dim_0_sizes.append(array_ops.shape(w[p])[0])
+ num_total_ids = math_ops.reduce_sum(
+ math_ops.cast(array_ops.stack(dim_0_sizes), flat_ids.dtype))
+ ids_per_partition = num_total_ids // num_partitions
+ extras = num_total_ids % num_partitions
+
+ p_assignments = math_ops.maximum(
+ flat_ids // (ids_per_partition + 1),
+ (flat_ids - extras) // ids_per_partition)
+
+ # Emulate a conditional using a boolean indicator tensor
+ new_ids = array_ops.where(p_assignments < extras,
+ flat_ids % (ids_per_partition + 1),
+ (flat_ids - extras) % ids_per_partition)
+
+ # Cast partition assignments to int32 for use in dynamic_partition.
+ # There really should not be more than 2^32 partitions.
+ p_assignments = math_ops.cast(p_assignments, dtypes.int32)
+ # Partition list of ids based on assignments into num_partitions
+ # separate lists.
+ gather_ids = data_flow_ops.dynamic_partition(new_ids,
+ p_assignments,
+ num_partitions)
+ # Append these to the lists for use in the later update.
+ num_partitions_by_var.append(num_partitions)
+ p_assignments_by_var.append(p_assignments)
+ gather_ids_by_var.append(gather_ids)
+
+ # Gather the weights from each partition.
+ partition_gathered_weights = []
+ for p in range(num_partitions):
+ with ops.colocate_with(w[p]):
+ partition_gathered_weights.append(
+ array_ops.gather(w[p], gather_ids[p]))
+
+          # Stitch the weights back together in the order they had before we
+          # dynamic_partitioned them.
+ condition_indices = data_flow_ops.dynamic_partition(
+ math_ops.range(array_ops.shape(new_ids)[0]),
+ p_assignments, num_partitions)
+ batch_gathered_weights = data_flow_ops.dynamic_stitch(
+ condition_indices, partition_gathered_weights)
+ else:
+ w_as_tensor = internal_convert_to_tensor(w)
+ with ops.device(w_as_tensor.device):
+ batch_gathered_weights = array_ops.gather(
+ w_as_tensor, sparse_idx)
+ sparse_weights.append(batch_gathered_weights)
# pylint: disable=protected-access
esu, sfw, dfw = gen_sdca_ops.sdca_optimizer(
@@ -355,12 +499,25 @@ class SdcaModel(object):
with ops.control_dependencies([esu]):
update_ops = [self._hashtable.insert(example_ids_hashed, esu)]
# Update the weights before the proximal step.
- for w, i, u in zip(self._slots['unshrinked_sparse_features_weights'],
- sparse_indices, sfw):
- update_ops.append(state_ops.scatter_add(w, i, u))
+ for v_num, (w, i, u) in enumerate(
+ zip(self._slots['unshrinked_sparse_features_weights'],
+ sparse_indices, sfw)):
+ if (isinstance(w, var_ops.PartitionedVariable) or
+ isinstance(w, list)):
+ update_ops += self._get_partitioned_update_ops(
+ v_num, num_partitions_by_var, p_assignments_by_var,
+ gather_ids_by_var, w, u, p_assignments, num_partitions)
+ else:
+ update_ops.append(state_ops.scatter_add(w, i, u))
for w, u in zip(self._slots['unshrinked_dense_features_weights'], dfw):
- update_ops.append(w.assign_add(u))
-
+ if (isinstance(w, var_ops.PartitionedVariable) or
+ isinstance(w, list)):
+ split_updates = array_ops.split(
+ u, num_or_size_splits=[v.shape.as_list()[0] for v in w])
+ for v, split_update in zip(w, split_updates):
+ update_ops.append(state_ops.assign_add(v, split_update))
+ else:
+ update_ops.append(state_ops.assign_add(w, u))
if not global_step:
return control_flow_ops.group(*update_ops)
with ops.control_dependencies(update_ops):
@@ -385,21 +542,22 @@ class SdcaModel(object):
for name in ['sparse_features_weights', 'dense_features_weights']:
for var, slot_var in zip(self._variables[name],
self._slots['unshrinked_' + name]):
- update_ops.append(var.assign(slot_var))
+ for v, sv in zip(self._var_to_list(var), self._var_to_list(slot_var)):
+ update_ops.append(v.assign(sv))
# Apply proximal step.
with ops.control_dependencies(update_ops):
update_ops = []
for name in ['sparse_features_weights', 'dense_features_weights']:
for var in self._variables[name]:
- with ops.device(var.device):
- # pylint: disable=protected-access
- update_ops.append(
- gen_sdca_ops.sdca_shrink_l1(
- self._convert_n_to_tensor(
- [var], as_ref=True),
- l1=self._symmetric_l1_regularization(),
- l2=self._symmetric_l2_regularization()))
+ for v in self._var_to_list(var):
+ with ops.device(v.device):
+ # pylint: disable=protected-access
+ update_ops.append(
+ gen_sdca_ops.sdca_shrink_l1(
+ self._convert_n_to_tensor([v], as_ref=True),
+ l1=self._symmetric_l1_regularization(),
+ l2=self._symmetric_l2_regularization()))
return control_flow_ops.group(*update_ops)
def approximate_duality_gap(self):
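
The core of the new sparse path in minimize() is the `div` id-to-partition assignment. A self-contained sketch of that arithmetic, with plain Python ints standing in for the TF ops and a made-up variable size:

    # 'div' partitioning: the first `extras` partitions hold
    # (ids_per_partition + 1) ids each, the remaining partitions hold
    # ids_per_partition each. Mirrors the p_assignments/new_ids formulas
    # in minimize() above.
    num_total_ids = 10
    num_partitions = 3
    ids_per_partition = num_total_ids // num_partitions  # 3
    extras = num_total_ids % num_partitions              # 1

    def assign(flat_id):
      # Partition index, via the same maximum() trick as the TF code.
      p = max(flat_id // (ids_per_partition + 1),
              (flat_id - extras) // ids_per_partition)
      # Offset of the id within its partition.
      new_id = (flat_id % (ids_per_partition + 1) if p < extras
                else (flat_id - extras) % ids_per_partition)
      return p, new_id

    # Ids 0..3 land in partition 0; 4..6 in partition 1; 7..9 in partition 2.
    assert [assign(i)[0] for i in range(10)] == [0, 0, 0, 0, 1, 1, 1, 2, 2, 2]
    assert assign(4) == (1, 0) and assign(9) == (2, 2)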
diff --git a/tensorflow/contrib/linear_optimizer/python/sdca_estimator.py b/tensorflow/contrib/linear_optimizer/python/sdca_estimator.py
index d4e54c82f9..200e7de6b9 100644
--- a/tensorflow/contrib/linear_optimizer/python/sdca_estimator.py
+++ b/tensorflow/contrib/linear_optimizer/python/sdca_estimator.py
@@ -116,6 +116,7 @@ def sdca_model_fn(features, labels, mode, params, config=None):
num_loss_partitions = params["num_loss_partitions"]
weight_column_name = params["weight_column_name"]
update_weights_hook = params.get("update_weights_hook", None)
+ partitioner = params["partitioner"]
loss_type = None
if isinstance(head, head_lib._BinarySvmHead): # pylint: disable=protected-access
@@ -136,12 +137,14 @@ def sdca_model_fn(features, labels, mode, params, config=None):
example_id_column=example_id_column,
num_loss_partitions=n_loss_partitions,
symmetric_l1_regularization=l1_regularization,
- symmetric_l2_regularization=l2_regularization)
+ symmetric_l2_regularization=l2_regularization,
+ partitioner=partitioner)
parent_scope = "linear"
with variable_scope.variable_scope(
- values=features.values(), name_or_scope=parent_scope) as scope:
+ values=features.values(), name_or_scope=parent_scope,
+ partitioner=partitioner) as scope:
features = features.copy()
features.update(layers.transform_features(features, feature_columns))
logits, columns_to_variables, bias = (
@@ -213,7 +216,8 @@ class _SDCAEstimator(estimator.Estimator):
l2_regularization=1.0,
num_loss_partitions=None,
config=None,
- feature_engineering_fn=None):
+ feature_engineering_fn=None,
+ partitioner=None):
"""Construct a `_SDCAEstimator` estimator object.
Args:
@@ -241,6 +245,8 @@ class _SDCAEstimator(estimator.Estimator):
feature_engineering_fn: Feature engineering function. Takes features and
labels which are the output of `input_fn` and returns features and
labels which will be fed into the model.
+ partitioner: Variable partitioner for the primal weights (`div`
+ partitioning strategy will be used).
Returns:
A `_SDCAEstimator` estimator.
@@ -267,6 +273,7 @@ class _SDCAEstimator(estimator.Estimator):
"l2_regularization": l2_regularization,
"weight_column_name": weight_column_name,
"update_weights_hook": _SdcaUpdateWeightsHook(),
+ "partitioner": partitioner,
}
super(_SDCAEstimator, self).__init__(
@@ -336,7 +343,8 @@ class SDCALogisticClassifier(_SDCAEstimator):
l2_regularization=1.0,
num_loss_partitions=None,
config=None,
- feature_engineering_fn=None):
+ feature_engineering_fn=None,
+ partitioner=None):
"""Construct a `SDCALogisticClassifier` object.
Args:
@@ -361,6 +369,8 @@ class SDCALogisticClassifier(_SDCAEstimator):
feature_engineering_fn: Feature engineering function. Takes features and
labels which are the output of `input_fn` and returns features and
labels which will be fed into the model.
+ partitioner: Variable partitioner for the primal weights (`div`
+ partitioning strategy will be used).
Returns:
A `SDCALogisticClassifier` estimator.
@@ -376,7 +386,8 @@ class SDCALogisticClassifier(_SDCAEstimator):
l2_regularization=l2_regularization,
num_loss_partitions=num_loss_partitions,
config=config,
- feature_engineering_fn=None)
+ feature_engineering_fn=None,
+ partitioner=partitioner)
def predict_classes(self, input_fn=None):
"""Runs inference to determine the predicted class.
@@ -463,7 +474,8 @@ class SDCALinearRegressor(_SDCAEstimator):
l2_regularization=1.0,
num_loss_partitions=None,
config=None,
- feature_engineering_fn=None):
+ feature_engineering_fn=None,
+ partitioner=None):
"""Construct a `SDCALinearRegressor` estimator object.
@@ -489,6 +501,8 @@ class SDCALinearRegressor(_SDCAEstimator):
feature_engineering_fn: Feature engineering function. Takes features and
labels which are the output of `input_fn` and returns features and
labels which will be fed into the model.
+ partitioner: Variable partitioner for the primal weights (`div`
+ partitioning strategy will be used).
Returns:
A `SDCALinearRegressor` estimator.
@@ -503,7 +517,8 @@ class SDCALinearRegressor(_SDCAEstimator):
l2_regularization=l2_regularization,
num_loss_partitions=num_loss_partitions,
config=config,
- feature_engineering_fn=None)
+ feature_engineering_fn=None,
+ partitioner=partitioner)
def predict_scores(self, input_fn):
"""Returns predicted scores for given features.
diff --git a/tensorflow/contrib/linear_optimizer/python/sdca_estimator_test.py b/tensorflow/contrib/linear_optimizer/python/sdca_estimator_test.py
index bed3d5139f..6476671882 100644
--- a/tensorflow/contrib/linear_optimizer/python/sdca_estimator_test.py
+++ b/tensorflow/contrib/linear_optimizer/python/sdca_estimator_test.py
@@ -25,6 +25,7 @@ from tensorflow.contrib.linear_optimizer.python import sdca_estimator
from tensorflow.core.protobuf import config_pb2
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import sparse_tensor
+from tensorflow.python.ops import partitioned_variables
from tensorflow.python.platform import test
@@ -273,6 +274,47 @@ class SDCALogisticClassifierTest(test.TestCase):
metrics = classifier.evaluate(input_fn=input_fn, steps=1)
self.assertGreater(metrics['accuracy'], 0.9)
+ def testPartitionedMixedFeatures(self):
+ """Tests SDCALogisticClassifier with a mix of features (partitioned)."""
+
+ def input_fn():
+ return {
+ 'example_id':
+ constant_op.constant(['1', '2', '3']),
+ 'price':
+ constant_op.constant([[0.6], [0.8], [0.3]]),
+ 'sq_footage':
+ constant_op.constant([900.0, 700.0, 600.0]),
+ 'country':
+ sparse_tensor.SparseTensor(
+ values=['IT', 'US', 'GB'],
+ indices=[[0, 0], [1, 3], [2, 1]],
+ dense_shape=[3, 5]),
+ 'weights':
+ constant_op.constant([[3.0], [1.0], [1.0]])
+ }, constant_op.constant([[1], [0], [1]])
+
+ with self._single_threaded_test_session():
+ price = feature_column_lib.real_valued_column('price')
+ sq_footage_bucket = feature_column_lib.bucketized_column(
+ feature_column_lib.real_valued_column('sq_footage'),
+ boundaries=[650.0, 800.0])
+ country = feature_column_lib.sparse_column_with_hash_bucket(
+ 'country', hash_bucket_size=5)
+ sq_footage_country = feature_column_lib.crossed_column(
+ [sq_footage_bucket, country], hash_bucket_size=10)
+ classifier = sdca_estimator.SDCALogisticClassifier(
+ example_id_column='example_id',
+ feature_columns=[
+ price, sq_footage_bucket, country, sq_footage_country
+ ],
+ weight_column_name='weights',
+ partitioner=partitioned_variables.fixed_size_partitioner(
+ num_shards=2, axis=0))
+ classifier.fit(input_fn=input_fn, steps=50)
+ metrics = classifier.evaluate(input_fn=input_fn, steps=1)
+ self.assertGreater(metrics['accuracy'], 0.9)
+
class SDCALinearRegressorTest(test.TestCase):
@@ -350,6 +392,48 @@ class SDCALinearRegressorTest(test.TestCase):
loss = regressor.evaluate(input_fn=input_fn, steps=1)['loss']
self.assertLess(loss, 0.05)
+ def testMixedFeaturesArbitraryWeightsPartitioned(self):
+ """Tests SDCALinearRegressor works with a mix of features (partitioned)."""
+
+ def input_fn():
+ return {
+ 'example_id':
+ constant_op.constant(['1', '2', '3']),
+ 'price':
+ constant_op.constant([[0.6], [0.8], [0.3]]),
+ 'sq_footage':
+ constant_op.constant([[900.0], [700.0], [600.0]]),
+ 'country':
+ sparse_tensor.SparseTensor(
+ values=['IT', 'US', 'GB'],
+ indices=[[0, 0], [1, 3], [2, 1]],
+ dense_shape=[3, 5]),
+ 'weights':
+ constant_op.constant([[3.0], [5.0], [7.0]])
+ }, constant_op.constant([[1.55], [-1.25], [-3.0]])
+
+ with self._single_threaded_test_session():
+ price = feature_column_lib.real_valued_column('price')
+ sq_footage_bucket = feature_column_lib.bucketized_column(
+ feature_column_lib.real_valued_column('sq_footage'),
+ boundaries=[650.0, 800.0])
+ country = feature_column_lib.sparse_column_with_hash_bucket(
+ 'country', hash_bucket_size=5)
+ sq_footage_country = feature_column_lib.crossed_column(
+ [sq_footage_bucket, country], hash_bucket_size=10)
+ regressor = sdca_estimator.SDCALinearRegressor(
+ example_id_column='example_id',
+ feature_columns=[
+ price, sq_footage_bucket, country, sq_footage_country
+ ],
+ l2_regularization=1.0,
+ weight_column_name='weights',
+ partitioner=partitioned_variables.fixed_size_partitioner(
+ num_shards=2, axis=0))
+ regressor.fit(input_fn=input_fn, steps=20)
+ loss = regressor.evaluate(input_fn=input_fn, steps=1)['loss']
+ self.assertLess(loss, 0.05)
+
def testSdcaOptimizerSparseFeaturesWithL1Reg(self):
"""SDCALinearRegressor works with sparse features and L1 regularization."""
diff --git a/tensorflow/contrib/linear_optimizer/python/sdca_optimizer.py b/tensorflow/contrib/linear_optimizer/python/sdca_optimizer.py
index 12039ecc6f..9872c6f97c 100644
--- a/tensorflow/contrib/linear_optimizer/python/sdca_optimizer.py
+++ b/tensorflow/contrib/linear_optimizer/python/sdca_optimizer.py
@@ -64,7 +64,8 @@ class SDCAOptimizer(object):
of workers running the train steps. It defaults to 1 (single machine).
`num_table_shards` defines the number of shards for the internal state
table, typically set to match the number of parameter servers for large
- data sets.
+ data sets. You can also specify a `partitioner` object to partition the primal
+ weights during training (`div` partitioning strategy will be used).
"""
def __init__(self,
@@ -73,13 +74,15 @@ class SDCAOptimizer(object):
num_table_shards=None,
symmetric_l1_regularization=0.0,
symmetric_l2_regularization=1.0,
- adaptive=True):
+ adaptive=True,
+ partitioner=None):
self._example_id_column = example_id_column
self._num_loss_partitions = num_loss_partitions
self._num_table_shards = num_table_shards
self._symmetric_l1_regularization = symmetric_l1_regularization
self._symmetric_l2_regularization = symmetric_l2_regularization
self._adaptive = adaptive
+ self._partitioner = partitioner
def get_name(self):
return 'SDCAOptimizer'
@@ -108,6 +111,10 @@ class SDCAOptimizer(object):
def adaptive(self):
return self._adaptive
+ @property
+ def partitioner(self):
+ return self._partitioner
+
def get_train_step(self, columns_to_variables, weight_column_name, loss_type,
features, targets, global_step):
"""Returns the training operation of an SdcaModel optimizer."""
@@ -175,10 +182,12 @@ class SDCAOptimizer(object):
sparse_feature_column = _dense_tensor_to_sparse_feature_column(
dense_bucket_tensor)
sparse_feature_with_values.append(sparse_feature_column)
- # For bucketized columns, the variables list contains exactly one
- # element.
- sparse_feature_with_values_weights.append(
- columns_to_variables[column][0])
+ # If a partitioner was used during variable creation, we will have a
+        # list of more than one Variable here.
+ vars_to_append = columns_to_variables[column][0]
+ if len(columns_to_variables[column]) > 1:
+ vars_to_append = columns_to_variables[column]
+ sparse_feature_with_values_weights.append(vars_to_append)
elif isinstance(
column,
(
@@ -226,8 +235,12 @@ class SDCAOptimizer(object):
array_ops.shape(ids)[0]), [-1])
sparse_feature_with_values.append(
SparseFeatureColumn(example_ids_filtered, reproject_ids, weights))
- sparse_feature_with_values_weights.append(
- columns_to_variables[column][0])
+ # If a partitioner was used during variable creation, we will have a
+          # list of more than one Variable here.
+ vars_to_append = columns_to_variables[column][0]
+ if len(columns_to_variables[column]) > 1:
+ vars_to_append = columns_to_variables[column]
+ sparse_feature_with_values_weights.append(vars_to_append)
else:
raise ValueError('SDCAOptimizer does not support column type %s.' %
type(column).__name__)
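
Finally, a hedged sketch of passing the new argument to SDCAOptimizer directly; the argument values are illustrative, and everything except `partitioner` predates this commit:

    from tensorflow.contrib.linear_optimizer.python import sdca_optimizer
    from tensorflow.python.ops import partitioned_variables

    optimizer = sdca_optimizer.SDCAOptimizer(
        example_id_column='example_id',
        num_loss_partitions=1,
        symmetric_l2_regularization=1.0,
        # New in this commit: shard the primal weights with `div` partitioning.
        partitioner=partitioned_variables.fixed_size_partitioner(
            num_shards=2, axis=0))
    assert optimizer.partitioner is not None  # exposed via the new property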