author    A. Unique TensorFlower <gardener@tensorflow.org>  2016-11-11 16:04:49 -0800
committer TensorFlower Gardener <gardener@tensorflow.org>   2016-11-11 16:10:33 -0800
commit    f27dea2f016baf09040bf5aec705511486a3f205
tree      a84f9faa617b73d903b8c200997d7779f67b7f36
parent    52fe9dce7a21007c4acd0254bd9e24021c291acb
Enable C++ shape fn from python for more quantized ops.
Change: 138929922
 tensorflow/core/ops/nn_ops_test.cc | 32
 tensorflow/python/ops/array_ops.py | 54
 2 files changed, 38 insertions(+), 48 deletions(-)
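Context for the diff below: registering None as an op's Python shape function disables shape inference for that op, so all of its outputs get unknown shapes, while registering common_shapes.call_cpp_shape_fn forwards the op to the shape function declared alongside REGISTER_OP in C++. A minimal sketch of the registration patterns this change moves between (the op names here are placeholders, not real ops; the module paths and call signature are the ones used in the diff):

from tensorflow.python.framework import common_shapes
from tensorflow.python.framework import ops

# Before: shape inference disabled; every output gets an unknown shape.
ops.RegisterShape("SomePlaceholderOpA")(None)

# After: delegate to the C++ shape function registered with the op.
ops.RegisterShape("SomePlaceholderOpB")(common_shapes.call_cpp_shape_fn)

# Ops whose output shape depends on the *value* of an input (such as the
# target-shape input of Reshape-like ops) ask the C++ side to materialize
# that input as a partial shape:
@ops.RegisterShape("SomePlaceholderOpC")
def _delegate_shape(op):
  return common_shapes.call_cpp_shape_fn(op, input_tensors_as_shapes_needed=[1])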
diff --git a/tensorflow/core/ops/nn_ops_test.cc b/tensorflow/core/ops/nn_ops_test.cc
index 360085c8bb..3618769dc0 100644
--- a/tensorflow/core/ops/nn_ops_test.cc
+++ b/tensorflow/core/ops/nn_ops_test.cc
@@ -151,6 +151,38 @@ TEST(NNOpsTest, BatchNormWithGlobalNormalization_ShapeFn) {
"[d0_0,d0_1,d0_2,d0_3|d1_0|d2_0|d3_0|d4_0]");
}
+TEST(NNOpsTest, QuantizedBatchNormWithGlobalNormalization_ShapeFn) {
+ // These are the same tests as BatchNormWithGlobalNormalization tests, but
+ // with extra scalar inputs and outputs for the mins and maxes.
+
+ ShapeInferenceTestOp op("QuantizedBatchNormWithGlobalNormalization");
+
+ // Test rank errors.
+ INFER_ERROR("Shape must be rank 4 but is rank 3", op,
+ "[1,2,3];?;?;?;?;?;?;?;?;?;?;?;?;?;?");
+ INFER_ERROR("Shape must be rank 1 but is rank 3", op,
+ "?;?;?;[1,2,3];?;?;?;?;?;?;?;?;?;?;?");
+ INFER_ERROR("Shape must be rank 1 but is rank 3", op,
+ "?;?;?;?;?;?;[1,2,3];?;?;?;?;?;?;?;?");
+ INFER_ERROR("Shape must be rank 1 but is rank 3", op,
+ "?;?;?;?;?;?;?;?;?;[1,2,3];?;?;?;?;?");
+ INFER_ERROR("Shape must be rank 1 but is rank 3", op,
+ "?;?;?;?;?;?;?;?;?;?;?;?;[1,2,3];?;?");
+
+ // Last dim of the first input is merged with the single dim in the other 4 inputs.
+ INFER_OK(op, "?;[];[];?;[];[];?;[];[];?;[];[];?;[];[]", "[?,?,?,?];[];[]");
+ INFER_OK(op, "?;[];[];[1];[];[];?;[];[];?;[];[];?;[];[]",
+ "[?,?,?,d3_0];[];[]");
+ INFER_OK(op, "?;[];[];?;[];[];[1];[];[];?;[];[];?;[];[]",
+ "[?,?,?,d6_0];[];[]");
+ INFER_OK(op, "?;[];[];?;[];[];?;[];[];[1];[];[];?;[];[]",
+ "[?,?,?,d9_0];[];[]");
+ INFER_OK(op, "?;[];[];?;[];[];?;[];[];?;[];[];[1];[];[]",
+ "[?,?,?,d12_0];[];[]");
+ INFER_OK(op, "[1,2,3,4];[];[];[4];[];[];[4];[];[];[4];[];[];[4];[];[]",
+ "[d0_0,d0_1,d0_2,d0_3|d3_0|d6_0|d9_0|d12_0];[];[]");
+}
+
TEST(NNOpsTest, BatchNormWithGlobalNormalizationGrad_ShapeFn) {
ShapeInferenceTestOp op("BatchNormWithGlobalNormalizationGrad");
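A note on the test syntax above: each INFER_OK/INFER_ERROR string lists one shape per input, separated by ";", where "?" is an unknown shape, "[]" a scalar, and "[1,2,3]" a rank-3 shape; in the expected output, d<i>_<j> names dimension j of input i, and "|" marks dimensions that were merged. The merge rule the new test exercises can be sketched standalone (plain Python, not TensorFlow code; None stands for an unknown dimension):

def merge_dim(a, b):
  # Merge two possibly-unknown dimension sizes the way shape inference does:
  # unknown defers to known, and two known sizes must agree.
  if a is None:
    return b
  if b is None or a == b:
    return a
  raise ValueError("Dimensions %s and %s are not compatible" % (a, b))

def quantized_batch_norm_shape(t, m, v, beta, gamma):
  # t is the rank-4 input; m, v, beta, gamma are the rank-1 parameters.
  # The channel (last) dim of t is merged with each parameter's single dim.
  channels = t[3]
  for param in (m, v, beta, gamma):
    channels = merge_dim(channels, param[0])
  # Outputs: the normalized tensor plus two scalars for its min and max.
  return [t[:3] + [channels], [], []]

# Mirrors INFER_OK(op, "[1,2,3,4];[];[];[4];[];[];[4];[];[];[4];[];[];[4];[];[]",
#                  "[d0_0,d0_1,d0_2,d0_3|d3_0|d6_0|d9_0|d12_0];[];[]"):
print(quantized_batch_norm_shape([1, 2, 3, 4], [4], [4], [4], [4]))
# -> [[1, 2, 3, 4], [], []]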
diff --git a/tensorflow/python/ops/array_ops.py b/tensorflow/python/ops/array_ops.py
index 1b8045f8bf..66160f5717 100644
--- a/tensorflow/python/ops/array_ops.py
+++ b/tensorflow/python/ops/array_ops.py
@@ -1861,49 +1861,6 @@ def _DelegateReshapeShape(op):
return common_shapes.call_cpp_shape_fn(op, input_tensors_as_shapes_needed=[1])
-def _ReshapeShape(op):
- """Shape function for Reshape op."""
- input_shape = op.inputs[0].get_shape()
- if input_shape.ndims is not None:
- num_elements = tensor_shape.Dimension(1)
- for dim in input_shape.dims:
- num_elements *= dim
- else:
- num_elements = tensor_shape.Dimension(None)
- new_shape = tensor_util.constant_value_as_shape(op.inputs[1])
- if new_shape.ndims is None:
- # We have no information about the shape of the output.
- return [new_shape]
- if None not in new_shape.as_list():
- # The new shape is fully defined.
- if (num_elements.value is not None
- and num_elements.value != np.prod(new_shape)):
- raise ValueError(
- "Cannot reshape a tensor with %d elements to shape %s (%d elements)"
- % (num_elements.value, new_shape, np.prod(new_shape)))
- elif num_elements.value is not None:
- # We know the number of elements, so we can calculate the missing
- # dimension in the new_shape.
- known_elements = 1
- unknown_indices = []
- for i, dim in enumerate(new_shape):
- if dim.value is None:
- unknown_indices.append(i)
- else:
- known_elements *= dim.value
- if known_elements != 0:
- if num_elements % known_elements != 0:
- raise ValueError("input has %s elements, which isn't divisible by %d" %
- (num_elements, known_elements))
- if len(unknown_indices) == 1:
- unknown_index = unknown_indices[0]
- new_shape = new_shape.merge_with(
- new_shape[:unknown_index].concatenate(
- [num_elements // known_elements]).concatenate(
- new_shape[unknown_index+1:]))
- return [new_shape]
-
-
ops.RegisterShape("BroadcastGradientArgs")(common_shapes.call_cpp_shape_fn)
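The deleted _ReshapeShape above is the logic now provided by the C++ shape fn: count the input's elements, then solve at most one unknown dimension of the requested shape. A standalone sketch of that rule, assuming the input's element count is fully known (plain Python; None marks an unknown target dim; the error messages mirror the deleted code):

def infer_reshape(num_elements, target):
  # target: list of ints with None for unknown dims. Returns the inferred
  # shape, filling a single unknown dim from the element count when possible.
  known = 1
  unknown = [i for i, d in enumerate(target) if d is None]
  for d in target:
    if d is not None:
      known *= d
  if not unknown:
    if num_elements != known:
      raise ValueError(
          "Cannot reshape a tensor with %d elements to shape %s (%d elements)"
          % (num_elements, target, known))
    return target
  if known != 0 and num_elements % known != 0:
    raise ValueError("input has %d elements, which isn't divisible by %d"
                     % (num_elements, known))
  if len(unknown) == 1 and known != 0:
    target = list(target)
    target[unknown[0]] = num_elements // known
  return target

print(infer_reshape(24, [2, None, 4]))  # -> [2, 3, 4]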
@@ -2592,12 +2549,13 @@ def where(condition, x=None, y=None, name=None):
@ops.RegisterShape("QuantizedReshape")
-def _QuantizedReshapeShape(op):
- return _ReshapeShape(op) + [tensor_shape.scalar(), tensor_shape.scalar()]
+def _DelegateQuantizedReshapeShape(op):
+ return common_shapes.call_cpp_shape_fn(
+ op, input_tensors_as_shapes_needed=[1])
-# TODO(cwhipkey): Verify and enable shape functions for these.
-ops.RegisterShape("QuantizeV2")(None)
-ops.RegisterShape("QuantizedBatchNormWithGlobalNormalization")(None)
+ops.RegisterShape("QuantizeV2")(common_shapes.call_cpp_shape_fn)
+ops.RegisterShape("QuantizedBatchNormWithGlobalNormalization")(
+ common_shapes.call_cpp_shape_fn)
ops.RegisterShape("QuantizedConcat")(common_shapes.call_cpp_shape_fn)
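Finally, note what the QuantizedReshape hunk preserves: the old Python function returned the reshaped data shape plus two scalars for the output min and max, and the C++ shape fn it now delegates to keeps that three-output contract. A short sketch building on infer_reshape from the earlier sketch:

def quantized_reshape_shapes(num_elements, target):
  # Data shape as for plain Reshape, plus two scalars ([]) for the output
  # min/max, matching the old `_ReshapeShape(op) + [scalar(), scalar()]`.
  return [infer_reshape(num_elements, target), [], []]

print(quantized_reshape_shapes(12, [3, None]))  # -> [[3, 4], [], []]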