aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2017-09-29 12:30:31 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2017-09-29 12:39:36 -0700
commitc0502aff716a6b7889c5eb23cd06b5bda414bf9e (patch)
tree509561c28d9f96956ec93eee2eb1199a005f49e0
parent0fb83965a209eb03c1c090e3e540fd7c2c7d1025 (diff)
Internal refactoring.
PiperOrigin-RevId: 170517511
-rw-r--r--tensorflow/python/layers/convolutional.py22
-rw-r--r--tensorflow/python/ops/nn_ops.py574
2 files changed, 383 insertions, 213 deletions
diff --git a/tensorflow/python/layers/convolutional.py b/tensorflow/python/layers/convolutional.py
index 9dec3b5a47..b11a210aca 100644
--- a/tensorflow/python/layers/convolutional.py
+++ b/tensorflow/python/layers/convolutional.py
@@ -21,12 +21,14 @@ from __future__ import division
from __future__ import print_function
from tensorflow.python.eager import context
+from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_shape
from tensorflow.python.layers import base
from tensorflow.python.layers import utils
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import init_ops
from tensorflow.python.ops import nn
+from tensorflow.python.ops import nn_ops
class _Conv(base.Layer):
@@ -151,16 +153,22 @@ class _Conv(base.Layer):
self.bias = None
self.input_spec = base.InputSpec(ndim=self.rank + 2,
axes={channel_axis: input_dim})
+ with ops.name_scope(None, 'convolution', [self.kernel]) as name:
+ self._convolution_op = nn_ops.Convolution(
+ input_shape,
+ filter_shape=self.kernel.get_shape(),
+ dilation_rate=self.dilation_rate,
+ strides=self.strides,
+ padding=self.padding.upper(),
+ data_format=utils.convert_data_format(self.data_format,
+ self.rank + 2),
+ name=name)
self.built = True
def call(self, inputs):
- outputs = nn.convolution(
- input=inputs,
- filter=self.kernel,
- dilation_rate=self.dilation_rate,
- strides=self.strides,
- padding=self.padding.upper(),
- data_format=utils.convert_data_format(self.data_format, self.rank + 2))
+ # TODO(agarwal): do we need this name_scope ?
+ with ops.name_scope(None, 'convolution', [inputs, self.kernel]):
+ outputs = self._convolution_op(inputs, self.kernel.value())
if self.use_bias:
if self.data_format == 'channels_first':
diff --git a/tensorflow/python/ops/nn_ops.py b/tensorflow/python/ops/nn_ops.py
index bd726ca631..21b3129180 100644
--- a/tensorflow/python/ops/nn_ops.py
+++ b/tensorflow/python/ops/nn_ops.py
@@ -87,9 +87,43 @@ def _non_atrous_convolution(input, filter, padding, data_format=None, # pylint:
"""
with ops.name_scope(name, "non_atrous_convolution", [input, filter]) as scope:
input = ops.convert_to_tensor(input, name="input")
+ input_shape = input.get_shape()
filter = ops.convert_to_tensor(filter, name="filter")
- filter_shape = filter.get_shape().with_rank(input.get_shape().ndims)
- input_shape = input.get_shape().with_rank(filter_shape.ndims)
+ filter_shape = filter.get_shape()
+ op = _NonAtrousConvolution(input_shape,
+ filter_shape=filter_shape,
+ padding=padding,
+ data_format=data_format,
+ strides=strides,
+ name=scope)
+ return op(input, filter)
+
+
+class _NonAtrousConvolution(object):
+ """Helper class for _non_atrous_convolution.
+
+ Note that this class assumes that shapes of input and filter passed to
+ __call__ are compatible with input_shape and filter_shape passed to the
+ constructor.
+
+ Arguments:
+ input_shape: static input shape, i.e. input.get_shape().
+ filter_shape: static filter shape, i.e. filter.get_shape().
+ padding: see _non_atrous_convolution.
+ data_format: see _non_atrous_convolution.
+ strides: see _non_atrous_convolution.
+ name: see _non_atrous_convolution.
+ """
+
+ def __init__(self,
+ input_shape,
+ filter_shape, # pylint: disable=redefined-builtin
+ padding, data_format=None,
+ strides=None, name=None):
+ filter_shape = filter_shape.with_rank(input_shape.ndims)
+ self.padding = padding
+ self.name = name
+ input_shape = input_shape.with_rank(filter_shape.ndims)
if input_shape.ndims is None:
raise ValueError("Rank of convolution must be known")
if input_shape.ndims < 3 or input_shape.ndims > 5:
@@ -109,13 +143,9 @@ def _non_atrous_convolution(input, filter, padding, data_format=None, # pylint:
data_format_2d = "NCHW"
else:
raise ValueError("data_format must be \"NWC\" or \"NCW\".")
- return conv1d(
- value=input,
- filters=filter,
- stride=strides[0],
- padding=padding,
- data_format=data_format_2d,
- name=scope)
+ self.strides = strides[0]
+ self.data_format = data_format_2d
+ self.conv_op = self._conv1d
elif conv_dims == 2:
if data_format is None or data_format == "NHWC":
data_format = "NHWC"
@@ -124,13 +154,9 @@ def _non_atrous_convolution(input, filter, padding, data_format=None, # pylint:
strides = [1, 1] + list(strides)
else:
raise ValueError("data_format must be \"NHWC\" or \"NCHW\".")
- return gen_nn_ops.conv2d(
- input=input,
- filter=filter,
- strides=strides,
- padding=padding,
- data_format=data_format,
- name=name)
+ self.strides = strides
+ self.data_format = data_format
+ self.conv_op = gen_nn_ops.conv2d
elif conv_dims == 3:
if data_format is None or data_format == "NDHWC":
strides = [1] + list(strides) + [1]
@@ -139,13 +165,26 @@ def _non_atrous_convolution(input, filter, padding, data_format=None, # pylint:
else:
raise ValueError("data_format must be \"NDHWC\" or \"NCDHW\". Have: %s"
% data_format)
- return gen_nn_ops.conv3d(
- input=input,
- filter=filter,
- strides=strides,
- padding=padding,
- data_format=data_format,
- name=name)
+ self.strides = strides
+ self.data_format = data_format
+ self.conv_op = gen_nn_ops.conv3d
+
+ # Note that we need this adapter since argument names for conv1d don't match
+ # those for gen_nn_ops.conv2d and gen_nn_ops.conv3d.
+ # pylint: disable=redefined-builtin
+ def _conv1d(self, input, filter, strides, padding, data_format, name):
+ return conv1d(value=input, filters=filter, stride=strides, padding=padding,
+ data_format=data_format, name=name)
+ # pylint: enable=redefined-builtin
+
+ def __call__(self, inp, filter): # pylint: disable=redefined-builtin
+ return self.conv_op(
+ input=inp,
+ filter=filter,
+ strides=self.strides,
+ padding=self.padding,
+ data_format=self.data_format,
+ name=self.name)
def with_space_to_batch(
@@ -291,172 +330,252 @@ def with_space_to_batch(
"""
input = ops.convert_to_tensor(input, name="input")
- dilation_rate = ops.convert_to_tensor(dilation_rate,
- dtypes.int32,
- name="dilation_rate")
- try:
- rate_shape = dilation_rate.get_shape().with_rank(1)
- except ValueError:
- raise ValueError("rate must be rank 1")
+ input_shape = input.get_shape()
+
+ def build_op(num_spatial_dims, padding):
+ return lambda inp, _: op(inp, num_spatial_dims, padding)
+
+ new_op = _WithSpaceToBatch(input_shape,
+ dilation_rate,
+ padding,
+ build_op,
+ filter_shape=filter_shape,
+ spatial_dims=spatial_dims,
+ data_format=data_format)
+ return new_op(input, None)
+
+
+class _WithSpaceToBatch(object):
+ """Helper class for with_space_to_batch.
+
+ Note that this class assumes that shapes of input and filter passed to
+ __call__ are compatible with input_shape and filter_shape passed to the
+ constructor.
+
+ Arguments
+ input_shape: static shape of input. i.e. input.get_shape().
+ dilation_rate: see with_space_to_batch
+ padding: see with_space_to_batch
+ build_op: Function that maps (num_spatial_dims, paddings) -> (function that
+ maps (input, filter) -> output).
+ filter_shape: see with_space_to_batch
+ spatial_dims: see with_space_to_batch
+ data_format: see with_space_to_batch
+ """
- if not dilation_rate.get_shape().is_fully_defined():
- raise ValueError("rate must have known shape")
+ def __init__(self,
+ input_shape,
+ dilation_rate,
+ padding,
+ build_op,
+ filter_shape=None,
+ spatial_dims=None,
+ data_format=None):
+ """Helper class for _with_space_to_batch."""
+ dilation_rate = ops.convert_to_tensor(dilation_rate,
+ dtypes.int32,
+ name="dilation_rate")
+ try:
+ rate_shape = dilation_rate.get_shape().with_rank(1)
+ except ValueError:
+ raise ValueError("rate must be rank 1")
- num_spatial_dims = rate_shape[0].value
+ if not dilation_rate.get_shape().is_fully_defined():
+ raise ValueError("rate must have known shape")
- if data_format is not None and data_format.startswith("NC"):
- starting_spatial_dim = 2
- else:
- starting_spatial_dim = 1
-
- if spatial_dims is None:
- spatial_dims = range(starting_spatial_dim,
- num_spatial_dims + starting_spatial_dim)
- orig_spatial_dims = list(spatial_dims)
- spatial_dims = sorted(set(int(x) for x in orig_spatial_dims))
- if spatial_dims != orig_spatial_dims or any(x < 1 for x in spatial_dims):
- raise ValueError(
- "spatial_dims must be a montonically increasing sequence of positive "
- "integers") # pylint: disable=line-too-long
+ num_spatial_dims = rate_shape[0].value
- if data_format is not None and data_format.startswith("NC"):
- expected_input_rank = spatial_dims[-1]
- else:
- expected_input_rank = spatial_dims[-1] + 1
-
- try:
- input.get_shape().with_rank_at_least(expected_input_rank)
- except ValueError:
- ValueError("input tensor must have rank %d at least" %
- (expected_input_rank))
-
- const_rate = tensor_util.constant_value(dilation_rate)
- rate_or_const_rate = dilation_rate
- if const_rate is not None:
- rate_or_const_rate = const_rate
- if np.any(const_rate < 1):
- raise ValueError("dilation_rate must be positive")
- if np.all(const_rate == 1):
- return op(input, num_spatial_dims, padding)
-
- # We have two padding contributions. The first is used for converting "SAME"
- # to "VALID". The second is required so that the height and width of the
- # zero-padded value tensor are multiples of rate.
-
- # Padding required to reduce to "VALID" convolution
- if padding == "SAME":
- if filter_shape is None:
- raise ValueError("filter_shape must be specified for SAME padding")
- filter_shape = ops.convert_to_tensor(filter_shape, name="filter_shape")
- const_filter_shape = tensor_util.constant_value(filter_shape)
- if const_filter_shape is not None:
- filter_shape = const_filter_shape
-
- # Spatial dimensions of the filters and the upsampled filters in which we
- # introduce (rate - 1) zeros between consecutive filter values.
- filter_spatial_shape = filter_shape[:num_spatial_dims]
- dilated_filter_spatial_shape = (filter_spatial_shape +
- (filter_spatial_shape - 1) *
- (rate_or_const_rate - 1))
- pad_extra_shape = dilated_filter_spatial_shape - 1
-
- # When full_padding_shape is odd, we pad more at end, following the same
- # convention as conv2d.
- pad_extra_start = pad_extra_shape // 2
- pad_extra_end = pad_extra_shape - pad_extra_start
- base_paddings = array_ops.stack([[pad_extra_start[i], pad_extra_end[i]]
- for i in range(num_spatial_dims)])
- elif padding == "VALID":
- base_paddings = np.zeros([num_spatial_dims, 2], np.int32)
- else:
- raise ValueError("Invalid padding method %r" % padding)
-
- # Handle input whose shape is unknown during graph creation.
- input_spatial_shape = None
- if input.get_shape().ndims is not None:
- input_shape_list = input.get_shape().as_list()
- input_spatial_shape = [input_shape_list[i] for i in spatial_dims]
- if input_spatial_shape is None or None in input_spatial_shape:
- input_shape_tensor = array_ops.shape(input)
- input_spatial_shape = array_ops.stack(
- [input_shape_tensor[i] for i in spatial_dims])
-
- paddings, crops = array_ops.required_space_to_batch_paddings(
- input_shape=input_spatial_shape,
- base_paddings=base_paddings,
- block_shape=dilation_rate)
-
- def adjust(orig, fill_value):
- """Returns an `adjusted` version of `orig` based on `spatial_dims`.
-
- Tensor of the same type as `orig` and with shape
- `[max(spatial_dims), ...]` where:
-
- adjusted[spatial_dims[i] - 1, ...] = orig[i, ...]
-
- for 0 <= i < len(spatial_dims), and
-
- adjusted[j, ...] = fill_value
-
- for j != spatial_dims[i] - 1 for some i.
-
- If `orig` is a constant value, then the result will be a constant value.
-
- Args:
- orig: Tensor of rank > max(spatial_dims).
- fill_value: Numpy scalar (of same data type as `orig) specifying the fill
- value for non-spatial dimensions.
-
- Returns:
- `adjusted` tensor.
- """
- fill_dims = orig.get_shape().as_list()[1:]
- dtype = orig.dtype.as_numpy_dtype
- parts = []
- const_orig = tensor_util.constant_value(orig)
- const_or_orig = const_orig if const_orig is not None else orig
- prev_spatial_dim = 0
- i = 0
- while i < len(spatial_dims):
- start_i = i
- start_spatial_dim = spatial_dims[i]
- if start_spatial_dim > 1:
- # Fill in any gap from the previous spatial dimension (or dimension 1 if
- # this is the first spatial dimension) with `fill_value`.
- parts.append(
- np.full(
- [start_spatial_dim - 1 - prev_spatial_dim] + fill_dims,
- fill_value,
- dtype=dtype))
- # Find the largest value of i such that:
- # [spatial_dims[start_i], ..., spatial_dims[i]]
- # == [start_spatial_dim, ..., start_spatial_dim + i - start_i],
- # i.e. the end of a contiguous group of spatial dimensions.
- while (i + 1 < len(spatial_dims) and
- spatial_dims[i + 1] == spatial_dims[i] + 1):
- i += 1
- parts.append(const_or_orig[start_i:i + 1])
- prev_spatial_dim = spatial_dims[i]
- i += 1
- if const_orig is not None:
- return np.concatenate(parts)
+ if data_format is not None and data_format.startswith("NC"):
+ starting_spatial_dim = 2
else:
- return array_ops.concat(parts, 0)
+ starting_spatial_dim = 1
+
+ if spatial_dims is None:
+ spatial_dims = range(starting_spatial_dim,
+ num_spatial_dims + starting_spatial_dim)
+ orig_spatial_dims = list(spatial_dims)
+ spatial_dims = sorted(set(int(x) for x in orig_spatial_dims))
+ if spatial_dims != orig_spatial_dims or any(x < 1 for x in spatial_dims):
+ raise ValueError(
+ "spatial_dims must be a montonically increasing sequence of positive "
+ "integers") # pylint: disable=line-too-long
+
+ if data_format is not None and data_format.startswith("NC"):
+ expected_input_rank = spatial_dims[-1]
+ else:
+ expected_input_rank = spatial_dims[-1] + 1
- dilation_rate = adjust(dilation_rate, 1)
- paddings = adjust(paddings, 0)
- crops = adjust(crops, 0)
+ try:
+ input_shape.with_rank_at_least(expected_input_rank)
+ except ValueError:
+ ValueError("input tensor must have rank %d at least" %
+ (expected_input_rank))
+
+ const_rate = tensor_util.constant_value(dilation_rate)
+ rate_or_const_rate = dilation_rate
+ if const_rate is not None:
+ rate_or_const_rate = const_rate
+ if np.any(const_rate < 1):
+ raise ValueError("dilation_rate must be positive")
+ if np.all(const_rate == 1):
+ self.call = build_op(num_spatial_dims, padding)
+ return
+
+ # We have two padding contributions. The first is used for converting "SAME"
+ # to "VALID". The second is required so that the height and width of the
+ # zero-padded value tensor are multiples of rate.
- input_converted = array_ops.space_to_batch_nd(
- input=input,
- block_shape=dilation_rate,
- paddings=paddings)
+ # Padding required to reduce to "VALID" convolution
+ if padding == "SAME":
+ if filter_shape is None:
+ raise ValueError("filter_shape must be specified for SAME padding")
+ filter_shape = ops.convert_to_tensor(filter_shape, name="filter_shape")
+ const_filter_shape = tensor_util.constant_value(filter_shape)
+ if const_filter_shape is not None:
+ filter_shape = const_filter_shape
+ self.base_paddings = _with_space_to_batch_base_paddings(
+ const_filter_shape,
+ num_spatial_dims,
+ rate_or_const_rate)
+ else:
+ self.num_spatial_dims = num_spatial_dims
+ self.rate_or_const_rate = rate_or_const_rate
+ self.base_paddings = None
+ elif padding == "VALID":
+ self.base_paddings = np.zeros([num_spatial_dims, 2], np.int32)
+ else:
+ raise ValueError("Invalid padding method %r" % padding)
+
+ self.input_shape = input_shape
+ self.spatial_dims = spatial_dims
+ self.dilation_rate = dilation_rate
+ self.op = build_op(num_spatial_dims, "VALID")
+ self.call = self._with_space_to_batch_call
+
+ def _with_space_to_batch_call(self, inp, filter): # pylint: disable=redefined-builtin
+ """Call functionality for with_space_to_batch."""
+ # Handle input whose shape is unknown during graph creation.
+ input_spatial_shape = None
+ input_shape = self.input_shape
+ spatial_dims = self.spatial_dims
+ if input_shape.ndims is not None:
+ input_shape_list = input_shape.as_list()
+ input_spatial_shape = [input_shape_list[i] for i in spatial_dims]
+ if input_spatial_shape is None or None in input_spatial_shape:
+ input_shape_tensor = array_ops.shape(inp)
+ input_spatial_shape = array_ops.stack(
+ [input_shape_tensor[i] for i in spatial_dims])
+
+ base_paddings = self.base_paddings
+ if base_paddings is None:
+ # base_paddings could not be computed at build time since static filter
+ # shape was not fully defined.
+ filter_shape = array_ops.shape(filter)
+ base_paddings = _with_space_to_batch_base_paddings(
+ filter_shape,
+ self.num_spatial_dims,
+ self.rate_or_const_rate)
+ paddings, crops = array_ops.required_space_to_batch_paddings(
+ input_shape=input_spatial_shape,
+ base_paddings=base_paddings,
+ block_shape=self.dilation_rate)
+
+ dilation_rate = _with_space_to_batch_adjust(self.dilation_rate, 1,
+ spatial_dims)
+ paddings = _with_space_to_batch_adjust(paddings, 0, spatial_dims)
+ crops = _with_space_to_batch_adjust(crops, 0, spatial_dims)
+ input_converted = array_ops.space_to_batch_nd(
+ input=inp,
+ block_shape=dilation_rate,
+ paddings=paddings)
+
+ result = self.op(input_converted, filter)
+
+ result_converted = array_ops.batch_to_space_nd(
+ input=result, block_shape=dilation_rate, crops=crops)
+ return result_converted
+
+ def __call__(self, inp, filter): # pylint: disable=redefined-builtin
+ return self.call(inp, filter)
+
+
+def _with_space_to_batch_base_paddings(filter_shape, num_spatial_dims,
+ rate_or_const_rate):
+ """Helper function to compute base_paddings."""
+ # Spatial dimensions of the filters and the upsampled filters in which we
+ # introduce (rate - 1) zeros between consecutive filter values.
+ filter_spatial_shape = filter_shape[:num_spatial_dims]
+ dilated_filter_spatial_shape = (filter_spatial_shape +
+ (filter_spatial_shape - 1) *
+ (rate_or_const_rate - 1))
+ pad_extra_shape = dilated_filter_spatial_shape - 1
+
+ # When full_padding_shape is odd, we pad more at end, following the same
+ # convention as conv2d.
+ pad_extra_start = pad_extra_shape // 2
+ pad_extra_end = pad_extra_shape - pad_extra_start
+ base_paddings = array_ops.stack([[pad_extra_start[i], pad_extra_end[i]]
+ for i in range(num_spatial_dims)])
+ return base_paddings
+
+
+def _with_space_to_batch_adjust(orig, fill_value, spatial_dims):
+ """Returns an `adjusted` version of `orig` based on `spatial_dims`.
+
+ Tensor of the same type as `orig` and with shape
+ `[max(spatial_dims), ...]` where:
+
+ adjusted[spatial_dims[i] - 1, ...] = orig[i, ...]
+
+ for 0 <= i < len(spatial_dims), and
+
+ adjusted[j, ...] = fill_value
+
+ for j != spatial_dims[i] - 1 for some i.
+
+ If `orig` is a constant value, then the result will be a constant value.
- result = op(input_converted, num_spatial_dims, "VALID")
+ Args:
+ orig: Tensor of rank > max(spatial_dims).
+ fill_value: Numpy scalar (of same data type as `orig) specifying the fill
+ value for non-spatial dimensions.
+ spatial_dims: See with_space_to_batch.
- result_converted = array_ops.batch_to_space_nd(
- input=result, block_shape=dilation_rate, crops=crops)
- return result_converted
+ Returns:
+ `adjusted` tensor.
+ """
+ fill_dims = orig.get_shape().as_list()[1:]
+ dtype = orig.dtype.as_numpy_dtype
+ parts = []
+ const_orig = tensor_util.constant_value(orig)
+ const_or_orig = const_orig if const_orig is not None else orig
+ prev_spatial_dim = 0
+ i = 0
+ while i < len(spatial_dims):
+ start_i = i
+ start_spatial_dim = spatial_dims[i]
+ if start_spatial_dim > 1:
+ # Fill in any gap from the previous spatial dimension (or dimension 1 if
+ # this is the first spatial dimension) with `fill_value`.
+ parts.append(
+ np.full(
+ [start_spatial_dim - 1 - prev_spatial_dim] + fill_dims,
+ fill_value,
+ dtype=dtype))
+ # Find the largest value of i such that:
+ # [spatial_dims[start_i], ..., spatial_dims[i]]
+ # == [start_spatial_dim, ..., start_spatial_dim + i - start_i],
+ # i.e. the end of a contiguous group of spatial dimensions.
+ while (i + 1 < len(spatial_dims) and
+ spatial_dims[i + 1] == spatial_dims[i] + 1):
+ i += 1
+ parts.append(const_or_orig[start_i:i + 1])
+ prev_spatial_dim = spatial_dims[i]
+ i += 1
+ if const_orig is not None:
+ return np.concatenate(parts)
+ else:
+ return array_ops.concat(parts, 0)
def _get_strides_and_dilation_rate(num_spatial_dims, strides, dilation_rate):
@@ -620,58 +739,100 @@ def convolution(input, filter, # pylint: disable=redefined-builtin
# pylint: enable=line-too-long
with ops.name_scope(name, "convolution", [input, filter]) as name:
input = ops.convert_to_tensor(input, name="input")
+ input_shape = input.get_shape()
filter = ops.convert_to_tensor(filter, name="filter")
- num_total_dims = filter.get_shape().ndims
+ filter_shape = filter.get_shape()
+ op = Convolution(input_shape,
+ filter_shape,
+ padding,
+ strides=strides,
+ dilation_rate=dilation_rate,
+ name=name, data_format=data_format)
+ return op(input, filter)
+
+
+class Convolution(object):
+ """Helper class for convolution.
+
+ Note that this class assumes that shapes of input and filter passed to
+ __call__ are compatible with input_shape and filter_shape passed to the
+ constructor.
+
+ Arguments
+ input_shape: static shape of input. i.e. input.get_shape().
+ filter_shape: static shape of the filter. i.e. filter.get_shape().
+ padding: see convolution.
+ strides: see convolution.
+ dilation_rate: see convolution.
+ name: see convolution.
+ data_format: see convolution.
+ """
+
+ def __init__(self,
+ input_shape,
+ filter_shape,
+ padding, strides=None, dilation_rate=None,
+ name=None, data_format=None):
+ """Helper function for convolution."""
+ num_total_dims = filter_shape.ndims
if num_total_dims is None:
- num_total_dims = input.get_shape().ndims
+ num_total_dims = input_shape.ndims
if num_total_dims is None:
raise ValueError("rank of input or filter must be known")
num_spatial_dims = num_total_dims - 2
try:
- input.get_shape().with_rank(num_spatial_dims + 2)
+ input_shape.with_rank(num_spatial_dims + 2)
except ValueError:
ValueError("input tensor must have rank %d" % (num_spatial_dims + 2))
try:
- filter.get_shape().with_rank(num_spatial_dims + 2)
+ filter_shape.with_rank(num_spatial_dims + 2)
except ValueError:
ValueError("filter tensor must have rank %d" % (num_spatial_dims + 2))
if data_format is None or not data_format.startswith("NC"):
- input_channels_dim = input.get_shape()[num_spatial_dims + 1]
+ input_channels_dim = input_shape[num_spatial_dims + 1]
spatial_dims = range(1, num_spatial_dims+1)
else:
- input_channels_dim = input.get_shape()[1]
+ input_channels_dim = input_shape[1]
spatial_dims = range(2, num_spatial_dims+2)
- if not input_channels_dim.is_compatible_with(filter.get_shape()[
+ if not input_channels_dim.is_compatible_with(filter_shape[
num_spatial_dims]):
raise ValueError(
- "number of input channels does not match corresponding dimension of filter, "
- "{} != {}".format(input_channels_dim, filter.get_shape()[
+ "number of input channels does not match corresponding dimension of "
+ "filter, {} != {}".format(input_channels_dim, filter_shape[
num_spatial_dims]))
strides, dilation_rate = _get_strides_and_dilation_rate(
num_spatial_dims, strides, dilation_rate)
- def op(input_converted, _, padding):
- return _non_atrous_convolution(
- input=input_converted,
- filter=filter,
- padding=padding,
- data_format=data_format,
- strides=strides,
- name=name)
-
- return with_space_to_batch(
- input=input,
- filter_shape=array_ops.shape(filter),
- spatial_dims=spatial_dims,
+ self.input_shape = input_shape
+ self.filter_shape = filter_shape
+ self.data_format = data_format
+ self.strides = strides
+ self.name = name
+ self.conv_op = _WithSpaceToBatch(
+ input_shape,
dilation_rate=dilation_rate,
padding=padding,
- op=op)
+ build_op=self._build_op,
+ filter_shape=filter_shape,
+ spatial_dims=spatial_dims)
+
+ def _build_op(self, _, padding):
+ return _NonAtrousConvolution(
+ self.input_shape,
+ filter_shape=self.filter_shape,
+ padding=padding,
+ data_format=self.data_format,
+ strides=self.strides,
+ name=self.name)
+
+ def __call__(self, inp, filter): # pylint: disable=redefined-builtin
+ return self.conv_op(inp, filter)
def pool(input, # pylint: disable=redefined-builtin
@@ -977,7 +1138,7 @@ def atrous_conv2d(value, filters, rate, padding, name=None):
def conv2d_transpose(value,
- filter,
+ filter, # pylint: disable=redefined-builtin
output_shape,
strides,
padding="SAME",
@@ -1196,7 +1357,7 @@ def atrous_conv2d_transpose(value,
def conv3d_transpose(value,
- filter,
+ filter, # pylint: disable=redefined-builtin
output_shape,
strides,
padding="SAME",
@@ -1328,7 +1489,7 @@ def crelu(features, name=None):
Concatenates a ReLU which selects only the positive part of the activation
with a ReLU which selects only the *negative* part of the activation.
Note that as a result this non-linearity doubles the depth of the activations.
- Source: [Understanding and Improving Convolutional Neural Networks via Concatenated Rectified Linear Units. W. Shang, et al.](https://arxiv.org/abs/1603.05201)
+ Source: [Understanding and Improving Convolutional Neural Networks via Concatenated Rectified Linear Units. W. Shang, et al.](https://arxiv.org/abs/1603.05201)
Args:
features: A `Tensor` with type `float`, `double`, `int32`, `int64`, `uint8`,
@@ -2115,6 +2276,7 @@ def erosion2d(value, kernel, strides, rates, padding, name=None):
padding=padding,
name=name))
+
def in_top_k(predictions, targets, k, name=None):
r"""Says whether the targets are in the top `K` predictions.