diff options
author | A. Unique TensorFlower <gardener@tensorflow.org> | 2017-09-11 16:09:49 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2017-09-11 16:13:44 -0700 |
commit | ab7f22de6aba356fae564e3e8bbb0beb9a98acb4 (patch) | |
tree | 7404f15441a50aaed92659338c96898d54ed5cef /tensorflow/contrib/fused_conv | |
parent | 3a98035fa8fe8d02960c605e210fbf8af2d14516 (diff) |
Move FusedConvBiasActivationShape out of common_shape_fns.cc to a lambda inside the op.
PiperOrigin-RevId: 168300911
Diffstat (limited to 'tensorflow/contrib/fused_conv')
-rw-r--r-- | tensorflow/contrib/fused_conv/ops/fused_conv2d_bias_activation_op.cc | 48 |
1 files changed, 47 insertions, 1 deletions
diff --git a/tensorflow/contrib/fused_conv/ops/fused_conv2d_bias_activation_op.cc b/tensorflow/contrib/fused_conv/ops/fused_conv2d_bias_activation_op.cc index 48f058b4c5..c9d0e1f41c 100644 --- a/tensorflow/contrib/fused_conv/ops/fused_conv2d_bias_activation_op.cc +++ b/tensorflow/contrib/fused_conv/ops/fused_conv2d_bias_activation_op.cc @@ -52,7 +52,53 @@ REGISTER_OP("FusedConv2DBiasActivation") .Attr("data_format: {'NHWC', 'NCHW', 'NCHW_VECT_C'} = 'NHWC'") .Attr("filter_format: {'HWIO', 'OIHW', 'OIHW_VECT_I'} = 'HWIO'") .Attr("activation_mode: {'Relu'} = 'Relu'") - .SetShapeFn(shape_inference::FusedConvBiasActivationShape) + .SetShapeFn([](shape_inference::InferenceContext* c) { + using shape_inference::ShapeHandle; + using shape_inference::DimensionHandle; + TF_RETURN_IF_ERROR(shape_inference::Conv2DShape(c)); + + string data_format_str, filter_format_str; + TF_RETURN_IF_ERROR(c->GetAttr("data_format", &data_format_str)); + TF_RETURN_IF_ERROR(c->GetAttr("filter_format", &filter_format_str)); + + TensorFormat data_format; + FormatFromString(data_format_str, &data_format); + FilterTensorFormat filter_format; + FilterFormatFromString(filter_format_str, &filter_format); + + constexpr int num_spatial_dims = 2; + const int rank = + GetTensorDimsFromSpatialDims(num_spatial_dims, data_format); + ShapeHandle filter_shape; + TF_RETURN_IF_ERROR(c->WithRank(c->input(1), rank, &filter_shape)); + + DimensionHandle output_depth_dim = + c->Dim(filter_shape, + GetFilterDimIndex<num_spatial_dims>(filter_format, 'O')); + int64 output_depth_dim_val = c->Value(output_depth_dim); + + ShapeHandle bias_shape; + // Bias should be a 1-D tensor. + TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 1, &bias_shape)); + DimensionHandle bias_dim = c->Dim(bias_shape, 0); + int64 bias_dim_val = c->Value(bias_dim); + + if (output_depth_dim_val != bias_dim_val) { + return errors::InvalidArgument( + "Output depth dimension (", output_depth_dim_val, + ") and bias dimension (", bias_dim_val, ") do not match."); + } + + // Check side input shape matches the output shape. + ShapeHandle side_input_shape; + TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(3), 1, &side_input_shape)); + if (c->Rank(side_input_shape) > 1) { + ShapeHandle unused; + TF_RETURN_IF_ERROR(c->Merge(side_input_shape, c->output(0), &unused)); + } + + return Status::OK(); + }) .Doc(R"doc( Computes a fused kernel which implements: 2-D convolution, adds side input, with separate scaling on convolution and side inputs, then adds bias and |