diff options
author | 2018-08-30 01:19:12 -0700 | |
---|---|---|
committer | 2018-08-30 01:23:04 -0700 | |
commit | c73d4e56eb2ac66e8fb519cbe83c5f7bddbfc80a (patch) | |
tree | 734b735c85564e14bee1171c7b891787654cbe67 | |
parent | d004f08ee6102b2081a4ae14420e9eb76b1eb669 (diff) |
Changing DNNLinearCombinedClassifier to be "re-entrant" with respect to variable_scope. So that it doesn't assume it is called in the "root" scope of variables.
PiperOrigin-RevId: 210866643
-rw-r--r-- | tensorflow/python/estimator/canned/dnn_linear_combined.py | 9 |
1 files changed, 5 insertions, 4 deletions
diff --git a/tensorflow/python/estimator/canned/dnn_linear_combined.py b/tensorflow/python/estimator/canned/dnn_linear_combined.py index 62a1adf78c..9799cf9e98 100644 --- a/tensorflow/python/estimator/canned/dnn_linear_combined.py +++ b/tensorflow/python/estimator/canned/dnn_linear_combined.py @@ -161,8 +161,8 @@ def _dnn_linear_combined_model_fn(features, with variable_scope.variable_scope( dnn_parent_scope, values=tuple(six.itervalues(features)), - partitioner=dnn_partitioner): - + partitioner=dnn_partitioner) as scope: + dnn_absolute_scope = scope.name dnn_logit_fn = dnn._dnn_logit_fn_builder( # pylint: disable=protected-access units=head.logits_dimension, hidden_units=dnn_hidden_units, @@ -186,6 +186,7 @@ def _dnn_linear_combined_model_fn(features, linear_parent_scope, values=tuple(six.itervalues(features)), partitioner=input_layer_partitioner) as scope: + linear_absolute_scope = scope.name logit_fn = linear._linear_logit_fn_builder( # pylint: disable=protected-access units=head.logits_dimension, feature_columns=linear_feature_columns, @@ -211,14 +212,14 @@ def _dnn_linear_combined_model_fn(features, loss, var_list=ops.get_collection( ops.GraphKeys.TRAINABLE_VARIABLES, - scope=dnn_parent_scope))) + scope=dnn_absolute_scope))) if linear_logits is not None: train_ops.append( linear_optimizer.minimize( loss, var_list=ops.get_collection( ops.GraphKeys.TRAINABLE_VARIABLES, - scope=linear_parent_scope))) + scope=linear_absolute_scope))) train_op = control_flow_ops.group(*train_ops) with ops.control_dependencies([train_op]): |