author    A. Unique TensorFlower <gardener@tensorflow.org>  2018-04-12 11:51:34 -0700
committer TensorFlower Gardener <gardener@tensorflow.org>   2018-04-12 11:54:06 -0700
commit    454a22aa29dc2dba355094aabe733cd8419f2788 (patch)
tree      d3477592a3abee37e0f19372c10ac8359748a9ce
parent    10e60219b71fc48e07b0afaa6edeec2d9afac24d (diff)
Construct Orthogonal kernels for 2d convolutions.
PiperOrigin-RevId: 192645769
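For context, a minimal usage sketch of the new initializer (illustrative only, not part of this patch; it assumes the TF 1.x tf.layers API and the tf.contrib.framework export added below):

import tensorflow as tf

# Hypothetical example: initialize a conv layer's kernel with the new 2D
# orthogonal initializer. The kernel must be square, and the number of input
# channels (16 here) must not exceed the number of output filters (32).
init = tf.contrib.framework.convolutional_orthogonal_2d(gain=1.0, seed=1)
x = tf.random_normal([8, 28, 28, 16])  # NHWC input
y = tf.layers.conv2d(x, filters=32, kernel_size=3, padding="same",
                     use_bias=False, kernel_initializer=init)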
-rw-r--r--  tensorflow/contrib/framework/__init__.py         |   2
-rw-r--r--  tensorflow/python/kernel_tests/init_ops_test.py  |  99
-rw-r--r--  tensorflow/python/ops/init_ops.py                | 186
3 files changed, 282 insertions(+), 5 deletions(-)
diff --git a/tensorflow/contrib/framework/__init__.py b/tensorflow/contrib/framework/__init__.py
index cbb68bd3eb..a52907f163 100644
--- a/tensorflow/contrib/framework/__init__.py
+++ b/tensorflow/contrib/framework/__init__.py
@@ -72,6 +72,7 @@ See the @{$python/contrib.framework} guide.
@@variable
@@VariableDeviceChooser
@@convolutional_delta_orthogonal
+@@convolutional_orthogonal_2d
@@zero_initializer
@@load_checkpoint
@@ -116,6 +117,7 @@ from tensorflow.python.framework.smart_cond import smart_constant_value
from tensorflow.python.framework.tensor_spec import BoundedTensorSpec
from tensorflow.python.framework.tensor_spec import TensorSpec
from tensorflow.python.ops.init_ops import convolutional_delta_orthogonal
+from tensorflow.python.ops.init_ops import convolutional_orthogonal_2d
from tensorflow.python.util.all_util import remove_undocumented
_allowed_symbols = ['nest']
diff --git a/tensorflow/python/kernel_tests/init_ops_test.py b/tensorflow/python/kernel_tests/init_ops_test.py
index 1e5c118cbc..f7a7119b34 100644
--- a/tensorflow/python/kernel_tests/init_ops_test.py
+++ b/tensorflow/python/kernel_tests/init_ops_test.py
@@ -551,7 +551,6 @@ class OrthogonalInitializerTest(test.TestCase):
init2 = init_ops.orthogonal_initializer(gain=3.14, seed=1, dtype=dtype)
with self.test_session(graph=ops.Graph(), use_gpu=True):
t1 = init1(shape).eval()
- with self.test_session(graph=ops.Graph(), use_gpu=True):
t2 = init2(shape).eval()
return np.allclose(t1, t2 / 3.14, rtol=1e-15, atol=1e-15)
@@ -610,7 +609,6 @@ class ConvolutionDeltaOrthogonalInitializerTest(test.TestCase):
seed=1, dtype=dtype)
with self.test_session(graph=ops.Graph(), use_gpu=True):
t1 = init1(shape).eval()
- with self.test_session(graph=ops.Graph(), use_gpu=True):
t2 = init2(shape).eval()
return np.allclose(t1, t2 / 3.14, rtol=1e-15, atol=1e-15)
@@ -674,6 +672,103 @@ class ConvolutionDeltaOrthogonalInitializerTest(test.TestCase):
self.assertAllClose(abs_value, count, rtol=tol, atol=tol)
+class ConvolutionOrthogonal2dInitializerTest(test.TestCase):
+
+ def testInitializerIdentical(self):
+ for dtype in [dtypes.float32, dtypes.float64]:
+ init1 = init_ops.convolutional_orthogonal_2d(seed=1, dtype=dtype)
+ init2 = init_ops.convolutional_orthogonal_2d(seed=1, dtype=dtype)
+ self.assertTrue(identicaltest(self, init1, init2, (3, 3, 10, 10)))
+
+ def testInitializerDifferent(self):
+ for dtype in [dtypes.float32, dtypes.float64]:
+ init1 = init_ops.convolutional_orthogonal_2d(seed=1, dtype=dtype)
+ init2 = init_ops.convolutional_orthogonal_2d(seed=2, dtype=dtype)
+ self.assertFalse(identicaltest(self, init1, init2, (3, 3, 10, 10)))
+
+ def testDuplicatedInitializer(self):
+ init = init_ops.convolutional_orthogonal_2d()
+ self.assertFalse(duplicated_initializer(self, init, 1, (3, 3, 10, 10)))
+
+ def testInvalidDataType(self):
+ self.assertRaises(
+ ValueError, init_ops.convolutional_orthogonal_2d,
+ dtype=dtypes.string)
+
+ def testInvalidShape(self):
+ init1 = init_ops.convolutional_orthogonal_2d()
+ with self.test_session(graph=ops.Graph(), use_gpu=True):
+ self.assertRaises(ValueError, init1, shape=[3, 3, 6, 5])
+
+ def testGain(self):
+ shape = (3, 3, 10, 10)
+ for dtype in [dtypes.float32, dtypes.float64]:
+ init1 = init_ops.convolutional_orthogonal_2d(seed=1, dtype=dtype)
+ init2 = init_ops.convolutional_orthogonal_2d(gain=3.14,
+ seed=1, dtype=dtype)
+ with self.test_session(graph=ops.Graph(), use_gpu=True):
+ t1 = init1(shape).eval()
+ t2 = init2(shape).eval()
+ return np.allclose(t1, t2 / 3.14, rtol=1e-15, atol=1e-15)
+
+ def testShapesValues(self):
+ def circular_pad(input_, width, kernel_size):
+ """Pad input_ for computing (circular) convolution.
+
+ Args:
+ input_: the input tensor
+ width: the spatial size (height and width) of the square input tensor.
+ kernel_size: the kernel size of the filter.
+ Returns:
+ a tensor whose width is (width + kernel_size - 1).
+ """
+ beg = kernel_size // 2
+ end = kernel_size - 1 - beg
+
+ tmp_up = array_ops.slice(input_, [0, width - beg, 0, 0],
+ [-1, beg, width, -1])
+ tmp_down = array_ops.slice(input_, [0, 0, 0, 0], [-1, end, width, -1])
+ tmp = array_ops.concat([tmp_up, input_, tmp_down], 1)
+
+ new_width = width + kernel_size - 1
+ tmp_left = array_ops.slice(tmp, [0, 0, width - beg, 0],
+ [-1, new_width, beg, -1])
+ tmp_right = array_ops.slice(tmp, [0, 0, 0, 0], [-1, new_width, end, -1])
+
+ final = array_ops.concat([tmp_left, tmp, tmp_right], 2)
+ return final
+
+ cout = 45
+ shape = [64, 28, 28, 32]
+ outputs_shape = shape[0:-1] + [cout]
+ dtype = dtypes.float32
+ tol = 1e-3
+ gain = 3.14
+ # Check orthogonality/isometry by computing the ratio between
+ # the 2-norms of the inputs and outputs.
+ for kernel_size in [[1, 1], [2, 2], [3, 3], [4, 4], [5, 5]]:
+ convolution = convolutional.conv2d
+ inputs = random_ops.random_normal(shape, dtype=dtype)
+ inputs_2norm = linalg_ops.norm(inputs)
+ input_with_circular_pad = circular_pad(inputs, shape[1], kernel_size[0])
+ outputs = convolution(
+ input_with_circular_pad, padding="valid", filters=cout,
+ kernel_size=kernel_size, use_bias=False,
+ kernel_initializer=init_ops.convolutional_orthogonal_2d(gain=gain))
+ outputs_2norm = linalg_ops.norm(outputs)
+ my_ops = variables.global_variables_initializer()
+ with self.test_session(use_gpu=True) as sess:
+ sess.run(my_ops)
+ # Check the shape of the outputs
+ t = outputs.eval()
+ self.assertAllEqual(t.shape, outputs_shape)
+ # Check isometry of the orthogonal kernel.
+ self.assertAllClose(
+ sess.run(inputs_2norm)/np.sqrt(np.prod(shape)),
+ sess.run(outputs_2norm)/(np.sqrt(np.prod(shape))*np.sqrt(gain)),
+ rtol=tol, atol=tol)
+
+
class IdentityInitializerTest(test.TestCase):
def testInvalidDataType(self):
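The testShapesValues test added above pads the input circularly with hand-rolled slice/concat ops and then checks that a "valid" convolution with the orthogonal kernel scales the input 2-norm by exactly sqrt(gain). A NumPy sketch of the same padding (illustrative, not part of the patch; circular_pad_np is a hypothetical helper, and np.pad with mode="wrap" is assumed to produce the same layout as the hand-rolled slicing):

import numpy as np

def circular_pad_np(x, kernel_size):
  # x: [batch, height, width, channels]. Pad H and W circularly so that a
  # "valid" convolution with a kernel_size x kernel_size filter keeps the
  # spatial size, mirroring circular_pad in the test above.
  beg = kernel_size // 2
  end = kernel_size - 1 - beg
  return np.pad(x, [(0, 0), (beg, end), (beg, end), (0, 0)], mode="wrap")

x = np.random.randn(2, 8, 8, 3)
padded = circular_pad_np(x, 3)
assert padded.shape == (2, 10, 10, 3)
# The property the test asserts: with an orthogonal kernel of gain g,
# ||conv_valid(padded)||_2 is approximately sqrt(g) * ||x||_2.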
diff --git a/tensorflow/python/ops/init_ops.py b/tensorflow/python/ops/init_ops.py
index 9dfe5ffbf4..5ded3f7cc2 100644
--- a/tensorflow/python/ops/init_ops.py
+++ b/tensorflow/python/ops/init_ops.py
@@ -499,10 +499,10 @@ class Orthogonal(Initializer):
Args:
gain: multiplicative factor to apply to the orthogonal matrix
- dtype: The type of the output.
seed: A Python integer. Used to create random seeds. See
@{tf.set_random_seed}
for behavior.
+ dtype: The data type.
"""
def __init__(self, gain=1.0, seed=None, dtype=dtypes.float32):
@@ -552,10 +552,10 @@ class ConvolutionDeltaOrthogonal(Initializer):
gain: multiplicative factor to apply to the orthogonal matrix. Default is 1.
The 2-norm of an input is multiplied by a factor of 'sqrt(gain)' after
applying this convolution.
- dtype: The type of the output.
seed: A Python integer. Used to create random seeds. See
@{tf.set_random_seed}
for behavior.
+ dtype: The data type.
"""
def __init__(self, gain=1.0, seed=None, dtype=dtypes.float32):
@@ -581,7 +581,6 @@ class ConvolutionDeltaOrthogonal(Initializer):
q, r = linalg_ops.qr(a, full_matrices=False)
# Make Q uniform
d = array_ops.diag_part(r)
- # ph = d / math_ops.abs(d)
q *= math_ops.sign(d)
q = q[:shape[-2], :]
q *= math_ops.sqrt(math_ops.cast(self.gain, dtype=dtype))
@@ -601,6 +600,186 @@ class ConvolutionDeltaOrthogonal(Initializer):
return {"gain": self.gain, "seed": self.seed, "dtype": self.dtype.name}
+class ConvolutionOrthogonal2D(Initializer):
+ """Initializer that generates a 2D orthogonal kernel for ConvNets.
+
+ The shape of the tensor must have length 4. The number of input
+ filters must not exceed the number of output filters.
+ The orthogonality (and hence isometry) is exact when the inputs are
+ circularly padded. There are finite-width effects with non-circular
+ padding (e.g. zero padding).
+
+ Args:
+ gain: multiplicative factor to apply to the orthogonal matrix. Default is 1.
+ The 2-norm of an input is multiplied by a factor of 'sqrt(gain)' after
+ applying this convolution.
+ seed: A Python integer. Used to create random seeds. See
+ @{tf.set_random_seed}
+ for behavior.
+ dtype: The data type.
+ """
+
+ def __init__(self, gain=1.0, seed=None, dtype=dtypes.float32):
+ self.gain = gain
+ self.dtype = _assert_float_dtype(dtypes.as_dtype(dtype))
+ self.seed = seed
+
+ def __call__(self, shape, dtype=None, partition_info=None):
+ if dtype is None:
+ dtype = self.dtype
+ # Check the shape
+ if len(shape) != 4:
+ raise ValueError("The tensor to initialize must be four-dimensional")
+
+ if shape[-2] > shape[-1]:
+ raise ValueError("In_filters cannot be greater than out_filters.")
+
+ if shape[0] != shape[1]:
+ raise ValueError("Kernel sizes must be equal.")
+
+ kernel = self._orthogonal_kernel(shape[0], shape[2], shape[3])
+ kernel *= math_ops.sqrt(math_ops.cast(self.gain, dtype=dtype))
+ return kernel
+
+ def get_config(self):
+ return {"gain": self.gain, "seed": self.seed, "dtype": self.dtype.name}
+
+ # Helper functions.
+ def _orthogonal_matrix(self, n):
+ """Construct an n x n orthogonal matrix.
+
+ Args:
+ n: dimension.
+ Returns:
+ an n x n orthogonal matrix.
+ """
+ a = random_ops.random_normal([n, n], dtype=self.dtype, seed=self.seed)
+ if self.seed:
+ self.seed += 1
+ q, r = linalg_ops.qr(a)
+ d = array_ops.diag_part(r)
+ # make q uniform
+ q *= math_ops.sign(d)
+ return q
+
+ def _symmetric_projection(self, n):
+ """Compute an n x n symmetric projection matrix.
+
+ Args:
+ n: dimension.
+ Returns:
+ an n x n symmetric projection matrix, i.e. a matrix P such that P = P*P and P = P^T.
+ """
+ q = self._orthogonal_matrix(n)
+ # randomly zeroing out some columns
+ mask = math_ops.cast(random_ops.random_normal([n], seed=self.seed) > 0,
+ self.dtype)
+ if self.seed:
+ self.seed += 1
+ c = math_ops.multiply(q, mask)
+ return math_ops.matmul(c, array_ops.matrix_transpose(c))
+
+ def _dict_to_tensor(self, x, k1, k2):
+ """Convert a dictionary to a tensor.
+
+ Args:
+ x: a k1 * k2 dictionary.
+ k1: first dimension of x.
+ k2: second dimension of x.
+ Returns:
+ a k1 * k2 tensor.
+ """
+
+ return array_ops.stack([array_ops.stack([x[i, j] for j in range(k2)])
+ for i in range(k1)])
+
+ def _block_orth(self, p1, p2):
+ """Construct a 2 x 2 kernel. Used to construct the orthogonal kernel.
+
+ Args:
+ p1: a symmetric projection matrix
+ p2: a symmetric projection matrix
+ Returns:
+ a 2 x 2 kernel [[p1p2, p1(1-p2)],
+ [(1-p1)p2, (1-p1)(1-p2)]].
+ Raises:
+ ValueError: if the dimensions of p1 and p2 are different.
+ """
+ if p1.shape.as_list() != p2.shape.as_list():
+ raise ValueError("The dimension of the matrices must be the same.")
+ n = p1.shape.as_list()[0]
+ kernel2x2 = {}
+ eye = linalg_ops.eye(n, dtype=self.dtype)
+ kernel2x2[0, 0] = math_ops.matmul(p1, p2)
+ kernel2x2[0, 1] = math_ops.matmul(p1, (eye - p2))
+ kernel2x2[1, 0] = math_ops.matmul((eye - p1), p2)
+ kernel2x2[1, 1] = math_ops.matmul((eye - p1), (eye - p2))
+
+ return kernel2x2
+
+ def _matrix_conv(self, m1, m2):
+ """Matrix convolution.
+
+ Args:
+ m1: a k x k dictionary, each element an n x n matrix.
+ m2: an l x l dictionary, each element an n x n matrix.
+
+ Returns:
+ a (k + l - 1) x (k + l - 1) dictionary, each element an n x n matrix.
+ Raises:
+ ValueError: if the entries of m1 and m2 are of different dimensions.
+ """
+
+ n = (m1[0, 0]).shape.as_list()[0]
+ if n != (m2[0, 0]).shape.as_list()[0]:
+ raise ValueError("The entries in matrices m1 and m2 "
+ "must have the same dimensions!")
+ k = int(np.sqrt(len(m1)))
+ l = int(np.sqrt(len(m2)))
+ result = {}
+ size = k + l - 1
+ # Compute matrix convolution between m1 and m2.
+ for i in range(size):
+ for j in range(size):
+ result[i, j] = array_ops.zeros([n, n], self.dtype)
+ for index1 in range(min(k, i + 1)):
+ for index2 in range(min(k, j + 1)):
+ if (i - index1) < l and (j - index2) < l:
+ result[i, j] += math_ops.matmul(m1[index1, index2],
+ m2[i - index1, j - index2])
+ return result
+
+ def _orthogonal_kernel(self, ksize, cin, cout):
+ """Construct orthogonal kernel for convolution.
+
+ Args:
+ ksize: kernel size
+ cin: number of input channels
+ cout: number of output channels
+ Returns:
+ a [ksize, ksize, cin, cout] orthogonal kernel.
+ Raises:
+ ValueError: if cin > cout.
+ """
+ if cin > cout:
+ raise ValueError("The number of input channels cannot exceed "
+ "the number of output channels.")
+ orth = self._orthogonal_matrix(cout)[0:cin, :]
+ if ksize == 1:
+ return array_ops.expand_dims(array_ops.expand_dims(orth, 0), 0)
+
+ p = self._block_orth(self._symmetric_projection(cout),
+ self._symmetric_projection(cout))
+ for _ in range(ksize - 2):
+ temp = self._block_orth(self._symmetric_projection(cout),
+ self._symmetric_projection(cout))
+ p = self._matrix_conv(p, temp)
+ for i in range(ksize):
+ for j in range(ksize):
+ p[i, j] = math_ops.matmul(orth, p[i, j])
+
+ return self._dict_to_tensor(p, ksize, ksize)
+
+
@tf_export("keras.initializers.Identity", "initializers.identity")
class Identity(Initializer):
"""Initializer that generates the identity matrix.
@@ -646,6 +825,7 @@ variance_scaling_initializer = VarianceScaling
orthogonal_initializer = Orthogonal
identity_initializer = Identity
convolutional_delta_orthogonal = ConvolutionDeltaOrthogonal
+convolutional_orthogonal_2d = ConvolutionOrthogonal2D
# pylint: enable=invalid-name
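The construction in _symmetric_projection and _block_orth rests on the identity that, for symmetric projections P1 and P2 (P = P^T = P*P), the four blocks of the 2 x 2 kernel satisfy sum_{i,j} B[i,j]^T B[i,j] = I, which is what makes the resulting convolution norm-preserving under circular padding. A small NumPy check of that identity (an illustrative sketch, not part of the patch):

import numpy as np

def random_symmetric_projection(n, rng):
  # Mirror _symmetric_projection: take an orthogonal Q, randomly zero out
  # some of its columns, then P = C C^T satisfies P = P^T and P @ P = P.
  q, _ = np.linalg.qr(rng.standard_normal((n, n)))
  mask = rng.standard_normal(n) > 0
  c = q * mask
  return c @ c.T

rng = np.random.default_rng(0)
n = 5
p1 = random_symmetric_projection(n, rng)
p2 = random_symmetric_projection(n, rng)
eye = np.eye(n)
blocks = {(0, 0): p1 @ p2, (0, 1): p1 @ (eye - p2),
          (1, 0): (eye - p1) @ p2, (1, 1): (eye - p1) @ (eye - p2)}
total = sum(b.T @ b for b in blocks.values())
assert np.allclose(total, eye)  # the blocks form an orthogonal 2 x 2 kernel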