about | summary | refs | log | tree | commit | diff | homepage
path: root/tensorflow/contrib/learn
diff options
context:
space:
mode:
author: A. Unique TensorFlower <gardener@tensorflow.org> 2018-05-23 15:46:03 -0700
committer: TensorFlower Gardener <gardener@tensorflow.org> 2018-05-23 15:48:49 -0700
commit: 78c3a8870d2f748f356415e8d7acf9748d09c197 (patch)
tree: c4c0eab4e79e14e1ddfa3e9b77de42ea1b2c96f8 /tensorflow/contrib/learn
parent: f504a2445051c4c48eb9edd6a023b1f33a2793f2 (diff)
Add support for partitioned variables to SDCA.
PiperOrigin-RevId: 197803127
Diffstat (limited to 'tensorflow/contrib/learn')
-rw-r--r-- tensorflow/contrib/learn/python/learn/estimators/linear.py | 6
-rw-r--r-- tensorflow/contrib/learn/python/learn/estimators/linear_test.py | 112
2 files changed, 116 insertions, 2 deletions
diff --git a/tensorflow/contrib/learn/python/learn/estimators/linear.py b/tensorflow/contrib/learn/python/learn/estimators/linear.py
index 70b70af98c..e100bc7a1e 100644
--- a/tensorflow/contrib/learn/python/learn/estimators/linear.py
+++ b/tensorflow/contrib/learn/python/learn/estimators/linear.py
@@ -31,7 +31,6 @@ import six
from tensorflow.contrib import layers
from tensorflow.contrib.framework import deprecated
from tensorflow.contrib.framework import deprecated_arg_values
-from tensorflow.python.training import training_util
from tensorflow.contrib.layers.python.layers import feature_column
from tensorflow.contrib.learn.python.learn.estimators import estimator
from tensorflow.contrib.learn.python.learn.estimators import head as head_lib
@@ -51,6 +50,7 @@ from tensorflow.python.ops import variable_scope
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.training import session_run_hook
from tensorflow.python.training import training as train
+from tensorflow.python.training import training_util
# The default learning rate of 0.2 is a historical artifact of the initial
@@ -244,7 +244,9 @@ def sdca_model_fn(features, labels, mode, params):
parent_scope = "linear"
with variable_scope.variable_scope(
- values=features.values(), name_or_scope=parent_scope) as scope:
+ values=features.values(),
+ name_or_scope=parent_scope,
+ partitioner=optimizer.partitioner) as scope:
features = features.copy()
features.update(layers.transform_features(features, feature_columns))
logits, columns_to_variables, bias = (
diff --git a/tensorflow/contrib/learn/python/learn/estimators/linear_test.py b/tensorflow/contrib/learn/python/learn/estimators/linear_test.py
index 0a863f0e20..597ca4e86d 100644
--- a/tensorflow/contrib/learn/python/learn/estimators/linear_test.py
+++ b/tensorflow/contrib/learn/python/learn/estimators/linear_test.py
@@ -43,6 +43,7 @@ from tensorflow.python.framework import dtypes
from tensorflow.python.framework import sparse_tensor
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import math_ops
+from tensorflow.python.ops import partitioned_variables
from tensorflow.python.platform import test
from tensorflow.python.training import ftrl
from tensorflow.python.training import input as input_lib
@@ -966,6 +967,63 @@ class LinearClassifierTest(test.TestCase):
scores = classifier.evaluate(input_fn=input_fn, steps=1)
self.assertGreater(scores['accuracy'], 0.9)
+ def testSdcaOptimizerPartitionedVariables(self):
+ """Tests LinearClassifier with SDCAOptimizer with partitioned variables."""
+
+ def input_fn():
+ return {
+ 'example_id':
+ constant_op.constant(['1', '2', '3']),
+ 'price':
+ constant_op.constant([[0.6], [0.8], [0.3]]),
+ 'sq_footage':
+ constant_op.constant([[900.0], [700.0], [600.0]]),
+ 'country':
+ sparse_tensor.SparseTensor(
+ values=['IT', 'US', 'GB'],
+ indices=[[0, 0], [1, 3], [2, 1]],
+ dense_shape=[3, 5]),
+ 'weights':
+ constant_op.constant([[3.0], [1.0], [1.0]])
+ }, constant_op.constant([[1], [0], [1]])
+
+ price = feature_column_lib.real_valued_column('price')
+ sq_footage_bucket = feature_column_lib.bucketized_column(
+ feature_column_lib.real_valued_column('sq_footage'),
+ boundaries=[650.0, 800.0])
+ country = feature_column_lib.sparse_column_with_hash_bucket(
+ 'country', hash_bucket_size=5)
+ sq_footage_country = feature_column_lib.crossed_column(
+ [sq_footage_bucket, country], hash_bucket_size=10)
+
+ sdca_optimizer = sdca_optimizer_lib.SDCAOptimizer(
+ example_id_column='example_id',
+ partitioner=partitioned_variables.fixed_size_partitioner(
+ num_shards=2, axis=0))
+
+ tf_config = {
+ 'cluster': {
+ run_config.TaskType.PS: ['fake_ps_0', 'fake_ps_1']
+ }
+ }
+ with test.mock.patch.dict('os.environ',
+ {'TF_CONFIG': json.dumps(tf_config)}):
+ config = run_config.RunConfig()
+ # Because we did not start a distributed cluster, we need to pass an
+ # empty ClusterSpec, otherwise the device_setter will look for
+ # distributed jobs, such as "/job:ps" which are not present.
+ config._cluster_spec = server_lib.ClusterSpec({})
+
+ classifier = linear.LinearClassifier(
+ feature_columns=[price, sq_footage_bucket, country, sq_footage_country],
+ weight_column_name='weights',
+ optimizer=sdca_optimizer,
+ config=config)
+ classifier.fit(input_fn=input_fn, steps=50)
+ scores = classifier.evaluate(input_fn=input_fn, steps=1)
+ print('all scores = {}'.format(scores))
+ self.assertGreater(scores['accuracy'], 0.9)
+
def testEval(self):
"""Tests that eval produces correct metrics.
"""
@@ -1540,6 +1598,60 @@ class LinearRegressorTest(test.TestCase):
loss = regressor.evaluate(input_fn=input_fn, steps=1)['loss']
self.assertLess(loss, 0.05)
+ def testSdcaOptimizerPartitionedVariables(self):
+ """Tests LinearRegressor with SDCAOptimizer with partitioned variables."""
+
+ def input_fn():
+ return {
+ 'example_id':
+ constant_op.constant(['1', '2', '3']),
+ 'price':
+ constant_op.constant([0.6, 0.8, 0.3]),
+ 'sq_footage':
+ constant_op.constant([[900.0], [700.0], [600.0]]),
+ 'country':
+ sparse_tensor.SparseTensor(
+ values=['IT', 'US', 'GB'],
+ indices=[[0, 0], [1, 3], [2, 1]],
+ dense_shape=[3, 5]),
+ 'weights':
+ constant_op.constant([[3.0], [5.0], [7.0]])
+ }, constant_op.constant([[1.55], [-1.25], [-3.0]])
+
+ price = feature_column_lib.real_valued_column('price')
+ sq_footage_bucket = feature_column_lib.bucketized_column(
+ feature_column_lib.real_valued_column('sq_footage'),
+ boundaries=[650.0, 800.0])
+ country = feature_column_lib.sparse_column_with_hash_bucket(
+ 'country', hash_bucket_size=5)
+ sq_footage_country = feature_column_lib.crossed_column(
+ [sq_footage_bucket, country], hash_bucket_size=10)
+ sdca_optimizer = sdca_optimizer_lib.SDCAOptimizer(
+ example_id_column='example_id', symmetric_l2_regularization=1.0,
+ partitioner=partitioned_variables.fixed_size_partitioner(
+ num_shards=2, axis=0))
+ tf_config = {
+ 'cluster': {
+ run_config.TaskType.PS: ['fake_ps_0', 'fake_ps_1']
+ }
+ }
+ with test.mock.patch.dict('os.environ',
+ {'TF_CONFIG': json.dumps(tf_config)}):
+ config = run_config.RunConfig()
+ # Because we did not start a distributed cluster, we need to pass an
+ # empty ClusterSpec, otherwise the device_setter will look for
+ # distributed jobs, such as "/job:ps" which are not present.
+ config._cluster_spec = server_lib.ClusterSpec({})
+
+ regressor = linear.LinearRegressor(
+ feature_columns=[price, sq_footage_bucket, country, sq_footage_country],
+ weight_column_name='weights',
+ optimizer=sdca_optimizer,
+ config=config)
+ regressor.fit(input_fn=input_fn, steps=20)
+ loss = regressor.evaluate(input_fn=input_fn, steps=1)['loss']
+ self.assertLess(loss, 0.05)
+
def testSdcaOptimizerSparseFeaturesWithL1Reg(self):
"""Tests LinearClassifier with SDCAOptimizer and sparse features."""