diff options
author | 2016-11-11 16:04:49 -0800 | |
---|---|---|
committer | 2016-11-11 16:10:33 -0800 | |
commit | f27dea2f016baf09040bf5aec705511486a3f205 (patch) | |
tree | a84f9faa617b73d903b8c200997d7779f67b7f36 | |
parent | 52fe9dce7a21007c4acd0254bd9e24021c291acb (diff) |
Enable C++ shape fn from python for more quantized ops.
Change: 138929922
-rw-r--r-- | tensorflow/core/ops/nn_ops_test.cc | 32 | ||||
-rw-r--r-- | tensorflow/python/ops/array_ops.py | 54 |
2 files changed, 38 insertions, 48 deletions
diff --git a/tensorflow/core/ops/nn_ops_test.cc b/tensorflow/core/ops/nn_ops_test.cc index 360085c8bb..3618769dc0 100644 --- a/tensorflow/core/ops/nn_ops_test.cc +++ b/tensorflow/core/ops/nn_ops_test.cc @@ -151,6 +151,38 @@ TEST(NNOpsTest, BatchNormWithGlobalNormalization_ShapeFn) { "[d0_0,d0_1,d0_2,d0_3|d1_0|d2_0|d3_0|d4_0]"); } +TEST(NNOpsTest, QuantizedBatchNormWithGlobalNormalization_ShapeFn) { + // These are the same tests as BatchNormWithGlobalNormalization tests, but + // with extra scalar inputs and outputs for the mins and maxes. + + ShapeInferenceTestOp op("QuantizedBatchNormWithGlobalNormalization"); + + // Test rank errors. + INFER_ERROR("Shape must be rank 4 but is rank 3", op, + "[1,2,3];?;?;?;?;?;?;?;?;?;?;?;?;?;?"); + INFER_ERROR("Shape must be rank 1 but is rank 3", op, + "?;?;?;[1,2,3];?;?;?;?;?;?;?;?;?;?;?"); + INFER_ERROR("Shape must be rank 1 but is rank 3", op, + "?;?;?;?;?;?;[1,2,3];?;?;?;?;?;?;?;?"); + INFER_ERROR("Shape must be rank 1 but is rank 3", op, + "?;?;?;?;?;?;?;?;?;[1,2,3];?;?;?;?;?"); + INFER_ERROR("Shape must be rank 1 but is rank 3", op, + "?;?;?;?;?;?;?;?;?;?;?;?;[1,2,3];?;?"); + + // last dim of first input is merged with the single dim in other 4 inputs. + INFER_OK(op, "?;[];[];?;[];[];?;[];[];?;[];[];?;[];[]", "[?,?,?,?];[];[]"); + INFER_OK(op, "?;[];[];[1];[];[];?;[];[];?;[];[];?;[];[]", + "[?,?,?,d3_0];[];[]"); + INFER_OK(op, "?;[];[];?;[];[];[1];[];[];?;[];[];?;[];[]", + "[?,?,?,d6_0];[];[]"); + INFER_OK(op, "?;[];[];?;[];[];?;[];[];[1];[];[];?;[];[]", + "[?,?,?,d9_0];[];[]"); + INFER_OK(op, "?;[];[];?;[];[];?;[];[];?;[];[];[1];[];[]", + "[?,?,?,d12_0];[];[]"); + INFER_OK(op, "[1,2,3,4];[];[];[4];[];[];[4];[];[];[4];[];[];[4];[];[]", + "[d0_0,d0_1,d0_2,d0_3|d3_0|d6_0|d9_0|d12_0];[];[]"); +} + TEST(NNOpsTest, BatchNormWithGlobalNormalizationGrad_ShapeFn) { ShapeInferenceTestOp op("BatchNormWithGlobalNormalizationGrad"); diff --git a/tensorflow/python/ops/array_ops.py b/tensorflow/python/ops/array_ops.py index 1b8045f8bf..66160f5717 100644 --- a/tensorflow/python/ops/array_ops.py +++ b/tensorflow/python/ops/array_ops.py @@ -1861,49 +1861,6 @@ def _DelegateReshapeShape(op): return common_shapes.call_cpp_shape_fn(op, input_tensors_as_shapes_needed=[1]) -def _ReshapeShape(op): - """Shape function for Reshape op.""" - input_shape = op.inputs[0].get_shape() - if input_shape.ndims is not None: - num_elements = tensor_shape.Dimension(1) - for dim in input_shape.dims: - num_elements *= dim - else: - num_elements = tensor_shape.Dimension(None) - new_shape = tensor_util.constant_value_as_shape(op.inputs[1]) - if new_shape.ndims is None: - # We have no information about the shape of the output. - return [new_shape] - if None not in new_shape.as_list(): - # The new shape is fully defined. - if (num_elements.value is not None - and num_elements.value != np.prod(new_shape)): - raise ValueError( - "Cannot reshape a tensor with %d elements to shape %s (%d elements)" - % (num_elements.value, new_shape, np.prod(new_shape))) - elif num_elements.value is not None: - # We know the number of elements, so we can calculate the missing - # dimension in the new_shape. - known_elements = 1 - unknown_indices = [] - for i, dim in enumerate(new_shape): - if dim.value is None: - unknown_indices.append(i) - else: - known_elements *= dim.value - if known_elements != 0: - if num_elements % known_elements != 0: - raise ValueError("input has %s elements, which isn't divisible by %d" % - (num_elements, known_elements)) - if len(unknown_indices) == 1: - unknown_index = unknown_indices[0] - new_shape = new_shape.merge_with( - new_shape[:unknown_index].concatenate( - [num_elements // known_elements]).concatenate( - new_shape[unknown_index+1:])) - return [new_shape] - - ops.RegisterShape("BroadcastGradientArgs")(common_shapes.call_cpp_shape_fn) @@ -2592,12 +2549,13 @@ def where(condition, x=None, y=None, name=None): @ops.RegisterShape("QuantizedReshape") -def _QuantizedReshapeShape(op): - return _ReshapeShape(op) + [tensor_shape.scalar(), tensor_shape.scalar()] +def _DelegateQuantizedReshapeShape(op): + return common_shapes.call_cpp_shape_fn( + op, input_tensors_as_shapes_needed=[1]) -# TODO(cwhipkey): Verify and enable shape functions for these. -ops.RegisterShape("QuantizeV2")(None) -ops.RegisterShape("QuantizedBatchNormWithGlobalNormalization")(None) +ops.RegisterShape("QuantizeV2")(common_shapes.call_cpp_shape_fn) +ops.RegisterShape("QuantizedBatchNormWithGlobalNormalization")( + common_shapes.call_cpp_shape_fn) ops.RegisterShape("QuantizedConcat")(common_shapes.call_cpp_shape_fn) |