aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python/feature_column
diff options
context:
space:
mode:
authorGravatar Rohan Jain <rohanj@google.com>2018-06-19 17:26:04 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-06-19 17:29:25 -0700
commit1f48db29a4a0cf7e0017ad6aa3bb1f8f7ee8ff92 (patch)
tree6fea3447777aa79bfb62bafb7896f71a1a64d0e6 /tensorflow/python/feature_column
parent841031362630230c5e3bcb6915a842087619ec12 (diff)
Fixing a bug in linear_model where the name for the model is always set to 'linear_model'. This causes issues when we create multiple linear models in the same graph.
PiperOrigin-RevId: 201270816
Diffstat (limited to 'tensorflow/python/feature_column')
-rw-r--r--tensorflow/python/feature_column/feature_column.py4
-rw-r--r--tensorflow/python/feature_column/feature_column_test.py29
2 files changed, 28 insertions, 5 deletions
diff --git a/tensorflow/python/feature_column/feature_column.py b/tensorflow/python/feature_column/feature_column.py
index 5ae60028f4..40219e4b34 100644
--- a/tensorflow/python/feature_column/feature_column.py
+++ b/tensorflow/python/feature_column/feature_column.py
@@ -452,13 +452,15 @@ def linear_model(features,
ValueError: if an item in `feature_columns` is neither a `_DenseColumn`
nor `_CategoricalColumn`.
"""
+ with variable_scope.variable_scope(None, 'linear_model') as vs:
+ model_name = _strip_leading_slashes(vs.name)
linear_model_layer = _LinearModel(
feature_columns=feature_columns,
units=units,
sparse_combiner=sparse_combiner,
weight_collections=weight_collections,
trainable=trainable,
- name='linear_model')
+ name=model_name)
retval = linear_model_layer(features) # pylint: disable=not-callable
if cols_to_vars is not None:
cols_to_vars.update(linear_model_layer.cols_to_vars())
diff --git a/tensorflow/python/feature_column/feature_column_test.py b/tensorflow/python/feature_column/feature_column_test.py
index c80c1d1866..dc3dde6710 100644
--- a/tensorflow/python/feature_column/feature_column_test.py
+++ b/tensorflow/python/feature_column/feature_column_test.py
@@ -1257,14 +1257,14 @@ class CrossedColumnTest(test.TestCase):
}, (crossed,))
-def get_linear_model_bias():
- with variable_scope.variable_scope('linear_model', reuse=True):
+def get_linear_model_bias(name='linear_model'):
+ with variable_scope.variable_scope(name, reuse=True):
return variable_scope.get_variable('bias_weights')
-def get_linear_model_column_var(column):
+def get_linear_model_column_var(column, name='linear_model'):
return ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES,
- 'linear_model/' + column.name)[0]
+ name + '/' + column.name)[0]
def get_keras_linear_model_predictions(features,
@@ -1928,6 +1928,27 @@ class LinearModelTest(test.TestCase):
with self.assertRaisesOpError('Feature .* cannot have rank 0'):
sess.run(net, feed_dict={features['price']: np.array(1)})
+ def test_multiple_linear_models(self):
+ price = fc.numeric_column('price')
+ with ops.Graph().as_default():
+ features1 = {'price': [[1.], [5.]]}
+ features2 = {'price': [[2.], [10.]]}
+ predictions1 = fc.linear_model(features1, [price])
+ predictions2 = fc.linear_model(features2, [price])
+ bias1 = get_linear_model_bias(name='linear_model')
+ bias2 = get_linear_model_bias(name='linear_model_1')
+ price_var1 = get_linear_model_column_var(price, name='linear_model')
+ price_var2 = get_linear_model_column_var(price, name='linear_model_1')
+ with _initialized_session() as sess:
+ self.assertAllClose([0.], bias1.eval())
+ sess.run(price_var1.assign([[10.]]))
+ sess.run(bias1.assign([5.]))
+ self.assertAllClose([[15.], [55.]], predictions1.eval())
+ self.assertAllClose([0.], bias2.eval())
+ sess.run(price_var2.assign([[10.]]))
+ sess.run(bias2.assign([5.]))
+ self.assertAllClose([[25.], [105.]], predictions2.eval())
+
class _LinearModelTest(test.TestCase):