aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python/feature_column
diff options
context:
space:
mode:
authorGravatar Rohan Jain <rohanj@google.com>2018-04-19 15:32:37 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-04-19 15:36:00 -0700
commit4868ddd508a567a497935378956e9da18976f152 (patch)
treea19bf7589e8cce77408accf212923d053b9fff26 /tensorflow/python/feature_column
parent4bcf49c4b22205fc829f89da96e37f366c9fa9e6 (diff)
Simplifying cols_to_vars update
PiperOrigin-RevId: 193585237
Diffstat (limited to 'tensorflow/python/feature_column')
-rw-r--r--tensorflow/python/feature_column/feature_column.py6
-rw-r--r--tensorflow/python/feature_column/feature_column_test.py6
2 files changed, 4 insertions, 8 deletions
diff --git a/tensorflow/python/feature_column/feature_column.py b/tensorflow/python/feature_column/feature_column.py
index 87a52f8441..a7c4eabcb2 100644
--- a/tensorflow/python/feature_column/feature_column.py
+++ b/tensorflow/python/feature_column/feature_column.py
@@ -417,10 +417,8 @@ def linear_model(features,
trainable=trainable,
name='linear_model')
retval = linear_model_layer(features) # pylint: disable=not-callable
- if cols_to_vars is None:
- return retval
- for k, v in linear_model_layer.cols_to_vars().items():
- cols_to_vars[k] = v
+ if cols_to_vars is not None:
+ cols_to_vars.update(linear_model_layer.cols_to_vars())
return retval
diff --git a/tensorflow/python/feature_column/feature_column_test.py b/tensorflow/python/feature_column/feature_column_test.py
index 49e06b8245..d963dd9b55 100644
--- a/tensorflow/python/feature_column/feature_column_test.py
+++ b/tensorflow/python/feature_column/feature_column_test.py
@@ -1269,10 +1269,8 @@ def get_keras_linear_model_predictions(features,
trainable,
name='linear_model')
retval = keras_linear_model(features) # pylint: disable=not-callable
- if cols_to_vars is None:
- return retval
- for k, v in keras_linear_model.cols_to_vars().items():
- cols_to_vars[k] = v
+ if cols_to_vars is not None:
+ cols_to_vars.update(keras_linear_model.cols_to_vars())
return retval