author     2016-09-08 13:02:19 -0800
committer  2016-09-08 14:17:40 -0700
commit     c3a30a230f47a8ca8f5dd1dd79c63229ce1349b8
tree       53468947feb0b30a865e0b1210a5d734a09c77bf
parent     57d6a3ee564e89cf8318b5d2f3b851888f21b86e
Switch nn_ops shape fns to delegate to C++, for all ops that have a C++
implementation (the fractional pool ones don't yet).

Change the BiasAdd shape functions to require only rank 3, not 4, for NCHW.
This matches the behavior of GetBiasValueDims in bias_op.cc.

Remove the now-unused functions common_shapes.bias_add_shape and
common_shapes.bias_add_grad_shape.
Change: 132597521
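The first new test case below exercises exactly this change: a rank-3 NCHW input now passes BiasAdd shape inference. As a minimal sketch of the user-visible effect (graph-mode API of this era; the shapes mirror the new test and are otherwise illustrative):

    import tensorflow as tf

    # NCHW value of rank 3: (channel, height, width), with no batch dimension.
    # Shape inference previously demanded rank >= 4 here; rank 3 now suffices,
    # matching GetBiasValueDims in bias_op.cc.
    value = tf.placeholder(tf.float32, shape=[10, 11, 12])
    bias = tf.placeholder(tf.float32, shape=[10])  # length of the channel dim
    out = tf.nn.bias_add(value, bias, data_format="NCHW")
    print(out.get_shape())  # => (10, 11, 12)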
-rw-r--r--  tensorflow/core/framework/common_shape_fns.cc       |   4
-rw-r--r--  tensorflow/core/framework/common_shape_fns_test.cc  |  29
-rw-r--r--  tensorflow/python/framework/common_shapes.py        |  39
-rw-r--r--  tensorflow/python/kernel_tests/pooling_ops_test.py  |  26
-rw-r--r--  tensorflow/python/kernel_tests/topk_op_test.py      |   2
-rw-r--r--  tensorflow/python/ops/nn_ops.py                     | 174
6 files changed, 51 insertions, 223 deletions
diff --git a/tensorflow/core/framework/common_shape_fns.cc b/tensorflow/core/framework/common_shape_fns.cc
index fa224c8d6a..8df470aa22 100644
--- a/tensorflow/core/framework/common_shape_fns.cc
+++ b/tensorflow/core/framework/common_shape_fns.cc
@@ -135,7 +135,7 @@ Status BiasAddShape(shape_inference::InferenceContext* c) {
   Status s = c->GetAttr("data_format", &data_format);
 
   if (s.ok() && data_format == "NCHW") {
-    TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(0), 4, &input_shape));
+    TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(0), 3, &input_shape));
   } else {
     TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(0), 2, &input_shape));
   }
@@ -193,7 +193,7 @@ Status BiasAddGradShape(shape_inference::InferenceContext* c) {
   Status s = c->GetAttr("data_format", &data_format);
 
   if (s.ok() && data_format == "NCHW") {
-    TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(0), 4, &input_shape));
+    TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(0), 3, &input_shape));
     c->set_output(0, c->Vector(c->Dim(input_shape, -3)));
   } else {
     TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(0), 2, &input_shape));
diff --git a/tensorflow/core/framework/common_shape_fns_test.cc b/tensorflow/core/framework/common_shape_fns_test.cc
index f399c1b171..570ac28127 100644
--- a/tensorflow/core/framework/common_shape_fns_test.cc
+++ b/tensorflow/core/framework/common_shape_fns_test.cc
@@ -243,6 +243,19 @@ TEST(CommonShapeFnsTest, BiasAddShapeTest) {
   }
 
   {
+    // NCHW format with input rank 3
+    TF_CHECK_OK(NodeDefBuilder("test", "BiasAdd")
+                    .Input("a", 0, DT_FLOAT)
+                    .Input("b", 0, DT_FLOAT)
+                    .Attr("data_format", "NCHW")
+                    .Finalize(&def));
+    InferenceContext c(&def, op_def, {"[10,11,12]", "[10]"}, {});
+    TF_EXPECT_OK(BiasAddShape(&c));
+    ShapeHandle output = c.output(0);
+    EXPECT_EQ("[10,11,12]", c.DebugString(output));
+  }
+
+  {
     // Input rank not high enough
     InferenceContext c(&def, op_def, {"[3]", "[3]"}, {});
     EXPECT_FALSE(BiasAddShape(&c).ok());
@@ -256,7 +269,7 @@ TEST(CommonShapeFnsTest, BiasAddShapeTest) {
                     .Attr("data_format", "NCHW")
                     .Finalize(&def));
     // NCHW format
-    InferenceContext c(&def, op_def, {"[2,3,4]", "[3]"}, {});
+    InferenceContext c(&def, op_def, {"[2,3]", "[3]"}, {});
     EXPECT_FALSE(BiasAddShape(&c).ok());
   }
 }
@@ -314,6 +327,18 @@ TEST(CommonShapeFnsTest, BiasAddGradShapeTest) {
   }
 
   {
+    // NCHW format with input rank 3
+    TF_CHECK_OK(NodeDefBuilder("test", "BiasAddGrad")
+                    .Input("a", 0, DT_FLOAT)
+                    .Attr("data_format", "NCHW")
+                    .Finalize(&def));
+    InferenceContext c(&def, op_def, {"[10,11,12]"}, {});
+    TF_EXPECT_OK(BiasAddGradShape(&c));
+    ShapeHandle output = c.output(0);
+    EXPECT_EQ(10, c.Value(c.Dim(output, 0)));
+  }
+
+  {
     // Input rank not high enough
     InferenceContext c(&def, op_def, {"[3]"}, {});
     EXPECT_FALSE(BiasAddGradShape(&c).ok());
@@ -326,7 +351,7 @@ TEST(CommonShapeFnsTest, BiasAddGradShapeTest) {
                     .Attr("data_format", "NCHW")
                     .Finalize(&def));
     // NCHW format
-    InferenceContext c(&def, op_def, {"[2,3,4]"}, {});
+    InferenceContext c(&def, op_def, {"[2,3]"}, {});
    EXPECT_FALSE(BiasAddGradShape(&c).ok());
   }
 }
diff --git a/tensorflow/python/framework/common_shapes.py b/tensorflow/python/framework/common_shapes.py
index 8c5251bdf1..898c993f7d 100644
--- a/tensorflow/python/framework/common_shapes.py
+++ b/tensorflow/python/framework/common_shapes.py
@@ -102,45 +102,6 @@ def matmul_shape(op):
   return [tensor_shape.TensorShape([output_rows, output_cols])]
 
 
-def bias_add_shape(op):
-  """Shape function for a BiasAdd op."""
-  input_shape = op.inputs[0].get_shape().with_rank_at_least(2)
-  bias_shape = op.inputs[1].get_shape().with_rank(1)
-  if input_shape.ndims is not None:
-    # Output has the same shape as input, and matches the length of
-    # bias in its bias dimension.
-    try:
-      data_format = op.get_attr("data_format")
-    except ValueError:
-      data_format = None
-    if data_format == b"NCHW":
-      # Merge the length of bias_shape into the third-to-last dimension.
-      output_shape = input_shape[0:-3].concatenate(input_shape[-3].merge_with(
-          bias_shape[0])).concatenate(input_shape[-2:])
-    else:
-      output_shape = input_shape[0:-1].concatenate(input_shape[-1].merge_with(
-          bias_shape[0]))
-  else:
-    output_shape = tensor_shape.unknown_shape()
-  return [output_shape]
-
-
-def bias_add_grad_shape(op):
-  """Shape function for a BiasAddGrad op."""
-  input_shape = op.inputs[0].get_shape().with_rank_at_least(2)
-  try:
-    data_format = op.get_attr("data_format")
-  except ValueError:
-    data_format = None
-
-  if data_format == b"NCHW":
-    output_shape = input_shape[-3]
-  else:
-    output_shape = input_shape[-1]
-
-  return [output_shape]
-
-
 def get_conv_output_size(input_size, filter_size, strides, padding_type):
   """Returns the spatial size of a n-d convolution/pooling output."""
   input_size = tuple([tensor_shape.as_dimension(x).value for x in input_size])
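The two Python shape functions deleted above merged the bias length into the input's channel dimension; the C++ BiasAddShape that registration now delegates to is expected to preserve that merge. A hedged sketch (shapes are illustrative):

    import tensorflow as tf

    value = tf.placeholder(tf.float32, shape=[32, None])  # unknown channel dim
    bias = tf.placeholder(tf.float32, shape=[64])
    out = tf.nn.bias_add(value, bias)
    # The bias length constrains the previously unknown last dimension.
    print(out.get_shape())  # => (32, 64)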
diff --git a/tensorflow/python/kernel_tests/pooling_ops_test.py b/tensorflow/python/kernel_tests/pooling_ops_test.py
index 185fdc14bc..9a1f96b6fe 100644
--- a/tensorflow/python/kernel_tests/pooling_ops_test.py
+++ b/tensorflow/python/kernel_tests/pooling_ops_test.py
@@ -929,30 +929,12 @@ class PoolingTest(tf.test.TestCase):
       pool_func(tf.placeholder(tf.float32, shape=[1, 3]),
                 ksize=[1, 1, 1, 1], strides=[1, 1, 1, 1], padding="SAME")
 
-    # Illegal strides.
-    with self.assertRaisesRegexp(ValueError, "strides in the batch"):
-      tf.nn.max_pool_with_argmax(
-          tf.placeholder(tf.float32),
-          ksize=[1, 1, 1, 1],
-          strides=[2, 1, 1, 1],
-          padding="SAME")
-
-    # Filter larger than input.
-    for pool_func in [tf.nn.max_pool_with_argmax]:
-      with self.assertRaisesRegexp(ValueError,
-                                   "Filter must not be larger than the input"):
-        pool_func(tf.placeholder(tf.float32,
-                                 shape=[32, 20, 20, 3]),
-                  ksize=[1, 20, 21, 1], strides=[1, 1, 1, 1], padding="SAME")
-      with self.assertRaisesRegexp(ValueError,
-                                   "Filter must not be larger than the input"):
-        pool_func(tf.placeholder(tf.float32,
-                                 shape=[32, 20, 20, 3]),
-                  ksize=[1, 21, 20, 1], strides=[1, 1, 1, 1], padding="SAME")
-
   def testOpEdgeCases(self):
     with self.test_session() as sess:
-      for pool_func in [tf.nn.max_pool, tf.nn.avg_pool]:
+      pool_funcs = [tf.nn.max_pool, tf.nn.avg_pool]
+      if tf.test.is_gpu_available():
+        pool_funcs.append(tf.nn.max_pool_with_argmax)
+      for pool_func in pool_funcs:
        # Illegal strides.
         with self.assertRaisesRegexp(
             tf.errors.UnimplementedError,
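The rewritten testOpEdgeCases folds max_pool_with_argmax into the shared loop, but only when a GPU is available, since the op is treated here as GPU-only. The guard pattern, as a standalone sketch:

    import tensorflow as tf

    pool_funcs = [tf.nn.max_pool, tf.nn.avg_pool]
    if tf.test.is_gpu_available():
      # max_pool_with_argmax is assumed to have no CPU kernel at this point.
      pool_funcs.append(tf.nn.max_pool_with_argmax)
    for pool_func in pool_funcs:
      print(pool_func.__name__)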
diff --git a/tensorflow/python/kernel_tests/topk_op_test.py b/tensorflow/python/kernel_tests/topk_op_test.py
index 4550a2831c..1f9632ef03 100644
--- a/tensorflow/python/kernel_tests/topk_op_test.py
+++ b/tensorflow/python/kernel_tests/topk_op_test.py
@@ -81,7 +81,7 @@ class TopKTest(tf.test.TestCase):
   def testKTooLarge(self):
     inputs = [[0.1, 0.2], [0.3, 0.4]]
     with self.assertRaisesRegexp(
-        ValueError, r"input.shape \(2, 2\) must have last dimension >= k = 4"):
+        ValueError, r"must have last dimension >= k = 4"):
       tf.nn.top_k(inputs, 4)
 
   def testTopKGradients(self):
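The regex is relaxed because the message now originates in the C++ shape function rather than the Python _TopKShape deleted in nn_ops.py below, so the test pins only the stable suffix of the error text. A sketch of the behavior it checks:

    import tensorflow as tf

    inputs = [[0.1, 0.2], [0.3, 0.4]]
    try:
      tf.nn.top_k(inputs, 4)  # last dimension is 2, so k = 4 is too large
    except ValueError as e:
      assert "must have last dimension >= k = 4" in str(e)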
diff --git a/tensorflow/python/ops/nn_ops.py b/tensorflow/python/ops/nn_ops.py
index c27e594381..9560ffe66d 100644
--- a/tensorflow/python/ops/nn_ops.py
+++ b/tensorflow/python/ops/nn_ops.py
@@ -393,9 +393,10 @@ def bias_add(value, bias, data_format=None, name=None):
     return gen_nn_ops._bias_add(value, bias, data_format=data_format,
                                 name=name)
 
-ops.RegisterShape("BiasAdd")(common_shapes.bias_add_shape)
-
-ops.RegisterShape("BiasAddGrad")(common_shapes.bias_add_grad_shape)
+ops.RegisterShape("BiasAddV1")(common_shapes.call_cpp_shape_fn)
+ops.RegisterShape("BiasAdd")(common_shapes.call_cpp_shape_fn)
+ops.RegisterShape("BiasAddGradV1")(common_shapes.call_cpp_shape_fn)
+ops.RegisterShape("BiasAddGrad")(common_shapes.call_cpp_shape_fn)
 
 
 # pylint: disable=protected-access
@@ -426,11 +427,6 @@ def bias_add_v1(value, bias, name=None):
     return gen_nn_ops._bias_add_v1(value, bias, name=name)
 
 
-ops.RegisterShape("BiasAddV1")(common_shapes.bias_add_shape)
-
-ops.RegisterShape("BiasAddGradV1")(common_shapes.bias_add_grad_shape)
-
-
 def crelu(features, name=None):
   """Computes Concatenated ReLU.
 
@@ -866,23 +862,12 @@ ops.RegisterShape("LRNGrad")(common_shapes.call_cpp_shape_fn)
 ops.RegisterShape("Softmax")(common_shapes.call_cpp_shape_fn)
 ops.RegisterShape("LogSoftmax")(common_shapes.call_cpp_shape_fn)
 ops.RegisterShape("InTopK")(common_shapes.call_cpp_shape_fn)
+ops.RegisterShape("TopK")(common_shapes.call_cpp_shape_fn)
 
 
-@ops.RegisterShape("TopK")
 @ops.RegisterShape("TopKV2")
-def _TopKShape(op):
-  """Shape function for TopK and TopKV2 ops."""
-  input_shape = op.inputs[0].get_shape().with_rank_at_least(1)
-  if len(op.inputs) >= 2:
-    k = tensor_util.constant_value(op.inputs[1])
-  else:
-    k = op.get_attr("k")
-  last = input_shape[-1].value
-  if last is not None and k is not None and last < k:
-    raise ValueError("input.shape %s must have last dimension >= k = %d" %
-                     (input_shape, k))
-  output_shape = input_shape[:-1].concatenate([k])
-  return [output_shape, output_shape]
+def _TopKV2Shape(op):
+  return common_shapes.call_cpp_shape_fn(op, input_tensors_needed=[1])
 
 
 ops.RegisterShape("BatchNormWithGlobalNormalization")(
@@ -960,24 +945,12 @@ def _FusedResizeAndPadConv2DShape(op):
   return [tensor_shape.TensorShape(output_shape)]
 
 
-@ops.RegisterShape("MaxPoolWithArgmax")
-def _MaxPoolWithArgMaxShape(op):
-  """Shape function for MaxPoolWithArgmax op."""
-  return common_shapes.max_pool_shape(op) * 2
+ops.RegisterShape("MaxPoolWithArgmax")(common_shapes.call_cpp_shape_fn)
 
 
 @ops.RegisterShape("AvgPoolGrad")
 def _AvgPoolGradShape(op):
-  """Shape function for the AvgPoolGrad op."""
-  orig_input_shape = tensor_util.constant_value(op.inputs[0])
-  if orig_input_shape is not None:
-    return [tensor_shape.TensorShape(orig_input_shape.tolist())]
-  else:
-    # NOTE(mrry): We could in principle work out the shape from the
-    # gradients and the attrs, but if we do not know orig_input_shape
-    # statically, then we are unlikely to know the shape of the
-    # gradients either.
-    return [tensor_shape.unknown_shape(ndims=4)]
+  return common_shapes.call_cpp_shape_fn(op, input_tensors_needed=[0])
 
 
 @ops.RegisterShape("FractionalMaxPool")
@@ -1015,50 +988,22 @@ def _fractional_avg_pool_grad_shape(op):
 
 @ops.RegisterShape("Conv2DBackpropFilter")
 def _Conv2DBackpropFilterShape(op):
-  """Shape function for the Conv2DBackpropFilter op."""
-  filter_shape = tensor_util.constant_value(op.inputs[1])
-  if filter_shape is not None:
-    return [tensor_shape.TensorShape(filter_shape.tolist())]
-  else:
-    # NOTE(mrry): We could in principle work out the shape from the
-    # gradients and the attrs, but if we do not know filter_shape
-    # statically, then we are unlikely to know the shape of the
-    # gradients either.
-    return [tensor_shape.unknown_shape(ndims=4)]
+  return common_shapes.call_cpp_shape_fn(op, input_tensors_needed=[1])
 
 
 @ops.RegisterShape("Conv2DBackpropInput")
 def _Conv2DBackpropInputShape(op):
-  """Shape function for the Conv2DBackpropInput op."""
-  input_shape = tensor_util.constant_value(op.inputs[0])
-  if input_shape is not None:
-    return [tensor_shape.TensorShape(input_shape.tolist())]
-  else:
-    # NOTE(mrry): We could in principle work out the shape from the
-    # gradients and the attrs, but if we do not know input_shape
-    # statically, then we are unlikely to know the shape of the
-    # gradients either.
-    return [tensor_shape.unknown_shape(ndims=4)]
+  return common_shapes.call_cpp_shape_fn(op, input_tensors_needed=[0])
 
 
 @ops.RegisterShape("DepthwiseConv2dNativeBackpropFilter")
 def _DepthwiseConv2dNativeBackpropFilterShape(op):
-  """Shape function for the DepthwiseConv2dNativeBackpropFilter op."""
-  filter_shape = tensor_util.constant_value(op.inputs[1])
-  if filter_shape is not None:
-    return [tensor_shape.TensorShape(filter_shape.tolist())]
-  else:
-    return [tensor_shape.unknown_shape(ndims=4)]
+  return common_shapes.call_cpp_shape_fn(op, input_tensors_needed=[1])
 
 
 @ops.RegisterShape("DepthwiseConv2dNativeBackpropInput")
 def _DepthwiseConv2dNativeBackpropInputShape(op):
-  """Shape function for the DepthwiseConv2dNativeBackpropInput op."""
-  input_shape = tensor_util.constant_value(op.inputs[0])
-  if input_shape is not None:
-    return [tensor_shape.TensorShape(input_shape.tolist())]
-  else:
-    return [tensor_shape.unknown_shape(ndims=4)]
+  return common_shapes.call_cpp_shape_fn(op, input_tensors_needed=[0])
 
 
 ops.RegisterShape("MaxPoolGrad")(common_shapes.call_cpp_shape_fn)
@@ -1136,55 +1081,9 @@ def _calc_depthwise_conv_weight_params(graph, node):
                                          filter_channel_multiplier))
 
 
-@ops.RegisterShape("Conv3D")
-def _Conv3DShape(op):
-  """Shape function for Conv3D."""
-  input_shape = op.inputs[0].get_shape().with_rank(5)
-  filter_shape = op.inputs[1].get_shape().with_rank(5)
-
-  batch_size = input_shape[0]
-  out_channels = filter_shape[4]
-  # Check that the input number of channels is compatible between
-  # input data and filter size.
-  input_shape[4].assert_is_compatible_with(filter_shape[3])
-
-  stride_b, stride_p, stride_r, stride_c, stride_d = op.get_attr("strides")
-  assert stride_b == 1
-  assert stride_d == 1
-
-  padding_type = op.get_attr("padding")
-  out_planes, out_rows, out_cols = common_shapes.get_conv_output_size(
-      input_shape[1:4], filter_shape[0:3], (stride_p, stride_r, stride_c),
-      padding_type)
-
-  return [tensor_shape.TensorShape([batch_size, out_planes, out_rows, out_cols,
-                                    out_channels])]
-
-
-@ops.RegisterShape("MaxPool3D")
-@ops.RegisterShape("AvgPool3D")
-def _Pool3DShape(op):
-  """Shape function for Max/AvgPool3D."""
-  input_shape = op.inputs[0].get_shape().with_rank(5)
-  ksize_b, ksize_p, ksize_r, ksize_c, ksize_d = op.get_attr("ksize")
-  assert ksize_b == 1
-  assert ksize_d == 1
-
-  stride_b, stride_p, stride_r, stride_c, stride_d = op.get_attr("strides")
-  assert stride_b == 1
-  assert stride_d == 1
-
-  batch_size = input_shape[0]
-  channels = input_shape[4]
-
-  padding = op.get_attr("padding")
-  out_planes, out_rows, out_cols = common_shapes.get_conv_output_size(
-      input_shape[1:4], (ksize_p, ksize_r, ksize_c),
-      (stride_p, stride_r, stride_c), padding)
-  return [tensor_shape.TensorShape([batch_size, out_planes, out_rows, out_cols,
-                                    channels])]
-
-
+ops.RegisterShape("Conv3D")(common_shapes.call_cpp_shape_fn)
+ops.RegisterShape("MaxPool3D")(common_shapes.call_cpp_shape_fn)
+ops.RegisterShape("AvgPool3D")(common_shapes.call_cpp_shape_fn)
 ops.RegisterShape("Conv3DBackpropFilter")(common_shapes.call_cpp_shape_fn)
 ops.RegisterShape("Conv3DBackpropInput")(common_shapes.call_cpp_shape_fn)
 ops.RegisterShape("Conv3DBackpropFilterV2")(common_shapes.call_cpp_shape_fn)
@@ -1392,46 +1291,7 @@ def conv1d(value, filters, stride, padding,
   return array_ops.squeeze(result, [1])
 
 
-@ops.RegisterShape("Dilation2D")
-def _Dilation2DShape(op):
-  """Shape function for Dilation2D op."""
-  input_shape = op.inputs[0].get_shape().with_rank(4)
-  filter_shape = op.inputs[1].get_shape().with_rank(3)
-
-  batch_size = input_shape[0]
-  in_rows = input_shape[1]
-  in_cols = input_shape[2]
-  depth = input_shape[3]
-
-  filter_rows = filter_shape[0]
-  filter_cols = filter_shape[1]
-  # Check that the input depths are compatible.
-  input_shape[3].assert_is_compatible_with(filter_shape[2])
-
-  stride_b, stride_r, stride_c, stride_d = op.get_attr("strides")
-  if stride_b != 1 or stride_d != 1:
-    raise ValueError("Current implementation does not yet support "
-                     "strides in the batch and depth dimensions.")
-
-  rate_b, rate_r, rate_c, rate_d = op.get_attr("rates")
-  if rate_b != 1 or rate_d != 1:
-    raise ValueError("Current implementation does not yet support "
-                     "rates in the batch and depth dimensions.")
-
-  filter_rows_eff = filter_rows + (filter_rows - 1) * (rate_r - 1)
-  filter_cols_eff = filter_cols + (filter_cols - 1) * (rate_c - 1)
-
-  padding = op.get_attr("padding")
-  out_rows, out_cols = common_shapes.get2d_conv_output_size(in_rows, in_cols,
-                                                            filter_rows_eff,
-                                                            filter_cols_eff,
-                                                            stride_r, stride_c,
-                                                            padding)
-
-  output_shape = [batch_size, out_rows, out_cols, depth]
-  return [tensor_shape.TensorShape(output_shape)]
-
-
+ops.RegisterShape("Dilation2D")(common_shapes.call_cpp_shape_fn)
 ops.RegisterShape("Dilation2DBackpropInput")(common_shapes.call_cpp_shape_fn)
 ops.RegisterShape("Dilation2DBackpropFilter")(common_shapes.call_cpp_shape_fn)
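Several registrations above pass input_tensors_needed to call_cpp_shape_fn: those ops derive their output shape from the value of an input tensor, so that input must be constant-folded and handed to the C++ shape function. A hedged usage sketch for Conv2DBackpropInput, whose input 0 carries the original input's shape (argument values are illustrative):

    import tensorflow as tf

    out_backprop = tf.placeholder(tf.float32, [1, 2, 2, 1])
    filt = tf.placeholder(tf.float32, [3, 3, 1, 1])
    # input_sizes is a graph constant, so shape inference can recover it.
    grad = tf.nn.conv2d_backprop_input(
        input_sizes=tf.constant([1, 4, 4, 1]),
        filter=filt,
        out_backprop=out_backprop,
        strides=[1, 2, 2, 1],
        padding="SAME")
    print(grad.get_shape())  # => (1, 4, 4, 1)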