author    Vijay Vasudevan <vrv@google.com>    2016-11-11 14:53:04 -0800
committer TensorFlower Gardener <gardener@tensorflow.org>    2016-11-11 15:05:31 -0800
commit    d071825f3f02a835a7ccfc396e1955b7e47f1cbd (patch)
tree      b990288cbe673e0af9229e549aac75d9af835753
parent    ea26dae5de6be94126fc1744294872b3859d8ac4 (diff)
Switch FusedConvPad to use C++ shape function.
Verified against the Python implementation using debug_python_shape_fn on optimize_for_inference_test.
Change: 138921973
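For context, this is the general pattern the commit adopts: shape inference moves out of the Python registration and into a C++ function attached to the op with SetShapeFn, built on the same shape_inference::InferenceContext calls that the new CommonFusedConvCalculations helper uses below. A minimal sketch of that pattern, using a hypothetical ToyOp that is not part of this change:

#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/shape_inference.h"

namespace tensorflow {

using shape_inference::InferenceContext;
using shape_inference::ShapeHandle;

// Hypothetical "ToyOp", shown only to illustrate the SetShapeFn pattern:
// constrain the input to rank 4 and propagate its (possibly partially
// known) shape to the output.
REGISTER_OP("ToyOp")
    .Input("input: float")
    .Output("output: float")
    .SetShapeFn([](InferenceContext* c) {
      ShapeHandle input;
      TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 4, &input));
      c->set_output(0, input);
      return Status::OK();
    });

}  // namespace tensorflow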
-rw-r--r--    tensorflow/core/ops/nn_ops.cc      101
-rw-r--r--    tensorflow/python/ops/nn_ops.py     77
2 files changed, 103 insertions, 75 deletions
diff --git a/tensorflow/core/ops/nn_ops.cc b/tensorflow/core/ops/nn_ops.cc
index 8d04ae85c6..02440bd626 100644
--- a/tensorflow/core/ops/nn_ops.cc
+++ b/tensorflow/core/ops/nn_ops.cc
@@ -627,6 +627,101 @@ data_format: Specify the data format of the input and output data. With the
[batch, in_channels, in_height, in_width].
)doc");
+namespace {
+
+Status CommonFusedConvCalculations(InferenceContext* c, bool has_resize) {
+ ShapeHandle input;
+ TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 4, &input));
+
+ ShapeHandle resized = input;
+ int paddings_index = 1;
+ int filter_index = 2;
+ if (has_resize) {
+ paddings_index = 2;
+ filter_index = 3;
+
+ ShapeHandle unused_size;
+ TF_RETURN_IF_ERROR(c->Merge(c->input(1), c->Vector(2), &unused_size));
+
+ const Tensor* size = c->input_tensor(1);
+ DimensionHandle new_height = c->UnknownDim();
+ DimensionHandle new_width = c->UnknownDim();
+ if (size != nullptr) {
+ new_height = c->MakeDim(size->flat<int32>()(0));
+ new_width = c->MakeDim(size->flat<int32>()(1));
+ }
+ TF_RETURN_IF_ERROR(c->ReplaceDim(resized, 1, new_height, &resized));
+ TF_RETURN_IF_ERROR(c->ReplaceDim(resized, 2, new_width, &resized));
+ }
+
+ ShapeHandle paddings;
+ TF_RETURN_IF_ERROR(c->WithRank(c->input(paddings_index), 2, &paddings));
+ TF_RETURN_IF_ERROR(
+ c->WithRank(resized, c->Value(c->Dim(paddings, 0)), &resized));
+ TF_RETURN_IF_ERROR(
+ c->Merge(paddings, c->Matrix(c->Rank(resized), 2), &paddings));
+
+ const Tensor* paddings_t = c->input_tensor(paddings_index);
+ ShapeHandle padded;
+ if (paddings_t != nullptr) {
+ std::vector<DimensionHandle> output_dims;
+ for (int i = 0; i < 4; ++i) {
+ DimensionHandle dim = c->Dim(resized, i);
+ int64 p0 = static_cast<int64>(paddings_t->matrix<int32>()(i, 0));
+ int64 p1 = static_cast<int64>(paddings_t->matrix<int32>()(i, 1));
+ if (p0 < 0 || p1 < 0) {
+ return errors::InvalidArgument("Paddings must be non-negative");
+ }
+
+ TF_RETURN_IF_ERROR(c->Add(dim, p0 + p1, &dim));
+ output_dims.push_back(dim);
+ }
+ padded = c->MakeShape(output_dims);
+ } else {
+ padded = c->UnknownShapeOfRank(4);
+ }
+
+ // Work out the convolution's effect with 'padded' as the input.
+ ShapeHandle filter;
+ TF_RETURN_IF_ERROR(c->WithRank(c->input(filter_index), 4, &filter));
+ std::vector<int32> strides;
+ TF_RETURN_IF_ERROR(c->GetAttr("strides", &strides));
+ if (strides.size() != 4) {
+ return errors::InvalidArgument(
+ "Operation requires the stride attribute to contain 4 values, but ",
+ "got: ", strides.size());
+ }
+
+ int32 stride_rows = strides[1];
+ int32 stride_cols = strides[2];
+
+ DimensionHandle batch_size_dim = c->Dim(padded, 0);
+ DimensionHandle in_rows_dim = c->Dim(padded, 1);
+ DimensionHandle in_cols_dim = c->Dim(padded, 2);
+ DimensionHandle filter_rows_dim = c->Dim(filter, 0);
+ DimensionHandle filter_cols_dim = c->Dim(filter, 1);
+ DimensionHandle output_depth_dim = c->Dim(filter, 3);
+
+ DimensionHandle unused;
+ TF_RETURN_IF_ERROR(c->Merge(c->Dim(padded, 3), c->Dim(filter, 2), &unused));
+
+ Padding padding;
+ TF_RETURN_IF_ERROR(c->GetAttr("padding", &padding));
+
+ DimensionHandle output_rows, output_cols;
+ TF_RETURN_IF_ERROR(GetWindowedOutputSizeFromDims(
+ c, in_rows_dim, filter_rows_dim, stride_rows, padding, &output_rows));
+ TF_RETURN_IF_ERROR(GetWindowedOutputSizeFromDims(
+ c, in_cols_dim, filter_cols_dim, stride_cols, padding, &output_cols));
+
+ ShapeHandle output_shape = c->MakeShape(
+ {batch_size_dim, output_rows, output_cols, output_depth_dim});
+ c->set_output(0, output_shape);
+ return Status::OK();
+}
+
+} // namespace
+
REGISTER_OP("FusedResizeAndPadConv2D")
.Input("input: T")
.Input("size: int32")
@@ -638,6 +733,9 @@ REGISTER_OP("FusedResizeAndPadConv2D")
.Attr(GetMirrorPadModeAttrString())
.Attr("strides: list(int)")
.Attr(GetPaddingAttrString())
+ .SetShapeFn([](InferenceContext* c) {
+ return CommonFusedConvCalculations(c, true /* has_resize */);
+ })
.Doc(R"doc(
Performs a resize and padding as a preprocess during a convolution.
@@ -676,6 +774,9 @@ REGISTER_OP("FusedPadConv2D")
.Attr(GetMirrorPadModeAttrString())
.Attr("strides: list(int)")
.Attr(GetPaddingAttrString())
+ .SetShapeFn([](InferenceContext* c) {
+ return CommonFusedConvCalculations(c, false /* has_resize */);
+ })
.Doc(R"doc(
Performs a padding as a preprocess during a convolution.
diff --git a/tensorflow/python/ops/nn_ops.py b/tensorflow/python/ops/nn_ops.py
index 0fdb98f172..4787a42537 100644
--- a/tensorflow/python/ops/nn_ops.py
+++ b/tensorflow/python/ops/nn_ops.py
@@ -1666,87 +1666,14 @@ ops.RegisterShape("AvgPool")(common_shapes.call_cpp_shape_fn)
ops.RegisterShape("MaxPool")(common_shapes.call_cpp_shape_fn)
-def _CommonFusedConvCalculations(op, has_resize):
- """Shape function for Fused*Conv2D ops."""
- # The bilinear resize shape calculation.
- input_shape = op.inputs[0].get_shape().with_rank(4)
- if has_resize:
- unused_size_shape = op.inputs[1].get_shape().merge_with([2])
- size = tensor_util.constant_value(op.inputs[1])
- if size is not None:
- height = size[0]
- width = size[1]
- else:
- height = None
- width = None
- resized_shape = tensor_shape.TensorShape(
- [input_shape[0], height, width, input_shape[3]])
- paddings_index = 2
- filter_index = 3
- else:
- resized_shape = input_shape
- paddings_index = 1
- filter_index = 2
-
- # Calculates the effect of the padding.
- paddings_shape = op.inputs[paddings_index].get_shape().with_rank(2)
- resized_shape = resized_shape.with_rank(paddings_shape[0].value)
- paddings_shape = paddings_shape.merge_with(
- tensor_shape.matrix(resized_shape.ndims, 2))
- paddings = tensor_util.constant_value(op.inputs[paddings_index])
- if paddings is None:
- padded_shape = tensor_shape.unknown_shape(ndims=resized_shape.ndims)
- else:
- output_dims = []
- for i, dim in enumerate(resized_shape.dims):
- if paddings[i, 0] < 0 or paddings[i, 1] < 0:
- raise ValueError("paddings must be non-negative")
- output_dims.append(dim + paddings[i, 0] + paddings[i, 1])
- padded_shape = tensor_shape.TensorShape(output_dims)
-
- # Finally work out the convolution's effect.
- filter_shape = op.inputs[filter_index].get_shape().with_rank(4)
-
- batch_size = padded_shape[0]
- in_rows = padded_shape[1]
- in_cols = padded_shape[2]
-
- filter_rows = filter_shape[0]
- filter_cols = filter_shape[1]
- depth_out = filter_shape[3]
- # Check that the input depths are compatible.
- padded_shape[3].assert_is_compatible_with(filter_shape[2])
-
- stride_b, stride_r, stride_c, stride_d = op.get_attr("strides")
-
- if stride_b != 1 or stride_d != 1:
- raise ValueError("Current implementation does not yet support "
- "strides in the batch and depth dimensions.")
- # TODO(mrry,shlens): Raise an error if the stride would cause
- # information in the input to be ignored. This will require a change
- # in the kernel implementation.
- padding = op.get_attr("padding")
- out_rows, out_cols = common_shapes.get2d_conv_output_size(in_rows, in_cols,
- filter_rows,
- filter_cols,
- stride_r,
- stride_c,
- padding)
-
- output_shape = [batch_size, out_rows, out_cols, depth_out]
- return [tensor_shape.TensorShape(output_shape)]
-
-
@ops.RegisterShape("FusedResizeAndPadConv2D")
def _FusedResizeAndPadConv2DShape(op):
- """Shape function for FusedResizeAndPadConv2D op."""
- return _CommonFusedConvCalculations(op, True)
+ return common_shapes.call_cpp_shape_fn(op, input_tensors_needed=[1, 2])
@ops.RegisterShape("FusedPadConv2D")
def _FusedPadConv2DShape(op):
- """Shape function for FusedResizeAndPadConv2D op."""
- return _CommonFusedConvCalculations(op, False)
+ return common_shapes.call_cpp_shape_fn(op, input_tensors_needed=[1])
ops.RegisterShape("MaxPoolWithArgmax")(common_shapes.call_cpp_shape_fn)