aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/linear_optimizer
diff options
context:
space:
mode:
authorGravatar Eddie Zhou <eddz@google.com>2018-09-17 15:06:34 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-09-17 15:11:26 -0700
commit55581a5bed7108c2d39ab603db8c916b6d624648 (patch)
tree74e6dcfdec0579db6b71181190504888290dbd6b /tensorflow/contrib/linear_optimizer
parent2015dc15784e635c40f256ed9f3b9b0b3539daaf (diff)
Fix testing bug where partitioned primals wasn't actually being tested (constructing Variable directly instead of get_variable under scope with partitioner).
PiperOrigin-RevId: 213345447
Diffstat (limited to 'tensorflow/contrib/linear_optimizer')
-rw-r--r--tensorflow/contrib/linear_optimizer/python/kernel_tests/sdca_ops_test.py33
1 files changed, 17 insertions, 16 deletions
diff --git a/tensorflow/contrib/linear_optimizer/python/kernel_tests/sdca_ops_test.py b/tensorflow/contrib/linear_optimizer/python/kernel_tests/sdca_ops_test.py
index 1d2db1cec8..7a1914d41f 100644
--- a/tensorflow/contrib/linear_optimizer/python/kernel_tests/sdca_ops_test.py
+++ b/tensorflow/contrib/linear_optimizer/python/kernel_tests/sdca_ops_test.py
@@ -134,7 +134,7 @@ def make_random_examples_and_variables_dicts(num_examples, dim, num_non_zero):
return examples_dict, variables_dict
-def make_variable_dict(max_age, max_gender, partitioned=False):
+def make_variable_dict(max_age, max_gender, num_shards=None, partitioned=False):
# TODO(sibyl-toe9oF2e): Figure out how to derive max_age & max_gender from
# examples_dict.
partitioner = None
@@ -142,14 +142,15 @@ def make_variable_dict(max_age, max_gender, partitioned=False):
partitioner = partitioned_variables.fixed_size_partitioner(num_shards=2,
axis=0)
with variable_scope.variable_scope(
- name_or_scope='variables',
+ name_or_scope=('variables/shard_{}'.format(num_shards)
+ if num_shards else 'variables'),
partitioner=partitioner):
- age_weights = variables_lib.Variable(
- array_ops.zeros(
- [max_age + 1], dtype=dtypes.float32))
- gender_weights = variables_lib.Variable(
- array_ops.zeros(
- [max_gender + 1], dtype=dtypes.float32))
+ age_weights = variable_scope.get_variable(
+ name='age',
+ initializer=array_ops.zeros([max_age + 1], dtype=dtypes.float32))
+ gender_weights = variable_scope.get_variable(
+ name='gender',
+ initializer=array_ops.zeros([max_gender + 1], dtype=dtypes.float32))
return dict(
sparse_features_weights=[age_weights, gender_weights],
dense_features_weights=[])
@@ -242,7 +243,7 @@ class SdcaWithLogisticLossTest(SdcaModelTest):
for num_shards in _SHARD_NUMBERS:
with self._single_threaded_test_session():
examples = make_example_dict(example_protos, example_weights)
- variables = make_variable_dict(1, 1)
+ variables = make_variable_dict(1, 1, num_shards)
options = dict(
symmetric_l2_regularization=1,
symmetric_l1_regularization=0,
@@ -290,7 +291,7 @@ class SdcaWithLogisticLossTest(SdcaModelTest):
for num_shards in _SHARD_NUMBERS:
with self._single_threaded_test_session():
examples = make_example_dict(example_protos, example_weights)
- variables = make_variable_dict(1, 1, partitioned=True)
+ variables = make_variable_dict(1, 1, num_shards, partitioned=True)
options = dict(
symmetric_l2_regularization=1,
symmetric_l1_regularization=0,
@@ -463,7 +464,7 @@ class SdcaWithLogisticLossTest(SdcaModelTest):
for num_shards in _SHARD_NUMBERS:
with self._single_threaded_test_session():
examples = make_example_dict(example_protos, example_weights)
- variables = make_variable_dict(1, 1)
+ variables = make_variable_dict(1, 1, num_shards)
options = dict(
symmetric_l2_regularization=0,
symmetric_l1_regularization=0,
@@ -521,7 +522,7 @@ class SdcaWithLogisticLossTest(SdcaModelTest):
with self._single_threaded_test_session():
# Only use examples 0 and 2
examples = make_example_dict(example_protos, example_weights)
- variables = make_variable_dict(1, 1)
+ variables = make_variable_dict(1, 1, num_shards)
options = dict(
symmetric_l2_regularization=1,
symmetric_l1_regularization=0,
@@ -561,7 +562,7 @@ class SdcaWithLogisticLossTest(SdcaModelTest):
for num_shards in _SHARD_NUMBERS:
with self._single_threaded_test_session():
examples = make_example_dict(example_protos, example_weights)
- variables = make_variable_dict(1, 1)
+ variables = make_variable_dict(1, 1, num_shards)
options = dict(
symmetric_l2_regularization=1,
symmetric_l1_regularization=0,
@@ -598,7 +599,7 @@ class SdcaWithLogisticLossTest(SdcaModelTest):
for num_shards in _SHARD_NUMBERS:
with self._single_threaded_test_session():
examples = make_example_dict(example_protos, example_weights)
- variables = make_variable_dict(3, 1)
+ variables = make_variable_dict(3, 1, num_shards)
options = dict(
symmetric_l2_regularization=1,
symmetric_l1_regularization=0,
@@ -639,7 +640,7 @@ class SdcaWithLogisticLossTest(SdcaModelTest):
for num_shards in _SHARD_NUMBERS:
with self._single_threaded_test_session():
examples = make_example_dict(example_protos, example_weights)
- variables = make_variable_dict(1, 1)
+ variables = make_variable_dict(1, 1, num_shards)
options = dict(
symmetric_l2_regularization=1,
symmetric_l1_regularization=0,
@@ -679,7 +680,7 @@ class SdcaWithLogisticLossTest(SdcaModelTest):
for num_shards in _SHARD_NUMBERS:
with self._single_threaded_test_session():
examples = make_example_dict(example_protos, example_weights)
- variables = make_variable_dict(1, 1)
+ variables = make_variable_dict(1, 1, num_shards)
options = dict(
symmetric_l2_regularization=1,
symmetric_l1_regularization=0,