aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python/ops/nn_ops.py
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/python/ops/nn_ops.py')
-rw-r--r--tensorflow/python/ops/nn_ops.py12
1 files changed, 11 insertions, 1 deletions
diff --git a/tensorflow/python/ops/nn_ops.py b/tensorflow/python/ops/nn_ops.py
index a563e7c588..8c1083d9cc 100644
--- a/tensorflow/python/ops/nn_ops.py
+++ b/tensorflow/python/ops/nn_ops.py
@@ -452,6 +452,7 @@ class _WithSpaceToBatch(object):
self.input_shape = input_shape
self.spatial_dims = spatial_dims
self.dilation_rate = dilation_rate
+ self.data_format = data_format
self.op = build_op(num_spatial_dims, "VALID")
self.call = self._with_space_to_batch_call
@@ -496,6 +497,14 @@ class _WithSpaceToBatch(object):
result_converted = array_ops.batch_to_space_nd(
input=result, block_shape=dilation_rate, crops=crops)
+
+ # Recover channel information for output shape if channels are not last.
+ if self.data_format is not None and self.data_format.startswith("NC"):
+ if not result_converted.shape[1].value:
+ output_shape = result_converted.shape.as_list()
+ output_shape[1] = filter.shape[-1]
+ result_converted.set_shape(output_shape)
+
return result_converted
def __call__(self, inp, filter): # pylint: disable=redefined-builtin
@@ -823,7 +832,8 @@ class Convolution(object):
padding=padding,
build_op=self._build_op,
filter_shape=filter_shape,
- spatial_dims=spatial_dims)
+ spatial_dims=spatial_dims,
+ data_format=data_format)
def _build_op(self, _, padding):
return _NonAtrousConvolution(