aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python/ops/nn_ops.py
diff options
context:
space:
mode:
authorGravatar Jonathan Hseu <jhseu@google.com>2016-09-29 15:05:32 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2016-09-29 16:17:09 -0700
commit1283b84a49a9f5e14aca833cf981b61848aaf916 (patch)
treeffba9d2d8ba549bd5981cc84748d2db8858fc676 /tensorflow/python/ops/nn_ops.py
parentef9f5fee0a079f6bed445064e8e9d18fb7a904d8 (diff)
Merge changes from github.
Change: 134721831
Diffstat (limited to 'tensorflow/python/ops/nn_ops.py')
-rw-r--r--tensorflow/python/ops/nn_ops.py48
1 files changed, 33 insertions, 15 deletions
diff --git a/tensorflow/python/ops/nn_ops.py b/tensorflow/python/ops/nn_ops.py
index 6486463b2c..b70cf83d91 100644
--- a/tensorflow/python/ops/nn_ops.py
+++ b/tensorflow/python/ops/nn_ops.py
@@ -1429,28 +1429,34 @@ ops.RegisterShape("AvgPool")(common_shapes.call_cpp_shape_fn)
ops.RegisterShape("MaxPool")(common_shapes.call_cpp_shape_fn)
-@ops.RegisterShape("FusedResizeAndPadConv2D")
-def _FusedResizeAndPadConv2DShape(op):
- """Shape function for FusedResizeAndPadConv2D op."""
+def _CommonFusedConvCalculations(op, has_resize):
+ """Shape function for Fused*Conv2D ops."""
# The bilinear resize shape calculation.
input_shape = op.inputs[0].get_shape().with_rank(4)
- unused_size_shape = op.inputs[1].get_shape().merge_with([2])
- size = tensor_util.constant_value(op.inputs[1])
- if size is not None:
- height = size[0]
- width = size[1]
+ if has_resize:
+ unused_size_shape = op.inputs[1].get_shape().merge_with([2])
+ size = tensor_util.constant_value(op.inputs[1])
+ if size is not None:
+ height = size[0]
+ width = size[1]
+ else:
+ height = None
+ width = None
+ resized_shape = tensor_shape.TensorShape(
+ [input_shape[0], height, width, input_shape[3]])
+ paddings_index = 2
+ filter_index = 3
else:
- height = None
- width = None
- resized_shape = tensor_shape.TensorShape(
- [input_shape[0], height, width, input_shape[3]])
+ resized_shape = input_shape
+ paddings_index = 1
+ filter_index = 2
# Calculates the effect of the padding.
- paddings_shape = op.inputs[2].get_shape().with_rank(2)
+ paddings_shape = op.inputs[paddings_index].get_shape().with_rank(2)
resized_shape = resized_shape.with_rank(paddings_shape[0].value)
paddings_shape = paddings_shape.merge_with(
tensor_shape.matrix(resized_shape.ndims, 2))
- paddings = tensor_util.constant_value(op.inputs[2])
+ paddings = tensor_util.constant_value(op.inputs[paddings_index])
if paddings is None:
padded_shape = tensor_shape.unknown_shape(ndims=resized_shape.ndims)
else:
@@ -1462,7 +1468,7 @@ def _FusedResizeAndPadConv2DShape(op):
padded_shape = tensor_shape.TensorShape(output_dims)
# Finally work out the convolution's effect.
- filter_shape = op.inputs[3].get_shape().with_rank(4)
+ filter_shape = op.inputs[filter_index].get_shape().with_rank(4)
batch_size = padded_shape[0]
in_rows = padded_shape[1]
@@ -1494,6 +1500,18 @@ def _FusedResizeAndPadConv2DShape(op):
return [tensor_shape.TensorShape(output_shape)]
+@ops.RegisterShape("FusedResizeAndPadConv2D")
+def _FusedResizeAndPadConv2DShape(op):
+ """Shape function for FusedResizeAndPadConv2D op."""
+ return _CommonFusedConvCalculations(op, True)
+
+
+@ops.RegisterShape("FusedPadConv2D")
+def _FusedPadConv2DShape(op):
+ """Shape function for FusedResizeAndPadConv2D op."""
+ return _CommonFusedConvCalculations(op, False)
+
+
ops.RegisterShape("MaxPoolWithArgmax")(common_shapes.call_cpp_shape_fn)