author     Francois Chollet <fchollet@google.com>  2017-04-17 15:20:31 -0800
committer  TensorFlower Gardener <gardener@tensorflow.org>  2017-04-17 16:39:05 -0700
commit     fd561221d2fe782d320b97346dfffb41f38d2bcf
tree       814fb813ff9b93e2535a65ccaba679f6c4b267e4
parent     ae84106edc892b60976b1635907009888150989f
Refactor Keras initializers to rely on core TF initializers; add serialization methods to core TF initializers.
Change: 153403157
-rw-r--r--  tensorflow/contrib/keras/BUILD | 2
-rw-r--r--  tensorflow/contrib/keras/python/keras/initializers.py | 286
-rw-r--r--  tensorflow/contrib/keras/python/keras/initializers_test.py | 115
-rw-r--r--  tensorflow/contrib/keras/python/keras/integration_test.py | 8
-rw-r--r--  tensorflow/contrib/keras/python/keras/optimizers_test.py | 2
-rw-r--r--  tensorflow/python/ops/init_ops.py | 109
-rw-r--r--  tensorflow/tools/api/golden/tensorflow.constant_initializer.pbtxt | 8
-rw-r--r--  tensorflow/tools/api/golden/tensorflow.ones_initializer.pbtxt | 8
-rw-r--r--  tensorflow/tools/api/golden/tensorflow.orthogonal_initializer.pbtxt | 10
-rw-r--r--  tensorflow/tools/api/golden/tensorflow.random_normal_initializer.pbtxt | 8
-rw-r--r--  tensorflow/tools/api/golden/tensorflow.random_uniform_initializer.pbtxt | 8
-rw-r--r--  tensorflow/tools/api/golden/tensorflow.truncated_normal_initializer.pbtxt | 8
-rw-r--r--  tensorflow/tools/api/golden/tensorflow.uniform_unit_scaling_initializer.pbtxt | 8
-rw-r--r--  tensorflow/tools/api/golden/tensorflow.zeros_initializer.pbtxt | 8
14 files changed, 239 insertions(+), 349 deletions(-)
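
The headline addition is that the core TF initializers now implement `get_config()` and `from_config()`, so a configured initializer round-trips through a plain JSON-serializable dict. A minimal sketch of the new API (assuming a TF 1.x graph context; variable names here are illustrative):

```
import tensorflow as tf

# Configure a seeded initializer, serialize it, and rebuild it.
init = tf.random_uniform_initializer(minval=-1.0, maxval=1.0, seed=124)
config = init.get_config()
# => {'minval': -1.0, 'maxval': 1.0, 'seed': 124, 'dtype': 'float32'}
restored = tf.random_uniform_initializer.from_config(config)

# With a fixed op-level seed, both initializers yield identical values;
# the new serialization test in this commit relies on that determinism.
v1 = tf.get_variable('v1', shape=(3, 3), initializer=init)
v2 = tf.get_variable('v2', shape=(3, 3), initializer=restored)
```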
diff --git a/tensorflow/contrib/keras/BUILD b/tensorflow/contrib/keras/BUILD
index 438e2056c6..5166ba37a3 100644
--- a/tensorflow/contrib/keras/BUILD
+++ b/tensorflow/contrib/keras/BUILD
@@ -134,7 +134,7 @@ py_library(
py_test(
name = "integration_test",
- size = "small",
+ size = "medium",
srcs = ["python/keras/integration_test.py"],
srcs_version = "PY2AND3",
tags = ["notsan"],
diff --git a/tensorflow/contrib/keras/python/keras/initializers.py b/tensorflow/contrib/keras/python/keras/initializers.py
index f9cb35e171..b0b71e7cb4 100644
--- a/tensorflow/contrib/keras/python/keras/initializers.py
+++ b/tensorflow/contrib/keras/python/keras/initializers.py
@@ -18,247 +18,20 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
-import math
-
import numpy as np
import six
-from tensorflow.contrib.keras.python.keras import backend as K
from tensorflow.contrib.keras.python.keras.utils.generic_utils import deserialize_keras_object
from tensorflow.contrib.keras.python.keras.utils.generic_utils import serialize_keras_object
-from tensorflow.python.framework import tensor_shape
-
-
-class Initializer(object):
- """Initializer base class: all initializers inherit from this class.
- """
-
- def __call__(self, shape, dtype=None):
- raise NotImplementedError
-
- def get_config(self):
- return {}
-
- @classmethod
- def from_config(cls, config):
- return cls(**config)
-
-
-class Zeros(Initializer):
- """Initializer that generates tensors initialized to 0.
- """
-
- def __call__(self, shape, dtype=None):
- return K.constant(0, shape=shape, dtype=dtype)
-
-
-class Ones(Initializer):
- """Initializer that generates tensors initialized to 1.
- """
-
- def __call__(self, shape, dtype=None):
- return K.constant(1, shape=shape, dtype=dtype)
-
-
-class Constant(Initializer):
- """Initializer that generates tensors initialized to a constant value.
-
- Arguments:
- value: float; the value of the generator tensors.
- """
-
- def __init__(self, value=0):
- self.value = value
-
- def __call__(self, shape, dtype=None):
- return K.constant(self.value, shape=shape, dtype=dtype)
-
- def get_config(self):
- return {'value': self.value}
-
-
-class RandomNormal(Initializer):
- """Initializer that generates tensors with a normal distribution.
-
- Arguments:
- mean: a python scalar or a scalar tensor. Mean of the random values
- to generate.
- stddev: a python scalar or a scalar tensor. Standard deviation of the
- random values to generate.
- seed: A Python integer. Used to seed the random generator.
- """
-
- def __init__(self, mean=0., stddev=0.05, seed=None):
- self.mean = mean
- self.stddev = stddev
- self.seed = seed
-
- def __call__(self, shape, dtype=None):
- return K.random_normal(
- shape, self.mean, self.stddev, dtype=dtype, seed=self.seed)
-
- def get_config(self):
- return {'mean': self.mean, 'stddev': self.stddev, 'seed': self.seed}
-
-
-class RandomUniform(Initializer):
- """Initializer that generates tensors with a uniform distribution.
-
- Arguments:
- minval: A python scalar or a scalar tensor. Lower bound of the range
- of random values to generate.
- maxval: A python scalar or a scalar tensor. Upper bound of the range
- of random values to generate. Defaults to 1 for float types.
- seed: A Python integer. Used to seed the random generator.
- """
-
- def __init__(self, minval=-0.05, maxval=0.05, seed=None):
- self.minval = minval
- self.maxval = maxval
- self.seed = seed
-
- def __call__(self, shape, dtype=None):
- return K.random_uniform(
- shape, self.minval, self.maxval, dtype=dtype, seed=self.seed)
-
- def get_config(self):
- return {
- 'minval': self.minval,
- 'maxval': self.maxval,
- 'seed': self.seed,
- }
-
-
-class TruncatedNormal(Initializer):
- """Initializer that generates a truncated normal distribution.
-
- These values are similar to values from a `RandomNormal`
- except that values more than two standard deviations from the mean
- are discarded and re-drawn. This is the recommended initializer for
- neural network weights and filters.
-
- Arguments:
- mean: a python scalar or a scalar tensor. Mean of the random values
- to generate.
- stddev: a python scalar or a scalar tensor. Standard deviation of the
- random values to generate.
- seed: A Python integer. Used to seed the random generator.
- """
-
- def __init__(self, mean=0., stddev=0.05, seed=None):
- self.mean = mean
- self.stddev = stddev
- self.seed = seed
-
- def __call__(self, shape, dtype=None):
- return K.truncated_normal(
- shape, self.mean, self.stddev, dtype=dtype, seed=self.seed)
-
- def get_config(self):
- return {'mean': self.mean, 'stddev': self.stddev, 'seed': self.seed}
-
-
-class VarianceScaling(Initializer):
- """Initializer capable of adapting its scale to the shape of weights.
-
- With `distribution="normal"`, samples are drawn from a truncated normal
- distribution centered on zero, with `stddev = sqrt(scale / n)` where n is:
-
- - number of input units in the weight tensor, if mode = "fan_in"
- - number of output units, if mode = "fan_out"
- - average of the numbers of input and output units, if mode = "fan_avg"
-
- With `distribution="uniform"`,
- samples are drawn from a uniform distribution
- within [-limit, limit], with `limit = sqrt(3 * scale / n)`.
-
- Arguments:
- scale: Scaling factor (positive float).
- mode: One of "fan_in", "fan_out", "fan_avg".
- distribution: Random distribution to use. One of "normal", "uniform".
- seed: A Python integer. Used to seed the random generator.
-
- Raises:
- ValueError: In case of an invalid value for the "scale", mode" or
- "distribution" arguments.
- """
-
- def __init__(self, scale=1.0, mode='fan_in', distribution='normal',
- seed=None):
- if scale <= 0.:
- raise ValueError('`scale` must be a positive float. Got:', scale)
- mode = mode.lower()
- if mode not in {'fan_in', 'fan_out', 'fan_avg'}:
- raise ValueError('Invalid `mode` argument: '
- 'expected on of {"fan_in", "fan_out", "fan_avg"} '
- 'but got', mode)
- distribution = distribution.lower()
- if distribution not in {'normal', 'uniform'}:
- raise ValueError('Invalid `distribution` argument: '
- 'expected one of {"normal", "uniform"} '
- 'but got', distribution)
- self.scale = scale
- self.mode = mode
- self.distribution = distribution
- self.seed = seed
-
- def __call__(self, shape, dtype=None):
- fan_in, fan_out = _compute_fans(shape)
- scale = self.scale
- if self.mode == 'fan_in':
- scale /= max(1., fan_in)
- elif self.mode == 'fan_out':
- scale /= max(1., fan_out)
- else:
- scale /= max(1., float(fan_in + fan_out) / 2)
- if self.distribution == 'normal':
- stddev = math.sqrt(scale)
- return K.truncated_normal(shape, 0., stddev, dtype=dtype, seed=self.seed)
- else:
- limit = math.sqrt(3. * scale)
- return K.random_uniform(shape, -limit, limit, dtype=dtype, seed=self.seed)
-
- def get_config(self):
- return {
- 'scale': self.scale,
- 'mode': self.mode,
- 'distribution': self.distribution,
- 'seed': self.seed
- }
-
-
-class Orthogonal(Initializer):
- """Initializer that generates a random orthogonal matrix.
-
- Arguments:
- gain: Multiplicative factor to apply to the orthogonal matrix.
- seed: A Python integer. Used to seed the random generator.
-
- References:
- Saxe et al., http://arxiv.org/abs/1312.6120
- """
-
- def __init__(self, gain=1., seed=None):
- self.gain = gain
- self.seed = seed
-
- def __call__(self, shape, dtype=None):
- num_rows = 1
- for dim in shape[:-1]:
- num_rows *= dim
- num_cols = shape[-1]
- flat_shape = (num_rows, num_cols)
- if self.seed is not None:
- np.random.seed(self.seed)
- a = np.random.normal(0.0, 1.0, flat_shape)
- u, _, v = np.linalg.svd(a, full_matrices=False)
- # Pick the one with the correct shape.
- q = u if u.shape == flat_shape else v
- q = q.reshape(shape)
- return self.gain * q[:shape[0], :shape[1]]
-
- def get_config(self):
- return {'gain': self.gain, 'seed': self.seed}
+from tensorflow.python.ops.init_ops import Constant
+from tensorflow.python.ops.init_ops import Initializer
+from tensorflow.python.ops.init_ops import Ones
+from tensorflow.python.ops.init_ops import Orthogonal
+from tensorflow.python.ops.init_ops import RandomNormal
+from tensorflow.python.ops.init_ops import RandomUniform
+from tensorflow.python.ops.init_ops import TruncatedNormal
+from tensorflow.python.ops.init_ops import VarianceScaling
+from tensorflow.python.ops.init_ops import Zeros
class Identity(Initializer):
@@ -406,47 +179,6 @@ orthogonal = Orthogonal
# Utility functions
-def _compute_fans(shape, data_format='channels_last'):
- """Computes the number of input and output units for a weight shape.
-
- Arguments:
- shape: Integer shape tuple.
- data_format: Image data format to use for convolution kernels.
- Note that all kernels in Keras are standardized on the
- `channels_last` ordering (even when inputs are set
- to `channels_first`).
-
- Returns:
- A tuple of scalars, `(fan_in, fan_out)`.
-
- Raises:
- ValueError: in case of invalid `data_format` argument.
- """
- shape = tensor_shape.TensorShape(shape).as_list()
- if len(shape) == 2:
- fan_in = shape[0]
- fan_out = shape[1]
- elif len(shape) in {3, 4, 5}:
- # Assuming convolution kernels (1D, 2D or 3D).
- # TH kernel shape: (depth, input_depth, ...)
- # TF kernel shape: (..., input_depth, depth)
- if data_format == 'channels_first':
- receptive_field_size = np.prod(shape[2:])
- fan_in = shape[1] * receptive_field_size
- fan_out = shape[0] * receptive_field_size
- elif data_format == 'channels_last':
- receptive_field_size = np.prod(shape[:2])
- fan_in = shape[-2] * receptive_field_size
- fan_out = shape[-1] * receptive_field_size
- else:
- raise ValueError('Invalid data_format: ' + data_format)
- else:
- # No specific assumptions.
- fan_in = math.sqrt(np.prod(shape))
- fan_out = math.sqrt(np.prod(shape))
- return fan_in, fan_out
-
-
def serialize(initializer):
return serialize_keras_object(initializer)
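
The deleted `_compute_fans` helper (and the `VarianceScaling` math above it) now lives in `init_ops`, which the tests below call directly. For reference, a standalone sketch of the same arithmetic for the channels_last case, using the `(5, 6, 4, 2)` conv-kernel shape from the tests (plain NumPy; `compute_fans` here is an illustrative mirror, not the library function):

```
import numpy as np

def compute_fans(shape):
    # Mirrors the deleted Keras helper for 2D weights and
    # channels_last conv kernels (other branches omitted).
    if len(shape) == 2:
        return shape[0], shape[1]
    receptive_field_size = np.prod(shape[:2])
    fan_in = shape[-2] * receptive_field_size
    fan_out = shape[-1] * receptive_field_size
    return fan_in, fan_out

fan_in, fan_out = compute_fans((5, 6, 4, 2))  # (120, 60)

# lecun_uniform is VarianceScaling(scale=1., mode='fan_in',
# distribution='uniform'), so samples lie in [-limit, limit]:
limit = np.sqrt(3. * 1. / fan_in)  # ~0.158, the bound test_lecun_uniform checks
```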
diff --git a/tensorflow/contrib/keras/python/keras/initializers_test.py b/tensorflow/contrib/keras/python/keras/initializers_test.py
index 7436fbb390..c9f50c28ea 100644
--- a/tensorflow/contrib/keras/python/keras/initializers_test.py
+++ b/tensorflow/contrib/keras/python/keras/initializers_test.py
@@ -21,121 +21,132 @@ from __future__ import print_function
import numpy as np
from tensorflow.contrib.keras.python import keras
+from tensorflow.python.ops import init_ops
from tensorflow.python.platform import test
-def _runner(init, shape, target_mean=None, target_std=None,
- target_max=None, target_min=None):
- variable = keras.backend.variable(init(shape))
- output = keras.backend.get_value(variable)
- lim = 3e-2
- if target_std is not None:
- assert abs(output.std() - target_std) < lim, output.std()
- if target_mean is not None:
- assert abs(output.mean() - target_mean) < lim, output.mean()
- if target_max is not None:
- assert abs(output.max() - target_max) < lim, output.max()
- if target_min is not None:
- assert abs(output.min() - target_min) < lim, output.min()
-
-
class KerasInitializersTest(test.TestCase):
+ def _runner(self, init, shape, target_mean=None, target_std=None,
+ target_max=None, target_min=None):
+ variable = keras.backend.variable(init(shape))
+ output = keras.backend.get_value(variable)
+ lim = 3e-2
+ if target_std is not None:
+ self.assertGreater(lim, abs(output.std() - target_std))
+ if target_mean is not None:
+ self.assertGreater(lim, abs(output.mean() - target_mean))
+ if target_max is not None:
+ self.assertGreater(lim, abs(output.max() - target_max))
+ if target_min is not None:
+ self.assertGreater(lim, abs(output.min() - target_min))
+
+ # Test serialization (assumes deterministic behavior).
+ config = init.get_config()
+ reconstructed_init = init.__class__.from_config(config)
+ variable = keras.backend.variable(reconstructed_init(shape))
+ output_2 = keras.backend.get_value(variable)
+ self.assertAllClose(output, output_2, atol=1e-4)
+
def test_uniform(self):
tensor_shape = (9, 6, 7)
with self.test_session():
- _runner(keras.initializers.RandomUniform(minval=-1, maxval=1, seed=124),
- tensor_shape,
- target_mean=0., target_max=1, target_min=-1)
+ self._runner(keras.initializers.RandomUniform(minval=-1,
+ maxval=1,
+ seed=124),
+ tensor_shape,
+ target_mean=0., target_max=1, target_min=-1)
def test_normal(self):
tensor_shape = (8, 12, 99)
with self.test_session():
- _runner(keras.initializers.RandomNormal(mean=0, stddev=1, seed=153),
- tensor_shape,
- target_mean=0., target_std=1)
+ self._runner(keras.initializers.RandomNormal(mean=0, stddev=1, seed=153),
+ tensor_shape,
+ target_mean=0., target_std=1)
def test_truncated_normal(self):
tensor_shape = (12, 99, 7)
with self.test_session():
- _runner(keras.initializers.TruncatedNormal(mean=0, stddev=1, seed=126),
- tensor_shape,
- target_mean=0., target_std=None, target_max=2)
+ self._runner(keras.initializers.TruncatedNormal(mean=0,
+ stddev=1,
+ seed=126),
+ tensor_shape,
+ target_mean=0., target_std=None, target_max=2)
def test_constant(self):
tensor_shape = (5, 6, 4)
with self.test_session():
- _runner(keras.initializers.Constant(2), tensor_shape,
- target_mean=2, target_max=2, target_min=2)
+ self._runner(keras.initializers.Constant(2), tensor_shape,
+ target_mean=2, target_max=2, target_min=2)
def test_lecun_uniform(self):
tensor_shape = (5, 6, 4, 2)
with self.test_session():
- fan_in, _ = keras.initializers._compute_fans(tensor_shape)
+ fan_in, _ = init_ops._compute_fans(tensor_shape)
scale = np.sqrt(3. / fan_in)
- _runner(keras.initializers.lecun_uniform(seed=123), tensor_shape,
- target_mean=0., target_max=scale, target_min=-scale)
+ self._runner(keras.initializers.lecun_uniform(seed=123), tensor_shape,
+ target_mean=0., target_max=scale, target_min=-scale)
def test_glorot_uniform(self):
tensor_shape = (5, 6, 4, 2)
with self.test_session():
- fan_in, fan_out = keras.initializers._compute_fans(tensor_shape)
+ fan_in, fan_out = init_ops._compute_fans(tensor_shape)
scale = np.sqrt(6. / (fan_in + fan_out))
- _runner(keras.initializers.glorot_uniform(seed=123), tensor_shape,
- target_mean=0., target_max=scale, target_min=-scale)
+ self._runner(keras.initializers.glorot_uniform(seed=123), tensor_shape,
+ target_mean=0., target_max=scale, target_min=-scale)
def test_he_uniform(self):
tensor_shape = (5, 6, 4, 2)
with self.test_session():
- fan_in, _ = keras.initializers._compute_fans(tensor_shape)
+ fan_in, _ = init_ops._compute_fans(tensor_shape)
scale = np.sqrt(6. / fan_in)
- _runner(keras.initializers.he_uniform(seed=123), tensor_shape,
- target_mean=0., target_max=scale, target_min=-scale)
+ self._runner(keras.initializers.he_uniform(seed=123), tensor_shape,
+ target_mean=0., target_max=scale, target_min=-scale)
def test_glorot_normal(self):
tensor_shape = (5, 6, 4, 2)
with self.test_session():
- fan_in, fan_out = keras.initializers._compute_fans(tensor_shape)
+ fan_in, fan_out = init_ops._compute_fans(tensor_shape)
scale = np.sqrt(2. / (fan_in + fan_out))
- _runner(keras.initializers.glorot_normal(seed=123), tensor_shape,
- target_mean=0., target_std=None, target_max=2 * scale)
+ self._runner(keras.initializers.glorot_normal(seed=123), tensor_shape,
+ target_mean=0., target_std=None, target_max=2 * scale)
def test_he_normal(self):
tensor_shape = (5, 6, 4, 2)
with self.test_session():
- fan_in, _ = keras.initializers._compute_fans(tensor_shape)
+ fan_in, _ = init_ops._compute_fans(tensor_shape)
scale = np.sqrt(2. / fan_in)
- _runner(keras.initializers.he_normal(seed=123), tensor_shape,
- target_mean=0., target_std=None, target_max=2 * scale)
+ self._runner(keras.initializers.he_normal(seed=123), tensor_shape,
+ target_mean=0., target_std=None, target_max=2 * scale)
def test_orthogonal(self):
- tensor_shape = (7, 8)
+ tensor_shape = (10, 10)
with self.test_session():
- _runner(keras.initializers.orthogonal(seed=123), tensor_shape,
- target_mean=0.)
+ self._runner(keras.initializers.orthogonal(seed=123), tensor_shape,
+ target_mean=0.)
def test_identity(self):
with self.test_session():
tensor_shape = (3, 4, 5)
with self.assertRaises(ValueError):
- _runner(keras.initializers.identity(), tensor_shape,
- target_mean=1. / tensor_shape[0], target_max=1.)
+ self._runner(keras.initializers.identity(), tensor_shape,
+ target_mean=1. / tensor_shape[0], target_max=1.)
tensor_shape = (3, 3)
- _runner(keras.initializers.identity(), tensor_shape,
- target_mean=1. / tensor_shape[0], target_max=1.)
+ self._runner(keras.initializers.identity(), tensor_shape,
+ target_mean=1. / tensor_shape[0], target_max=1.)
def test_zero(self):
tensor_shape = (4, 5)
with self.test_session():
- _runner(keras.initializers.zeros(), tensor_shape,
- target_mean=0., target_max=0.)
+ self._runner(keras.initializers.zeros(), tensor_shape,
+ target_mean=0., target_max=0.)
def test_one(self):
tensor_shape = (4, 5)
with self.test_session():
- _runner(keras.initializers.ones(), tensor_shape,
- target_mean=1., target_max=1.)
+ self._runner(keras.initializers.ones(), tensor_shape,
+ target_mean=1., target_max=1.)
if __name__ == '__main__':
diff --git a/tensorflow/contrib/keras/python/keras/integration_test.py b/tensorflow/contrib/keras/python/keras/integration_test.py
index 3a3d36ca1c..16d0713b31 100644
--- a/tensorflow/contrib/keras/python/keras/integration_test.py
+++ b/tensorflow/contrib/keras/python/keras/integration_test.py
@@ -33,13 +33,13 @@ class KerasIntegrationTest(test.TestCase):
(x_train, y_train), (x_test, y_test) = testing_utils.get_test_data(
train_samples=200,
test_samples=100,
- input_shape=(8,),
+ input_shape=(10,),
num_classes=2)
y_train = keras.utils.to_categorical(y_train)
y_test = keras.utils.to_categorical(y_test)
model = keras.models.Sequential([
- keras.layers.Dense(8,
+ keras.layers.Dense(16,
activation='relu',
input_shape=x_train.shape[1:]),
keras.layers.Dropout(0.1),
@@ -59,13 +59,13 @@ class KerasIntegrationTest(test.TestCase):
(x_train, y_train), (x_test, y_test) = testing_utils.get_test_data(
train_samples=200,
test_samples=100,
- input_shape=(8,),
+ input_shape=(10,),
num_classes=2)
y_train = keras.utils.to_categorical(y_train)
y_test = keras.utils.to_categorical(y_test)
inputs = keras.layers.Input(shape=x_train.shape[1:])
- x = keras.layers.Dense(8, activation='relu')(inputs)
+ x = keras.layers.Dense(16, activation='relu')(inputs)
x = keras.layers.Dropout(0.1)(x)
outputs = keras.layers.Dense(y_train.shape[-1], activation='softmax')(x)
diff --git a/tensorflow/contrib/keras/python/keras/optimizers_test.py b/tensorflow/contrib/keras/python/keras/optimizers_test.py
index b3aaddb7c0..af5e3c99b9 100644
--- a/tensorflow/contrib/keras/python/keras/optimizers_test.py
+++ b/tensorflow/contrib/keras/python/keras/optimizers_test.py
@@ -41,7 +41,7 @@ def _test_optimizer(optimizer, target=0.75):
input_shape=(10,),
num_classes=2)
y_train = keras.utils.to_categorical(y_train)
- model = _get_model(x_train.shape[1], 10, y_train.shape[1])
+ model = _get_model(x_train.shape[1], 20, y_train.shape[1])
model.compile(loss='categorical_crossentropy',
optimizer=optimizer,
metrics=['accuracy'])
diff --git a/tensorflow/python/ops/init_ops.py b/tensorflow/python/ops/init_ops.py
index 60175965da..67fff9c803 100644
--- a/tensorflow/python/ops/init_ops.py
+++ b/tensorflow/python/ops/init_ops.py
@@ -49,30 +49,65 @@ class Initializer(object):
def __call__(self, shape, dtype=None, partition_info=None):
raise NotImplementedError
+ def get_config(self):
+ """Returns the configuration of the initializer as a JSON-serializable dict.
+
+ Returns:
+ A JSON-serializable Python dict.
+ """
+ return {}
+
+ @classmethod
+ def from_config(cls, config):
+ """Instantiates an initializer from a configuration dictionary.
+
+ Example:
+
+ ```
+ initializer = RandomUniform(-1, 1)
+ config = initializer.get_config()
+ initializer = RandomUniform.from_config(config)
+ ```
+
+ Arguments:
+ config: A Python dictionary.
+ It will typically be the output of `get_config`.
+
+ Returns:
+ An Initializer instance.
+ """
+ return cls(**config)
+
class Zeros(Initializer):
"""Initializer that generates tensors initialized to 0."""
def __init__(self, dtype=dtypes.float32):
- self.dtype = dtype
+ self.dtype = dtypes.as_dtype(dtype)
def __call__(self, shape, dtype=None, partition_info=None):
if dtype is None:
dtype = self.dtype
return array_ops.zeros(shape, dtype)
+ def get_config(self):
+ return {"dtype": self.dtype.name}
+
class Ones(Initializer):
"""Initializer that generates tensors initialized to 1."""
def __init__(self, dtype=dtypes.float32):
- self.dtype = dtype
+ self.dtype = dtypes.as_dtype(dtype)
def __call__(self, shape, dtype=None, partition_info=None):
if dtype is None:
dtype = self.dtype
return array_ops.ones(shape, dtype)
+ def get_config(self):
+ return {"dtype": self.dtype.name}
+
class Constant(Initializer):
"""Initializer that generates tensors with constant values.
@@ -151,14 +186,27 @@ class Constant(Initializer):
def __init__(self, value=0, dtype=dtypes.float32, verify_shape=False):
self.value = value
- self.dtype = dtype
- self.verify_shape = verify_shape
+ self.dtype = dtypes.as_dtype(dtype)
+ self._verify_shape = verify_shape
- def __call__(self, shape, dtype=None, partition_info=None):
+ def __call__(self, shape,
+ dtype=None,
+ partition_info=None,
+ verify_shape=None):
if dtype is None:
dtype = self.dtype
+ if verify_shape is None:
+ verify_shape = self._verify_shape
return constant_op.constant(self.value, dtype=dtype, shape=shape,
- verify_shape=self.verify_shape)
+ verify_shape=verify_shape)
+
+ def get_config(self):
+ # We don't include `verify_shape` for compatibility with Keras.
+ # `verify_shape` should be passed as an argument to `__call__` rather
+ # than as a constructor argument: conceptually it isn't a property
+ # of the initializer.
+ return {"value": self.value,
+ "dtype": self.dtype.name}
class RandomUniform(Initializer):
@@ -179,7 +227,7 @@ class RandomUniform(Initializer):
self.minval = minval
self.maxval = maxval
self.seed = seed
- self.dtype = dtype
+ self.dtype = dtypes.as_dtype(dtype)
def __call__(self, shape, dtype=None, partition_info=None):
if dtype is None:
@@ -187,6 +235,12 @@ class RandomUniform(Initializer):
return random_ops.random_uniform(shape, self.minval, self.maxval,
dtype, seed=self.seed)
+ def get_config(self):
+ return {"minval": self.minval,
+ "maxval": self.maxval,
+ "seed": self.seed,
+ "dtype": self.dtype.name}
+
class RandomNormal(Initializer):
"""Initializer that generates tensors with a normal distribution.
@@ -206,7 +260,7 @@ class RandomNormal(Initializer):
self.mean = mean
self.stddev = stddev
self.seed = seed
- self.dtype = _assert_float_dtype(dtype)
+ self.dtype = _assert_float_dtype(dtypes.as_dtype(dtype))
def __call__(self, shape, dtype=None, partition_info=None):
if dtype is None:
@@ -214,6 +268,12 @@ class RandomNormal(Initializer):
return random_ops.random_normal(shape, self.mean, self.stddev,
dtype, seed=self.seed)
+ def get_config(self):
+ return {"mean": self.mean,
+ "stddev": self.stddev,
+ "seed": self.seed,
+ "dtype": self.dtype.name}
+
class TruncatedNormal(Initializer):
"""Initializer that generates a truncated normal distribution.
@@ -238,7 +298,7 @@ class TruncatedNormal(Initializer):
self.mean = mean
self.stddev = stddev
self.seed = seed
- self.dtype = _assert_float_dtype(dtype)
+ self.dtype = _assert_float_dtype(dtypes.as_dtype(dtype))
def __call__(self, shape, dtype=None, partition_info=None):
if dtype is None:
@@ -246,6 +306,12 @@ class TruncatedNormal(Initializer):
return random_ops.truncated_normal(shape, self.mean, self.stddev,
dtype, seed=self.seed)
+ def get_config(self):
+ return {"mean": self.mean,
+ "stddev": self.stddev,
+ "seed": self.seed,
+ "dtype": self.dtype.name}
+
class UniformUnitScaling(Initializer):
"""Initializer that generates tensors without scaling variance.
@@ -277,7 +343,7 @@ class UniformUnitScaling(Initializer):
def __init__(self, factor=1.0, seed=None, dtype=dtypes.float32):
self.factor = factor
self.seed = seed
- self.dtype = _assert_float_dtype(dtype)
+ self.dtype = _assert_float_dtype(dtypes.as_dtype(dtype))
def __call__(self, shape, dtype=None, partition_info=None):
if dtype is None:
@@ -298,6 +364,11 @@ class UniformUnitScaling(Initializer):
return random_ops.random_uniform(shape, -max_val, max_val,
dtype, seed=self.seed)
+ def get_config(self):
+ return {"factor": self.factor,
+ "seed": self.seed,
+ "dtype": self.dtype.name}
+
class VarianceScaling(Initializer):
"""Initializer capable of adapting its scale to the shape of weights tensors.
@@ -342,7 +413,7 @@ class VarianceScaling(Initializer):
self.mode = mode
self.distribution = distribution
self.seed = seed
- self.dtype = _assert_float_dtype(dtype)
+ self.dtype = _assert_float_dtype(dtypes.as_dtype(dtype))
def __call__(self, shape, dtype=None, partition_info=None):
if dtype is None:
@@ -367,6 +438,13 @@ class VarianceScaling(Initializer):
return random_ops.random_uniform(shape, -limit, limit,
dtype, seed=self.seed)
+ def get_config(self):
+ return {"scale": self.scale,
+ "mode": self.mode,
+ "distribution": self.distribution,
+ "seed": self.seed,
+ "dtype": self.dtype.name}
+
class Orthogonal(Initializer):
"""Initializer that generates an orthogonal matrix.
@@ -388,9 +466,9 @@ class Orthogonal(Initializer):
for behavior.
"""
- def __init__(self, gain=1.0, dtype=dtypes.float32, seed=None):
+ def __init__(self, gain=1.0, seed=None, dtype=dtypes.float32):
self.gain = gain
- self.dtype = _assert_float_dtype(dtype)
+ self.dtype = _assert_float_dtype(dtypes.as_dtype(dtype))
self.seed = seed
def __call__(self, shape, dtype=None, partition_info=None):
@@ -421,6 +499,11 @@ class Orthogonal(Initializer):
q = array_ops.transpose(v)
return self.gain * array_ops.reshape(q, shape)
+ def get_config(self):
+ return {"gain": self.gain,
+ "seed": self.seed,
+ "dtype": self.dtype.name}
+
# Aliases.
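
One detail in the hunks above deserves a note: `get_config()` stores `self.dtype.name` (a string) to keep the dict JSON-serializable, and the constructors now route `dtype` through `dtypes.as_dtype()`, which is what lets the inherited `from_config` (a bare `cls(**config)`) feed that string straight back into `__init__`. A small sketch of the round-trip via a public alias:

```
import tensorflow as tf

init = tf.zeros_initializer(dtype=tf.float64)
config = init.get_config()   # {'dtype': 'float64'} -- a plain string
restored = tf.zeros_initializer.from_config(config)

# The constructor applies dtypes.as_dtype('float64'), so the restored
# initializer carries the same DType as the original.
assert restored.dtype == init.dtype
```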
diff --git a/tensorflow/tools/api/golden/tensorflow.constant_initializer.pbtxt b/tensorflow/tools/api/golden/tensorflow.constant_initializer.pbtxt
index d34bfe5147..00ec669b16 100644
--- a/tensorflow/tools/api/golden/tensorflow.constant_initializer.pbtxt
+++ b/tensorflow/tools/api/golden/tensorflow.constant_initializer.pbtxt
@@ -7,4 +7,12 @@ tf_class {
name: "__init__"
argspec: "args=[\'self\', \'value\', \'dtype\', \'verify_shape\'], varargs=None, keywords=None, defaults=[\'0\', \"<dtype: \'float32\'>\", \'False\'], "
}
+ member_method {
+ name: "from_config"
+ argspec: "args=[\'cls\', \'config\'], varargs=None, keywords=None, defaults=None"
+ }
+ member_method {
+ name: "get_config"
+ argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
+ }
}
diff --git a/tensorflow/tools/api/golden/tensorflow.ones_initializer.pbtxt b/tensorflow/tools/api/golden/tensorflow.ones_initializer.pbtxt
index d84ddc6eb0..210b56242b 100644
--- a/tensorflow/tools/api/golden/tensorflow.ones_initializer.pbtxt
+++ b/tensorflow/tools/api/golden/tensorflow.ones_initializer.pbtxt
@@ -7,4 +7,12 @@ tf_class {
name: "__init__"
argspec: "args=[\'self\', \'dtype\'], varargs=None, keywords=None, defaults=[\"<dtype: \'float32\'>\"], "
}
+ member_method {
+ name: "from_config"
+ argspec: "args=[\'cls\', \'config\'], varargs=None, keywords=None, defaults=None"
+ }
+ member_method {
+ name: "get_config"
+ argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
+ }
}
diff --git a/tensorflow/tools/api/golden/tensorflow.orthogonal_initializer.pbtxt b/tensorflow/tools/api/golden/tensorflow.orthogonal_initializer.pbtxt
index c8e266e70c..13ec7454f4 100644
--- a/tensorflow/tools/api/golden/tensorflow.orthogonal_initializer.pbtxt
+++ b/tensorflow/tools/api/golden/tensorflow.orthogonal_initializer.pbtxt
@@ -5,6 +5,14 @@ tf_class {
is_instance: "<type \'object\'>"
member_method {
name: "__init__"
- argspec: "args=[\'self\', \'gain\', \'dtype\', \'seed\'], varargs=None, keywords=None, defaults=[\'1.0\', \"<dtype: \'float32\'>\", \'None\'], "
+ argspec: "args=[\'self\', \'gain\', \'seed\', \'dtype\'], varargs=None, keywords=None, defaults=[\'1.0\', \'None\', \"<dtype: \'float32\'>\"], "
+ }
+ member_method {
+ name: "from_config"
+ argspec: "args=[\'cls\', \'config\'], varargs=None, keywords=None, defaults=None"
+ }
+ member_method {
+ name: "get_config"
+ argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
}
}
diff --git a/tensorflow/tools/api/golden/tensorflow.random_normal_initializer.pbtxt b/tensorflow/tools/api/golden/tensorflow.random_normal_initializer.pbtxt
index 70308bc601..5993fdeb9c 100644
--- a/tensorflow/tools/api/golden/tensorflow.random_normal_initializer.pbtxt
+++ b/tensorflow/tools/api/golden/tensorflow.random_normal_initializer.pbtxt
@@ -7,4 +7,12 @@ tf_class {
name: "__init__"
argspec: "args=[\'self\', \'mean\', \'stddev\', \'seed\', \'dtype\'], varargs=None, keywords=None, defaults=[\'0.0\', \'1.0\', \'None\', \"<dtype: \'float32\'>\"], "
}
+ member_method {
+ name: "from_config"
+ argspec: "args=[\'cls\', \'config\'], varargs=None, keywords=None, defaults=None"
+ }
+ member_method {
+ name: "get_config"
+ argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
+ }
}
diff --git a/tensorflow/tools/api/golden/tensorflow.random_uniform_initializer.pbtxt b/tensorflow/tools/api/golden/tensorflow.random_uniform_initializer.pbtxt
index 37bb1956e8..a434ed1599 100644
--- a/tensorflow/tools/api/golden/tensorflow.random_uniform_initializer.pbtxt
+++ b/tensorflow/tools/api/golden/tensorflow.random_uniform_initializer.pbtxt
@@ -7,4 +7,12 @@ tf_class {
name: "__init__"
argspec: "args=[\'self\', \'minval\', \'maxval\', \'seed\', \'dtype\'], varargs=None, keywords=None, defaults=[\'0\', \'None\', \'None\', \"<dtype: \'float32\'>\"], "
}
+ member_method {
+ name: "from_config"
+ argspec: "args=[\'cls\', \'config\'], varargs=None, keywords=None, defaults=None"
+ }
+ member_method {
+ name: "get_config"
+ argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
+ }
}
diff --git a/tensorflow/tools/api/golden/tensorflow.truncated_normal_initializer.pbtxt b/tensorflow/tools/api/golden/tensorflow.truncated_normal_initializer.pbtxt
index 7c48f4af07..c1e1c230a9 100644
--- a/tensorflow/tools/api/golden/tensorflow.truncated_normal_initializer.pbtxt
+++ b/tensorflow/tools/api/golden/tensorflow.truncated_normal_initializer.pbtxt
@@ -7,4 +7,12 @@ tf_class {
name: "__init__"
argspec: "args=[\'self\', \'mean\', \'stddev\', \'seed\', \'dtype\'], varargs=None, keywords=None, defaults=[\'0.0\', \'1.0\', \'None\', \"<dtype: \'float32\'>\"], "
}
+ member_method {
+ name: "from_config"
+ argspec: "args=[\'cls\', \'config\'], varargs=None, keywords=None, defaults=None"
+ }
+ member_method {
+ name: "get_config"
+ argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
+ }
}
diff --git a/tensorflow/tools/api/golden/tensorflow.uniform_unit_scaling_initializer.pbtxt b/tensorflow/tools/api/golden/tensorflow.uniform_unit_scaling_initializer.pbtxt
index 4558db619e..e1b18dc92f 100644
--- a/tensorflow/tools/api/golden/tensorflow.uniform_unit_scaling_initializer.pbtxt
+++ b/tensorflow/tools/api/golden/tensorflow.uniform_unit_scaling_initializer.pbtxt
@@ -7,4 +7,12 @@ tf_class {
name: "__init__"
argspec: "args=[\'self\', \'factor\', \'seed\', \'dtype\'], varargs=None, keywords=None, defaults=[\'1.0\', \'None\', \"<dtype: \'float32\'>\"], "
}
+ member_method {
+ name: "from_config"
+ argspec: "args=[\'cls\', \'config\'], varargs=None, keywords=None, defaults=None"
+ }
+ member_method {
+ name: "get_config"
+ argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
+ }
}
diff --git a/tensorflow/tools/api/golden/tensorflow.zeros_initializer.pbtxt b/tensorflow/tools/api/golden/tensorflow.zeros_initializer.pbtxt
index 8313009a68..e229b02cee 100644
--- a/tensorflow/tools/api/golden/tensorflow.zeros_initializer.pbtxt
+++ b/tensorflow/tools/api/golden/tensorflow.zeros_initializer.pbtxt
@@ -7,4 +7,12 @@ tf_class {
name: "__init__"
argspec: "args=[\'self\', \'dtype\'], varargs=None, keywords=None, defaults=[\"<dtype: \'float32\'>\"], "
}
+ member_method {
+ name: "from_config"
+ argspec: "args=[\'cls\', \'config\'], varargs=None, keywords=None, defaults=None"
+ }
+ member_method {
+ name: "get_config"
+ argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
+ }
}