author     Francois Chollet <fchollet@google.com>          2018-04-10 13:49:37 -0700
committer  TensorFlower Gardener <gardener@tensorflow.org>  2018-04-10 13:51:54 -0700
commit     693b339ab2f062ec5bbb29f976c5d1fd94fbffa5 (patch)
tree       1e11b6becc6b156b6e89ebb3f4d5d2f886bed188 /tensorflow/python/feature_column
parent     6b593d329005ffb1a10b1c9cd1374d2cdb620b21 (diff)
Refactor layers:
- tf.layers layers now subclass tf.keras.layers layers.
- tf.keras.layers is now agnostic to variable scopes and global collections (future-proof). It also uses ResourceVariable everywhere by default.
- As a result, tf.keras.layers is in general lower-complexity, with fewer hacks and workarounds. However, some of the current code is temporary (variable creation should arguably be moved to Checkpointable, and there are some dependency issues that will require later refactors).
- The legacy tf.layers layers behavior is kept, with references to variable scopes and global collections injected in the subclassed tf.layers.base.Layer class (the content of tf.layers.base.Layer is the complexity differential between the old implementation and the new one).
Note: this refactor does slightly change the behavior of tf.layers.base.Layer, by disabling extreme edge-case behavior that either has long been invalid, or is dangerous and should most definitely be disabled. This will not affect any users since such behaviors only existed in the base Layer unit tests. The behaviors disabled are:
- Option to create reusable variables in `call` (already invalid for some time).
- Option to use a variable scope to create layer variables outside of the layer while not having the layer track such variables locally.
PiperOrigin-RevId: 192339798
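The subclass-injection pattern the commit message describes can be sketched in plain Python. This is an illustrative stand-in, not the actual TensorFlow source: the names `GLOBAL_VARIABLES`, `KerasStyleLayer`, `LegacyLayer`, and this simplified `add_weight` are hypothetical, chosen only to show a scope-agnostic base class with variable-scope naming and global-collection registration layered on in a legacy subclass.

```python
GLOBAL_VARIABLES = []  # stand-in for a TF global collection


class KerasStyleLayer:
    """Scope-agnostic base: tracks its own variables locally only."""

    def __init__(self, name):
        self.name = name
        self.weights = []

    def add_weight(self, name, value):
        var = {"name": name, "value": value}
        self.weights.append(var)
        return var


class LegacyLayer(KerasStyleLayer):
    """Legacy tf.layers-style subclass: injects scoping and collections.

    This is the "complexity differential" kept out of the base class:
    scope-prefixed variable names plus global-collection registration.
    """

    def add_weight(self, name, value):
        var = super().add_weight("%s/%s" % (self.name, name), value)
        GLOBAL_VARIABLES.append(var)
        return var


base = KerasStyleLayer("dense")
base.add_weight("kernel", 1.0)

legacy = LegacyLayer("dense_legacy")
legacy.add_weight("kernel", 2.0)
```

The base class stays free of global state, so it can later back checkpointing or eager execution, while the legacy subclass preserves the old scope-and-collection behavior for existing users.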
Diffstat (limited to 'tensorflow/python/feature_column')
-rw-r--r--  tensorflow/python/feature_column/feature_column.py | 35
1 file changed, 18 insertions(+), 17 deletions(-)
```diff
diff --git a/tensorflow/python/feature_column/feature_column.py b/tensorflow/python/feature_column/feature_column.py
index 3a315e5c2e..7a104fa4ac 100644
--- a/tensorflow/python/feature_column/feature_column.py
+++ b/tensorflow/python/feature_column/feature_column.py
@@ -581,24 +581,25 @@ class _LinearModel(training.Model):
         **kwargs)

   def call(self, features):
-    for column in self._feature_columns:
-      if not isinstance(column, (_DenseColumn, _CategoricalColumn)):
-        raise ValueError(
-            'Items of feature_columns must be either a '
-            '_DenseColumn or _CategoricalColumn. Given: {}'.format(column))
-    weighted_sums = []
-    ordered_columns = []
-    builder = _LazyBuilder(features)
-    for layer in sorted(self._column_layers.values(), key=lambda x: x.name):
-      ordered_columns.append(layer._feature_column)  # pylint: disable=protected-access
-      weighted_sum = layer(builder)
-      weighted_sums.append(weighted_sum)
+    with variable_scope.variable_scope(self.name):
+      for column in self._feature_columns:
+        if not isinstance(column, (_DenseColumn, _CategoricalColumn)):
+          raise ValueError(
+              'Items of feature_columns must be either a '
+              '_DenseColumn or _CategoricalColumn. Given: {}'.format(column))
+      weighted_sums = []
+      ordered_columns = []
+      builder = _LazyBuilder(features)
+      for layer in sorted(self._column_layers.values(), key=lambda x: x.name):
+        ordered_columns.append(layer._feature_column)  # pylint: disable=protected-access
+        weighted_sum = layer(builder)
+        weighted_sums.append(weighted_sum)

-    _verify_static_batch_size_equality(weighted_sums, ordered_columns)
-    predictions_no_bias = math_ops.add_n(
-        weighted_sums, name='weighted_sum_no_bias')
-    predictions = nn_ops.bias_add(
-        predictions_no_bias, self._bias_layer(builder), name='weighted_sum')  # pylint: disable=not-callable
+      _verify_static_batch_size_equality(weighted_sums, ordered_columns)
+      predictions_no_bias = math_ops.add_n(
+          weighted_sums, name='weighted_sum_no_bias')
+      predictions = nn_ops.bias_add(
+          predictions_no_bias, self._bias_layer(builder), name='weighted_sum')  # pylint: disable=not-callable
     return predictions

   def _add_layers(self, layers):
```
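The control flow of `_LinearModel.call` above — apply each column layer in name-sorted order, sum the weighted outputs, then add a bias — can be sketched in plain Python. The function and column names below (`linear_model_call`, `price_column`, `age_column`) are hypothetical stand-ins for per-feature-column layers, not part of the real API; `sum(...)` and `+ bias` stand in for `math_ops.add_n` and `nn_ops.bias_add`.

```python
def linear_model_call(column_layers, features, bias):
    # Apply column layers in a deterministic (name-sorted) order, mirroring
    # sorted(self._column_layers.values(), key=lambda x: x.name) in the diff.
    weighted_sums = [
        layer(features)
        for layer in sorted(column_layers, key=lambda l: l.__name__)
    ]
    # add_n over the per-column sums, then bias_add, in the real code.
    return sum(weighted_sums) + bias


def price_column(features):
    # Stand-in for a dense column's weighted sum: weight 0.5 on "price".
    return 0.5 * features["price"]


def age_column(features):
    # Stand-in for a second column: weight 2.0 on "age".
    return 2.0 * features["age"]


prediction = linear_model_call([price_column, age_column],
                               {"price": 10.0, "age": 3.0}, bias=1.0)
# 0.5 * 10.0 + 2.0 * 3.0 + 1.0 = 12.0
```

The point of the change in the diff is not this arithmetic, which is unchanged, but that the whole computation now runs inside `variable_scope(self.name)`, so the per-column layers create their variables under the model's name scope.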