about summary refs log tree commit diff homepage
diff options
context:
space:
mode:
author: Francois Chollet <fchollet@google.com> 2018-03-23 23:04:03 -0700
committer: TensorFlower Gardener <gardener@tensorflow.org> 2018-03-25 04:59:27 -0700
commitf95347a96c431b63183856128bfea3943585f938 (patch)
tree319cd0b392c41eefe0f09483d4de4287282cf194
parent917b79250b0e65aa7856b2418b68292d919cd5dc (diff)
Trivial update of layer imports in eager execution examples, to reflect recommended practices.
PiperOrigin-RevId: 190319480
-rw-r--r--tensorflow/contrib/eager/python/examples/gan/mnist.py21
-rw-r--r--tensorflow/contrib/eager/python/examples/linear_regression/linear_regression.py4
-rw-r--r--tensorflow/contrib/eager/python/examples/resnet50/resnet50.py43
-rw-r--r--tensorflow/contrib/eager/python/examples/rnn_colorbot/rnn_colorbot.py6
-rw-r--r--tensorflow/contrib/eager/python/examples/rnn_ptb/rnn_ptb.py8
5 files changed, 46 insertions, 36 deletions
diff --git a/tensorflow/contrib/eager/python/examples/gan/mnist.py b/tensorflow/contrib/eager/python/examples/gan/mnist.py
index 2b7e199fad..b80c909023 100644
--- a/tensorflow/contrib/eager/python/examples/gan/mnist.py
+++ b/tensorflow/contrib/eager/python/examples/gan/mnist.py
@@ -32,6 +32,7 @@ import tensorflow as tf
import tensorflow.contrib.eager as tfe
from tensorflow.examples.tutorials.mnist import input_data
+layers = tf.keras.layers
FLAGS = None
@@ -56,15 +57,15 @@ class Discriminator(tf.keras.Model):
else:
assert data_format == 'channels_last'
self._input_shape = [-1, 28, 28, 1]
- self.conv1 = tf.layers.Conv2D(
+ self.conv1 = layers.Conv2D(
64, 5, padding='SAME', data_format=data_format, activation=tf.tanh)
- self.pool1 = tf.layers.AveragePooling2D(2, 2, data_format=data_format)
- self.conv2 = tf.layers.Conv2D(
+ self.pool1 = layers.AveragePooling2D(2, 2, data_format=data_format)
+ self.conv2 = layers.Conv2D(
128, 5, data_format=data_format, activation=tf.tanh)
- self.pool2 = tf.layers.AveragePooling2D(2, 2, data_format=data_format)
- self.flatten = tf.layers.Flatten()
- self.fc1 = tf.layers.Dense(1024, activation=tf.tanh)
- self.fc2 = tf.layers.Dense(1, activation=None)
+ self.pool2 = layers.AveragePooling2D(2, 2, data_format=data_format)
+ self.flatten = layers.Flatten()
+ self.fc1 = layers.Dense(1024, activation=tf.tanh)
+ self.fc2 = layers.Dense(1, activation=None)
def call(self, inputs):
"""Return two logits per image estimating input authenticity.
@@ -112,16 +113,16 @@ class Generator(tf.keras.Model):
else:
assert data_format == 'channels_last'
self._pre_conv_shape = [-1, 6, 6, 128]
- self.fc1 = tf.layers.Dense(6 * 6 * 128, activation=tf.tanh)
+ self.fc1 = layers.Dense(6 * 6 * 128, activation=tf.tanh)
# In call(), we reshape the output of fc1 to _pre_conv_shape
# Deconvolution layer. Resulting image shape: (batch, 14, 14, 64)
- self.conv1 = tf.layers.Conv2DTranspose(
+ self.conv1 = layers.Conv2DTranspose(
64, 4, strides=2, activation=None, data_format=data_format)
# Deconvolution layer. Resulting image shape: (batch, 28, 28, 1)
- self.conv2 = tf.layers.Conv2DTranspose(
+ self.conv2 = layers.Conv2DTranspose(
1, 2, strides=2, activation=tf.nn.sigmoid, data_format=data_format)
def call(self, inputs):
diff --git a/tensorflow/contrib/eager/python/examples/linear_regression/linear_regression.py b/tensorflow/contrib/eager/python/examples/linear_regression/linear_regression.py
index 6ab847cb78..4e1380afb2 100644
--- a/tensorflow/contrib/eager/python/examples/linear_regression/linear_regression.py
+++ b/tensorflow/contrib/eager/python/examples/linear_regression/linear_regression.py
@@ -32,6 +32,8 @@ import tensorflow as tf
import tensorflow.contrib.eager as tfe
+layers = tf.keras.layers
+
class LinearModel(tf.keras.Model):
"""A TensorFlow linear regression model."""
@@ -39,7 +41,7 @@ class LinearModel(tf.keras.Model):
def __init__(self):
"""Constructs a LinearModel object."""
super(LinearModel, self).__init__()
- self._hidden_layer = tf.layers.Dense(1)
+ self._hidden_layer = layers.Dense(1)
def call(self, xs):
"""Invoke the linear model.
diff --git a/tensorflow/contrib/eager/python/examples/resnet50/resnet50.py b/tensorflow/contrib/eager/python/examples/resnet50/resnet50.py
index 6b59413141..a28bc8a43d 100644
--- a/tensorflow/contrib/eager/python/examples/resnet50/resnet50.py
+++ b/tensorflow/contrib/eager/python/examples/resnet50/resnet50.py
@@ -28,6 +28,8 @@ import functools
import tensorflow as tf
+layers = tf.keras.layers
+
class _IdentityBlock(tf.keras.Model):
"""_IdentityBlock is the block that has no conv layer at shortcut.
@@ -49,23 +51,23 @@ class _IdentityBlock(tf.keras.Model):
bn_name_base = 'bn' + str(stage) + block + '_branch'
bn_axis = 1 if data_format == 'channels_first' else 3
- self.conv2a = tf.layers.Conv2D(
+ self.conv2a = layers.Conv2D(
filters1, (1, 1), name=conv_name_base + '2a', data_format=data_format)
- self.bn2a = tf.layers.BatchNormalization(
+ self.bn2a = layers.BatchNormalization(
axis=bn_axis, name=bn_name_base + '2a')
- self.conv2b = tf.layers.Conv2D(
+ self.conv2b = layers.Conv2D(
filters2,
kernel_size,
padding='same',
data_format=data_format,
name=conv_name_base + '2b')
- self.bn2b = tf.layers.BatchNormalization(
+ self.bn2b = layers.BatchNormalization(
axis=bn_axis, name=bn_name_base + '2b')
- self.conv2c = tf.layers.Conv2D(
+ self.conv2c = layers.Conv2D(
filters3, (1, 1), name=conv_name_base + '2c', data_format=data_format)
- self.bn2c = tf.layers.BatchNormalization(
+ self.bn2c = layers.BatchNormalization(
axis=bn_axis, name=bn_name_base + '2c')
def call(self, input_tensor, training=False):
@@ -113,34 +115,34 @@ class _ConvBlock(tf.keras.Model):
bn_name_base = 'bn' + str(stage) + block + '_branch'
bn_axis = 1 if data_format == 'channels_first' else 3
- self.conv2a = tf.layers.Conv2D(
+ self.conv2a = layers.Conv2D(
filters1, (1, 1),
strides=strides,
name=conv_name_base + '2a',
data_format=data_format)
- self.bn2a = tf.layers.BatchNormalization(
+ self.bn2a = layers.BatchNormalization(
axis=bn_axis, name=bn_name_base + '2a')
- self.conv2b = tf.layers.Conv2D(
+ self.conv2b = layers.Conv2D(
filters2,
kernel_size,
padding='same',
name=conv_name_base + '2b',
data_format=data_format)
- self.bn2b = tf.layers.BatchNormalization(
+ self.bn2b = layers.BatchNormalization(
axis=bn_axis, name=bn_name_base + '2b')
- self.conv2c = tf.layers.Conv2D(
+ self.conv2c = layers.Conv2D(
filters3, (1, 1), name=conv_name_base + '2c', data_format=data_format)
- self.bn2c = tf.layers.BatchNormalization(
+ self.bn2c = layers.BatchNormalization(
axis=bn_axis, name=bn_name_base + '2c')
- self.conv_shortcut = tf.layers.Conv2D(
+ self.conv_shortcut = layers.Conv2D(
filters3, (1, 1),
strides=strides,
name=conv_name_base + '1',
data_format=data_format)
- self.bn_shortcut = tf.layers.BatchNormalization(
+ self.bn_shortcut = layers.BatchNormalization(
axis=bn_axis, name=bn_name_base + '1')
def call(self, input_tensor, training=False):
@@ -219,15 +221,15 @@ class ResNet50(tf.keras.Model):
return _IdentityBlock(
3, filters, stage=stage, block=block, data_format=data_format)
- self.conv1 = tf.layers.Conv2D(
+ self.conv1 = layers.Conv2D(
64, (7, 7),
strides=(2, 2),
data_format=data_format,
padding='same',
name='conv1')
bn_axis = 1 if data_format == 'channels_first' else 3
- self.bn_conv1 = tf.layers.BatchNormalization(axis=bn_axis, name='bn_conv1')
- self.max_pool = tf.layers.MaxPooling2D(
+ self.bn_conv1 = layers.BatchNormalization(axis=bn_axis, name='bn_conv1')
+ self.max_pool = layers.MaxPooling2D(
(3, 3), strides=(2, 2), data_format=data_format)
self.l2a = conv_block([64, 64, 256], stage=2, block='a', strides=(1, 1))
@@ -250,11 +252,12 @@ class ResNet50(tf.keras.Model):
self.l5b = id_block([512, 512, 2048], stage=5, block='b')
self.l5c = id_block([512, 512, 2048], stage=5, block='c')
- self.avg_pool = tf.layers.AveragePooling2D(
+ self.avg_pool = layers.AveragePooling2D(
(7, 7), strides=(7, 7), data_format=data_format)
if self.include_top:
- self.fc1000 = tf.layers.Dense(classes, name='fc1000')
+ self.flatten = layers.Flatten()
+ self.fc1000 = layers.Dense(classes, name='fc1000')
else:
reduction_indices = [1, 2] if data_format == 'channels_last' else [2, 3]
reduction_indices = tf.constant(reduction_indices)
@@ -298,7 +301,7 @@ class ResNet50(tf.keras.Model):
x = self.avg_pool(x)
if self.include_top:
- return self.fc1000(tf.layers.flatten(x))
+ return self.fc1000(self.flatten(x))
elif self.global_pooling:
return self.global_pooling(x)
else:
diff --git a/tensorflow/contrib/eager/python/examples/rnn_colorbot/rnn_colorbot.py b/tensorflow/contrib/eager/python/examples/rnn_colorbot/rnn_colorbot.py
index 88fffc962f..492adbe1d8 100644
--- a/tensorflow/contrib/eager/python/examples/rnn_colorbot/rnn_colorbot.py
+++ b/tensorflow/contrib/eager/python/examples/rnn_colorbot/rnn_colorbot.py
@@ -73,6 +73,8 @@ try:
except ImportError:
HAS_MATPLOTLIB = False
+layers = tf.keras.layers
+
def parse(line):
"""Parse a line from the colors dataset."""
@@ -152,7 +154,7 @@ class RNNColorbot(tf.keras.Model):
self.cells = self._add_cells(
[tf.nn.rnn_cell.BasicLSTMCell(size) for size in rnn_cell_sizes])
- self.relu = tf.layers.Dense(
+ self.relu = layers.Dense(
label_dimension, activation=tf.nn.relu, name="relu")
def call(self, inputs, training=False):
@@ -204,7 +206,7 @@ class RNNColorbot(tf.keras.Model):
def _add_cells(self, cells):
# "Magic" required for keras.Model classes to track all the variables in
- # a list of tf.layers.Layer objects.
+ # a list of layers.Layer objects.
# TODO(ashankar): Figure out API so user code doesn't have to do this.
for i, c in enumerate(cells):
setattr(self, "cell-%d" % i, c)
diff --git a/tensorflow/contrib/eager/python/examples/rnn_ptb/rnn_ptb.py b/tensorflow/contrib/eager/python/examples/rnn_ptb/rnn_ptb.py
index 69cd16d12c..a90048d813 100644
--- a/tensorflow/contrib/eager/python/examples/rnn_ptb/rnn_ptb.py
+++ b/tensorflow/contrib/eager/python/examples/rnn_ptb/rnn_ptb.py
@@ -38,6 +38,8 @@ import tensorflow as tf
from tensorflow.contrib.cudnn_rnn.python.layers import cudnn_rnn
from tensorflow.contrib.eager.python import tfe
+layers = tf.keras.layers
+
class RNN(tf.keras.Model):
"""A static RNN.
@@ -74,14 +76,14 @@ class RNN(tf.keras.Model):
def _add_cells(self, cells):
# "Magic" required for keras.Model classes to track all the variables in
- # a list of tf.layers.Layer objects.
+ # a list of Layer objects.
# TODO(ashankar): Figure out API so user code doesn't have to do this.
for i, c in enumerate(cells):
setattr(self, "cell-%d" % i, c)
return cells
-class Embedding(tf.layers.Layer):
+class Embedding(layers.Layer):
"""An Embedding layer."""
def __init__(self, vocab_size, embedding_dim, **kwargs):
@@ -132,7 +134,7 @@ class PTBModel(tf.keras.Model):
else:
self.rnn = RNN(hidden_dim, num_layers, self.keep_ratio)
- self.linear = tf.layers.Dense(
+ self.linear = layers.Dense(
vocab_size, kernel_initializer=tf.random_uniform_initializer(-0.1, 0.1))
self._output_shape = [-1, embedding_dim]