diff options
author | Rohan Jain <rohanj@google.com> | 2018-06-19 17:26:04 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-06-19 17:29:25 -0700 |
commit | 1f48db29a4a0cf7e0017ad6aa3bb1f8f7ee8ff92 (patch) | |
tree | 6fea3447777aa79bfb62bafb7896f71a1a64d0e6 /tensorflow/python/feature_column | |
parent | 841031362630230c5e3bcb6915a842087619ec12 (diff) |
Fixing a bug in linear_model where the name for the model is always set to 'linear_model'. This causes issues when we create multiple linear models in the same graph.
PiperOrigin-RevId: 201270816
Diffstat (limited to 'tensorflow/python/feature_column')
-rw-r--r-- | tensorflow/python/feature_column/feature_column.py | 4 | ||||
-rw-r--r-- | tensorflow/python/feature_column/feature_column_test.py | 29 |
2 files changed, 28 insertions, 5 deletions
diff --git a/tensorflow/python/feature_column/feature_column.py b/tensorflow/python/feature_column/feature_column.py index 5ae60028f4..40219e4b34 100644 --- a/tensorflow/python/feature_column/feature_column.py +++ b/tensorflow/python/feature_column/feature_column.py @@ -452,13 +452,15 @@ def linear_model(features, ValueError: if an item in `feature_columns` is neither a `_DenseColumn` nor `_CategoricalColumn`. """ + with variable_scope.variable_scope(None, 'linear_model') as vs: + model_name = _strip_leading_slashes(vs.name) linear_model_layer = _LinearModel( feature_columns=feature_columns, units=units, sparse_combiner=sparse_combiner, weight_collections=weight_collections, trainable=trainable, - name='linear_model') + name=model_name) retval = linear_model_layer(features) # pylint: disable=not-callable if cols_to_vars is not None: cols_to_vars.update(linear_model_layer.cols_to_vars()) diff --git a/tensorflow/python/feature_column/feature_column_test.py b/tensorflow/python/feature_column/feature_column_test.py index c80c1d1866..dc3dde6710 100644 --- a/tensorflow/python/feature_column/feature_column_test.py +++ b/tensorflow/python/feature_column/feature_column_test.py @@ -1257,14 +1257,14 @@ class CrossedColumnTest(test.TestCase): }, (crossed,)) -def get_linear_model_bias(): - with variable_scope.variable_scope('linear_model', reuse=True): +def get_linear_model_bias(name='linear_model'): + with variable_scope.variable_scope(name, reuse=True): return variable_scope.get_variable('bias_weights') -def get_linear_model_column_var(column): +def get_linear_model_column_var(column, name='linear_model'): return ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES, - 'linear_model/' + column.name)[0] + name + '/' + column.name)[0] def get_keras_linear_model_predictions(features, @@ -1928,6 +1928,27 @@ class LinearModelTest(test.TestCase): with self.assertRaisesOpError('Feature .* cannot have rank 0'): sess.run(net, feed_dict={features['price']: np.array(1)}) + def test_multiple_linear_models(self): + price = fc.numeric_column('price') + with ops.Graph().as_default(): + features1 = {'price': [[1.], [5.]]} + features2 = {'price': [[2.], [10.]]} + predictions1 = fc.linear_model(features1, [price]) + predictions2 = fc.linear_model(features2, [price]) + bias1 = get_linear_model_bias(name='linear_model') + bias2 = get_linear_model_bias(name='linear_model_1') + price_var1 = get_linear_model_column_var(price, name='linear_model') + price_var2 = get_linear_model_column_var(price, name='linear_model_1') + with _initialized_session() as sess: + self.assertAllClose([0.], bias1.eval()) + sess.run(price_var1.assign([[10.]])) + sess.run(bias1.assign([5.])) + self.assertAllClose([[15.], [55.]], predictions1.eval()) + self.assertAllClose([0.], bias2.eval()) + sess.run(price_var2.assign([[10.]])) + sess.run(bias2.assign([5.])) + self.assertAllClose([[25.], [105.]], predictions2.eval()) + class _LinearModelTest(test.TestCase): |