diff options
Diffstat (limited to 'tensorflow/python/ops/nn_ops.py')
-rw-r--r-- | tensorflow/python/ops/nn_ops.py | 12 |
1 files changed, 11 insertions, 1 deletions
diff --git a/tensorflow/python/ops/nn_ops.py b/tensorflow/python/ops/nn_ops.py index a563e7c588..8c1083d9cc 100644 --- a/tensorflow/python/ops/nn_ops.py +++ b/tensorflow/python/ops/nn_ops.py @@ -452,6 +452,7 @@ class _WithSpaceToBatch(object): self.input_shape = input_shape self.spatial_dims = spatial_dims self.dilation_rate = dilation_rate + self.data_format = data_format self.op = build_op(num_spatial_dims, "VALID") self.call = self._with_space_to_batch_call @@ -496,6 +497,14 @@ class _WithSpaceToBatch(object): result_converted = array_ops.batch_to_space_nd( input=result, block_shape=dilation_rate, crops=crops) + + # Recover channel information for output shape if channels are not last. + if self.data_format is not None and self.data_format.startswith("NC"): + if not result_converted.shape[1].value: + output_shape = result_converted.shape.as_list() + output_shape[1] = filter.shape[-1] + result_converted.set_shape(output_shape) + return result_converted def __call__(self, inp, filter): # pylint: disable=redefined-builtin @@ -823,7 +832,8 @@ class Convolution(object): padding=padding, build_op=self._build_op, filter_shape=filter_shape, - spatial_dims=spatial_dims) + spatial_dims=spatial_dims, + data_format=data_format) def _build_op(self, _, padding): return _NonAtrousConvolution( |