diff options
author | 2016-07-25 16:16:02 -0800 | |
---|---|---|
committer | 2016-07-25 17:32:58 -0700 | |
commit | 2dc33a83ad0c6a394e35fdaad769f418cc376fff (patch) | |
tree | c6fe3f144086c03a185536b7a63f557b84fe9914 /tensorflow/contrib/quantization | |
parent | 81a2892e6f6906c8a1c6e27a7607071328bba8c3 (diff) |
Add existing common shape function uses to MatMuls, Conv2Ds, AvgPool.
Since common shape functions are already tested, and the additions
here are pretty straight-forward extensions that use the common
shapes, I've elided tests for them, but could add them if we thought
it was useful.
Change: 128418673
Diffstat (limited to 'tensorflow/contrib/quantization')
-rw-r--r-- | tensorflow/contrib/quantization/ops/math_ops.cc | 17 | ||||
-rw-r--r-- | tensorflow/contrib/quantization/ops/nn_ops.cc | 14 |
2 files changed, 31 insertions, 0 deletions
diff --git a/tensorflow/contrib/quantization/ops/math_ops.cc b/tensorflow/contrib/quantization/ops/math_ops.cc index 204b544972..6bc408531a 100644 --- a/tensorflow/contrib/quantization/ops/math_ops.cc +++ b/tensorflow/contrib/quantization/ops/math_ops.cc @@ -13,11 +13,16 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ +#include "tensorflow/core/framework/common_shape_fns.h" #include "tensorflow/core/framework/numeric_op.h" #include "tensorflow/core/framework/op.h" +#include "tensorflow/core/framework/shape_inference.h" namespace tensorflow { +using shape_inference::InferenceContext; +using shape_inference::Shape; + REGISTER_OP("QuantizedMatMul") .Input("a: T1") .Input("b: T2") @@ -33,6 +38,18 @@ REGISTER_OP("QuantizedMatMul") .Attr("Toutput: quantizedtype = DT_QINT32") .Attr("transpose_a: bool = false") .Attr("transpose_b: bool = false") + .SetShapeFn([](InferenceContext* c) { + TF_RETURN_IF_ERROR(shape_inference::MatMulShape(c)); + const Shape* unused; + TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &unused)); + TF_RETURN_IF_ERROR(c->WithRank(c->input(3), 0, &unused)); + TF_RETURN_IF_ERROR(c->WithRank(c->input(4), 0, &unused)); + TF_RETURN_IF_ERROR(c->WithRank(c->input(5), 0, &unused)); + + c->set_output(1, c->Scalar()); + c->set_output(2, c->Scalar()); + return Status::OK(); + }) .Doc(R"doc( Perform a quantized matrix multiplication of `a` by the matrix `b`. diff --git a/tensorflow/contrib/quantization/ops/nn_ops.cc b/tensorflow/contrib/quantization/ops/nn_ops.cc index ef99be0d48..fd12d155db 100644 --- a/tensorflow/contrib/quantization/ops/nn_ops.cc +++ b/tensorflow/contrib/quantization/ops/nn_ops.cc @@ -13,12 +13,17 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ +#include "tensorflow/core/framework/common_shape_fns.h" #include "tensorflow/core/framework/numeric_op.h" #include "tensorflow/core/framework/op.h" +#include "tensorflow/core/framework/shape_inference.h" #include "tensorflow/core/util/padding.h" namespace tensorflow { +using shape_inference::InferenceContext; +using shape_inference::Shape; + REGISTER_OP("QuantizedAvgPool") .Input("input: T") .Input("min_input: float") @@ -30,6 +35,15 @@ REGISTER_OP("QuantizedAvgPool") .Attr("ksize: list(int)") .Attr("strides: list(int)") .Attr(GetPaddingAttrString()) + .SetShapeFn([](InferenceContext* c) { + TF_RETURN_IF_ERROR(shape_inference::AvgPoolShape(c)); + const Shape* unused; + TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused)); + TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &unused)); + c->set_output(1, c->Scalar()); + c->set_output(2, c->Scalar()); + return Status::OK(); + }) .Doc(R"doc( Produces the average pool of the input tensor for quantized types. |