about summary refs log tree commit diff homepage
path: root/tensorflow/contrib/kfac
diff options
context:
space:
mode:
author: A. Unique TensorFlower <gardener@tensorflow.org> 2018-01-30 17:58:28 -0800
committer: Michael Case <mikecase@google.com> 2018-01-31 16:36:36 -0800
commit: 4c2ee0464a7c50afa1db3973215b2c7df8ade0e4 (patch)
tree: b8c27a4845c4e292183cfc0ef1eec22716dd4c2e /tensorflow/contrib/kfac
parent: f6d3f0a5506ff221ba52c4f800d57c8d6b47643d (diff)
K-FAC: Wrap extract_image_patches() for compatibility with XLA.
PiperOrigin-RevId: 183923073
Diffstat (limited to 'tensorflow/contrib/kfac')
-rw-r--r--  tensorflow/contrib/kfac/python/ops/fisher_factors.py | 63
1 file changed, 55 insertions(+), 8 deletions(-)
diff --git a/tensorflow/contrib/kfac/python/ops/fisher_factors.py b/tensorflow/contrib/kfac/python/ops/fisher_factors.py
index f59168cbc0..bcba18ae14 100644
--- a/tensorflow/contrib/kfac/python/ops/fisher_factors.py
+++ b/tensorflow/contrib/kfac/python/ops/fisher_factors.py
@@ -31,6 +31,7 @@ from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import init_ops
from tensorflow.python.ops import linalg_ops
from tensorflow.python.ops import math_ops
+from tensorflow.python.ops import nn
from tensorflow.python.ops import special_math_ops
from tensorflow.python.ops import variable_scope
from tensorflow.python.ops import variables
@@ -111,6 +112,54 @@ def diagonal_covariance_initializer(shape, dtype, partition_info): # pylint: di
return array_ops.ones(shape, dtype)
+def extract_image_patches(image, ksizes, strides, padding, name=None):
+ """Extracts image patches for an N-dimensional convolution.
+
+ This function is a compatibility wrapper over tf.extract_image_patches(), as
+ ExtractImagePatches isn't yet implemented in XLA.
+
+ Args:
+ image: Tensor of shape [batch, in_x, in_y, ..., in_channels]. Input images.
+ All dimensions except 'batch' must be defined.
+ ksizes: [filter_x, filter_y, ...]. Spatial shape of filter in each
+ dimension.
+ strides: [stride_x, stride_y, ...]. Spatial stride for filter in each
+ dimension.
+ padding: str. "VALID" or "SAME".
+ name: str or None. name of Op.
+
+ Returns:
+ result: [batch, out_x, out_y, ..., filter_x, filter_y, ..., in_channels].
+ Contains image patches to which conv kernel would be applied for each
+ output location. [out_x, out_y, ...] depends on padding.
+ """
+ if not utils.on_tpu():
+ return array_ops.extract_image_patches(
+ image,
+ ksizes=([1] + list(ksizes) + [1]),
+ strides=([1] + list(strides) + [1]),
+ rates=[1, 1, 1, 1],
+ padding=padding,
+ name=name)
+
+ with tf_ops.name_scope(name, "extract_image_patches",
+ [image, ksizes, strides, padding]):
+ batch = image.shape.as_list()[0]
+ in_channels = image.shape.as_list()[-1]
+
+ # Map each input feature to a location in the output.
+ out_channels = np.prod(ksizes) * in_channels
+ filters = linalg_ops.eye(out_channels)
+ filters = array_ops.reshape(filters, ksizes + [in_channels, out_channels])
+
+ result = nn.convolution(image, filters, padding, strides=strides)
+ out_spatial = result.shape.as_list()[1:-1]
+ result = array_ops.reshape(
+ result, [batch or -1] + out_spatial + ksizes + [in_channels])
+
+ return result
+
+
def compute_cov(tensor, tensor_right=None, normalizer=None):
"""Compute the empirical second moment of the rows of a 2D Tensor.
@@ -668,11 +717,10 @@ class ConvDiagonalFactor(DiagonalFactor):
# TODO(b/64144716): there is potential here for a big savings in terms
# of memory use.
- patches = array_ops.extract_image_patches(
+ patches = extract_image_patches(
self._inputs,
- ksizes=[1, filter_height, filter_width, 1],
- strides=self._strides,
- rates=[1, 1, 1, 1],
+ ksizes=[filter_height, filter_width],
+ strides=self._strides[1:-1],
padding=self._padding)
if self._has_bias:
@@ -816,11 +864,10 @@ class ConvInputKroneckerFactor(InverseProvidingFactor):
# TODO(b/64144716): there is potential here for a big savings in terms of
# memory use.
- patches = array_ops.extract_image_patches(
+ patches = extract_image_patches(
self._inputs,
- ksizes=[1, filter_height, filter_width, 1],
- strides=self._strides,
- rates=[1, 1, 1, 1],
+ ksizes=[filter_height, filter_width],
+ strides=self._strides[1:-1],
padding=self._padding)
flatten_size = (filter_height * filter_width * in_channels)