aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2017-07-28 09:33:43 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2017-07-28 09:37:51 -0700
commit7635e9db104a095918891c3be6f5d53669f075af (patch)
tree3e505fade3b47b2955fb6d9b3d940568a6ab4be7
parent046f912bb5e18fae87cf96ec04d035dddde633ad (diff)
Handle num_rows < num_cols case in orthogonal_initializer properly (as the original SVD-based implementation would) instead of just padding with zeros.
Run pyformat. PiperOrigin-RevId: 163478869
-rw-r--r--tensorflow/python/ops/init_ops.py121
1 file changed, 58 insertions, 63 deletions
diff --git a/tensorflow/python/ops/init_ops.py b/tensorflow/python/ops/init_ops.py
index 42b4f952bb..203a3cd485 100644
--- a/tensorflow/python/ops/init_ops.py
+++ b/tensorflow/python/ops/init_ops.py
@@ -12,7 +12,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
-
"""Operations often used for initializing tensors.
All variable initializers returned by functions in this file should have the
@@ -190,24 +189,20 @@ class Constant(Initializer):
self.dtype = dtypes.as_dtype(dtype)
self._verify_shape = verify_shape
- def __call__(self, shape,
- dtype=None,
- partition_info=None,
- verify_shape=None):
+ def __call__(self, shape, dtype=None, partition_info=None, verify_shape=None):
if dtype is None:
dtype = self.dtype
if verify_shape is None:
verify_shape = self._verify_shape
- return constant_op.constant(self.value, dtype=dtype, shape=shape,
- verify_shape=verify_shape)
+ return constant_op.constant(
+ self.value, dtype=dtype, shape=shape, verify_shape=verify_shape)
def get_config(self):
# We don't include `verify_shape` for compatibility with Keras.
# `verify_shape` should be passed as an argument to `__call__` rather
# than as a constructor argument: conceptually it isn't a property
# of the initializer.
- return {"value": self.value,
- "dtype": self.dtype.name}
+ return {"value": self.value, "dtype": self.dtype.name}
class RandomUniform(Initializer):
@@ -233,14 +228,16 @@ class RandomUniform(Initializer):
def __call__(self, shape, dtype=None, partition_info=None):
if dtype is None:
dtype = self.dtype
- return random_ops.random_uniform(shape, self.minval, self.maxval,
- dtype, seed=self.seed)
+ return random_ops.random_uniform(
+ shape, self.minval, self.maxval, dtype, seed=self.seed)
def get_config(self):
- return {"minval": self.minval,
- "maxval": self.maxval,
- "seed": self.seed,
- "dtype": self.dtype.name}
+ return {
+ "minval": self.minval,
+ "maxval": self.maxval,
+ "seed": self.seed,
+ "dtype": self.dtype.name
+ }
class RandomNormal(Initializer):
@@ -266,14 +263,16 @@ class RandomNormal(Initializer):
def __call__(self, shape, dtype=None, partition_info=None):
if dtype is None:
dtype = self.dtype
- return random_ops.random_normal(shape, self.mean, self.stddev,
- dtype, seed=self.seed)
+ return random_ops.random_normal(
+ shape, self.mean, self.stddev, dtype, seed=self.seed)
def get_config(self):
- return {"mean": self.mean,
- "stddev": self.stddev,
- "seed": self.seed,
- "dtype": self.dtype.name}
+ return {
+ "mean": self.mean,
+ "stddev": self.stddev,
+ "seed": self.seed,
+ "dtype": self.dtype.name
+ }
class TruncatedNormal(Initializer):
@@ -304,14 +303,16 @@ class TruncatedNormal(Initializer):
def __call__(self, shape, dtype=None, partition_info=None):
if dtype is None:
dtype = self.dtype
- return random_ops.truncated_normal(shape, self.mean, self.stddev,
- dtype, seed=self.seed)
+ return random_ops.truncated_normal(
+ shape, self.mean, self.stddev, dtype, seed=self.seed)
def get_config(self):
- return {"mean": self.mean,
- "stddev": self.stddev,
- "seed": self.seed,
- "dtype": self.dtype.name}
+ return {
+ "mean": self.mean,
+ "stddev": self.stddev,
+ "seed": self.seed,
+ "dtype": self.dtype.name
+ }
class UniformUnitScaling(Initializer):
@@ -362,13 +363,11 @@ class UniformUnitScaling(Initializer):
# Avoid errors when initializing zero-size tensors.
input_size = max(input_size, 1.0)
max_val = math.sqrt(3 / input_size) * self.factor
- return random_ops.random_uniform(shape, -max_val, max_val,
- dtype, seed=self.seed)
+ return random_ops.random_uniform(
+ shape, -max_val, max_val, dtype, seed=self.seed)
def get_config(self):
- return {"factor": self.factor,
- "seed": self.seed,
- "dtype": self.dtype.name}
+ return {"factor": self.factor, "seed": self.seed, "dtype": self.dtype.name}
class VarianceScaling(Initializer):
@@ -398,7 +397,8 @@ class VarianceScaling(Initializer):
"distribution" arguments.
"""
- def __init__(self, scale=1.0,
+ def __init__(self,
+ scale=1.0,
mode="fan_in",
distribution="normal",
seed=None,
@@ -432,27 +432,31 @@ class VarianceScaling(Initializer):
scale /= max(1., (fan_in + fan_out) / 2.)
if self.distribution == "normal":
stddev = math.sqrt(scale)
- return random_ops.truncated_normal(shape, 0.0, stddev,
- dtype, seed=self.seed)
+ return random_ops.truncated_normal(
+ shape, 0.0, stddev, dtype, seed=self.seed)
else:
limit = math.sqrt(3.0 * scale)
- return random_ops.random_uniform(shape, -limit, limit,
- dtype, seed=self.seed)
+ return random_ops.random_uniform(
+ shape, -limit, limit, dtype, seed=self.seed)
def get_config(self):
- return {"scale": self.scale,
- "mode": self.mode,
- "distribution": self.distribution,
- "seed": self.seed,
- "dtype": self.dtype.name}
+ return {
+ "scale": self.scale,
+ "mode": self.mode,
+ "distribution": self.distribution,
+ "seed": self.seed,
+ "dtype": self.dtype.name
+ }
class Orthogonal(Initializer):
"""Initializer that generates an orthogonal matrix.
  If the shape of the tensor to initialize is two-dimensional, it is initialized
- with an orthogonal matrix obtained from the singular value decomposition of a
- matrix of uniform random numbers.
+ with an orthogonal matrix obtained from the QR decomposition of a matrix of
+ uniform random numbers. If the matrix has fewer rows than columns then the
+ output will have orthogonal rows. Otherwise, the output will have orthogonal
+ columns.
If the shape of the tensor to initialize is more than two-dimensional,
a matrix of shape `(shape[0] * ... * shape[n - 2], shape[n - 1])`
@@ -485,27 +489,23 @@ class Orthogonal(Initializer):
for dim in shape[:-1]:
num_rows *= dim
num_cols = shape[-1]
- flat_shape = (num_rows, num_cols)
+ flat_shape = (num_cols, num_rows) if num_rows < num_cols else (num_rows,
+ num_cols)
# Generate a random matrix
a = random_ops.random_normal(flat_shape, dtype=dtype, seed=self.seed)
# Compute the qr factorization
q, r = linalg_ops.qr(a, full_matrices=False)
# Make Q uniform
- square_len = math_ops.minimum(num_rows, num_cols)
- d = array_ops.diag_part(r[:square_len, :square_len])
+ d = array_ops.diag_part(r)
ph = d / math_ops.abs(d)
q *= ph
- # Pad zeros to Q (if rows smaller than cols)
if num_rows < num_cols:
- padding = array_ops.zeros([num_rows, num_cols - num_rows], dtype=dtype)
- q = array_ops.concat([q, padding], 1)
+ q = array_ops.matrix_transpose(q)
return self.gain * array_ops.reshape(q, shape)
def get_config(self):
- return {"gain": self.gain,
- "seed": self.seed,
- "dtype": self.dtype.name}
+ return {"gain": self.gain, "seed": self.seed, "dtype": self.dtype.name}
# Aliases.
@@ -520,6 +520,7 @@ truncated_normal_initializer = TruncatedNormal
uniform_unit_scaling_initializer = UniformUnitScaling
variance_scaling_initializer = VarianceScaling
orthogonal_initializer = Orthogonal
+
# pylint: enable=invalid-name
@@ -542,11 +543,8 @@ def glorot_uniform_initializer(seed=None, dtype=dtypes.float32):
Returns:
An initializer.
"""
- return variance_scaling_initializer(scale=1.0,
- mode="fan_avg",
- distribution="uniform",
- seed=seed,
- dtype=dtype)
+ return variance_scaling_initializer(
+ scale=1.0, mode="fan_avg", distribution="uniform", seed=seed, dtype=dtype)
def glorot_normal_initializer(seed=None, dtype=dtypes.float32):
@@ -568,11 +566,8 @@ def glorot_normal_initializer(seed=None, dtype=dtypes.float32):
Returns:
An initializer.
"""
- return variance_scaling_initializer(scale=1.0,
- mode="fan_avg",
- distribution="normal",
- seed=seed,
- dtype=dtype)
+ return variance_scaling_initializer(
+ scale=1.0, mode="fan_avg", distribution="normal", seed=seed, dtype=dtype)
# Utility functions.