Diffstat (limited to 'tensorflow/contrib/linear_optimizer/python/kernel_tests/sdca_ops_test.py')
-rw-r--r--  tensorflow/contrib/linear_optimizer/python/kernel_tests/sdca_ops_test.py | 101
1 file changed, 82 insertions(+), 19 deletions(-)
diff --git a/tensorflow/contrib/linear_optimizer/python/kernel_tests/sdca_ops_test.py b/tensorflow/contrib/linear_optimizer/python/kernel_tests/sdca_ops_test.py
index 1d2db1cec8..8466dc36d1 100644
--- a/tensorflow/contrib/linear_optimizer/python/kernel_tests/sdca_ops_test.py
+++ b/tensorflow/contrib/linear_optimizer/python/kernel_tests/sdca_ops_test.py
@@ -125,7 +125,7 @@ def make_random_examples_and_variables_dicts(num_examples, dim, num_non_zero):
],
example_ids=[str(i) for i in range(num_examples)])
- weights = variables_lib.Variable(
+ weights = variables_lib.VariableV1(
array_ops.zeros([dim], dtype=dtypes.float32))
variables_dict = dict(
sparse_features_weights=[weights],
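A note on the Variable -> VariableV1 swaps in this patch (an aside, not part of the diff): VariableV1 is the explicit TF1-style variable class defined in tensorflow.python.ops.variables, and pinning it keeps these tests on the original ref-variable behaviour while plain Variable moves toward the V2 resource-variable semantics. Its constructor is used exactly like the lines being replaced, for example:

from tensorflow.python.framework import dtypes
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import variables as variables_lib

# Same call pattern as the patched test code: a zero-initialized
# float32 vector held in a V1-style variable.
weights = variables_lib.VariableV1(
    array_ops.zeros([10], dtype=dtypes.float32))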
@@ -134,7 +134,7 @@ def make_random_examples_and_variables_dicts(num_examples, dim, num_non_zero):
return examples_dict, variables_dict
-def make_variable_dict(max_age, max_gender, partitioned=False):
+def make_variable_dict(max_age, max_gender, num_shards=None, partitioned=False):
# TODO(sibyl-toe9oF2e): Figure out how to derive max_age & max_gender from
# examples_dict.
partitioner = None
@@ -142,14 +142,15 @@ def make_variable_dict(max_age, max_gender, partitioned=False):
partitioner = partitioned_variables.fixed_size_partitioner(num_shards=2,
axis=0)
with variable_scope.variable_scope(
- name_or_scope='variables',
+ name_or_scope=('variables/shard_{}'.format(num_shards)
+ if num_shards else 'variables'),
partitioner=partitioner):
- age_weights = variables_lib.Variable(
- array_ops.zeros(
- [max_age + 1], dtype=dtypes.float32))
- gender_weights = variables_lib.Variable(
- array_ops.zeros(
- [max_gender + 1], dtype=dtypes.float32))
+ age_weights = variable_scope.get_variable(
+ name='age',
+ initializer=array_ops.zeros([max_age + 1], dtype=dtypes.float32))
+ gender_weights = variable_scope.get_variable(
+ name='gender',
+ initializer=array_ops.zeros([max_gender + 1], dtype=dtypes.float32))
return dict(
sparse_features_weights=[age_weights, gender_weights],
dense_features_weights=[])
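For context on the make_variable_dict hunk above (an aside, not part of the diff): the weights now come from variable_scope.get_variable inside a scope whose name incorporates num_shards, so the tests can rebuild variables while looping over shard counts without name collisions, and the optional fixed_size_partitioner turns each weight into a PartitionedVariable. Below is a minimal standalone sketch of that behaviour, assuming the TF 1.x API surface via tf.compat.v1; make_weights is a made-up name for illustration.

import tensorflow.compat.v1 as tf

tf.disable_eager_execution()

def make_weights(dim, num_shards=None, partitioned=False):
  # Mirror the patched make_variable_dict: optional 2-way partitioning along
  # axis 0, plus a per-num_shards scope name so repeated calls in a loop over
  # shard counts do not clash on variable names.
  partitioner = None
  if partitioned:
    partitioner = tf.fixed_size_partitioner(num_shards=2, axis=0)
  with tf.variable_scope(
      'variables/shard_{}'.format(num_shards) if num_shards else 'variables',
      partitioner=partitioner):
    return tf.get_variable(
        name='age',
        initializer=tf.zeros([dim], dtype=tf.float32))

w = make_weights(4, num_shards=1, partitioned=True)
# With the partitioner active, get_variable returns a PartitionedVariable
# covering the full [4] shape, backed by two [2]-shaped shards.
print(type(w).__name__, w.get_shape().as_list())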
@@ -183,7 +184,7 @@ def make_dense_examples_and_variables_dicts(dense_features_values, weights,
dense_tensors.append(dense_tensor)
# Add variables of shape [feature_column_dimension].
dense_weights.append(
- variables_lib.Variable(
+ variables_lib.VariableV1(
array_ops.zeros(
[dense_tensor.get_shape().as_list()[1]], dtype=dtypes.float32)))
@@ -242,7 +243,7 @@ class SdcaWithLogisticLossTest(SdcaModelTest):
for num_shards in _SHARD_NUMBERS:
with self._single_threaded_test_session():
examples = make_example_dict(example_protos, example_weights)
- variables = make_variable_dict(1, 1)
+ variables = make_variable_dict(1, 1, num_shards)
options = dict(
symmetric_l2_regularization=1,
symmetric_l1_regularization=0,
@@ -290,7 +291,7 @@ class SdcaWithLogisticLossTest(SdcaModelTest):
for num_shards in _SHARD_NUMBERS:
with self._single_threaded_test_session():
examples = make_example_dict(example_protos, example_weights)
- variables = make_variable_dict(1, 1, partitioned=True)
+ variables = make_variable_dict(1, 1, num_shards, partitioned=True)
options = dict(
symmetric_l2_regularization=1,
symmetric_l1_regularization=0,
@@ -322,6 +323,68 @@ class SdcaWithLogisticLossTest(SdcaModelTest):
self.assertAllClose(
0.01, lr.approximate_duality_gap().eval(), rtol=1e-2, atol=1e-2)
+ def testSomePartitionedPrimals(self):
+ # Setup test data
+ example_protos = [
+ make_example_proto({
+ 'age': [0],
+ 'gender': [0]
+ }, 0),
+ make_example_proto({
+ 'age': [0],
+ 'gender': [1]
+ }, 1),
+ ]
+ example_weights = [1.0, 1.0]
+ for num_shards in _SHARD_NUMBERS:
+ with self._single_threaded_test_session():
+ examples = make_example_dict(example_protos, example_weights)
+ # Explicitly make age a [1]-shaped Variable (which cannot be
+ # partitioned), while making gender a PartitionedVariable.
+ age_weights = variables_lib.VariableV1(
+ array_ops.zeros([1], dtype=dtypes.float32))
+ with variable_scope.variable_scope(
+ name_or_scope=('variables/shard_{}'.format(num_shards)
+ if num_shards else 'variables'),
+ partitioner=partitioned_variables.fixed_size_partitioner(
+ num_shards=2, axis=0)):
+ gender_weights = variable_scope.get_variable(
+ name='gender',
+ initializer=array_ops.zeros([2], dtype=dtypes.float32))
+ variables = dict(
+ sparse_features_weights=[age_weights, gender_weights],
+ dense_features_weights=[])
+ options = dict(
+ symmetric_l2_regularization=1,
+ symmetric_l1_regularization=0,
+ num_table_shards=num_shards,
+ loss_type='logistic_loss')
+
+ lr = SdcaModel(examples, variables, options)
+ variables_lib.global_variables_initializer().run()
+ unregularized_loss = lr.unregularized_loss(examples)
+ loss = lr.regularized_loss(examples)
+ predictions = lr.predictions(examples)
+ self.assertAllClose(0.693147, unregularized_loss.eval())
+ self.assertAllClose(0.693147, loss.eval())
+ train_op = lr.minimize()
+ for _ in range(_MAX_ITERATIONS):
+ train_op.run()
+ lr.update_weights(train_op).run()
+ # The high tolerance in unregularized_loss comparisons is due to the
+ # fact that it's possible to trade off unregularized_loss vs.
+ # regularization and still have a sum that is quite close to the
+ # optimal regularized_loss value. SDCA's duality gap only ensures that
+ # the regularized_loss is within 0.01 of optimal.
+ # 0.593014 is the optimal regularized_loss.
+ # 0.512591 is the unregularized_loss at that optimum.
+ self.assertAllClose(0.512591, unregularized_loss.eval(), atol=0.05)
+ self.assertAllClose(0.593014, loss.eval(), atol=0.01)
+ predicted_labels = get_binary_predictions_for_logistic(predictions)
+ self.assertAllEqual([0, 1], predicted_labels.eval())
+ self.assertAllClose(
+ 0.01, lr.approximate_duality_gap().eval(), rtol=1e-2, atol=1e-2)
+
def testSparseRandom(self):
dim = 20
num_examples = 1000
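Two of the constants asserted in testSomePartitionedPrimals above deserve a brief aside (again, not part of the diff). Before any training step all weights are zero, so every example is predicted at sigmoid(0) = 0.5, the per-example logistic loss is -log(0.5) = log(2) ~= 0.693147, and the L2 penalty on zero weights vanishes, which is why the unregularized and regularized losses both start at 0.693147. After training, the duality-gap comment in the test explains the asymmetric tolerances: SDCA only guarantees that the regularized loss lands within the duality gap (about 0.01 here) of its optimum, so that assertion uses atol=0.01, while the unregularized loss, which can be traded off against the regularizer, gets the looser atol=0.05. A one-line sanity check of the starting value:

import math

# Logistic loss of a 0.5 prediction: matches the asserted 0.693147.
print(-math.log(0.5))  # 0.6931471805599453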
@@ -463,7 +526,7 @@ class SdcaWithLogisticLossTest(SdcaModelTest):
for num_shards in _SHARD_NUMBERS:
with self._single_threaded_test_session():
examples = make_example_dict(example_protos, example_weights)
- variables = make_variable_dict(1, 1)
+ variables = make_variable_dict(1, 1, num_shards)
options = dict(
symmetric_l2_regularization=0,
symmetric_l1_regularization=0,
@@ -521,7 +584,7 @@ class SdcaWithLogisticLossTest(SdcaModelTest):
with self._single_threaded_test_session():
# Only use examples 0 and 2
examples = make_example_dict(example_protos, example_weights)
- variables = make_variable_dict(1, 1)
+ variables = make_variable_dict(1, 1, num_shards)
options = dict(
symmetric_l2_regularization=1,
symmetric_l1_regularization=0,
@@ -561,7 +624,7 @@ class SdcaWithLogisticLossTest(SdcaModelTest):
for num_shards in _SHARD_NUMBERS:
with self._single_threaded_test_session():
examples = make_example_dict(example_protos, example_weights)
- variables = make_variable_dict(1, 1)
+ variables = make_variable_dict(1, 1, num_shards)
options = dict(
symmetric_l2_regularization=1,
symmetric_l1_regularization=0,
@@ -598,7 +661,7 @@ class SdcaWithLogisticLossTest(SdcaModelTest):
for num_shards in _SHARD_NUMBERS:
with self._single_threaded_test_session():
examples = make_example_dict(example_protos, example_weights)
- variables = make_variable_dict(3, 1)
+ variables = make_variable_dict(3, 1, num_shards)
options = dict(
symmetric_l2_regularization=1,
symmetric_l1_regularization=0,
@@ -639,7 +702,7 @@ class SdcaWithLogisticLossTest(SdcaModelTest):
for num_shards in _SHARD_NUMBERS:
with self._single_threaded_test_session():
examples = make_example_dict(example_protos, example_weights)
- variables = make_variable_dict(1, 1)
+ variables = make_variable_dict(1, 1, num_shards)
options = dict(
symmetric_l2_regularization=1,
symmetric_l1_regularization=0,
@@ -679,7 +742,7 @@ class SdcaWithLogisticLossTest(SdcaModelTest):
for num_shards in _SHARD_NUMBERS:
with self._single_threaded_test_session():
examples = make_example_dict(example_protos, example_weights)
- variables = make_variable_dict(1, 1)
+ variables = make_variable_dict(1, 1, num_shards)
options = dict(
symmetric_l2_regularization=1,
symmetric_l1_regularization=0,
@@ -738,7 +801,7 @@ class SdcaWithLogisticLossTest(SdcaModelTest):
labels=[1.0, 0.0])
# Replace with a variable of size 1 instead of 2.
variables['dense_features_weights'] = [
- variables_lib.Variable(array_ops.zeros(
+ variables_lib.VariableV1(array_ops.zeros(
[1], dtype=dtypes.float32))
]
options = dict(