diff options
author | 2017-07-28 09:33:43 -0700 | |
---|---|---|
committer | 2017-07-28 09:37:51 -0700 | |
commit | 7635e9db104a095918891c3be6f5d53669f075af (patch) | |
tree | 3e505fade3b47b2955fb6d9b3d940568a6ab4be7 | |
parent | 046f912bb5e18fae87cf96ec04d035dddde633ad (diff) |
Handle num_rows < num_cols case in orthogonal_initializer properly (as the original SVD-based implementation would) instead of just padding with zeros.
Run pyformat.
PiperOrigin-RevId: 163478869
-rw-r--r-- | tensorflow/python/ops/init_ops.py | 121 |
1 file changed, 58 insertions, 63 deletions
diff --git a/tensorflow/python/ops/init_ops.py b/tensorflow/python/ops/init_ops.py index 42b4f952bb..203a3cd485 100644 --- a/tensorflow/python/ops/init_ops.py +++ b/tensorflow/python/ops/init_ops.py @@ -12,7 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== - """Operations often used for initializing tensors. All variable initializers returned by functions in this file should have the @@ -190,24 +189,20 @@ class Constant(Initializer): self.dtype = dtypes.as_dtype(dtype) self._verify_shape = verify_shape - def __call__(self, shape, - dtype=None, - partition_info=None, - verify_shape=None): + def __call__(self, shape, dtype=None, partition_info=None, verify_shape=None): if dtype is None: dtype = self.dtype if verify_shape is None: verify_shape = self._verify_shape - return constant_op.constant(self.value, dtype=dtype, shape=shape, - verify_shape=verify_shape) + return constant_op.constant( + self.value, dtype=dtype, shape=shape, verify_shape=verify_shape) def get_config(self): # We don't include `verify_shape` for compatibility with Keras. # `verify_shape` should be passed as an argument to `__call__` rather # than as a constructor argument: conceptually it isn't a property # of the initializer. 
- return {"value": self.value, - "dtype": self.dtype.name} + return {"value": self.value, "dtype": self.dtype.name} class RandomUniform(Initializer): @@ -233,14 +228,16 @@ class RandomUniform(Initializer): def __call__(self, shape, dtype=None, partition_info=None): if dtype is None: dtype = self.dtype - return random_ops.random_uniform(shape, self.minval, self.maxval, - dtype, seed=self.seed) + return random_ops.random_uniform( + shape, self.minval, self.maxval, dtype, seed=self.seed) def get_config(self): - return {"minval": self.minval, - "maxval": self.maxval, - "seed": self.seed, - "dtype": self.dtype.name} + return { + "minval": self.minval, + "maxval": self.maxval, + "seed": self.seed, + "dtype": self.dtype.name + } class RandomNormal(Initializer): @@ -266,14 +263,16 @@ class RandomNormal(Initializer): def __call__(self, shape, dtype=None, partition_info=None): if dtype is None: dtype = self.dtype - return random_ops.random_normal(shape, self.mean, self.stddev, - dtype, seed=self.seed) + return random_ops.random_normal( + shape, self.mean, self.stddev, dtype, seed=self.seed) def get_config(self): - return {"mean": self.mean, - "stddev": self.stddev, - "seed": self.seed, - "dtype": self.dtype.name} + return { + "mean": self.mean, + "stddev": self.stddev, + "seed": self.seed, + "dtype": self.dtype.name + } class TruncatedNormal(Initializer): @@ -304,14 +303,16 @@ class TruncatedNormal(Initializer): def __call__(self, shape, dtype=None, partition_info=None): if dtype is None: dtype = self.dtype - return random_ops.truncated_normal(shape, self.mean, self.stddev, - dtype, seed=self.seed) + return random_ops.truncated_normal( + shape, self.mean, self.stddev, dtype, seed=self.seed) def get_config(self): - return {"mean": self.mean, - "stddev": self.stddev, - "seed": self.seed, - "dtype": self.dtype.name} + return { + "mean": self.mean, + "stddev": self.stddev, + "seed": self.seed, + "dtype": self.dtype.name + } class UniformUnitScaling(Initializer): @@ -362,13 
+363,11 @@ class UniformUnitScaling(Initializer): # Avoid errors when initializing zero-size tensors. input_size = max(input_size, 1.0) max_val = math.sqrt(3 / input_size) * self.factor - return random_ops.random_uniform(shape, -max_val, max_val, - dtype, seed=self.seed) + return random_ops.random_uniform( + shape, -max_val, max_val, dtype, seed=self.seed) def get_config(self): - return {"factor": self.factor, - "seed": self.seed, - "dtype": self.dtype.name} + return {"factor": self.factor, "seed": self.seed, "dtype": self.dtype.name} class VarianceScaling(Initializer): @@ -398,7 +397,8 @@ class VarianceScaling(Initializer): "distribution" arguments. """ - def __init__(self, scale=1.0, + def __init__(self, + scale=1.0, mode="fan_in", distribution="normal", seed=None, @@ -432,27 +432,31 @@ class VarianceScaling(Initializer): scale /= max(1., (fan_in + fan_out) / 2.) if self.distribution == "normal": stddev = math.sqrt(scale) - return random_ops.truncated_normal(shape, 0.0, stddev, - dtype, seed=self.seed) + return random_ops.truncated_normal( + shape, 0.0, stddev, dtype, seed=self.seed) else: limit = math.sqrt(3.0 * scale) - return random_ops.random_uniform(shape, -limit, limit, - dtype, seed=self.seed) + return random_ops.random_uniform( + shape, -limit, limit, dtype, seed=self.seed) def get_config(self): - return {"scale": self.scale, - "mode": self.mode, - "distribution": self.distribution, - "seed": self.seed, - "dtype": self.dtype.name} + return { + "scale": self.scale, + "mode": self.mode, + "distribution": self.distribution, + "seed": self.seed, + "dtype": self.dtype.name + } class Orthogonal(Initializer): """Initializer that generates an orthogonal matrix. If the shape of the tensor to initialize is two-dimensional, i is initialized - with an orthogonal matrix obtained from the singular value decomposition of a - matrix of uniform random numbers. + with an orthogonal matrix obtained from the QR decomposition of a matrix of + uniform random numbers. 
If the matrix has fewer rows than columns then the + output will have orthogonal rows. Otherwise, the output will have orthogonal + columns. If the shape of the tensor to initialize is more than two-dimensional, a matrix of shape `(shape[0] * ... * shape[n - 2], shape[n - 1])` @@ -485,27 +489,23 @@ class Orthogonal(Initializer): for dim in shape[:-1]: num_rows *= dim num_cols = shape[-1] - flat_shape = (num_rows, num_cols) + flat_shape = (num_cols, num_rows) if num_rows < num_cols else (num_rows, + num_cols) # Generate a random matrix a = random_ops.random_normal(flat_shape, dtype=dtype, seed=self.seed) # Compute the qr factorization q, r = linalg_ops.qr(a, full_matrices=False) # Make Q uniform - square_len = math_ops.minimum(num_rows, num_cols) - d = array_ops.diag_part(r[:square_len, :square_len]) + d = array_ops.diag_part(r) ph = d / math_ops.abs(d) q *= ph - # Pad zeros to Q (if rows smaller than cols) if num_rows < num_cols: - padding = array_ops.zeros([num_rows, num_cols - num_rows], dtype=dtype) - q = array_ops.concat([q, padding], 1) + q = array_ops.matrix_transpose(q) return self.gain * array_ops.reshape(q, shape) def get_config(self): - return {"gain": self.gain, - "seed": self.seed, - "dtype": self.dtype.name} + return {"gain": self.gain, "seed": self.seed, "dtype": self.dtype.name} # Aliases. @@ -520,6 +520,7 @@ truncated_normal_initializer = TruncatedNormal uniform_unit_scaling_initializer = UniformUnitScaling variance_scaling_initializer = VarianceScaling orthogonal_initializer = Orthogonal + # pylint: enable=invalid-name @@ -542,11 +543,8 @@ def glorot_uniform_initializer(seed=None, dtype=dtypes.float32): Returns: An initializer. 
""" - return variance_scaling_initializer(scale=1.0, - mode="fan_avg", - distribution="uniform", - seed=seed, - dtype=dtype) + return variance_scaling_initializer( + scale=1.0, mode="fan_avg", distribution="uniform", seed=seed, dtype=dtype) def glorot_normal_initializer(seed=None, dtype=dtypes.float32): @@ -568,11 +566,8 @@ def glorot_normal_initializer(seed=None, dtype=dtypes.float32): Returns: An initializer. """ - return variance_scaling_initializer(scale=1.0, - mode="fan_avg", - distribution="normal", - seed=seed, - dtype=dtype) + return variance_scaling_initializer( + scale=1.0, mode="fan_avg", distribution="normal", seed=seed, dtype=dtype) # Utility functions. |