aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/quantization
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2016-08-23 16:35:14 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2016-08-23 17:54:30 -0700
commitf6bc8cabbd3ac1fb3acc36d3edbdce672cae7d12 (patch)
treea8c81b0269e68f57606052ab5573743f86995be6 /tensorflow/contrib/quantization
parentade1672d60d861c58e1930e93a1b396b22e7a4d9 (diff)
Add shape_inference::ShapeHandle and shape_inference::DimensionHandle to
replace uses of const Shape* and const Dimension*. This change only adds a typedef and updates references. A later change will make DimensionHandle and ShapeHandle real types instead of typedefs (to further hide the pointer access). Change: 131118981
Diffstat (limited to 'tensorflow/contrib/quantization')
-rw-r--r--tensorflow/contrib/quantization/ops/array_ops.cc8
-rw-r--r--tensorflow/contrib/quantization/ops/math_ops.cc6
-rw-r--r--tensorflow/contrib/quantization/ops/nn_ops.cc26
3 files changed, 20 insertions, 20 deletions
diff --git a/tensorflow/contrib/quantization/ops/array_ops.cc b/tensorflow/contrib/quantization/ops/array_ops.cc
index 7dd64f82e0..ff636c7957 100644
--- a/tensorflow/contrib/quantization/ops/array_ops.cc
+++ b/tensorflow/contrib/quantization/ops/array_ops.cc
@@ -20,7 +20,7 @@ limitations under the License.
namespace tensorflow {
using shape_inference::InferenceContext;
-using shape_inference::Shape;
+using shape_inference::ShapeHandle;
REGISTER_OP("QuantizeV2")
.Input("input: float")
@@ -33,7 +33,7 @@ REGISTER_OP("QuantizeV2")
.Attr("mode: {'MIN_COMBINED', 'MIN_FIRST'} = 'MIN_COMBINED'")
.SetShapeFn([](InferenceContext* c) {
TF_RETURN_IF_ERROR(shape_inference::UnchangedShape(c));
- const Shape* unused;
+ ShapeHandle unused;
TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused));
TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &unused));
c->set_output(1, c->Scalar());
@@ -110,7 +110,7 @@ REGISTER_OP("Dequantize")
.Attr("mode: {'MIN_COMBINED', 'MIN_FIRST'} = 'MIN_COMBINED'")
.SetShapeFn([](InferenceContext* c) {
TF_RETURN_IF_ERROR(shape_inference::UnchangedShape(c));
- const Shape* unused;
+ ShapeHandle unused;
TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused));
TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &unused));
return Status::OK();
@@ -168,7 +168,7 @@ REGISTER_OP("QuantizedConcat")
.Attr("T: type")
.SetShapeFn([](InferenceContext* c) {
TF_RETURN_IF_ERROR(shape_inference::ConcatShape(c));
- const Shape* unused;
+ ShapeHandle unused;
for (int i = 2; i < c->num_inputs(); ++i) {
TF_RETURN_IF_ERROR(c->WithRank(c->input(i), 0, &unused));
}
diff --git a/tensorflow/contrib/quantization/ops/math_ops.cc b/tensorflow/contrib/quantization/ops/math_ops.cc
index ed0930c2d6..93bb283630 100644
--- a/tensorflow/contrib/quantization/ops/math_ops.cc
+++ b/tensorflow/contrib/quantization/ops/math_ops.cc
@@ -21,7 +21,7 @@ limitations under the License.
namespace tensorflow {
using shape_inference::InferenceContext;
-using shape_inference::Shape;
+using shape_inference::ShapeHandle;
REGISTER_OP("QuantizedMatMul")
.Input("a: T1")
@@ -40,7 +40,7 @@ REGISTER_OP("QuantizedMatMul")
.Attr("transpose_b: bool = false")
.SetShapeFn([](InferenceContext* c) {
TF_RETURN_IF_ERROR(shape_inference::MatMulShape(c));
- const Shape* unused;
+ ShapeHandle unused;
TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &unused));
TF_RETURN_IF_ERROR(c->WithRank(c->input(3), 0, &unused));
TF_RETURN_IF_ERROR(c->WithRank(c->input(4), 0, &unused));
@@ -82,7 +82,7 @@ REGISTER_OP("QuantizeDownAndShrinkRange")
.Attr("out_type: quantizedtype")
.SetShapeFn([](InferenceContext* c) {
TF_RETURN_IF_ERROR(shape_inference::UnchangedShape(c));
- const Shape* unused;
+ ShapeHandle unused;
TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused));
TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &unused));
c->set_output(1, c->Scalar());
diff --git a/tensorflow/contrib/quantization/ops/nn_ops.cc b/tensorflow/contrib/quantization/ops/nn_ops.cc
index c33f318c6e..720377043d 100644
--- a/tensorflow/contrib/quantization/ops/nn_ops.cc
+++ b/tensorflow/contrib/quantization/ops/nn_ops.cc
@@ -21,9 +21,9 @@ limitations under the License.
namespace tensorflow {
-using shape_inference::Dimension;
+using shape_inference::DimensionHandle;
using shape_inference::InferenceContext;
-using shape_inference::Shape;
+using shape_inference::ShapeHandle;
REGISTER_OP("QuantizedAvgPool")
.Input("input: T")
@@ -38,7 +38,7 @@ REGISTER_OP("QuantizedAvgPool")
.Attr(GetPaddingAttrString())
.SetShapeFn([](InferenceContext* c) {
TF_RETURN_IF_ERROR(shape_inference::AvgPoolShape(c));
- const Shape* unused;
+ ShapeHandle unused;
TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused));
TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &unused));
c->set_output(1, c->Scalar());
@@ -76,7 +76,7 @@ REGISTER_OP("QuantizedBiasAdd")
.Attr("out_type: quantizedtype")
.SetShapeFn([](InferenceContext* c) {
TF_RETURN_IF_ERROR(shape_inference::BiasAddShape(c));
- const Shape* unused;
+ ShapeHandle unused;
TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &unused));
TF_RETURN_IF_ERROR(c->WithRank(c->input(3), 0, &unused));
TF_RETURN_IF_ERROR(c->WithRank(c->input(4), 0, &unused));
@@ -117,7 +117,7 @@ REGISTER_OP("QuantizedConv2D")
.Attr(GetPaddingAttrString())
.SetShapeFn([](InferenceContext* c) {
TF_RETURN_IF_ERROR(shape_inference::Conv2DShape(c));
- const Shape* unused;
+ ShapeHandle unused;
TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &unused));
TF_RETURN_IF_ERROR(c->WithRank(c->input(3), 0, &unused));
TF_RETURN_IF_ERROR(c->WithRank(c->input(4), 0, &unused));
@@ -159,7 +159,7 @@ REGISTER_OP("QuantizedMaxPool")
.Attr(GetPaddingAttrString())
.SetShapeFn([](InferenceContext* c) {
TF_RETURN_IF_ERROR(shape_inference::MaxPoolShape(c));
- const Shape* unused;
+ ShapeHandle unused;
TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused));
TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &unused));
c->set_output(1, c->Scalar());
@@ -193,7 +193,7 @@ REGISTER_OP("QuantizedRelu")
.Attr("out_type: quantizedtype = DT_QUINT8")
.SetShapeFn([](InferenceContext* c) {
TF_RETURN_IF_ERROR(shape_inference::UnchangedShape(c));
- const Shape* unused;
+ ShapeHandle unused;
TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused));
TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &unused));
c->set_output(1, c->Scalar());
@@ -222,7 +222,7 @@ REGISTER_OP("QuantizedRelu6")
.Attr("out_type: quantizedtype = DT_QUINT8")
.SetShapeFn([](InferenceContext* c) {
TF_RETURN_IF_ERROR(shape_inference::UnchangedShape(c));
- const Shape* unused;
+ ShapeHandle unused;
TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused));
TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &unused));
c->set_output(1, c->Scalar());
@@ -252,7 +252,7 @@ REGISTER_OP("QuantizedReluX")
.Attr("out_type: quantizedtype = DT_QUINT8")
.SetShapeFn([](InferenceContext* c) {
TF_RETURN_IF_ERROR(shape_inference::UnchangedShape(c));
- const Shape* unused;
+ ShapeHandle unused;
TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused));
TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &unused));
c->set_output(1, c->Scalar());
@@ -294,17 +294,17 @@ REGISTER_OP("QuantizedBatchNormWithGlobalNormalization")
.Attr("variance_epsilon: float")
.Attr("scale_after_normalization: bool")
.SetShapeFn([](InferenceContext* c) {
- const Shape* input;
+ ShapeHandle input;
TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 4, &input));
- const Dimension* last_dim = c->Dim(input, 3);
+ DimensionHandle last_dim = c->Dim(input, 3);
for (int i = 1; i < 5; ++i) { // covers m, v, beta, gamma
- const Shape* vec;
+ ShapeHandle vec;
TF_RETURN_IF_ERROR(c->WithRank(c->input(i * 3), 1, &vec));
TF_RETURN_IF_ERROR(c->Merge(last_dim, c->Dim(vec, 0), &last_dim));
}
- const Shape* out;
+ ShapeHandle out;
TF_RETURN_IF_ERROR(c->ReplaceDim(input, 3, last_dim, &out));
c->set_output(0, out);
c->set_output(1, c->Scalar());