diff options
Diffstat (limited to 'tensorflow/contrib/gan/python/train.py')
-rw-r--r-- | tensorflow/contrib/gan/python/train.py | 16 |
1 files changed, 9 insertions, 7 deletions
diff --git a/tensorflow/contrib/gan/python/train.py b/tensorflow/contrib/gan/python/train.py index a32ddd7a06..5d0ac93aec 100644 --- a/tensorflow/contrib/gan/python/train.py +++ b/tensorflow/contrib/gan/python/train.py @@ -279,14 +279,16 @@ def acgan_model( generator_inputs = _convert_tensor_or_l_or_d(generator_inputs) generated_data = generator_fn(generator_inputs) with variable_scope.variable_scope(discriminator_scope) as dis_scope: - (discriminator_gen_outputs, discriminator_gen_classification_logits - ) = _validate_acgan_discriminator_outputs( - discriminator_fn(generated_data, generator_inputs)) + with ops.name_scope(dis_scope.name+'/generated/'): + (discriminator_gen_outputs, discriminator_gen_classification_logits + ) = _validate_acgan_discriminator_outputs( + discriminator_fn(generated_data, generator_inputs)) with variable_scope.variable_scope(dis_scope, reuse=True): - real_data = ops.convert_to_tensor(real_data) - (discriminator_real_outputs, discriminator_real_classification_logits - ) = _validate_acgan_discriminator_outputs( - discriminator_fn(real_data, generator_inputs)) + with ops.name_scope(dis_scope.name+'/real/'): + real_data = ops.convert_to_tensor(real_data) + (discriminator_real_outputs, discriminator_real_classification_logits + ) = _validate_acgan_discriminator_outputs( + discriminator_fn(real_data, generator_inputs)) if check_shapes: if not generated_data.shape.is_compatible_with(real_data.shape): raise ValueError( |