aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2018-08-30 01:19:12 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-08-30 01:23:04 -0700
commitc73d4e56eb2ac66e8fb519cbe83c5f7bddbfc80a (patch)
tree734b735c85564e14bee1171c7b891787654cbe67
parentd004f08ee6102b2081a4ae14420e9eb76b1eb669 (diff)
Change DNNLinearCombinedClassifier to be "re-entrant" with respect to variable_scope, so that it does not assume it is called in the "root" variable scope.
PiperOrigin-RevId: 210866643
-rw-r--r--tensorflow/python/estimator/canned/dnn_linear_combined.py9
1 file changed, 5 insertions(+), 4 deletions(-)
diff --git a/tensorflow/python/estimator/canned/dnn_linear_combined.py b/tensorflow/python/estimator/canned/dnn_linear_combined.py
index 62a1adf78c..9799cf9e98 100644
--- a/tensorflow/python/estimator/canned/dnn_linear_combined.py
+++ b/tensorflow/python/estimator/canned/dnn_linear_combined.py
@@ -161,8 +161,8 @@ def _dnn_linear_combined_model_fn(features,
with variable_scope.variable_scope(
dnn_parent_scope,
values=tuple(six.itervalues(features)),
- partitioner=dnn_partitioner):
-
+ partitioner=dnn_partitioner) as scope:
+ dnn_absolute_scope = scope.name
dnn_logit_fn = dnn._dnn_logit_fn_builder( # pylint: disable=protected-access
units=head.logits_dimension,
hidden_units=dnn_hidden_units,
@@ -186,6 +186,7 @@ def _dnn_linear_combined_model_fn(features,
linear_parent_scope,
values=tuple(six.itervalues(features)),
partitioner=input_layer_partitioner) as scope:
+ linear_absolute_scope = scope.name
logit_fn = linear._linear_logit_fn_builder( # pylint: disable=protected-access
units=head.logits_dimension,
feature_columns=linear_feature_columns,
@@ -211,14 +212,14 @@ def _dnn_linear_combined_model_fn(features,
loss,
var_list=ops.get_collection(
ops.GraphKeys.TRAINABLE_VARIABLES,
- scope=dnn_parent_scope)))
+ scope=dnn_absolute_scope)))
if linear_logits is not None:
train_ops.append(
linear_optimizer.minimize(
loss,
var_list=ops.get_collection(
ops.GraphKeys.TRAINABLE_VARIABLES,
- scope=linear_parent_scope)))
+ scope=linear_absolute_scope)))
train_op = control_flow_ops.group(*train_ops)
with ops.control_dependencies([train_op]):