author    A. Unique TensorFlower <gardener@tensorflow.org>  2016-09-08 13:02:19 -0800
committer TensorFlower Gardener <gardener@tensorflow.org>  2016-09-08 14:17:40 -0700
commit    c3a30a230f47a8ca8f5dd1dd79c63229ce1349b8 (patch)
tree      53468947feb0b30a865e0b1210a5d734a09c77bf
parent    57d6a3ee564e89cf8318b5d2f3b851888f21b86e (diff)
Switch nn_ops shape fns to delegate to C++, for all that have a C++
implementation (the fractional pool ones don't yet).

Change the BiasAdd shape functions to require only rank 3, not 4, for
NCHW. This matches the behavior of GetBiasValueDims in bias_op.cc.

Remove the unused functions common_shapes.bias_add_shape and
common_shapes.bias_add_grad_shape.

Change: 132597521
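
A minimal sketch of the relaxed NCHW rank requirement, using the same shapes as the new unit test below (assumes the 2016-era Python API):

import tensorflow as tf

# A rank-3 NCHW input now passes shape inference; previously rank 4 was
# required. Dim -3 is the channel dimension, so the bias must have length 10.
x = tf.placeholder(tf.float32, shape=[10, 11, 12])
b = tf.placeholder(tf.float32, shape=[10])
y = tf.nn.bias_add(x, b, data_format="NCHW")
print(y.get_shape())  # (10, 11, 12)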
-rw-r--r--  tensorflow/core/framework/common_shape_fns.cc       |   4
-rw-r--r--  tensorflow/core/framework/common_shape_fns_test.cc  |  29
-rw-r--r--  tensorflow/python/framework/common_shapes.py        |  39
-rw-r--r--  tensorflow/python/kernel_tests/pooling_ops_test.py  |  26
-rw-r--r--  tensorflow/python/kernel_tests/topk_op_test.py      |   2
-rw-r--r--  tensorflow/python/ops/nn_ops.py                     | 174
6 files changed, 51 insertions(+), 223 deletions(-)
diff --git a/tensorflow/core/framework/common_shape_fns.cc b/tensorflow/core/framework/common_shape_fns.cc
index fa224c8d6a..8df470aa22 100644
--- a/tensorflow/core/framework/common_shape_fns.cc
+++ b/tensorflow/core/framework/common_shape_fns.cc
@@ -135,7 +135,7 @@ Status BiasAddShape(shape_inference::InferenceContext* c) {
Status s = c->GetAttr("data_format", &data_format);
if (s.ok() && data_format == "NCHW") {
- TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(0), 4, &input_shape));
+ TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(0), 3, &input_shape));
} else {
TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(0), 2, &input_shape));
}
@@ -193,7 +193,7 @@ Status BiasAddGradShape(shape_inference::InferenceContext* c) {
Status s = c->GetAttr("data_format", &data_format);
if (s.ok() && data_format == "NCHW") {
- TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(0), 4, &input_shape));
+ TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(0), 3, &input_shape));
c->set_output(0, c->Vector(c->Dim(input_shape, -3)));
} else {
TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(0), 2, &input_shape));
diff --git a/tensorflow/core/framework/common_shape_fns_test.cc b/tensorflow/core/framework/common_shape_fns_test.cc
index f399c1b171..570ac28127 100644
--- a/tensorflow/core/framework/common_shape_fns_test.cc
+++ b/tensorflow/core/framework/common_shape_fns_test.cc
@@ -243,6 +243,19 @@ TEST(CommonShapeFnsTest, BiasAddShapeTest) {
}
{
+ // NCHW format with input rank 3
+ TF_CHECK_OK(NodeDefBuilder("test", "BiasAdd")
+ .Input("a", 0, DT_FLOAT)
+ .Input("b", 0, DT_FLOAT)
+ .Attr("data_format", "NCHW")
+ .Finalize(&def));
+ InferenceContext c(&def, op_def, {"[10,11,12]", "[10]"}, {});
+ TF_EXPECT_OK(BiasAddShape(&c));
+ ShapeHandle output = c.output(0);
+ EXPECT_EQ("[10,11,12]", c.DebugString(output));
+ }
+
+ {
// Input rank not high enough
InferenceContext c(&def, op_def, {"[3]", "[3]"}, {});
EXPECT_FALSE(BiasAddShape(&c).ok());
@@ -256,7 +269,7 @@ TEST(CommonShapeFnsTest, BiasAddShapeTest) {
.Attr("data_format", "NCHW")
.Finalize(&def));
// NCHW format
- InferenceContext c(&def, op_def, {"[2,3,4]", "[3]"}, {});
+ InferenceContext c(&def, op_def, {"[2,3]", "[3]"}, {});
EXPECT_FALSE(BiasAddShape(&c).ok());
}
}
@@ -314,6 +327,18 @@ TEST(CommonShapeFnsTest, BiasAddGradShapeTest) {
}
{
+ // NCHW format with input rank 3
+ TF_CHECK_OK(NodeDefBuilder("test", "BiasAddGrad")
+ .Input("a", 0, DT_FLOAT)
+ .Attr("data_format", "NCHW")
+ .Finalize(&def));
+ InferenceContext c(&def, op_def, {"[10,11,12]"}, {});
+ TF_EXPECT_OK(BiasAddGradShape(&c));
+ ShapeHandle output = c.output(0);
+ EXPECT_EQ(10, c.Value(c.Dim(output, 0)));
+ }
+
+ {
// Input rank not high enough
InferenceContext c(&def, op_def, {"[3]"}, {});
EXPECT_FALSE(BiasAddGradShape(&c).ok());
@@ -326,7 +351,7 @@ TEST(CommonShapeFnsTest, BiasAddGradShapeTest) {
.Attr("data_format", "NCHW")
.Finalize(&def));
// NCHW format
- InferenceContext c(&def, op_def, {"[2,3,4]"}, {});
+ InferenceContext c(&def, op_def, {"[2,3]"}, {});
EXPECT_FALSE(BiasAddGradShape(&c).ok());
}
}
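
The matching BiasAddGrad behavior, sketched through the public gradient API (the call path here is an assumption; the output shape follows from c->Vector(c->Dim(input_shape, -3)) above):

import tensorflow as tf

x = tf.placeholder(tf.float32, shape=[10, 11, 12])  # rank-3 NCHW
b = tf.placeholder(tf.float32, shape=[10])
y = tf.nn.bias_add(x, b, data_format="NCHW")
# BiasAddGrad collapses the incoming gradient to a vector of length
# C = dim(-3) = 10, matching the test above.
db, = tf.gradients(y, b)
print(db.get_shape())  # (10,)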
diff --git a/tensorflow/python/framework/common_shapes.py b/tensorflow/python/framework/common_shapes.py
index 8c5251bdf1..898c993f7d 100644
--- a/tensorflow/python/framework/common_shapes.py
+++ b/tensorflow/python/framework/common_shapes.py
@@ -102,45 +102,6 @@ def matmul_shape(op):
return [tensor_shape.TensorShape([output_rows, output_cols])]
-def bias_add_shape(op):
- """Shape function for a BiasAdd op."""
- input_shape = op.inputs[0].get_shape().with_rank_at_least(2)
- bias_shape = op.inputs[1].get_shape().with_rank(1)
- if input_shape.ndims is not None:
- # Output has the same shape as input, and matches the length of
- # bias in its bias dimension.
- try:
- data_format = op.get_attr("data_format")
- except ValueError:
- data_format = None
- if data_format == b"NCHW":
- # Merge the length of bias_shape into the third-to-last dimension.
- output_shape = input_shape[0:-3].concatenate(input_shape[-3].merge_with(
- bias_shape[0])).concatenate(input_shape[-2:])
- else:
- output_shape = input_shape[0:-1].concatenate(input_shape[-1].merge_with(
- bias_shape[0]))
- else:
- output_shape = tensor_shape.unknown_shape()
- return [output_shape]
-
-
-def bias_add_grad_shape(op):
- """Shape function for a BiasAddGrad op."""
- input_shape = op.inputs[0].get_shape().with_rank_at_least(2)
- try:
- data_format = op.get_attr("data_format")
- except ValueError:
- data_format = None
-
- if data_format == b"NCHW":
- output_shape = input_shape[-3]
- else:
- output_shape = input_shape[-1]
-
- return [output_shape]
-
-
def get_conv_output_size(input_size, filter_size, strides, padding_type):
"""Returns the spatial size of a n-d convolution/pooling output."""
input_size = tuple([tensor_shape.as_dimension(x).value for x in input_size])
diff --git a/tensorflow/python/kernel_tests/pooling_ops_test.py b/tensorflow/python/kernel_tests/pooling_ops_test.py
index 185fdc14bc..9a1f96b6fe 100644
--- a/tensorflow/python/kernel_tests/pooling_ops_test.py
+++ b/tensorflow/python/kernel_tests/pooling_ops_test.py
@@ -929,30 +929,12 @@ class PoolingTest(tf.test.TestCase):
pool_func(tf.placeholder(tf.float32, shape=[1, 3]),
ksize=[1, 1, 1, 1], strides=[1, 1, 1, 1], padding="SAME")
- # Illegal strides.
- with self.assertRaisesRegexp(ValueError, "strides in the batch"):
- tf.nn.max_pool_with_argmax(
- tf.placeholder(tf.float32),
- ksize=[1, 1, 1, 1],
- strides=[2, 1, 1, 1],
- padding="SAME")
-
- # Filter larger than input.
- for pool_func in [tf.nn.max_pool_with_argmax]:
- with self.assertRaisesRegexp(ValueError,
- "Filter must not be larger than the input"):
- pool_func(tf.placeholder(tf.float32,
- shape=[32, 20, 20, 3]),
- ksize=[1, 20, 21, 1], strides=[1, 1, 1, 1], padding="SAME")
- with self.assertRaisesRegexp(ValueError,
- "Filter must not be larger than the input"):
- pool_func(tf.placeholder(tf.float32,
- shape=[32, 20, 20, 3]),
- ksize=[1, 21, 20, 1], strides=[1, 1, 1, 1], padding="SAME")
-
def testOpEdgeCases(self):
with self.test_session() as sess:
- for pool_func in [tf.nn.max_pool, tf.nn.avg_pool]:
+ pool_funcs = [tf.nn.max_pool, tf.nn.avg_pool]
+ if tf.test.is_gpu_available():
+ pool_funcs.append(tf.nn.max_pool_with_argmax)
+ for pool_func in pool_funcs:
# Illegal strides.
with self.assertRaisesRegexp(
tf.errors.UnimplementedError,
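
MaxPoolWithArgmax only has GPU kernels, which appears to be why the test now gates it on tf.test.is_gpu_available(). A hedged usage sketch:

import tensorflow as tf

if tf.test.is_gpu_available():
  x = tf.placeholder(tf.float32, [1, 4, 4, 1])
  out, argmax = tf.nn.max_pool_with_argmax(
      x, ksize=[1, 2, 2, 1], strides=[1, 2, 2, 1], padding="SAME")
  # Both outputs share one shape, just as the removed Python shape fn
  # (max_pool_shape(op) * 2) computed.
  print(out.get_shape())     # (1, 2, 2, 1)
  print(argmax.get_shape())  # (1, 2, 2, 1)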
diff --git a/tensorflow/python/kernel_tests/topk_op_test.py b/tensorflow/python/kernel_tests/topk_op_test.py
index 4550a2831c..1f9632ef03 100644
--- a/tensorflow/python/kernel_tests/topk_op_test.py
+++ b/tensorflow/python/kernel_tests/topk_op_test.py
@@ -81,7 +81,7 @@ class TopKTest(tf.test.TestCase):
def testKTooLarge(self):
inputs = [[0.1, 0.2], [0.3, 0.4]]
with self.assertRaisesRegexp(
- ValueError, r"input.shape \(2, 2\) must have last dimension >= k = 4"):
+ ValueError, r"must have last dimension >= k = 4"):
tf.nn.top_k(inputs, 4)
def testTopKGradients(self):
diff --git a/tensorflow/python/ops/nn_ops.py b/tensorflow/python/ops/nn_ops.py
index c27e594381..9560ffe66d 100644
--- a/tensorflow/python/ops/nn_ops.py
+++ b/tensorflow/python/ops/nn_ops.py
@@ -393,9 +393,10 @@ def bias_add(value, bias, data_format=None, name=None):
return gen_nn_ops._bias_add(value, bias, data_format=data_format, name=name)
-ops.RegisterShape("BiasAdd")(common_shapes.bias_add_shape)
-
-ops.RegisterShape("BiasAddGrad")(common_shapes.bias_add_grad_shape)
+ops.RegisterShape("BiasAddV1")(common_shapes.call_cpp_shape_fn)
+ops.RegisterShape("BiasAdd")(common_shapes.call_cpp_shape_fn)
+ops.RegisterShape("BiasAddGradV1")(common_shapes.call_cpp_shape_fn)
+ops.RegisterShape("BiasAddGrad")(common_shapes.call_cpp_shape_fn)
# pylint: disable=protected-access
@@ -426,11 +427,6 @@ def bias_add_v1(value, bias, name=None):
return gen_nn_ops._bias_add_v1(value, bias, name=name)
-ops.RegisterShape("BiasAddV1")(common_shapes.bias_add_shape)
-
-ops.RegisterShape("BiasAddGradV1")(common_shapes.bias_add_grad_shape)
-
-
def crelu(features, name=None):
"""Computes Concatenated ReLU.
@@ -866,23 +862,12 @@ ops.RegisterShape("LRNGrad")(common_shapes.call_cpp_shape_fn)
ops.RegisterShape("Softmax")(common_shapes.call_cpp_shape_fn)
ops.RegisterShape("LogSoftmax")(common_shapes.call_cpp_shape_fn)
ops.RegisterShape("InTopK")(common_shapes.call_cpp_shape_fn)
+ops.RegisterShape("TopK")(common_shapes.call_cpp_shape_fn)
-@ops.RegisterShape("TopK")
@ops.RegisterShape("TopKV2")
-def _TopKShape(op):
- """Shape function for TopK and TopKV2 ops."""
- input_shape = op.inputs[0].get_shape().with_rank_at_least(1)
- if len(op.inputs) >= 2:
- k = tensor_util.constant_value(op.inputs[1])
- else:
- k = op.get_attr("k")
- last = input_shape[-1].value
- if last is not None and k is not None and last < k:
- raise ValueError("input.shape %s must have last dimension >= k = %d" %
- (input_shape, k))
- output_shape = input_shape[:-1].concatenate([k])
- return [output_shape, output_shape]
+def _TopKV2Shape(op):
+ return common_shapes.call_cpp_shape_fn(op, input_tensors_needed=[1])
ops.RegisterShape("BatchNormWithGlobalNormalization")(
@@ -960,24 +945,12 @@ def _FusedResizeAndPadConv2DShape(op):
return [tensor_shape.TensorShape(output_shape)]
-@ops.RegisterShape("MaxPoolWithArgmax")
-def _MaxPoolWithArgMaxShape(op):
- """Shape function for MaxPoolWithArgmax op."""
- return common_shapes.max_pool_shape(op) * 2
+ops.RegisterShape("MaxPoolWithArgmax")(common_shapes.call_cpp_shape_fn)
@ops.RegisterShape("AvgPoolGrad")
def _AvgPoolGradShape(op):
- """Shape function for the AvgPoolGrad op."""
- orig_input_shape = tensor_util.constant_value(op.inputs[0])
- if orig_input_shape is not None:
- return [tensor_shape.TensorShape(orig_input_shape.tolist())]
- else:
- # NOTE(mrry): We could in principle work out the shape from the
- # gradients and the attrs, but if we do not know orig_input_shape
- # statically, then we are unlikely to know the shape of the
- # gradients either.
- return [tensor_shape.unknown_shape(ndims=4)]
+ return common_shapes.call_cpp_shape_fn(op, input_tensors_needed=[0])
@ops.RegisterShape("FractionalMaxPool")
@@ -1015,50 +988,22 @@ def _fractional_avg_pool_grad_shape(op):
@ops.RegisterShape("Conv2DBackpropFilter")
def _Conv2DBackpropFilterShape(op):
- """Shape function for the Conv2DBackpropFilter op."""
- filter_shape = tensor_util.constant_value(op.inputs[1])
- if filter_shape is not None:
- return [tensor_shape.TensorShape(filter_shape.tolist())]
- else:
- # NOTE(mrry): We could in principle work out the shape from the
- # gradients and the attrs, but if we do not know filter_shape
- # statically, then we are unlikely to know the shape of the
- # gradients either.
- return [tensor_shape.unknown_shape(ndims=4)]
+ return common_shapes.call_cpp_shape_fn(op, input_tensors_needed=[1])
@ops.RegisterShape("Conv2DBackpropInput")
def _Conv2DBackpropInputShape(op):
- """Shape function for the Conv2DBackpropInput op."""
- input_shape = tensor_util.constant_value(op.inputs[0])
- if input_shape is not None:
- return [tensor_shape.TensorShape(input_shape.tolist())]
- else:
- # NOTE(mrry): We could in principle work out the shape from the
- # gradients and the attrs, but if we do not know input_shape
- # statically, then we are unlikely to know the shape of the
- # gradients either.
- return [tensor_shape.unknown_shape(ndims=4)]
+ return common_shapes.call_cpp_shape_fn(op, input_tensors_needed=[0])
@ops.RegisterShape("DepthwiseConv2dNativeBackpropFilter")
def _DepthwiseConv2dNativeBackpropFilterShape(op):
- """Shape function for the DepthwiseConv2dNativeBackpropFilter op."""
- filter_shape = tensor_util.constant_value(op.inputs[1])
- if filter_shape is not None:
- return [tensor_shape.TensorShape(filter_shape.tolist())]
- else:
- return [tensor_shape.unknown_shape(ndims=4)]
+ return common_shapes.call_cpp_shape_fn(op, input_tensors_needed=[1])
@ops.RegisterShape("DepthwiseConv2dNativeBackpropInput")
def _DepthwiseConv2dNativeBackpropInputShape(op):
- """Shape function for the DepthwiseConv2dNativeBackpropInput op."""
- input_shape = tensor_util.constant_value(op.inputs[0])
- if input_shape is not None:
- return [tensor_shape.TensorShape(input_shape.tolist())]
- else:
- return [tensor_shape.unknown_shape(ndims=4)]
+ return common_shapes.call_cpp_shape_fn(op, input_tensors_needed=[0])
ops.RegisterShape("MaxPoolGrad")(common_shapes.call_cpp_shape_fn)
@@ -1136,55 +1081,9 @@ def _calc_depthwise_conv_weight_params(graph, node):
filter_channel_multiplier))
-@ops.RegisterShape("Conv3D")
-def _Conv3DShape(op):
- """Shape function for Conv3D."""
- input_shape = op.inputs[0].get_shape().with_rank(5)
- filter_shape = op.inputs[1].get_shape().with_rank(5)
-
- batch_size = input_shape[0]
- out_channels = filter_shape[4]
- # Check that the input number of channels is compatible between
- # input data and filter size.
- input_shape[4].assert_is_compatible_with(filter_shape[3])
-
- stride_b, stride_p, stride_r, stride_c, stride_d = op.get_attr("strides")
- assert stride_b == 1
- assert stride_d == 1
-
- padding_type = op.get_attr("padding")
- out_planes, out_rows, out_cols = common_shapes.get_conv_output_size(
- input_shape[1:4], filter_shape[0:3], (stride_p, stride_r, stride_c),
- padding_type)
-
- return [tensor_shape.TensorShape([batch_size, out_planes, out_rows, out_cols,
- out_channels])]
-
-
-@ops.RegisterShape("MaxPool3D")
-@ops.RegisterShape("AvgPool3D")
-def _Pool3DShape(op):
- """Shape function for Max/AvgPool3D."""
- input_shape = op.inputs[0].get_shape().with_rank(5)
- ksize_b, ksize_p, ksize_r, ksize_c, ksize_d = op.get_attr("ksize")
- assert ksize_b == 1
- assert ksize_d == 1
-
- stride_b, stride_p, stride_r, stride_c, stride_d = op.get_attr("strides")
- assert stride_b == 1
- assert stride_d == 1
-
- batch_size = input_shape[0]
- channels = input_shape[4]
-
- padding = op.get_attr("padding")
- out_planes, out_rows, out_cols = common_shapes.get_conv_output_size(
- input_shape[1:4], (ksize_p, ksize_r, ksize_c),
- (stride_p, stride_r, stride_c), padding)
- return [tensor_shape.TensorShape([batch_size, out_planes, out_rows, out_cols,
- channels])]
-
-
+ops.RegisterShape("Conv3D")(common_shapes.call_cpp_shape_fn)
+ops.RegisterShape("MaxPool3D")(common_shapes.call_cpp_shape_fn)
+ops.RegisterShape("AvgPool3D")(common_shapes.call_cpp_shape_fn)
ops.RegisterShape("Conv3DBackpropFilter")(common_shapes.call_cpp_shape_fn)
ops.RegisterShape("Conv3DBackpropInput")(common_shapes.call_cpp_shape_fn)
ops.RegisterShape("Conv3DBackpropFilterV2")(common_shapes.call_cpp_shape_fn)
@@ -1392,46 +1291,7 @@ def conv1d(value, filters, stride, padding,
return array_ops.squeeze(result, [1])
-@ops.RegisterShape("Dilation2D")
-def _Dilation2DShape(op):
- """Shape function for Dilation2D op."""
- input_shape = op.inputs[0].get_shape().with_rank(4)
- filter_shape = op.inputs[1].get_shape().with_rank(3)
-
- batch_size = input_shape[0]
- in_rows = input_shape[1]
- in_cols = input_shape[2]
- depth = input_shape[3]
-
- filter_rows = filter_shape[0]
- filter_cols = filter_shape[1]
- # Check that the input depths are compatible.
- input_shape[3].assert_is_compatible_with(filter_shape[2])
-
- stride_b, stride_r, stride_c, stride_d = op.get_attr("strides")
- if stride_b != 1 or stride_d != 1:
- raise ValueError("Current implementation does not yet support "
- "strides in the batch and depth dimensions.")
-
- rate_b, rate_r, rate_c, rate_d = op.get_attr("rates")
- if rate_b != 1 or rate_d != 1:
- raise ValueError("Current implementation does not yet support "
- "rates in the batch and depth dimensions.")
-
- filter_rows_eff = filter_rows + (filter_rows - 1) * (rate_r - 1)
- filter_cols_eff = filter_cols + (filter_cols - 1) * (rate_c - 1)
-
- padding = op.get_attr("padding")
- out_rows, out_cols = common_shapes.get2d_conv_output_size(in_rows, in_cols,
- filter_rows_eff,
- filter_cols_eff,
- stride_r, stride_c,
- padding)
-
- output_shape = [batch_size, out_rows, out_cols, depth]
- return [tensor_shape.TensorShape(output_shape)]
-
-
+ops.RegisterShape("Dilation2D")(common_shapes.call_cpp_shape_fn)
ops.RegisterShape("Dilation2DBackpropInput")(common_shapes.call_cpp_shape_fn)
ops.RegisterShape("Dilation2DBackpropFilter")(common_shapes.call_cpp_shape_fn)