aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python/ops/nn_ops.py
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/python/ops/nn_ops.py')
-rw-r--r--tensorflow/python/ops/nn_ops.py30
1 files changed, 28 insertions, 2 deletions
diff --git a/tensorflow/python/ops/nn_ops.py b/tensorflow/python/ops/nn_ops.py
index 47f48a7e16..8fbe698914 100644
--- a/tensorflow/python/ops/nn_ops.py
+++ b/tensorflow/python/ops/nn_ops.py
@@ -2215,6 +2215,31 @@ def xw_plus_b_v1(x, weights, biases, name=None): # pylint: disable=invalid-name
return bias_add_v1(mm, biases, name=name)
+def _get_noise_shape(x, noise_shape):
+ # If noise_shape is none return immediately.
+ if noise_shape is None:
+ return array_ops.shape(x)
+
+ try:
+ # Best effort to figure out the intended shape.
+ # If not possible, let the op to handle it.
+ # In eager mode exception will show up.
+ noise_shape_ = tensor_shape.as_shape(noise_shape)
+ except (TypeError, ValueError):
+ return noise_shape
+
+ if x.shape.dims is not None and len(x.shape.dims) == len(noise_shape_.dims):
+ new_dims = []
+ for i, dim in enumerate(x.shape.dims):
+ if noise_shape_.dims[i].value is None and dim.value is not None:
+ new_dims.append(dim.value)
+ else:
+ new_dims.append(noise_shape_.dims[i].value)
+ return tensor_shape.TensorShape(new_dims)
+
+ return noise_shape
+
+
@tf_export("nn.dropout")
def dropout(x, keep_prob, noise_shape=None, seed=None, name=None): # pylint: disable=invalid-name
"""Computes dropout.
@@ -2265,7 +2290,8 @@ def dropout(x, keep_prob, noise_shape=None, seed=None, name=None): # pylint: di
if tensor_util.constant_value(keep_prob) == 1:
return x
- noise_shape = noise_shape if noise_shape is not None else array_ops.shape(x)
+ noise_shape = _get_noise_shape(x, noise_shape)
+
# uniform [keep_prob, 1.0 + keep_prob)
random_tensor = keep_prob
random_tensor += random_ops.random_uniform(
@@ -2380,7 +2406,7 @@ def conv1d(value,
Args:
value: A 3D `Tensor`. Must be of type `float16` or `float32`.
- filters: A 3D `Tensor`. Must have the same type as `input`.
+ filters: A 3D `Tensor`. Must have the same type as `value`.
stride: An `integer`. The number of entries by which
the filter is moved right at each step.
padding: 'SAME' or 'VALID'