aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/eager
diff options
context:
space:
mode:
authorGravatar Xuechen Li <lxuechen@google.com>2018-08-02 17:37:13 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-08-02 17:41:39 -0700
commit294766e34ddfa801878f58560d088b07841bea26 (patch)
tree43f9d34ca59db0a029f2f472383fdc51fd5fe2e5 /tensorflow/contrib/eager
parentbb3ed5ee461988f1020b9768a42ce27966ec08dc (diff)
Remove sagan example.
PiperOrigin-RevId: 207195679
Diffstat (limited to 'tensorflow/contrib/eager')
-rw-r--r--tensorflow/contrib/eager/python/examples/BUILD2
-rw-r--r--tensorflow/contrib/eager/python/examples/sagan/BUILD59
-rw-r--r--tensorflow/contrib/eager/python/examples/sagan/config.py72
-rw-r--r--tensorflow/contrib/eager/python/examples/sagan/ops.py71
-rw-r--r--tensorflow/contrib/eager/python/examples/sagan/ops_test.py59
-rw-r--r--tensorflow/contrib/eager/python/examples/sagan/sagan.py232
-rw-r--r--tensorflow/contrib/eager/python/examples/sagan/sagan_test.py101
7 files changed, 0 insertions, 596 deletions
diff --git a/tensorflow/contrib/eager/python/examples/BUILD b/tensorflow/contrib/eager/python/examples/BUILD
index 12155a459c..6f02c90368 100644
--- a/tensorflow/contrib/eager/python/examples/BUILD
+++ b/tensorflow/contrib/eager/python/examples/BUILD
@@ -15,8 +15,6 @@ py_library(
"//tensorflow/contrib/eager/python/examples/revnet:config",
"//tensorflow/contrib/eager/python/examples/rnn_colorbot",
"//tensorflow/contrib/eager/python/examples/rnn_ptb",
- "//tensorflow/contrib/eager/python/examples/sagan",
- "//tensorflow/contrib/eager/python/examples/sagan:config",
"//tensorflow/contrib/eager/python/examples/spinn:data",
],
)
diff --git a/tensorflow/contrib/eager/python/examples/sagan/BUILD b/tensorflow/contrib/eager/python/examples/sagan/BUILD
deleted file mode 100644
index b470a41d81..0000000000
--- a/tensorflow/contrib/eager/python/examples/sagan/BUILD
+++ /dev/null
@@ -1,59 +0,0 @@
-licenses(["notice"]) # Apache 2.0
-
-package(default_visibility = ["//tensorflow:internal"])
-
-load("//tensorflow:tensorflow.bzl", "cuda_py_test")
-
-# Model
-py_library(
- name = "config",
- srcs = ["config.py"],
- srcs_version = "PY2AND3",
- deps = [
- "//tensorflow:tensorflow_py",
- ],
-)
-
-py_library(
- name = "ops",
- srcs = ["ops.py"],
- srcs_version = "PY2AND3",
- deps = [
- "//tensorflow:tensorflow_py",
- ],
-)
-
-py_library(
- name = "sagan",
- srcs = ["sagan.py"],
- srcs_version = "PY2AND3",
- deps = [
- ":ops",
- "//tensorflow:tensorflow_py",
- ],
-)
-
-# Tests
-cuda_py_test(
- name = "ops_test",
- size = "small",
- srcs = ["ops_test.py"],
- additional_deps = [
- ":ops",
- "//tensorflow:tensorflow_py",
- ],
-)
-
-cuda_py_test(
- name = "sagan_test",
- size = "large",
- srcs = ["sagan_test.py"],
- additional_deps = [
- ":config",
- ":sagan",
- "//tensorflow:tensorflow_py",
- ],
- tags = [
- "optonly",
- ],
-)
diff --git a/tensorflow/contrib/eager/python/examples/sagan/config.py b/tensorflow/contrib/eager/python/examples/sagan/config.py
deleted file mode 100644
index 1967bbd867..0000000000
--- a/tensorflow/contrib/eager/python/examples/sagan/config.py
+++ /dev/null
@@ -1,72 +0,0 @@
-# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ==============================================================================
-"""Self-attention generative adversarial with eager execution.
-
-Configuration in format of tf.contrib.training.HParams.
-Supports default 128x128 ImageNet.
-
-Reference [Self-Attention Generative Adversarial
-Networks](https://arxiv.org/pdf/1805.08318.pdf)
-
-"""
-
-from __future__ import absolute_import
-from __future__ import division
-from __future__ import print_function
-
-import tensorflow as tf
-tfe = tf.contrib.eager
-
-
-def get_hparams_imagenet():
- """Configurations to train SAGAN on 128x128 ImageNet dataset."""
- config = tf.contrib.training.HParams()
- if tf.test.is_gpu_available():
- config.add_hparam("image_shape", (3, 128, 128))
- config.add_hparam("data_format", "channels_first")
- config.add_hparam("g_init_shape", (512, 4, 4))
- else:
- config.add_hparam("image_shape", (128, 128, 3))
- config.add_hparam("data_format", "channels_first")
- config.add_hparam("g_init_shape", (4, 4, 512))
-
- config.add_hparam("latent_dim", 128)
- config.add_hparam("update_g_once_every", 1)
- config.add_hparam("batch_size", 64)
- config.add_hparam("d_init_filters", 32)
- config.add_hparam("num_upsamples", 5)
- # (512, 4, 4) -> (3, 128, 128)
- return config
-
-
-def get_hparams_mock():
- """Configurations of smaller networks for testing."""
- config = tf.contrib.training.HParams()
- if tf.test.is_gpu_available():
- config.add_hparam("image_shape", (3, 16, 16))
- config.add_hparam("data_format", "channels_first")
- config.add_hparam("g_init_shape", (32, 2, 2))
- else:
- config.add_hparam("image_shape", (16, 16, 3))
- config.add_hparam("data_format", "channels_last")
- config.add_hparam("g_init_shape", (2, 2, 32))
-
- config.add_hparam("latent_dim", 16)
- config.add_hparam("update_g_once_every", 1)
- config.add_hparam("batch_size", 2)
- config.add_hparam("d_init_filters", 4)
- config.add_hparam("num_upsamples", 3)
- # (32, 2, 2) -> (3, 16, 16)
- return config
diff --git a/tensorflow/contrib/eager/python/examples/sagan/ops.py b/tensorflow/contrib/eager/python/examples/sagan/ops.py
deleted file mode 100644
index 9a03cab1d1..0000000000
--- a/tensorflow/contrib/eager/python/examples/sagan/ops.py
+++ /dev/null
@@ -1,71 +0,0 @@
-# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ==============================================================================
-"""Self-attention generative adversarial with eager execution.
-
-Auxiliary operations.
-
-Reference [Self-Attention Generative Adversarial
-Networks](https://arxiv.org/pdf/1805.08318.pdf)
-"""
-from __future__ import absolute_import
-from __future__ import division
-from __future__ import print_function
-
-import tensorflow as tf
-
-
-def flatten_hw(x, data_format="channels_first"):
- """Flatten the input tensor across height and width dimensions."""
- if data_format == "channels_last":
- x = tf.transpose(x, perm=[0, 3, 1, 2]) # Convert to `channels_first`
-
- old_shape = tf.shape(x)
- new_shape = [old_shape[0], old_shape[2] * old_shape[3], old_shape[1]]
-
- return tf.reshape(x, new_shape)
-
-
-def broaden_hw(x, h, w, c, data_format="channels_first"):
- """Broaden dimension so that output has height and width."""
- if data_format == "channels_first":
- shape = [-1, c, h, w]
- else:
- shape = [-1, h, w, c]
-
- return tf.reshape(x, shape)
-
-
-class BroadenHW(tf.keras.layers.Layer):
- """Wrapper class so that `broaden_hw` can be used in `tf.keras.Sequential`."""
-
- def __init__(self, h, w, c, data_format="channels_first"):
- super(BroadenHW, self).__init__()
- self.h = h
- self.w = w
- self.c = c
- self.data_format = data_format
-
- def call(self, x):
- return broaden_hw(
- x, h=self.h, w=self.w, c=self.c, data_format=self.data_format)
-
- def compute_output_shape(self, input_shape):
- input_shape = tf.TensorShape(input_shape).as_list()
- if self.data_format == "channels_first":
- output_shape = (input_shape[0], self.c, self.h, self.w)
- else:
- output_shape = (input_shape[0], self.h, self.w, self.c)
-
- return tf.TensorShape(output_shape)
diff --git a/tensorflow/contrib/eager/python/examples/sagan/ops_test.py b/tensorflow/contrib/eager/python/examples/sagan/ops_test.py
deleted file mode 100644
index 3454985904..0000000000
--- a/tensorflow/contrib/eager/python/examples/sagan/ops_test.py
+++ /dev/null
@@ -1,59 +0,0 @@
-# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ==============================================================================
-"""Tests for auxiliary operations."""
-
-from __future__ import absolute_import
-from __future__ import division
-from __future__ import print_function
-
-import tensorflow as tf
-from tensorflow.contrib.eager.python.examples.sagan import ops
-
-
-class OpsTest(tf.test.TestCase):
-
- def test_flatten_hw(self):
- """Test `flatten_hw` function with mock object."""
-
- batch_size = 1
- # Default NCHW format
- if tf.test.is_gpu_available():
- x = tf.random_normal(shape=(batch_size, 3, 4, 4))
- y = ops.flatten_hw(x, data_format="channels_first")
- self.assertEqual(y.shape, (batch_size, 4 * 4, 3))
-
- # NHWC format
- x = tf.random_normal(shape=(batch_size, 4, 4, 3))
- y = ops.flatten_hw(x, data_format="channels_last")
- self.assertEqual(y.shape, (batch_size, 4 * 4, 3))
-
- def test_broaden_hw(self):
- """Test `broaden_hw` function with mock object."""
-
- batch_size = 1
- # NHWC format
- x = tf.random_normal(shape=[batch_size, 4 * 4 * 16])
- y = ops.broaden_hw(x, h=4, w=4, c=16, data_format="channels_last")
- self.assertEqual(y.shape, (batch_size, 4, 4, 16))
-
- # Default NCHW format
- if tf.test.is_gpu_available():
- y = ops.broaden_hw(x, h=4, w=4, c=16, data_format="channels_first")
- self.assertEqual(y.shape, (batch_size, 16, 4, 4))
-
-
-if __name__ == "__main__":
- tf.enable_eager_execution()
- tf.test.main()
diff --git a/tensorflow/contrib/eager/python/examples/sagan/sagan.py b/tensorflow/contrib/eager/python/examples/sagan/sagan.py
deleted file mode 100644
index 8130414985..0000000000
--- a/tensorflow/contrib/eager/python/examples/sagan/sagan.py
+++ /dev/null
@@ -1,232 +0,0 @@
-# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ==============================================================================
-"""Self-attention generative adversarial with eager execution.
-
-Code for main model.
-
-Reference [Self-Attention Generative Adversarial
-Networks](https://arxiv.org/pdf/1805.08318.pdf)
-"""
-
-from __future__ import absolute_import
-from __future__ import division
-from __future__ import print_function
-
-import numpy as np
-import tensorflow as tf
-from tensorflow.contrib.eager.python.examples.sagan import ops
-tfe = tf.contrib.eager
-
-
-class SelfAttentionModule(tf.keras.Model):
- """Self-attention module composed of convolutional layers."""
-
- def __init__(self,
- attention_features,
- original_features,
- data_format="channels_first"):
- """Initialize the module.
-
- Args:
- attention_features: Number of filters for the attention computation.
- original_features: Number of filters of the original Tensor.
- data_format: Either 'channels_first' or 'channels_last'
- """
- super(SelfAttentionModule, self).__init__()
- self.data_format = data_format
- # Matrix multiplication implemented as 2D Convolution
- self.f = tf.keras.layers.Conv2D(
- filters=attention_features,
- kernel_size=1,
- strides=(1, 1),
- data_format=data_format)
- self.g = tf.keras.layers.Conv2D(
- filters=attention_features,
- kernel_size=1,
- strides=(1, 1),
- data_format=data_format)
- self.h = tf.keras.layers.Conv2D(
- filters=original_features,
- kernel_size=1,
- strides=(1, 1),
- data_format=data_format)
- self.scale = tf.Variable(0., trainable=True)
-
- def call(self, x):
- f = self.f(x)
- g = self.g(x)
- h = self.h(x)
-
- f_flatten = ops.flatten_hw(f, data_format=self.data_format)
- g_flatten = ops.flatten_hw(g, data_format=self.data_format)
- h_flatten = ops.flatten_hw(h, data_format=self.data_format)
-
- s = tf.matmul(g_flatten, f_flatten, transpose_b=True)
- b = tf.nn.softmax(s, axis=-1)
- o = tf.matmul(b, h_flatten)
- y = self.scale * tf.reshape(o, tf.shape(x)) + x
-
- return y
-
- def compute_output_shape(self, input_shape):
- return input_shape
-
-
-class SAGAN(tf.contrib.checkpoint.Checkpointable):
- """Self-attention generative adversarial network."""
-
- def __init__(self, config):
- """Initialize the model.
-
- Args:
- config: tf.contrib.training.HParams object; specifies hyperparameters
- """
- super(SAGAN, self).__init__()
- self.config = config
- self.generator = self._construct_generator()
- self.discriminator = self._construct_discriminator()
-
- def _construct_generator(self):
- """Construct generator."""
- # TODO(lxuechen): Add spectral normalization for WGAN
- axis = 1 if self.config.data_format == "channels_first" else 3
-
- generator = tf.keras.Sequential()
- generator.add(
- tf.keras.layers.InputLayer(input_shape=(self.config.latent_dim,)))
- generator.add(
- tf.keras.layers.Dense(
- units=np.prod(self.config.g_init_shape), activation=tf.nn.relu))
-
- if self.config.data_format == "channels_first":
- c, h, w = self.config.g_init_shape
- else:
- h, w, c = self.config.g_init_shape
-
- # Reshape to NHWC/NCHW
- generator.add(
- ops.BroadenHW(h=h, w=w, c=c, data_format=self.config.data_format))
-
- filters_list = [c // 2**p for p in range(1, self.config.num_upsamples + 1)]
- filters_list[-1] = 3 # Standard RGB images
-
- for filters in filters_list[:len(filters_list) // 2]:
- generator.add(
- tf.keras.layers.Conv2DTranspose(
- filters=filters,
- kernel_size=4,
- strides=(2, 2),
- use_bias=False,
- padding="SAME",
- data_format=self.config.data_format))
- generator.add(tf.keras.layers.BatchNormalization(axis=axis))
- generator.add(tf.keras.layers.Activation("relu"))
-
- # pylint: disable=undefined-loop-variable
- generator.add(
- SelfAttentionModule(
- original_features=filters,
- attention_features=filters // 8,
- data_format=self.config.data_format))
- # pylint: enable=undefined-loop-variable
-
- for filters in filters_list[len(filters_list) // 2:]:
- generator.add(
- tf.keras.layers.Conv2DTranspose(
- filters=filters,
- kernel_size=4,
- strides=(2, 2),
- use_bias=False,
- padding="SAME",
- data_format=self.config.data_format))
- if filters == 3:
- # Assume Image rescaled to [-1, 1]
- generator.add(tf.keras.layers.Activation("tanh"))
- else:
- generator.add(tf.keras.layers.BatchNormalization(axis=axis))
- generator.add(tf.keras.layers.Activation("relu"))
-
- return generator
-
- def _construct_discriminator(self):
- """Construct discriminator."""
- # TODO(lxuechen): Add spectral normalization for WGAN
- discriminator = tf.keras.Sequential()
- discriminator.add(
- tf.keras.layers.InputLayer(input_shape=self.config.image_shape))
-
- filters_list = [
- self.config.d_init_filters * 2**p
- for p in range(self.config.num_upsamples)
- ]
-
- for filters in filters_list[:(len(filters_list) + 1) // 2]:
- discriminator.add(
- tf.keras.layers.Conv2D(
- filters=filters,
- kernel_size=4,
- strides=(2, 2),
- padding="SAME",
- data_format=self.config.data_format))
- discriminator.add(tf.keras.layers.LeakyReLU(alpha=.1))
-
- # pylint: disable=undefined-loop-variable
- discriminator.add(
- SelfAttentionModule(
- original_features=filters,
- attention_features=filters // 8,
- data_format=self.config.data_format))
- # pylint: enable=undefined-loop-variable
-
- for filters in filters_list[(len(filters_list) + 1) // 2:]:
- discriminator.add(
- tf.keras.layers.Conv2D(
- filters=filters,
- kernel_size=4,
- strides=(2, 2),
- padding="SAME",
- data_format=self.config.data_format))
- discriminator.add(tf.keras.layers.LeakyReLU(alpha=.1))
-
- discriminator.add(tf.keras.layers.Flatten())
- discriminator.add(tf.keras.layers.Dense(units=1))
-
- return discriminator
-
- def compute_loss_and_grads(self, real_images, noise, training=True):
- """Compute loss and gradients for both generator and discriminator."""
- # TODO(lxuechen): Add gradient penalty for discriminator
- with tf.GradientTape() as g_tape, tf.GradientTape() as d_tape:
- real_logits = self.discriminator(real_images, training=training)
-
- fake_images = self.generator.call(noise, training=training)
- fake_logits = self.discriminator.call(fake_images)
-
- g_loss = self.compute_g_loss(fake_logits)
- d_loss = self.compute_d_loss(fake_logits, real_logits)
-
- g_grads = g_tape.gradient(g_loss, self.generator.trainable_variables)
- d_grads = d_tape.gradient(d_loss, self.discriminator.trainable_variables)
-
- return g_loss, d_loss, g_grads, d_grads
-
- def compute_g_loss(self, fake_logits):
- return -tf.reduce_mean(fake_logits) # Hinge loss
-
- def compute_d_loss(self, fake_logits, real_logits):
- # Hinge loss
- real_loss = tf.reduce_mean(tf.nn.relu(1. - real_logits))
- fake_loss = tf.reduce_mean(tf.nn.relu(1. + fake_logits))
- return real_loss + fake_loss
diff --git a/tensorflow/contrib/eager/python/examples/sagan/sagan_test.py b/tensorflow/contrib/eager/python/examples/sagan/sagan_test.py
deleted file mode 100644
index 1834594510..0000000000
--- a/tensorflow/contrib/eager/python/examples/sagan/sagan_test.py
+++ /dev/null
@@ -1,101 +0,0 @@
-# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ==============================================================================
-"""Tests for self-attention generative adversarial network."""
-
-from __future__ import absolute_import
-from __future__ import division
-from __future__ import print_function
-
-import tensorflow as tf
-from tensorflow.contrib.eager.python.examples.sagan import config as config_
-from tensorflow.contrib.eager.python.examples.sagan import sagan
-tfe = tf.contrib.eager
-
-
-class SAGANTest(tf.test.TestCase):
-
- def setUp(self):
- super(SAGANTest, self).setUp()
- config = config_.get_hparams_mock()
- self.noise_shape = (config.batch_size, config.latent_dim)
- self.logits_shape = (config.batch_size, 1)
- self.images_shape = (config.batch_size,) + config.image_shape
-
- self.model = sagan.SAGAN(config=config)
- self.noise = tf.random_normal(shape=self.noise_shape)
- self.real_images = tf.random_normal(shape=self.images_shape)
- self.config = config
-
- def tearDown(self):
- del self.model
- del self.noise
- del self.real_images
- super(SAGANTest, self).tearDown()
-
- def test_generator_call(self):
- """Test `generator.__call__` function."""
- fake_images = self.model.generator(self.noise, training=False)
- self.assertEqual(fake_images.shape, self.images_shape)
-
- def test_generator_call_defun(self):
- """Test `generator.__call__` function with defun."""
- call_ = tfe.defun(self.model.generator.__call__)
- fake_images = call_(self.noise, training=False)
- self.assertEqual(fake_images.shape, self.images_shape)
-
- def test_discriminator_call(self):
- """Test `discriminator.__call__` function."""
- real_logits = self.model.discriminator(self.real_images)
- self.assertEqual(real_logits.shape, self.logits_shape)
-
- def test_discriminator_call_defun(self):
- """Test `discriminator.__call__` function with defun."""
- call_ = tfe.defun(self.model.discriminator.__call__)
- real_logits = call_(self.real_images)
- self.assertEqual(real_logits.shape, self.logits_shape)
-
- def test_compute_loss_and_grads(self):
- """Test `compute_loss_and_grads` function."""
- g_loss, d_loss, g_grads, d_grads = self.model.compute_loss_and_grads(
- self.real_images, self.noise, training=False)
- self.assertEqual(g_loss.shape, ())
- self.assertEqual(d_loss.shape, ())
- self.assertTrue(isinstance(g_grads, list))
- self.assertTrue(isinstance(d_grads, list))
- g_vars = self.model.generator.trainable_variables
- d_vars = self.model.discriminator.trainable_variables
-
- self.assertEqual(len(g_grads), len(g_vars))
- self.assertEqual(len(d_grads), len(d_vars))
-
- def test_compute_loss_and_grads_defun(self):
- """Test `compute_loss_and_grads` function with defun."""
- compute_loss_and_grads = tfe.defun(self.model.compute_loss_and_grads)
- g_loss, d_loss, g_grads, d_grads = compute_loss_and_grads(
- self.real_images, self.noise, training=False)
- self.assertEqual(g_loss.shape, ())
- self.assertEqual(d_loss.shape, ())
- self.assertTrue(isinstance(g_grads, list))
- self.assertTrue(isinstance(d_grads, list))
- g_vars = self.model.generator.trainable_variables
- d_vars = self.model.discriminator.trainable_variables
-
- self.assertEqual(len(g_grads), len(g_vars))
- self.assertEqual(len(d_grads), len(d_vars))
-
-
-if __name__ == "__main__":
- tf.enable_eager_execution()
- tf.test.main()