diff options
Diffstat (limited to 'tensorflow/python/ops/nn_ops.py')
-rw-r--r-- | tensorflow/python/ops/nn_ops.py | 48 |
1 file changed, 33 insertions, 15 deletions
diff --git a/tensorflow/python/ops/nn_ops.py b/tensorflow/python/ops/nn_ops.py index 6486463b2c..b70cf83d91 100644 --- a/tensorflow/python/ops/nn_ops.py +++ b/tensorflow/python/ops/nn_ops.py @@ -1429,28 +1429,34 @@ ops.RegisterShape("AvgPool")(common_shapes.call_cpp_shape_fn) ops.RegisterShape("MaxPool")(common_shapes.call_cpp_shape_fn) -@ops.RegisterShape("FusedResizeAndPadConv2D") -def _FusedResizeAndPadConv2DShape(op): - """Shape function for FusedResizeAndPadConv2D op.""" +def _CommonFusedConvCalculations(op, has_resize): + """Shape function for Fused*Conv2D ops.""" # The bilinear resize shape calculation. input_shape = op.inputs[0].get_shape().with_rank(4) - unused_size_shape = op.inputs[1].get_shape().merge_with([2]) - size = tensor_util.constant_value(op.inputs[1]) - if size is not None: - height = size[0] - width = size[1] + if has_resize: + unused_size_shape = op.inputs[1].get_shape().merge_with([2]) + size = tensor_util.constant_value(op.inputs[1]) + if size is not None: + height = size[0] + width = size[1] + else: + height = None + width = None + resized_shape = tensor_shape.TensorShape( + [input_shape[0], height, width, input_shape[3]]) + paddings_index = 2 + filter_index = 3 else: - height = None - width = None - resized_shape = tensor_shape.TensorShape( - [input_shape[0], height, width, input_shape[3]]) + resized_shape = input_shape + paddings_index = 1 + filter_index = 2 # Calculates the effect of the padding. 
- paddings_shape = op.inputs[2].get_shape().with_rank(2) + paddings_shape = op.inputs[paddings_index].get_shape().with_rank(2) resized_shape = resized_shape.with_rank(paddings_shape[0].value) paddings_shape = paddings_shape.merge_with( tensor_shape.matrix(resized_shape.ndims, 2)) - paddings = tensor_util.constant_value(op.inputs[2]) + paddings = tensor_util.constant_value(op.inputs[paddings_index]) if paddings is None: padded_shape = tensor_shape.unknown_shape(ndims=resized_shape.ndims) else: @@ -1462,7 +1468,7 @@ def _FusedResizeAndPadConv2DShape(op): padded_shape = tensor_shape.TensorShape(output_dims) # Finally work out the convolution's effect. - filter_shape = op.inputs[3].get_shape().with_rank(4) + filter_shape = op.inputs[filter_index].get_shape().with_rank(4) batch_size = padded_shape[0] in_rows = padded_shape[1] @@ -1494,6 +1500,18 @@ def _FusedResizeAndPadConv2DShape(op): return [tensor_shape.TensorShape(output_shape)] +@ops.RegisterShape("FusedResizeAndPadConv2D") +def _FusedResizeAndPadConv2DShape(op): + """Shape function for FusedResizeAndPadConv2D op.""" + return _CommonFusedConvCalculations(op, True) + + +@ops.RegisterShape("FusedPadConv2D") +def _FusedPadConv2DShape(op): + """Shape function for FusedPadConv2D op.""" + return _CommonFusedConvCalculations(op, False) + + ops.RegisterShape("MaxPoolWithArgmax")(common_shapes.call_cpp_shape_fn) |