aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/compiler
diff options
context:
space:
mode:
authorGravatar Adrian Kuegel <akuegel@google.com>2018-09-20 00:22:51 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-09-20 00:31:37 -0700
commit2ea398b12ed18b6c51e09f363021c6aa306c5179 (patch)
tree15fc14da174b1f8e7db629562d4ffea50b81b978 /tensorflow/compiler
parentfcfc5ad738b1521aa70aaad323079eb72493dcad (diff)
Add feature_group_count parameter of Convolution op to xla_client.py.
This parameter has been added to HLO to support depthwise convolution. PiperOrigin-RevId: 213761790
Diffstat (limited to 'tensorflow/compiler')
-rw-r--r--tensorflow/compiler/xla/python/local_computation_builder.cc6
-rw-r--r--tensorflow/compiler/xla/python/local_computation_builder.h3
-rw-r--r--tensorflow/compiler/xla/python/xla_client.py19
-rw-r--r--tensorflow/compiler/xla/python/xla_client_test.py24
4 files changed, 43 insertions, 9 deletions
diff --git a/tensorflow/compiler/xla/python/local_computation_builder.cc b/tensorflow/compiler/xla/python/local_computation_builder.cc
index 9da5dc0d2d..cd5fd33029 100644
--- a/tensorflow/compiler/xla/python/local_computation_builder.cc
+++ b/tensorflow/compiler/xla/python/local_computation_builder.cc
@@ -469,9 +469,11 @@ LocalOp LocalComputationBuilder::ConvGeneralDilated(
absl::Span<const int64> window_strides,
absl::Span<const std::pair<int64, int64>> padding,
absl::Span<const int64> lhs_dilation, absl::Span<const int64> rhs_dilation,
- const ConvolutionDimensionNumbers& dimension_numbers) {
+ const ConvolutionDimensionNumbers& dimension_numbers,
+ int64 feature_group_count) {
return xla::ConvGeneralDilated(lhs.op(), rhs.op(), window_strides, padding,
- lhs_dilation, rhs_dilation, dimension_numbers);
+ lhs_dilation, rhs_dilation, dimension_numbers,
+ feature_group_count);
}
LocalOp LocalComputationBuilder::ConvertElementType(
diff --git a/tensorflow/compiler/xla/python/local_computation_builder.h b/tensorflow/compiler/xla/python/local_computation_builder.h
index 1d5dfe5911..2166bb6721 100644
--- a/tensorflow/compiler/xla/python/local_computation_builder.h
+++ b/tensorflow/compiler/xla/python/local_computation_builder.h
@@ -248,7 +248,8 @@ class LocalComputationBuilder {
absl::Span<const std::pair<int64, int64> > padding,
absl::Span<const int64> lhs_dilation,
absl::Span<const int64> rhs_dilation,
- const ConvolutionDimensionNumbers& dimension_numbers);
+ const ConvolutionDimensionNumbers& dimension_numbers,
+ int64 feature_group_count);
LocalOp ConvertElementType(const LocalOp& operand,
PrimitiveType new_element_type);
diff --git a/tensorflow/compiler/xla/python/xla_client.py b/tensorflow/compiler/xla/python/xla_client.py
index fa4366ff07..bb303c5678 100644
--- a/tensorflow/compiler/xla/python/xla_client.py
+++ b/tensorflow/compiler/xla/python/xla_client.py
@@ -1109,7 +1109,7 @@ class ComputationBuilder(object):
dimension_numbers = GetDotDimensionsFromLists(dimension_numbers)
return self._client.DotGeneral(lhs, rhs, dimension_numbers)
- def Conv(self, lhs, rhs, window_strides, padding):
+ def Conv(self, lhs, rhs, window_strides, padding, feature_group_count=1):
"""Enqueues a Conv operation onto the computation.
Args:
@@ -1117,6 +1117,7 @@ class ComputationBuilder(object):
rhs: LocalOp for the rank N+2 array of kernel weights.
window_strides: length-N array-like of integer kernel strides.
padding: PaddingType representing either 'SAME' or 'VALID' padding.
+ feature_group_count: number of feature groups for grouped convolution.
Returns: a LocalOp representing the Conv operation.
"""
@@ -1125,10 +1126,11 @@ class ComputationBuilder(object):
self.GetShape(rhs).dimensions()[2:], window_strides)
dimension_numbers = self._GetConvDimensionNumbers(len(window_strides))
return self._client.ConvGeneralDilated(lhs, rhs, window_strides, pads, (),
- (), dimension_numbers)
+ (), dimension_numbers,
+ feature_group_count)
def ConvWithGeneralPadding(self, lhs, rhs, window_strides, padding,
- lhs_dilation, rhs_dilation):
+ lhs_dilation, rhs_dilation, feature_group_count=1):
"""Enqueues a ConvWithGeneralPadding operation onto the computation.
Args:
@@ -1138,6 +1140,7 @@ class ComputationBuilder(object):
padding: length-N array-like of pairs of integers of (low, high) padding.
lhs_dilation: length-N array-like of dilation factors.
rhs_dilation: length-N array-like of dilation factors.
+ feature_group_count: number of feature groups for grouped convolution.
Returns:
A ComputationdataHandle representing the added ConvWithGeneralPadding op.
@@ -1145,7 +1148,8 @@ class ComputationBuilder(object):
dimension_numbers = self._GetConvDimensionNumbers(len(window_strides))
return self._client.ConvGeneralDilated(lhs, rhs, window_strides, padding,
lhs_dilation, rhs_dilation,
- dimension_numbers)
+ dimension_numbers,
+ feature_group_count)
def _GetConvDimensionNumbers(self, num_spatial_dims):
"""Create ConvolutionDimensionNumbers proto for convolutions."""
@@ -1163,7 +1167,8 @@ class ComputationBuilder(object):
return dimension_numbers
def ConvGeneralDilated(self, lhs, rhs, window_strides, padding, lhs_dilation,
- rhs_dilation, dimension_numbers):
+ rhs_dilation, dimension_numbers,
+ feature_group_count=1):
"""Enqueues a ConvGeneralDilated operation onto the computation.
Args:
@@ -1190,6 +1195,7 @@ class ComputationBuilder(object):
labels appear in the rhs_spec string, so that window_strides[0] is
matched with the dimension corresponding to the first character
appearing in rhs_spec that is not 'I' or 'O'.
+ feature_group_count: number of feature groups for grouped convolution.
Returns: a LocalOp representing the ConvGenralDilated operation.
"""
@@ -1215,7 +1221,8 @@ class ComputationBuilder(object):
key=lambda i: rhs_spec.index(out_spec[i])))
return self._client.ConvGeneralDilated(lhs, rhs, window_strides, padding,
lhs_dilation, rhs_dilation,
- dimension_numbers)
+ dimension_numbers,
+ feature_group_count)
def Sort(self, operand, dimension=-1):
"""Enqueues a sort operation onto the computation."""
diff --git a/tensorflow/compiler/xla/python/xla_client_test.py b/tensorflow/compiler/xla/python/xla_client_test.py
index fd98e19457..82103f0313 100644
--- a/tensorflow/compiler/xla/python/xla_client_test.py
+++ b/tensorflow/compiler/xla/python/xla_client_test.py
@@ -661,6 +661,30 @@ class SingleOpTest(LocalComputationTest):
[40., 50., 0.]]]])
self._ExecuteAndCompareClose(c, expected=np.transpose(result, (1, 3, 0, 2)))
+ def testConvGeneralDilatedGroupedConvolutionF32(self):
+ c = self._NewComputation()
+ a = lambda *dims: np.arange(np.prod(dims)).reshape(dims).astype("float32")
+ lhs = a(1, 2, 2, 3)
+ rhs = a(2, 1, 1, 2) * 10
+ strides = [1, 1]
+ pads = [(1, 0), (0, 1)]
+ lhs_dilation = (2, 1)
+ rhs_dilation = (1, 1)
+ dimension_numbers = ("NCHW", "OIHW", "NCHW")
+ feature_group_count = 2
+ c.ConvGeneralDilated(c.Constant(lhs), c.Constant(rhs),
+ strides, pads, lhs_dilation, rhs_dilation,
+ dimension_numbers, feature_group_count)
+ result = np.array([[[[0., 0., 0.],
+ [10., 20., 0.],
+ [0., 0., 0.],
+ [40., 50., 0.]],
+ [[0., 0., 0.],
+ [330., 380., 160.],
+ [0., 0., 0.],
+ [480., 530., 220.]]]])
+ self._ExecuteAndCompareClose(c, expected=result)
+
def testBooleanNot(self):
c = self._NewComputation()
arr = NumpyArrayBool([True, False, True])