From d071825f3f02a835a7ccfc396e1955b7e47f1cbd Mon Sep 17 00:00:00 2001
From: Vijay Vasudevan
Date: Fri, 11 Nov 2016 14:53:04 -0800
Subject: Switch FusedConvPad to use C++ shape function.

Verified against python implementation using debug_python_shape_fn
on optimize_for_inference_test.
Change: 138921973
---
 tensorflow/core/ops/nn_ops.cc   | 101 ++++++++++++++++++++++++++++++++++++++++
 tensorflow/python/ops/nn_ops.py |  77 +-----------------------------
 2 files changed, 103 insertions(+), 75 deletions(-)

diff --git a/tensorflow/core/ops/nn_ops.cc b/tensorflow/core/ops/nn_ops.cc
index 8d04ae85c6..02440bd626 100644
--- a/tensorflow/core/ops/nn_ops.cc
+++ b/tensorflow/core/ops/nn_ops.cc
@@ -627,6 +627,101 @@ data_format: Specify the data format of the input and output data. With the
         [batch, in_channels, in_height, in_width].
 )doc");
 
+namespace {
+
+Status CommonFusedConvCalculations(InferenceContext* c, bool has_resize) {
+  ShapeHandle input;
+  TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 4, &input));
+
+  ShapeHandle resized = input;
+  int paddings_index = 1;
+  int filter_index = 2;
+  if (has_resize) {
+    paddings_index = 2;
+    filter_index = 3;
+
+    ShapeHandle unused_size;
+    TF_RETURN_IF_ERROR(c->Merge(c->input(1), c->Vector(2), &unused_size));
+
+    const Tensor* size = c->input_tensor(1);
+    DimensionHandle new_height = c->UnknownDim();
+    DimensionHandle new_width = c->UnknownDim();
+    if (size != nullptr) {
+      new_height = c->MakeDim(size->flat<int32>()(0));
+      new_width = c->MakeDim(size->flat<int32>()(1));
+    }
+    TF_RETURN_IF_ERROR(c->ReplaceDim(resized, 1, new_height, &resized));
+    TF_RETURN_IF_ERROR(c->ReplaceDim(resized, 2, new_width, &resized));
+  }
+
+  ShapeHandle paddings;
+  TF_RETURN_IF_ERROR(c->WithRank(c->input(paddings_index), 2, &paddings));
+  TF_RETURN_IF_ERROR(
+      c->WithRank(resized, c->Value(c->Dim(paddings, 0)), &resized));
+  TF_RETURN_IF_ERROR(
+      c->Merge(paddings, c->Matrix(c->Rank(resized), 2), &paddings));
+
+  const Tensor* paddings_t = c->input_tensor(paddings_index);
+  ShapeHandle padded;
+  if (paddings_t != nullptr) {
+    std::vector<DimensionHandle> output_dims;
+    for (int i = 0; i < 4; ++i) {
+      DimensionHandle dim = c->Dim(resized, i);
+      int64 p0 = static_cast<int64>(paddings_t->matrix<int32>()(i, 0));
+      int64 p1 = static_cast<int64>(paddings_t->matrix<int32>()(i, 1));
+      if (p0 < 0 || p1 < 0) {
+        return errors::InvalidArgument("Paddings must be non-negative");
+      }
+
+      TF_RETURN_IF_ERROR(c->Add(dim, p0 + p1, &dim));
+      output_dims.push_back(dim);
+    }
+    padded = c->MakeShape(output_dims);
+  } else {
+    padded = c->UnknownShapeOfRank(4);
+  }
+
+  // Work out the convolution's effect with 'padded' as the input.
+  ShapeHandle filter;
+  TF_RETURN_IF_ERROR(c->WithRank(c->input(filter_index), 4, &filter));
+  std::vector<int32> strides;
+  TF_RETURN_IF_ERROR(c->GetAttr("strides", &strides));
+  if (strides.size() != 4) {
+    return errors::InvalidArgument(
+        "Operation requires the stride attribute to contain 4 values, but ",
+        "got: ", strides.size());
+  }
+
+  int32 stride_rows = strides[1];
+  int32 stride_cols = strides[2];
+
+  DimensionHandle batch_size_dim = c->Dim(padded, 0);
+  DimensionHandle in_rows_dim = c->Dim(padded, 1);
+  DimensionHandle in_cols_dim = c->Dim(padded, 2);
+  DimensionHandle filter_rows_dim = c->Dim(filter, 0);
+  DimensionHandle filter_cols_dim = c->Dim(filter, 1);
+  DimensionHandle output_depth_dim = c->Dim(filter, 3);
+
+  DimensionHandle unused;
+  TF_RETURN_IF_ERROR(c->Merge(c->Dim(padded, 3), c->Dim(filter, 2), &unused));
+
+  Padding padding;
+  TF_RETURN_IF_ERROR(c->GetAttr("padding", &padding));
+
+  DimensionHandle output_rows, output_cols;
+  TF_RETURN_IF_ERROR(GetWindowedOutputSizeFromDims(
+      c, in_rows_dim, filter_rows_dim, stride_rows, padding, &output_rows));
+  TF_RETURN_IF_ERROR(GetWindowedOutputSizeFromDims(
+      c, in_cols_dim, filter_cols_dim, stride_cols, padding, &output_cols));
+
+  ShapeHandle output_shape = c->MakeShape(
+      {batch_size_dim, output_rows, output_cols, output_depth_dim});
+  c->set_output(0, output_shape);
+  return Status::OK();
+}
+
+}  // namespace
+
 REGISTER_OP("FusedResizeAndPadConv2D")
     .Input("input: T")
     .Input("size: int32")
@@ -638,6 +733,9 @@ REGISTER_OP("FusedResizeAndPadConv2D")
     .Attr(GetMirrorPadModeAttrString())
     .Attr("strides: list(int)")
     .Attr(GetPaddingAttrString())
+    .SetShapeFn([](InferenceContext* c) {
+      return CommonFusedConvCalculations(c, true /* has_resize */);
+    })
     .Doc(R"doc(
 Performs a resize and padding as a preprocess during a convolution.
 
@@ -676,6 +774,9 @@ REGISTER_OP("FusedPadConv2D")
     .Attr(GetMirrorPadModeAttrString())
     .Attr("strides: list(int)")
     .Attr(GetPaddingAttrString())
+    .SetShapeFn([](InferenceContext* c) {
+      return CommonFusedConvCalculations(c, false /* has_resize */);
+    })
     .Doc(R"doc(
 Performs a padding as a preprocess during a convolution.
 
diff --git a/tensorflow/python/ops/nn_ops.py b/tensorflow/python/ops/nn_ops.py
index 0fdb98f172..4787a42537 100644
--- a/tensorflow/python/ops/nn_ops.py
+++ b/tensorflow/python/ops/nn_ops.py
@@ -1666,87 +1666,14 @@ ops.RegisterShape("AvgPool")(common_shapes.call_cpp_shape_fn)
 ops.RegisterShape("MaxPool")(common_shapes.call_cpp_shape_fn)
 
 
-def _CommonFusedConvCalculations(op, has_resize):
-  """Shape function for Fused*Conv2D ops."""
-  # The bilinear resize shape calculation.
-  input_shape = op.inputs[0].get_shape().with_rank(4)
-  if has_resize:
-    unused_size_shape = op.inputs[1].get_shape().merge_with([2])
-    size = tensor_util.constant_value(op.inputs[1])
-    if size is not None:
-      height = size[0]
-      width = size[1]
-    else:
-      height = None
-      width = None
-    resized_shape = tensor_shape.TensorShape(
-        [input_shape[0], height, width, input_shape[3]])
-    paddings_index = 2
-    filter_index = 3
-  else:
-    resized_shape = input_shape
-    paddings_index = 1
-    filter_index = 2
-
-  # Calculates the effect of the padding.
-  paddings_shape = op.inputs[paddings_index].get_shape().with_rank(2)
-  resized_shape = resized_shape.with_rank(paddings_shape[0].value)
-  paddings_shape = paddings_shape.merge_with(
-      tensor_shape.matrix(resized_shape.ndims, 2))
-  paddings = tensor_util.constant_value(op.inputs[paddings_index])
-  if paddings is None:
-    padded_shape = tensor_shape.unknown_shape(ndims=resized_shape.ndims)
-  else:
-    output_dims = []
-    for i, dim in enumerate(resized_shape.dims):
-      if paddings[i, 0] < 0 or paddings[i, 1] < 0:
-        raise ValueError("paddings must be non-negative")
-      output_dims.append(dim + paddings[i, 0] + paddings[i, 1])
-    padded_shape = tensor_shape.TensorShape(output_dims)
-
-  # Finally work out the convolution's effect.
-  filter_shape = op.inputs[filter_index].get_shape().with_rank(4)
-
-  batch_size = padded_shape[0]
-  in_rows = padded_shape[1]
-  in_cols = padded_shape[2]
-
-  filter_rows = filter_shape[0]
-  filter_cols = filter_shape[1]
-  depth_out = filter_shape[3]
-  # Check that the input depths are compatible.
-  padded_shape[3].assert_is_compatible_with(filter_shape[2])
-
-  stride_b, stride_r, stride_c, stride_d = op.get_attr("strides")
-
-  if stride_b != 1 or stride_d != 1:
-    raise ValueError("Current implementation does not yet support "
-                     "strides in the batch and depth dimensions.")
-  # TODO(mrry,shlens): Raise an error if the stride would cause
-  # information in the input to be ignored. This will require a change
-  # in the kernel implementation.
-  padding = op.get_attr("padding")
-  out_rows, out_cols = common_shapes.get2d_conv_output_size(in_rows, in_cols,
-                                                            filter_rows,
-                                                            filter_cols,
-                                                            stride_r,
-                                                            stride_c,
-                                                            padding)
-
-  output_shape = [batch_size, out_rows, out_cols, depth_out]
-  return [tensor_shape.TensorShape(output_shape)]
-
-
 @ops.RegisterShape("FusedResizeAndPadConv2D")
 def _FusedResizeAndPadConv2DShape(op):
-  """Shape function for FusedResizeAndPadConv2D op."""
-  return _CommonFusedConvCalculations(op, True)
+  return common_shapes.call_cpp_shape_fn(op, input_tensors_needed=[1, 2])
 
 
 @ops.RegisterShape("FusedPadConv2D")
 def _FusedPadConv2DShape(op):
-  """Shape function for FusedResizeAndPadConv2D op."""
-  return _CommonFusedConvCalculations(op, False)
+  return common_shapes.call_cpp_shape_fn(op, input_tensors_needed=[1])
 
 
 ops.RegisterShape("MaxPoolWithArgmax")(common_shapes.call_cpp_shape_fn)
--
cgit v1.2.3
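For readers unfamiliar with the C++ shape-inference API this patch moves to, here is a
minimal sketch of the same SetShapeFn registration idiom. It is not part of the patch:
the op name "TinyPadConv", its single float input, and its shape-preserving output are
hypothetical, invented for illustration; only REGISTER_OP, SetShapeFn, WithRank,
set_output, and Status::OK mirror the real code above, and the includes assume
TensorFlow's standard framework headers of this era.

// Hypothetical example op; not added by this patch.
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/shape_inference.h"
#include "tensorflow/core/lib/core/errors.h"

namespace tensorflow {

using shape_inference::InferenceContext;
using shape_inference::ShapeHandle;

REGISTER_OP("TinyPadConv")
    .Input("input: float")
    .Output("output: float")
    .SetShapeFn([](InferenceContext* c) {
      // Constrain the input to rank 4 (NHWC), mirroring the first step of
      // CommonFusedConvCalculations above.
      ShapeHandle input;
      TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 4, &input));
      // For brevity, report the op as shape-preserving; the real fused ops
      // instead compute the resized, padded, and convolved output dims.
      c->set_output(0, input);
      return Status::OK();
    });

}  // namespace tensorflow

The payoff of registering the lambda on the op, as the patch does for both fused ops,
is that every language binding gets the same shape inference; the Python registrations
shrink to common_shapes.call_cpp_shape_fn, with input_tensors_needed naming the inputs
(size and paddings) whose constant values the C++ function reads via input_tensor().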