aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/linear_optimizer
diff options
context:
space:
mode:
authorGravatar Eddie Zhou <eddz@google.com>2018-09-17 17:06:11 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-09-17 17:09:54 -0700
commit928389d4d61f0cb5932672aeeafadb1c18514dd3 (patch)
tree60cd4556096a02eb92bff8cef005534f541d343a /tensorflow/contrib/linear_optimizer
parent6e8293f1cdf2efe3cec2efdcfa89174893b0bace (diff)
Fixed bug where a mixture of Variable and PartitionedVariable would break SDCA. Added new test that fails with `IndexError: list index out of range` in `_get_partitioned_update_ops` without the corresponding fix.
Note that the effect of this bug is minimal, because for Estimator users, it only applies to sparse features that are not partitionable (e.g. [1,]), since all variables are created with the same partitioner in Estimator. PiperOrigin-RevId: 213365956
Diffstat (limited to 'tensorflow/contrib/linear_optimizer')
-rw-r--r--tensorflow/contrib/linear_optimizer/python/kernel_tests/sdca_ops_test.py62
-rw-r--r--tensorflow/contrib/linear_optimizer/python/ops/sdca_ops.py26
2 files changed, 76 insertions, 12 deletions
diff --git a/tensorflow/contrib/linear_optimizer/python/kernel_tests/sdca_ops_test.py b/tensorflow/contrib/linear_optimizer/python/kernel_tests/sdca_ops_test.py
index 7a1914d41f..9ecf023e03 100644
--- a/tensorflow/contrib/linear_optimizer/python/kernel_tests/sdca_ops_test.py
+++ b/tensorflow/contrib/linear_optimizer/python/kernel_tests/sdca_ops_test.py
@@ -323,6 +323,68 @@ class SdcaWithLogisticLossTest(SdcaModelTest):
self.assertAllClose(
0.01, lr.approximate_duality_gap().eval(), rtol=1e-2, atol=1e-2)
+ def testSomePartitionedPrimals(self):
+ # Setup test data
+ example_protos = [
+ make_example_proto({
+ 'age': [0],
+ 'gender': [0]
+ }, 0),
+ make_example_proto({
+ 'age': [0],
+ 'gender': [1]
+ }, 1),
+ ]
+ example_weights = [1.0, 1.0]
+ for num_shards in _SHARD_NUMBERS:
+ with self._single_threaded_test_session():
+ examples = make_example_dict(example_protos, example_weights)
+ # Explicitly make age a [1]-shaped Variable (which cannot be
+ # partitioned), while making gender a PartitionedVariable.
+ age_weights = variables_lib.Variable(
+ array_ops.zeros([1], dtype=dtypes.float32))
+ with variable_scope.variable_scope(
+ name_or_scope=('variables/shard_{}'.format(num_shards)
+ if num_shards else 'variables'),
+ partitioner=partitioned_variables.fixed_size_partitioner(
+ num_shards=2, axis=0)):
+ gender_weights = variable_scope.get_variable(
+ name='gender',
+ initializer=array_ops.zeros([2], dtype=dtypes.float32))
+ variables = dict(
+ sparse_features_weights=[age_weights, gender_weights],
+ dense_features_weights=[])
+ options = dict(
+ symmetric_l2_regularization=1,
+ symmetric_l1_regularization=0,
+ num_table_shards=num_shards,
+ loss_type='logistic_loss')
+
+ lr = SdcaModel(examples, variables, options)
+ variables_lib.global_variables_initializer().run()
+ unregularized_loss = lr.unregularized_loss(examples)
+ loss = lr.regularized_loss(examples)
+ predictions = lr.predictions(examples)
+ self.assertAllClose(0.693147, unregularized_loss.eval())
+ self.assertAllClose(0.693147, loss.eval())
+ train_op = lr.minimize()
+ for _ in range(_MAX_ITERATIONS):
+ train_op.run()
+ lr.update_weights(train_op).run()
+ # The high tolerance in unregularized_loss comparisons is due to the
+ # fact that it's possible to trade off unregularized_loss vs.
+ # regularization and still have a sum that is quite close to the
+ # optimal regularized_loss value. SDCA's duality gap only ensures that
+ # the regularized_loss is within 0.01 of optimal.
+ # 0.593014 is the optimal regularized_loss.
+ # 0.512591 is the unregularized_loss at that optimum.
+ self.assertAllClose(0.512591, unregularized_loss.eval(), atol=0.05)
+ self.assertAllClose(0.593014, loss.eval(), atol=0.01)
+ predicted_labels = get_binary_predictions_for_logistic(predictions)
+ self.assertAllEqual([0, 1], predicted_labels.eval())
+ self.assertAllClose(
+ 0.01, lr.approximate_duality_gap().eval(), rtol=1e-2, atol=1e-2)
+
def testSparseRandom(self):
dim = 20
num_examples = 1000
diff --git a/tensorflow/contrib/linear_optimizer/python/ops/sdca_ops.py b/tensorflow/contrib/linear_optimizer/python/ops/sdca_ops.py
index 14f59a3f64..b98adf862b 100644
--- a/tensorflow/contrib/linear_optimizer/python/ops/sdca_ops.py
+++ b/tensorflow/contrib/linear_optimizer/python/ops/sdca_ops.py
@@ -400,14 +400,16 @@ class SdcaModel(object):
sparse_weights = []
sparse_indices = []
- # If we have partitioned variables, keep a few lists of Tensors around
- # that we need for the assign_add after the op call to
- # gen_sdca_ops.sdca_optimizer().
- num_partitions_by_var = []
- p_assignments_by_var = []
- gather_ids_by_var = []
- for w, i in zip(self._slots['unshrinked_sparse_features_weights'],
- sparse_feature_indices):
+ # If we have partitioned variables, keep a few dictionaries of Tensors
+ # around that we need for the assign_add after the op call to
+ # gen_sdca_ops.sdca_optimizer(). These are keyed because we may have a
+ # mix of partitioned and un-partitioned variables.
+ num_partitions_by_var = {}
+ p_assignments_by_var = {}
+ gather_ids_by_var = {}
+ for v_num, (w, i) in enumerate(
+ zip(self._slots['unshrinked_sparse_features_weights'],
+ sparse_feature_indices)):
# Append the sparse_indices (in full-variable space).
sparse_idx = math_ops.cast(
array_ops.unique(math_ops.cast(i, dtypes.int32))[0],
@@ -456,10 +458,10 @@ class SdcaModel(object):
gather_ids = data_flow_ops.dynamic_partition(new_ids,
p_assignments,
num_partitions)
- # Append these to the lists for use in the later update.
- num_partitions_by_var.append(num_partitions)
- p_assignments_by_var.append(p_assignments)
- gather_ids_by_var.append(gather_ids)
+ # Add these into the dictionaries for use in the later update.
+ num_partitions_by_var[v_num] = num_partitions
+ p_assignments_by_var[v_num] = p_assignments
+ gather_ids_by_var[v_num] = gather_ids
# Gather the weights from each partition.
partition_gathered_weights = []