aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/gan/python/train.py
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/contrib/gan/python/train.py')
-rw-r--r--tensorflow/contrib/gan/python/train.py16
1 files changed, 9 insertions, 7 deletions
diff --git a/tensorflow/contrib/gan/python/train.py b/tensorflow/contrib/gan/python/train.py
index a32ddd7a06..5d0ac93aec 100644
--- a/tensorflow/contrib/gan/python/train.py
+++ b/tensorflow/contrib/gan/python/train.py
@@ -279,14 +279,16 @@ def acgan_model(
generator_inputs = _convert_tensor_or_l_or_d(generator_inputs)
generated_data = generator_fn(generator_inputs)
with variable_scope.variable_scope(discriminator_scope) as dis_scope:
- (discriminator_gen_outputs, discriminator_gen_classification_logits
- ) = _validate_acgan_discriminator_outputs(
- discriminator_fn(generated_data, generator_inputs))
+ with ops.name_scope(dis_scope.name+'/generated/'):
+ (discriminator_gen_outputs, discriminator_gen_classification_logits
+ ) = _validate_acgan_discriminator_outputs(
+ discriminator_fn(generated_data, generator_inputs))
with variable_scope.variable_scope(dis_scope, reuse=True):
- real_data = ops.convert_to_tensor(real_data)
- (discriminator_real_outputs, discriminator_real_classification_logits
- ) = _validate_acgan_discriminator_outputs(
- discriminator_fn(real_data, generator_inputs))
+ with ops.name_scope(dis_scope.name+'/real/'):
+ real_data = ops.convert_to_tensor(real_data)
+ (discriminator_real_outputs, discriminator_real_classification_logits
+ ) = _validate_acgan_discriminator_outputs(
+ discriminator_fn(real_data, generator_inputs))
if check_shapes:
if not generated_data.shape.is_compatible_with(real_data.shape):
raise ValueError(