aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/quantization
diff options
context:
space:
mode:
authorGravatar Vijay Vasudevan <vrv@google.com>2016-07-25 16:16:02 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2016-07-25 17:32:58 -0700
commit2dc33a83ad0c6a394e35fdaad769f418cc376fff (patch)
treec6fe3f144086c03a185536b7a63f557b84fe9914 /tensorflow/contrib/quantization
parent81a2892e6f6906c8a1c6e27a7607071328bba8c3 (diff)
Add existing common shape function uses to MatMuls, Conv2Ds, AvgPool.
Since common shape functions are already tested, and the additions here are pretty straight-forward extensions that use the common shapes, I've elided tests for them, but could add them if we thought it was useful. Change: 128418673
Diffstat (limited to 'tensorflow/contrib/quantization')
-rw-r--r--tensorflow/contrib/quantization/ops/math_ops.cc17
-rw-r--r--tensorflow/contrib/quantization/ops/nn_ops.cc14
2 files changed, 31 insertions, 0 deletions
diff --git a/tensorflow/contrib/quantization/ops/math_ops.cc b/tensorflow/contrib/quantization/ops/math_ops.cc
index 204b544972..6bc408531a 100644
--- a/tensorflow/contrib/quantization/ops/math_ops.cc
+++ b/tensorflow/contrib/quantization/ops/math_ops.cc
@@ -13,11 +13,16 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
+#include "tensorflow/core/framework/common_shape_fns.h"
#include "tensorflow/core/framework/numeric_op.h"
#include "tensorflow/core/framework/op.h"
+#include "tensorflow/core/framework/shape_inference.h"
namespace tensorflow {
+using shape_inference::InferenceContext;
+using shape_inference::Shape;
+
REGISTER_OP("QuantizedMatMul")
.Input("a: T1")
.Input("b: T2")
@@ -33,6 +38,18 @@ REGISTER_OP("QuantizedMatMul")
.Attr("Toutput: quantizedtype = DT_QINT32")
.Attr("transpose_a: bool = false")
.Attr("transpose_b: bool = false")
+ .SetShapeFn([](InferenceContext* c) {
+ TF_RETURN_IF_ERROR(shape_inference::MatMulShape(c));
+ const Shape* unused;
+ TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &unused));
+ TF_RETURN_IF_ERROR(c->WithRank(c->input(3), 0, &unused));
+ TF_RETURN_IF_ERROR(c->WithRank(c->input(4), 0, &unused));
+ TF_RETURN_IF_ERROR(c->WithRank(c->input(5), 0, &unused));
+
+ c->set_output(1, c->Scalar());
+ c->set_output(2, c->Scalar());
+ return Status::OK();
+ })
.Doc(R"doc(
Perform a quantized matrix multiplication of `a` by the matrix `b`.
diff --git a/tensorflow/contrib/quantization/ops/nn_ops.cc b/tensorflow/contrib/quantization/ops/nn_ops.cc
index ef99be0d48..fd12d155db 100644
--- a/tensorflow/contrib/quantization/ops/nn_ops.cc
+++ b/tensorflow/contrib/quantization/ops/nn_ops.cc
@@ -13,12 +13,17 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
+#include "tensorflow/core/framework/common_shape_fns.h"
#include "tensorflow/core/framework/numeric_op.h"
#include "tensorflow/core/framework/op.h"
+#include "tensorflow/core/framework/shape_inference.h"
#include "tensorflow/core/util/padding.h"
namespace tensorflow {
+using shape_inference::InferenceContext;
+using shape_inference::Shape;
+
REGISTER_OP("QuantizedAvgPool")
.Input("input: T")
.Input("min_input: float")
@@ -30,6 +35,15 @@ REGISTER_OP("QuantizedAvgPool")
.Attr("ksize: list(int)")
.Attr("strides: list(int)")
.Attr(GetPaddingAttrString())
+ .SetShapeFn([](InferenceContext* c) {
+ TF_RETURN_IF_ERROR(shape_inference::AvgPoolShape(c));
+ const Shape* unused;
+ TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused));
+ TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &unused));
+ c->set_output(1, c->Scalar());
+ c->set_output(2, c->Scalar());
+ return Status::OK();
+ })
.Doc(R"doc(
Produces the average pool of the input tensor for quantized types.