Diffstat (limited to 'tensorflow/contrib/linear_optimizer/python/ops/sdca_ops.py')
-rw-r--r--  tensorflow/contrib/linear_optimizer/python/ops/sdca_ops.py  138
1 file changed, 74 insertions(+), 64 deletions(-)
diff --git a/tensorflow/contrib/linear_optimizer/python/ops/sdca_ops.py b/tensorflow/contrib/linear_optimizer/python/ops/sdca_ops.py
index 957a734b07..9d41f024ae 100644
--- a/tensorflow/contrib/linear_optimizer/python/ops/sdca_ops.py
+++ b/tensorflow/contrib/linear_optimizer/python/ops/sdca_ops.py
@@ -28,7 +28,6 @@ from tensorflow.python.framework.ops import name_scope
 from tensorflow.python.ops import array_ops
 from tensorflow.python.ops import control_flow_ops
 from tensorflow.python.ops import math_ops
-from tensorflow.python.ops import state_ops
 from tensorflow.python.ops import variables as var_ops
 from tensorflow.python.ops.nn import sigmoid_cross_entropy_with_logits
 from tensorflow.python.platform import resource_loader
@@ -139,30 +138,35 @@ class SdcaModel(object):
         ['loss_type', 'symmetric_l2_regularization',
          'symmetric_l1_regularization'], options)
 
+    for name in ['symmetric_l1_regularization', 'symmetric_l2_regularization']:
+      value = options[name]
+      if value < 0.0:
+        raise ValueError('%s should be non-negative. Found (%f)' %
+                         (name, value))
+
     self._container = container
     self._examples = examples
     self._variables = variables
     self._options = options
     self._solver_uuid = uuid.uuid4().hex
-    self._create_slots(variables)
-
-  # TODO(rohananil): Use optimizer interface to make use of slot creation
-  # logic
-  def _create_slots(self, variables):
-    self._slots = {}
-    # TODO(rohananil): Rename the slot keys to "unshrinked" weights.
-    self._slots['sparse_features_weights'] = []
-    self._slots['dense_features_weights'] = []
-    self._assign_ops = []
-    # Make an internal variable which has the updates before applying L1
+    self._create_slots()
+
+  def _symmetric_l2_regularization(self):
+    # The algorithmic requirement (for now) is a minimal L2 of 1.0.
+    return max(self._options['symmetric_l2_regularization'], 1.0)
+
+  # TODO(rohananil): Use optimizer interface to make use of slot creation logic.
+  def _create_slots(self):
+    # Make internal variables which hold the updates before applying L1
     # regularization.
-    for var_type in ['sparse_features_weights', 'dense_features_weights']:
-      for var in variables[var_type]:
-        if var is not None:
-          self._slots[var_type].append(var_ops.Variable(array_ops.zeros_like(
-              var.initialized_value(), dtypes.float32)))
-          self._assign_ops.append(state_ops.assign(var, self._slots[var_type][
-              -1]))
+    self._slots = {
+        'unshrinked_sparse_features_weights': [],
+        'unshrinked_dense_features_weights': [],
+    }
+    for name in ['sparse_features_weights', 'dense_features_weights']:
+      for var in self._variables[name]:
+        self._slots['unshrinked_' + name].append(var_ops.Variable(
+            array_ops.zeros_like(var.initialized_value(), dtypes.float32)))
 
   def _assertSpecified(self, items, check_in):
     for x in items:
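
Note: the rewritten _create_slots keeps one zero-initialized "unshrinked" shadow variable per user-visible weight variable; the solver updates the shadows, and minimize() later copies them back into the visible variables before L1 shrinkage. A minimal plain-Python sketch of the same bookkeeping (lists stand in for TensorFlow variables; all names here are illustrative, not the real API):

# Plain-Python analogue of _create_slots above; lists stand in for
# TensorFlow variables, so this is a sketch, not the real API.
def create_slots(variables):
  slots = {'unshrinked_sparse_features_weights': [],
           'unshrinked_dense_features_weights': []}
  for name in ['sparse_features_weights', 'dense_features_weights']:
    for var in variables[name]:
      # One zero-initialized shadow per weight variable, same shape.
      slots['unshrinked_' + name].append([0.0] * len(var))
  return slots

variables = {'sparse_features_weights': [[0.5, -0.2]],
             'dense_features_weights': [[1.0]]}
assert create_slots(variables)['unshrinked_sparse_features_weights'] == [[0.0, 0.0]]
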
@@ -177,33 +181,22 @@ class SdcaModel(object):
   def _l1_loss(self):
     """Computes the l1 loss of the model."""
     with name_scope('l1_loss'):
-      sparse_weights = self._convert_n_to_tensor(self._variables[
-          'sparse_features_weights'])
-      dense_weights = self._convert_n_to_tensor(self._variables[
-          'dense_features_weights'])
-      l1 = self._options['symmetric_l1_regularization']
-      loss = 0.0
-      for w in sparse_weights:
-        loss += l1 * math_ops.reduce_sum(abs(w))
-      for w in dense_weights:
-        loss += l1 * math_ops.reduce_sum(abs(w))
-      return loss
-
-  def _l2_loss(self):
+      total = 0.0
+      for name in ['sparse_features_weights', 'dense_features_weights']:
+        for weights in self._convert_n_to_tensor(self._variables[name]):
+          total += math_ops.reduce_sum(math_ops.abs(weights))
+      # SDCA L1 regularization cost is: l1 * sum(|weights|)
+      return self._options['symmetric_l1_regularization'] * total
+
+  def _l2_loss(self, l2):
     """Computes the l2 loss of the model."""
     with name_scope('l2_loss'):
-      sparse_weights = self._convert_n_to_tensor(self._variables[
-          'sparse_features_weights'])
-      dense_weights = self._convert_n_to_tensor(self._variables[
-          'dense_features_weights'])
-      l2 = self._options['symmetric_l2_regularization']
-      loss = 0.0
-      for w in sparse_weights:
-        loss += l2 * math_ops.reduce_sum(math_ops.square(w))
-      for w in dense_weights:
-        loss += l2 * math_ops.reduce_sum(math_ops.square(w))
-      # SDCA L2 regularization cost is 1/2 * l2 * sum(weights^2)
-      return loss / 2.0
+      total = 0.0
+      for name in ['sparse_features_weights', 'dense_features_weights']:
+        for weights in self._convert_n_to_tensor(self._variables[name]):
+          total += math_ops.reduce_sum(math_ops.square(weights))
+      # SDCA L2 regularization cost is: l2 * sum(weights^2) / 2
+      return l2 * total / 2
 
   def _convert_n_to_tensor(self, input_list, as_ref=False):
     """Converts input list to a set of tensors."""
@@ -265,31 +258,44 @@ class SdcaModel(object):
"""
with name_scope('sdca/minimize'):
sparse_features_indices = []
- sparse_features_weights = []
+ sparse_features_values = []
for sf in self._examples['sparse_features']:
sparse_features_indices.append(convert_to_tensor(sf.indices))
- sparse_features_weights.append(convert_to_tensor(sf.values))
+ sparse_features_values.append(convert_to_tensor(sf.values))
step_op = _sdca_ops.sdca_solver(
sparse_features_indices,
- sparse_features_weights,
+ sparse_features_values,
self._convert_n_to_tensor(self._examples['dense_features']),
convert_to_tensor(self._examples['example_weights']),
convert_to_tensor(self._examples['example_labels']),
convert_to_tensor(self._examples['example_ids']),
- self._convert_n_to_tensor(self._slots['sparse_features_weights'],
- as_ref=True),
- self._convert_n_to_tensor(self._slots['dense_features_weights'],
- as_ref=True),
+ self._convert_n_to_tensor(
+ self._slots['unshrinked_sparse_features_weights'],
+ as_ref=True),
+ self._convert_n_to_tensor(
+ self._slots['unshrinked_dense_features_weights'],
+ as_ref=True),
l1=self._options['symmetric_l1_regularization'],
- l2=self._options['symmetric_l2_regularization'],
+ l2=self._symmetric_l2_regularization(),
+ # TODO(rohananil): Provide empirical evidence for this. It is better
+ # to run more than one iteration on single mini-batch as we want to
+ # spend more time in compute. SDCA works better with larger
+ # mini-batches and there is also recent work that shows its better to
+ # reuse old samples than train on new samples.
+ # See: http://arxiv.org/abs/1602.02136.
num_inner_iterations=2,
loss_type=self._options['loss_type'],
container=self._container,
solver_uuid=self._solver_uuid)
with ops.control_dependencies([step_op]):
- assign_ops = control_flow_ops.group(*self._assign_ops)
- with ops.control_dependencies([assign_ops]):
+ assign_ops = []
+ for name in ['sparse_features_weights', 'dense_features_weights']:
+ for var, slot_var in zip(self._variables[name],
+ self._slots['unshrinked_' + name]):
+ assign_ops.append(var.assign(slot_var))
+ assign_group = control_flow_ops.group(*assign_ops)
+ with ops.control_dependencies([assign_group]):
return _sdca_ops.sdca_shrink_l1(
self._convert_n_to_tensor(
self._variables['sparse_features_weights'],
@@ -298,7 +304,7 @@ class SdcaModel(object):
                   self._variables['dense_features_weights'],
                   as_ref=True),
               l1=self._options['symmetric_l1_regularization'],
-              l2=self._options['symmetric_l2_regularization'])
+              l2=self._symmetric_l2_regularization())
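
Note: the control dependencies above pin down a three-phase order per step: (1) sdca_solver updates the unshrinked slots, (2) the slots are assigned into the visible weight variables, (3) sdca_shrink_l1 shrinks the visible copies while the slots keep the raw values. A plain-Python sketch of that ordering; the soft-thresholding used for phase 3 is an assumption about what sdca_shrink_l1 computes, not something this diff states:

# Sketch of the solver -> assign -> shrink pipeline; all helpers are
# hypothetical stand-ins, not the TensorFlow ops used above.
def soft_threshold(w, threshold):
  # Standard L1 proximal step (assumed behavior of sdca_shrink_l1).
  if w > threshold:
    return w - threshold
  if w < -threshold:
    return w + threshold
  return 0.0

def run_step(variables, slots, l1, solver_update):
  names = ['sparse_features_weights', 'dense_features_weights']
  solver_update(slots)                 # phase 1: update unshrinked weights
  for name in names:                   # phase 2: copy slots -> variables
    for var, slot in zip(variables[name], slots['unshrinked_' + name]):
      var[:] = slot
  for name in names:                   # phase 3: shrink the visible copies
    for var in variables[name]:
      var[:] = [soft_threshold(w, l1) for w in var]
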
 
   def approximate_duality_gap(self):
     """Add operations to compute the approximate duality gap.
@@ -307,15 +313,14 @@ class SdcaModel(object):
       An Operation that computes the approximate duality gap over all
       examples.
     """
-    return _sdca_ops.compute_duality_gap(
-        self._convert_n_to_tensor(self._slots['sparse_features_weights'],
-                                  as_ref=True),
-        self._convert_n_to_tensor(self._slots['dense_features_weights'],
-                                  as_ref=True),
-        l1=self._options['symmetric_l1_regularization'],
-        l2=self._options['symmetric_l2_regularization'],
+    (primal_loss, dual_loss, example_weights) = _sdca_ops.sdca_training_stats(
         container=self._container,
         solver_uuid=self._solver_uuid)
+    # Note that example_weights is guaranteed to be positive by
+    # sdca_training_stats, so dividing by it is safe.
+    return (primal_loss + dual_loss + math_ops.to_double(self._l1_loss()) +
+            (2.0 * math_ops.to_double(self._l2_loss(
+                self._symmetric_l2_regularization())))) / example_weights
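
Note: with sdca_training_stats returning aggregated primal and dual losses plus the total example weight, the gap estimate is (primal + dual + l1_loss + 2 * l2_loss) / example_weights; the factor of 2 undoes the division by 2 inside _l2_loss, and the clamped L2 matches what the solver optimizes. The arithmetic with made-up statistics:

# Made-up statistics standing in for the sdca_training_stats outputs.
primal_loss, dual_loss, example_weights = 12.0, -10.0, 100.0
l1_loss, l2_loss = 0.17, 1.29  # as _l1_loss/_l2_loss would report

gap = (primal_loss + dual_loss + l1_loss + 2.0 * l2_loss) / example_weights
assert abs(gap - 0.0475) < 1e-9
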
 
   def unregularized_loss(self, examples):
     """Add operations to compute the loss (without the regularization loss).
@@ -384,6 +389,11 @@ class SdcaModel(object):
     self._assertList(['sparse_features', 'dense_features'], examples)
     with name_scope('sdca/regularized_loss'):
       weights = convert_to_tensor(examples['example_weights'])
-      return ((
-          (self._l1_loss() + self._l2_loss()) / math_ops.reduce_sum(weights)) +
+      return (((
+          self._l1_loss() +
+          # Note that here we are using the raw regularization
+          # (as specified by the user) and *not*
+          # self._symmetric_l2_regularization().
+          self._l2_loss(self._options['symmetric_l2_regularization'])) /
+          math_ops.reduce_sum(weights)) +
               self.unregularized_loss(examples))
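
Note: after this change two L2 strengths coexist, as the comment above points out: the solver and the duality gap use the clamped value from _symmetric_l2_regularization(), while regularized_loss reports the raw user-specified penalty. A small check of the clamp with a hypothetical option value:

# Hypothetical options dict; shows the clamp in _symmetric_l2_regularization.
options = {'symmetric_l2_regularization': 0.25}
solver_l2 = max(options['symmetric_l2_regularization'], 1.0)  # used by the solver
reported_l2 = options['symmetric_l2_regularization']          # used in regularized_loss
assert (solver_l2, reported_l2) == (1.0, 0.25)
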