aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/fused_conv
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2017-09-11 16:09:49 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2017-09-11 16:13:44 -0700
commitab7f22de6aba356fae564e3e8bbb0beb9a98acb4 (patch)
tree7404f15441a50aaed92659338c96898d54ed5cef /tensorflow/contrib/fused_conv
parent3a98035fa8fe8d02960c605e210fbf8af2d14516 (diff)
Move FusedConvBiasActivationShape out of common_shape_fns.cc to a lambda inside the op.
PiperOrigin-RevId: 168300911
Diffstat (limited to 'tensorflow/contrib/fused_conv')
-rw-r--r--tensorflow/contrib/fused_conv/ops/fused_conv2d_bias_activation_op.cc48
1 files changed, 47 insertions, 1 deletions
diff --git a/tensorflow/contrib/fused_conv/ops/fused_conv2d_bias_activation_op.cc b/tensorflow/contrib/fused_conv/ops/fused_conv2d_bias_activation_op.cc
index 48f058b4c5..c9d0e1f41c 100644
--- a/tensorflow/contrib/fused_conv/ops/fused_conv2d_bias_activation_op.cc
+++ b/tensorflow/contrib/fused_conv/ops/fused_conv2d_bias_activation_op.cc
@@ -52,7 +52,53 @@ REGISTER_OP("FusedConv2DBiasActivation")
.Attr("data_format: {'NHWC', 'NCHW', 'NCHW_VECT_C'} = 'NHWC'")
.Attr("filter_format: {'HWIO', 'OIHW', 'OIHW_VECT_I'} = 'HWIO'")
.Attr("activation_mode: {'Relu'} = 'Relu'")
- .SetShapeFn(shape_inference::FusedConvBiasActivationShape)
+ .SetShapeFn([](shape_inference::InferenceContext* c) {
+ using shape_inference::ShapeHandle;
+ using shape_inference::DimensionHandle;
+ TF_RETURN_IF_ERROR(shape_inference::Conv2DShape(c));
+
+ string data_format_str, filter_format_str;
+ TF_RETURN_IF_ERROR(c->GetAttr("data_format", &data_format_str));
+ TF_RETURN_IF_ERROR(c->GetAttr("filter_format", &filter_format_str));
+
+ TensorFormat data_format;
+ FormatFromString(data_format_str, &data_format);
+ FilterTensorFormat filter_format;
+ FilterFormatFromString(filter_format_str, &filter_format);
+
+ constexpr int num_spatial_dims = 2;
+ const int rank =
+ GetTensorDimsFromSpatialDims(num_spatial_dims, data_format);
+ ShapeHandle filter_shape;
+ TF_RETURN_IF_ERROR(c->WithRank(c->input(1), rank, &filter_shape));
+
+ DimensionHandle output_depth_dim =
+ c->Dim(filter_shape,
+ GetFilterDimIndex<num_spatial_dims>(filter_format, 'O'));
+ int64 output_depth_dim_val = c->Value(output_depth_dim);
+
+ ShapeHandle bias_shape;
+ // Bias should be a 1-D tensor.
+ TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 1, &bias_shape));
+ DimensionHandle bias_dim = c->Dim(bias_shape, 0);
+ int64 bias_dim_val = c->Value(bias_dim);
+
+ if (output_depth_dim_val != bias_dim_val) {
+ return errors::InvalidArgument(
+ "Output depth dimension (", output_depth_dim_val,
+ ") and bias dimension (", bias_dim_val, ") do not match.");
+ }
+
+ // Check side input shape matches the output shape.
+ ShapeHandle side_input_shape;
+ TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(3), 1, &side_input_shape));
+ if (c->Rank(side_input_shape) > 1) {
+ ShapeHandle unused;
+ TF_RETURN_IF_ERROR(c->Merge(side_input_shape, c->output(0), &unused));
+ }
+
+ return Status::OK();
+ })
.Doc(R"doc(
Computes a fused kernel which implements: 2-D convolution, adds side input,
with separate scaling on convolution and side inputs, then adds bias and