aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/gan/python/losses/python/losses_impl.py
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/contrib/gan/python/losses/python/losses_impl.py')
-rw-r--r--tensorflow/contrib/gan/python/losses/python/losses_impl.py16
1 files changed, 7 insertions, 9 deletions
diff --git a/tensorflow/contrib/gan/python/losses/python/losses_impl.py b/tensorflow/contrib/gan/python/losses/python/losses_impl.py
index 1ba3a64167..d389748374 100644
--- a/tensorflow/contrib/gan/python/losses/python/losses_impl.py
+++ b/tensorflow/contrib/gan/python/losses/python/losses_impl.py
@@ -949,6 +949,11 @@ def cycle_consistency_loss(data_x,
* loss = (loss_x2x + loss_y2y) / 2
where `loss` is the final result.
+ For the L1-norm, we follow the original implementation:
+ https://github.com/junyanz/CycleGAN/blob/master/models/cycle_gan_model.lua
+ we use L1-norm of pixel-wise error normalized by data size such that
+ `cycle_loss_weight` can be specified independent of image size.
+
See https://arxiv.org/abs/1703.10593 for more details.
Args:
@@ -965,19 +970,12 @@ def cycle_consistency_loss(data_x,
A scalar `Tensor` of cycle consistency loss.
"""
- def _partial_cycle_consistency_loss(data, reconstructed_data):
- # Following the original implementation
- # https://github.com/junyanz/CycleGAN/blob/master/models/cycle_gan_model.lua
- # use L1-norm of pixel-wise error normalized by data size so that
- # `cycle_loss_weight` can be specified independent of image size.
- return math_ops.reduce_mean(math_ops.abs(data - reconstructed_data))
-
with ops.name_scope(
scope,
'cycle_consistency_loss',
values=[data_x, reconstructed_data_x, data_y, reconstructed_data_y]):
- loss_x2x = _partial_cycle_consistency_loss(data_x, reconstructed_data_x)
- loss_y2y = _partial_cycle_consistency_loss(data_y, reconstructed_data_y)
+ loss_x2x = losses.absolute_difference(data_x, reconstructed_data_x)
+ loss_y2y = losses.absolute_difference(data_y, reconstructed_data_y)
loss = (loss_x2x + loss_y2y) / 2.0
if add_summaries:
summary.scalar('cycle_consistency_loss_x2x', loss_x2x)