diff options
author | Eddie Zhou <eddz@google.com> | 2018-09-17 15:06:34 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-09-17 15:11:26 -0700 |
commit | 55581a5bed7108c2d39ab603db8c916b6d624648 (patch) | |
tree | 74e6dcfdec0579db6b71181190504888290dbd6b /tensorflow/contrib/linear_optimizer | |
parent | 2015dc15784e635c40f256ed9f3b9b0b3539daaf (diff) |
Fix testing bug where partitioned primals wasn't actually being tested (constructing Variable directly instead of get_variable under scope with partitioner).
PiperOrigin-RevId: 213345447
Diffstat (limited to 'tensorflow/contrib/linear_optimizer')
-rw-r--r-- | tensorflow/contrib/linear_optimizer/python/kernel_tests/sdca_ops_test.py | 33 |
1 files changed, 17 insertions, 16 deletions
diff --git a/tensorflow/contrib/linear_optimizer/python/kernel_tests/sdca_ops_test.py b/tensorflow/contrib/linear_optimizer/python/kernel_tests/sdca_ops_test.py index 1d2db1cec8..7a1914d41f 100644 --- a/tensorflow/contrib/linear_optimizer/python/kernel_tests/sdca_ops_test.py +++ b/tensorflow/contrib/linear_optimizer/python/kernel_tests/sdca_ops_test.py @@ -134,7 +134,7 @@ def make_random_examples_and_variables_dicts(num_examples, dim, num_non_zero): return examples_dict, variables_dict -def make_variable_dict(max_age, max_gender, partitioned=False): +def make_variable_dict(max_age, max_gender, num_shards=None, partitioned=False): # TODO(sibyl-toe9oF2e): Figure out how to derive max_age & max_gender from # examples_dict. partitioner = None @@ -142,14 +142,15 @@ def make_variable_dict(max_age, max_gender, partitioned=False): partitioner = partitioned_variables.fixed_size_partitioner(num_shards=2, axis=0) with variable_scope.variable_scope( - name_or_scope='variables', + name_or_scope=('variables/shard_{}'.format(num_shards) + if num_shards else 'variables'), partitioner=partitioner): - age_weights = variables_lib.Variable( - array_ops.zeros( - [max_age + 1], dtype=dtypes.float32)) - gender_weights = variables_lib.Variable( - array_ops.zeros( - [max_gender + 1], dtype=dtypes.float32)) + age_weights = variable_scope.get_variable( + name='age', + initializer=array_ops.zeros([max_age + 1], dtype=dtypes.float32)) + gender_weights = variable_scope.get_variable( + name='gender', + initializer=array_ops.zeros([max_gender + 1], dtype=dtypes.float32)) return dict( sparse_features_weights=[age_weights, gender_weights], dense_features_weights=[]) @@ -242,7 +243,7 @@ class SdcaWithLogisticLossTest(SdcaModelTest): for num_shards in _SHARD_NUMBERS: with self._single_threaded_test_session(): examples = make_example_dict(example_protos, example_weights) - variables = make_variable_dict(1, 1) + variables = make_variable_dict(1, 1, num_shards) options = dict( symmetric_l2_regularization=1, symmetric_l1_regularization=0, @@ -290,7 +291,7 @@ class SdcaWithLogisticLossTest(SdcaModelTest): for num_shards in _SHARD_NUMBERS: with self._single_threaded_test_session(): examples = make_example_dict(example_protos, example_weights) - variables = make_variable_dict(1, 1, partitioned=True) + variables = make_variable_dict(1, 1, num_shards, partitioned=True) options = dict( symmetric_l2_regularization=1, symmetric_l1_regularization=0, @@ -463,7 +464,7 @@ class SdcaWithLogisticLossTest(SdcaModelTest): for num_shards in _SHARD_NUMBERS: with self._single_threaded_test_session(): examples = make_example_dict(example_protos, example_weights) - variables = make_variable_dict(1, 1) + variables = make_variable_dict(1, 1, num_shards) options = dict( symmetric_l2_regularization=0, symmetric_l1_regularization=0, @@ -521,7 +522,7 @@ class SdcaWithLogisticLossTest(SdcaModelTest): with self._single_threaded_test_session(): # Only use examples 0 and 2 examples = make_example_dict(example_protos, example_weights) - variables = make_variable_dict(1, 1) + variables = make_variable_dict(1, 1, num_shards) options = dict( symmetric_l2_regularization=1, symmetric_l1_regularization=0, @@ -561,7 +562,7 @@ class SdcaWithLogisticLossTest(SdcaModelTest): for num_shards in _SHARD_NUMBERS: with self._single_threaded_test_session(): examples = make_example_dict(example_protos, example_weights) - variables = make_variable_dict(1, 1) + variables = make_variable_dict(1, 1, num_shards) options = dict( symmetric_l2_regularization=1, symmetric_l1_regularization=0, @@ -598,7 +599,7 @@ class SdcaWithLogisticLossTest(SdcaModelTest): for num_shards in _SHARD_NUMBERS: with self._single_threaded_test_session(): examples = make_example_dict(example_protos, example_weights) - variables = make_variable_dict(3, 1) + variables = make_variable_dict(3, 1, num_shards) options = dict( symmetric_l2_regularization=1, symmetric_l1_regularization=0, @@ -639,7 +640,7 @@ class SdcaWithLogisticLossTest(SdcaModelTest): for num_shards in _SHARD_NUMBERS: with self._single_threaded_test_session(): examples = make_example_dict(example_protos, example_weights) - variables = make_variable_dict(1, 1) + variables = make_variable_dict(1, 1, num_shards) options = dict( symmetric_l2_regularization=1, symmetric_l1_regularization=0, @@ -679,7 +680,7 @@ class SdcaWithLogisticLossTest(SdcaModelTest): for num_shards in _SHARD_NUMBERS: with self._single_threaded_test_session(): examples = make_example_dict(example_protos, example_weights) - variables = make_variable_dict(1, 1) + variables = make_variable_dict(1, 1, num_shards) options = dict( symmetric_l2_regularization=1, symmetric_l1_regularization=0, |