Diffstat (limited to 'tensorflow/python')
-rw-r--r--  tensorflow/python/estimator/model_fn.py                |  4
-rw-r--r--  tensorflow/python/framework/ops.py                     | 30
-rw-r--r--  tensorflow/python/framework/ops_test.py                |  9
-rw-r--r--  tensorflow/python/grappler/layout_optimizer_test.py    |  4
-rw-r--r--  tensorflow/python/kernel_tests/BUILD                   |  1
-rw-r--r--  tensorflow/python/kernel_tests/init_ops_test.py        | 40
-rw-r--r--  tensorflow/python/kernel_tests/shape_ops_test.py       | 23
-rw-r--r--  tensorflow/python/kernel_tests/sparse_slice_op_test.py | 22
-rw-r--r--  tensorflow/python/ops/array_grad.py                    |  8
-rw-r--r--  tensorflow/python/ops/control_flow_ops.py              |  1
-rw-r--r--  tensorflow/python/ops/init_ops.py                      | 24
-rw-r--r--  tensorflow/python/ops/losses/losses_impl.py            |  3
-rw-r--r--  tensorflow/python/ops/nn_ops.py                        |  3
14 files changed, 174 insertions(+), 27 deletions(-)
diff --git a/tensorflow/python/estimator/model_fn.py b/tensorflow/python/estimator/model_fn.py
index 009ac9d8fd..a9fd8f8e1a 100644
--- a/tensorflow/python/estimator/model_fn.py
+++ b/tensorflow/python/estimator/model_fn.py
@@ -99,7 +99,7 @@ class EstimatorSpec(
ignored in eval and infer modes. Example:
```python
- def my_model_fn(mode, features, labels):
+ def my_model_fn(features, labels, mode):
predictions = ...
loss = ...
train_op = ...
@@ -114,7 +114,7 @@ class EstimatorSpec(
given mode. Example:
```python
- def my_model_fn(mode, features, labels):
+ def my_model_fn(features, labels, mode):
if (mode == tf.estimator.ModeKeys.TRAIN or
mode == tf.estimator.ModeKeys.EVAL):
loss = ...
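For context, a minimal sketch of a `model_fn` written against this `(features, labels, mode)` ordering; the feature key `"x"`, the single dense layer, and the optimizer are illustrative assumptions, not part of this change:

```python
import tensorflow as tf

def my_model_fn(features, labels, mode):
  # Toy linear model; assumes a rank-2 feature under the key "x".
  predictions = tf.layers.dense(features["x"], units=1)
  if mode == tf.estimator.ModeKeys.PREDICT:
    return tf.estimator.EstimatorSpec(mode, predictions=predictions)
  loss = tf.losses.mean_squared_error(labels, predictions)
  train_op = tf.train.GradientDescentOptimizer(0.1).minimize(
      loss, global_step=tf.train.get_global_step())
  return tf.estimator.EstimatorSpec(
      mode, loss=loss, train_op=train_op, predictions=predictions)
```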
diff --git a/tensorflow/python/framework/ops.py b/tensorflow/python/framework/ops.py
index 89afd1d25b..cf0b1e36fb 100644
--- a/tensorflow/python/framework/ops.py
+++ b/tensorflow/python/framework/ops.py
@@ -3239,8 +3239,9 @@ class Graph(object):
     # the name will still appear in _names_in_use even though the name hasn't
     # been used. This is ok, just leave _names_in_use as-is in this case.
     # TODO(skyewm): make the C API guarantee no name conflicts.
-    if ret.name not in self._names_in_use:
-      self._names_in_use[ret.name] = 1
+    name_key = ret.name.lower()
+    if name_key not in self._names_in_use:
+      self._names_in_use[name_key] = 1
     self._create_op_helper(ret, compute_device=compute_device)
     return ret
@@ -3949,20 +3950,27 @@ class Graph(object):
     """
     if self._name_stack:
       name = self._name_stack + "/" + name
-    i = self._names_in_use.get(name, 0)
-    # Increment the number for "name".
+
+    # For the sake of checking for names in use, we treat names as case
+    # insensitive (e.g. foo = Foo).
+    name_key = name.lower()
+    i = self._names_in_use.get(name_key, 0)
+    # Increment the number for "name_key".
     if mark_as_used:
-      self._names_in_use[name] = i + 1
+      self._names_in_use[name_key] = i + 1
     if i > 0:
-      base_name = name
-      # Make sure the composed name is not already used.
-      while name in self._names_in_use:
-        name = "%s_%d" % (base_name, i)
+      base_name_key = name_key
+      # Make sure the composed name key is not already used.
+      while name_key in self._names_in_use:
+        name_key = "%s_%d" % (base_name_key, i)
         i += 1
-      # Mark the composed name as used in case someone wants
+      # Mark the composed name_key as used in case someone wants
       # to call unique_name("name_1").
       if mark_as_used:
-        self._names_in_use[name] = 1
+        self._names_in_use[name_key] = 1
+
+      # Return the new name with the original capitalization of the given name.
+      name = "%s_%d" % (name, i-1)
     return name
def get_name_scope(self):
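To make the new behavior concrete, a quick sketch against the internal `Graph` API, assuming a build that includes this change (expected results match the new test below):

```python
from tensorflow.python.framework import ops

g = ops.Graph()
print(g.unique_name("foo"))  # -> "foo"
# "Foo" now collides with "foo" under case-insensitive comparison, so the
# suffix is bumped while the caller's capitalization is preserved.
print(g.unique_name("Foo"))  # -> "Foo_1"
print(g.unique_name("FOO"))  # -> "FOO_2"
```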
diff --git a/tensorflow/python/framework/ops_test.py b/tensorflow/python/framework/ops_test.py
index c72406e92b..150100d771 100644
--- a/tensorflow/python/framework/ops_test.py
+++ b/tensorflow/python/framework/ops_test.py
@@ -965,6 +965,15 @@ class NameStackTest(test_util.TensorFlowTestCase):
     self.assertEqual("foo_1", g.unique_name("foo"))
     self.assertEqual("foo_3", g.unique_name("foo"))
 
+  def testUniqueNameCaseInsensitivity(self):
+    g = ops.Graph()
+    self.assertEqual("foo", g.unique_name("foo"))
+    self.assertEqual("Foo_1", g.unique_name("Foo"))
+    with g.name_scope("bar"):
+      self.assertEqual("bar/foo", g.unique_name("foo"))
+    with g.name_scope("Bar"):
+      self.assertEqual("Bar_1/foo", g.unique_name("foo"))
+
   def testInvalidNameRaisesError(self):
     g = ops.Graph()
     with g.name_scope(""):  # Should not raise
diff --git a/tensorflow/python/grappler/layout_optimizer_test.py b/tensorflow/python/grappler/layout_optimizer_test.py
index 2c9f391d01..7d07c77c79 100644
--- a/tensorflow/python/grappler/layout_optimizer_test.py
+++ b/tensorflow/python/grappler/layout_optimizer_test.py
@@ -1390,7 +1390,7 @@ class LayoutOptimizerTest(test.TestCase):
expected_num_transposes = 3
self.assertEqual(expected_num_transposes, num_transposes)
self._assert_trans_nhwc_to_nchw('map/while/Conv2D-0', nodes)
- self._assert_trans_nchw_to_nhwc('map/while/Add-0-2', nodes)
+ self._assert_trans_nchw_to_nhwc('map/while/Add_1-0-2', nodes)
self.assertAllClose(output_val_ref, output_val, atol=1e-3)
def testLoopWithVecAnd4D(self):
@@ -1414,7 +1414,7 @@ class LayoutOptimizerTest(test.TestCase):
expected_num_transposes = 2
self.assertEqual(expected_num_transposes, num_transposes)
self._assert_trans_nhwc_to_nchw('map/while/Conv2D-0', nodes)
- self._assert_trans_nchw_to_nhwc('map/while/Add-0-2', nodes)
+ self._assert_trans_nchw_to_nhwc('map/while/Add_1-0-2', nodes)
self.assertAllClose(output_val_ref, output_val, atol=1e-3)
def testBinaryOpSecondPort(self):
diff --git a/tensorflow/python/kernel_tests/BUILD b/tensorflow/python/kernel_tests/BUILD
index 5796c874f9..8a6614c837 100644
--- a/tensorflow/python/kernel_tests/BUILD
+++ b/tensorflow/python/kernel_tests/BUILD
@@ -893,6 +893,7 @@ tf_py_test(
"//third_party/py/numpy",
"//tensorflow/python:client_testlib",
"//tensorflow/python:framework",
+ "//tensorflow/python:sparse_grad",
"//tensorflow/python:sparse_ops",
],
)
diff --git a/tensorflow/python/kernel_tests/init_ops_test.py b/tensorflow/python/kernel_tests/init_ops_test.py
index 795aa67248..927ca012ae 100644
--- a/tensorflow/python/kernel_tests/init_ops_test.py
+++ b/tensorflow/python/kernel_tests/init_ops_test.py
@@ -364,14 +364,52 @@ class UniformUnitScalingInitializationTest(test.TestCase):
class VarianceScalingInitializationTest(test.TestCase):
+ def testTruncatedNormalDistribution(self):
+ shape = [100, 100]
+ expect_mean = 0.
+ expect_var = 1. / shape[0]
+ init = init_ops.variance_scaling_initializer(
+ distribution='truncated_normal')
+
+ with self.test_session(use_gpu=True), \
+ test.mock.patch.object(
+ random_ops, 'truncated_normal', wraps=random_ops.truncated_normal) \
+ as mock_truncated_normal:
+ x = init(shape).eval()
+ self.assertTrue(mock_truncated_normal.called)
+
+ self.assertNear(np.mean(x), expect_mean, err=1e-2)
+ self.assertNear(np.var(x), expect_var, err=1e-2)
+
def testNormalDistribution(self):
shape = [100, 100]
expect_mean = 0.
expect_var = 1. / shape[0]
init = init_ops.variance_scaling_initializer(distribution='normal')
- with self.test_session(use_gpu=True):
+ with self.test_session(use_gpu=True), \
+ test.mock.patch.object(
+ random_ops, 'truncated_normal', wraps=random_ops.truncated_normal) \
+ as mock_truncated_normal:
+ x = init(shape).eval()
+ self.assertTrue(mock_truncated_normal.called)
+
+ self.assertNear(np.mean(x), expect_mean, err=1e-2)
+ self.assertNear(np.var(x), expect_var, err=1e-2)
+
+ def testUntruncatedNormalDistribution(self):
+ shape = [100, 100]
+ expect_mean = 0.
+ expect_var = 1. / shape[0]
+ init = init_ops.variance_scaling_initializer(
+ distribution='untruncated_normal')
+
+ with self.test_session(use_gpu=True), \
+ test.mock.patch.object(
+ random_ops, 'random_normal', wraps=random_ops.random_normal) \
+ as mock_random_normal:
x = init(shape).eval()
+ self.assertTrue(mock_random_normal.called)
self.assertNear(np.mean(x), expect_mean, err=1e-2)
self.assertNear(np.var(x), expect_var, err=1e-2)
diff --git a/tensorflow/python/kernel_tests/shape_ops_test.py b/tensorflow/python/kernel_tests/shape_ops_test.py
index 7368251ab6..34e34d9d1b 100644
--- a/tensorflow/python/kernel_tests/shape_ops_test.py
+++ b/tensorflow/python/kernel_tests/shape_ops_test.py
@@ -642,6 +642,29 @@ class TileTest(test.TestCase):
err = gradient_checker.compute_gradient_error(a, [4, 2], tiled, [4, 4])
self.assertLess(err, 1e-3)
+ def testGradientWithSparseGradWithRank1(self):
+ inputs = constant_op.constant([1.0, 2.0, 3.0, 4.0],
+ dtype=dtypes.float32)
+ outputs = array_ops.gather(array_ops.tile(inputs, [3]),
+ [1, 5, 9, 3, 7, 2, 2, 2])
+ with self.test_session():
+ error = gradient_checker.compute_gradient_error(
+ inputs, inputs.get_shape().as_list(),
+ outputs, outputs.get_shape().as_list())
+ self.assertLess(error, 1e-4)
+
+ def testGradientWithSparseGradWithRank3(self):
+ inputs = constant_op.constant([1.0, 2.0, 3.0, 4.0],
+ dtype=dtypes.float32)
+ inputs = array_ops.reshape(inputs, [-1, 1, 1])
+ outputs = array_ops.gather(array_ops.tile(inputs, [3, 4, 2]),
+ [1, 5, 9, 3, 7, 2, 2, 2])
+ with self.test_session():
+ error = gradient_checker.compute_gradient_error(
+ inputs, inputs.get_shape().as_list(),
+ outputs, outputs.get_shape().as_list())
+ self.assertLess(error, 1e-4)
+
def testShapeFunctionEdgeCases(self):
# Unknown multiples shape.
inp = constant_op.constant(0.0, shape=[4, 4, 4, 4])
diff --git a/tensorflow/python/kernel_tests/sparse_slice_op_test.py b/tensorflow/python/kernel_tests/sparse_slice_op_test.py
index da116601f8..97f30daf4a 100644
--- a/tensorflow/python/kernel_tests/sparse_slice_op_test.py
+++ b/tensorflow/python/kernel_tests/sparse_slice_op_test.py
@@ -21,13 +21,15 @@ from __future__ import print_function
import numpy as np
from tensorflow.python.framework import sparse_tensor
+from tensorflow.python.ops import gradient_checker
from tensorflow.python.ops import sparse_ops
+import tensorflow.python.ops.sparse_grad # pylint: disable=unused-import
from tensorflow.python.platform import test
class SparseSliceOpTest(test.TestCase):
- def _SparseTensor_4x6(self):
+ def _SparseTensor_4x6(self, val_dtype=np.int64):
# [0 | |2 | |4 |5 ]
# [ |11| |13|14| ]
# [20| | |23| |25]
@@ -37,7 +39,7 @@ class SparseSliceOpTest(test.TestCase):
[2, 3], [2, 5], [3, 0], [3, 2], [3, 3], [3, 5]]).astype(
np.int64)
val = np.array([0, 2, 4, 5, 11, 13, 14, 20, 23, 25, 30, 32, 33, 35]).astype(
- np.int64)
+ val_dtype)
shape = np.array([4, 6]).astype(np.int64)
return sparse_tensor.SparseTensor(ind, val, shape)
@@ -244,6 +246,22 @@ class SparseSliceOpTest(test.TestCase):
self.assertAllEqual(sparse_tensor5.values.eval(), [5, 25, 35])
self.assertAllEqual(sparse_tensor5.dense_shape.eval(), [4, 1])
+ def testGradients(self):
+ sp_input = self._SparseTensor_4x6(val_dtype=np.float32)
+ start_and_size = [([0, 0], [4, 2]),
+ ([0, 2], [5, 2]),
+ ([0, 4], [5, 3])]
+
+ with self.test_session(use_gpu=False):
+ for start, size in start_and_size:
+ sp_output = sparse_ops.sparse_slice(sp_input, start, size)
+ nnz_in = len(sp_input.values.eval())
+ nnz_out = len(sp_output.values.eval())
+
+ err = gradient_checker.compute_gradient_error(
+ [sp_input.values], [(nnz_in,)], sp_output.values, (nnz_out,))
+ self.assertLess(err, 1e-3)
+
if __name__ == '__main__':
test.main()
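As a usage-level sketch of what the gradient exercised above enables (toy values; assumes a build where the `SparseSlice` gradient registered in sparse_grad.py below is available):

```python
import tensorflow as tf

sp = tf.SparseTensor(indices=[[0, 0], [1, 2], [2, 1]],
                     values=[1.0, 2.0, 3.0],
                     dense_shape=[3, 4])
# Keep rows 0-2, columns 0-1; only the entries at (0, 0) and (2, 1) survive.
sliced = tf.sparse_slice(sp, start=[0, 0], size=[3, 2])
grad, = tf.gradients(tf.reduce_sum(sliced.values), sp.values)

with tf.Session() as sess:
  print(sess.run(grad))  # [1., 0., 1.]: 1 inside the slice, 0 outside
```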
diff --git a/tensorflow/python/ops/array_grad.py b/tensorflow/python/ops/array_grad.py
index 3678bd4c1f..fe459a96b9 100644
--- a/tensorflow/python/ops/array_grad.py
+++ b/tensorflow/python/ops/array_grad.py
@@ -568,7 +568,6 @@ ops.NotDifferentiable("Size")
 @ops.RegisterGradient("Tile")
 def _TileGrad(op, grad):
   """Sum reduces grad along the tiled dimensions."""
-  assert isinstance(grad, ops.Tensor)
   input_shape = array_ops.shape(op.inputs[0])
   # We interleave multiples and input_shape to get split_shape,
   # reshape grad to split_shape, and reduce along all even
@@ -581,6 +580,13 @@ def _TileGrad(op, grad):
   split_shape = array_ops.reshape(
       array_ops.transpose(array_ops.stack([op.inputs[1], input_shape])), [-1])
   axes = math_ops.range(0, array_ops.size(split_shape), 2)
+  # Sum reduces grad along the first dimension for IndexedSlices
+  if isinstance(grad, ops.IndexedSlices):
+    grad = math_ops.unsorted_segment_sum(
+        grad.values,
+        math_ops.mod(grad.indices, input_shape[0]),
+        input_shape[0])
+    split_shape = array_ops.concat([[1], split_shape[1:]], axis=0)
   input_grad = math_ops.reduce_sum(array_ops.reshape(grad, split_shape), axes)
   # Fix shape inference
   if not context.executing_eagerly():
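A minimal sketch of the case this enables: a `tf.gather` of a tiled tensor produces an `IndexedSlices` gradient, and indices taken modulo the input's first dimension fold every tile copy back onto its source row (toy values for illustration):

```python
import tensorflow as tf

x = tf.constant([1.0, 2.0, 3.0, 4.0])
tiled = tf.tile(x, [3])          # 12 elements: x repeated three times
y = tf.gather(tiled, [1, 5, 9])  # indices 1, 5, 9 are all copies of x[1]
# The incoming gradient is an IndexedSlices; 1, 5, 9 mod 4 all map to row 1.
grad, = tf.gradients(tf.reduce_sum(y), x)

with tf.Session() as sess:
  print(sess.run(grad))  # [0., 3., 0., 0.]
```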
diff --git a/tensorflow/python/ops/control_flow_ops.py b/tensorflow/python/ops/control_flow_ops.py
index c8442b42d5..fc37805c79 100644
--- a/tensorflow/python/ops/control_flow_ops.py
+++ b/tensorflow/python/ops/control_flow_ops.py
@@ -3135,6 +3135,7 @@ def while_loop(cond,
happen is that the thread updating `x` can never get ahead of the
counter thread because the thread incrementing `x` depends on the value
of the counter.
+
```python
import tensorflow as tf
diff --git a/tensorflow/python/ops/init_ops.py b/tensorflow/python/ops/init_ops.py
index c41e952167..5bfc5ce2a7 100644
--- a/tensorflow/python/ops/init_ops.py
+++ b/tensorflow/python/ops/init_ops.py
@@ -43,7 +43,8 @@ from tensorflow.python.ops import linalg_ops_impl
from tensorflow.python.ops import gen_linalg_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import random_ops
-from tensorflow.python.util.deprecation import deprecated
+from tensorflow.python.util.deprecation import (
+ deprecated, deprecated_arg_values)
from tensorflow.python.util.tf_export import tf_export
@@ -409,8 +410,10 @@ class UniformUnitScaling(Initializer):
class VarianceScaling(Initializer):
"""Initializer capable of adapting its scale to the shape of weights tensors.
- With `distribution="normal"`, samples are drawn from a truncated normal
- distribution centered on zero, with `stddev = sqrt(scale / n)`
+ With `distribution="truncated_normal"` or `distribution="untruncated_normal"`,
+ samples are drawn from a truncated or untruncated normal distribution,
+ respectively, with a mean of zero and a standard deviation (after
+ truncation, if used) of `stddev = sqrt(scale / n)`
where n is:
- number of input units in the weight tensor, if mode = "fan_in"
- number of output units, if mode = "fan_out"
@@ -433,10 +436,14 @@ class VarianceScaling(Initializer):
"distribution" arguments.
"""
+ @deprecated_arg_values(
+ None,
+ "`normal` is a deprecated alias for `truncated_normal`",
+ distribution="normal")
def __init__(self,
scale=1.0,
mode="fan_in",
- distribution="normal",
+ distribution="truncated_normal",
seed=None,
dtype=dtypes.float32):
if scale <= 0.:
@@ -444,7 +451,8 @@ class VarianceScaling(Initializer):
if mode not in {"fan_in", "fan_out", "fan_avg"}:
raise ValueError("Invalid `mode` argument:", mode)
distribution = distribution.lower()
- if distribution not in {"normal", "uniform"}:
+ if distribution not in {"normal", "uniform",
+ "truncated_normal", "untruncated_normal"}:
raise ValueError("Invalid `distribution` argument:", distribution)
self.scale = scale
self.mode = mode
@@ -466,11 +474,15 @@ class VarianceScaling(Initializer):
       scale /= max(1., fan_out)
     else:
       scale /= max(1., (fan_in + fan_out) / 2.)
-    if self.distribution == "normal":
+    if self.distribution == "normal" or self.distribution == "truncated_normal":
       # constant taken from scipy.stats.truncnorm.std(a=-2, b=2, loc=0., scale=1.)
       stddev = math.sqrt(scale) / .87962566103423978
       return random_ops.truncated_normal(
           shape, 0.0, stddev, dtype, seed=self.seed)
+    elif self.distribution == "untruncated_normal":
+      stddev = math.sqrt(scale)
+      return random_ops.random_normal(
+          shape, 0.0, stddev, dtype, seed=self.seed)
     else:
       limit = math.sqrt(3.0 * scale)
       return random_ops.random_uniform(
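A small sketch contrasting the two normal variants, assuming a build that includes this change; both should land near `scale / fan_in` because the truncated path divides by the `.8796...` truncation correction:

```python
import numpy as np
import tensorflow as tf

shape = [100, 100]  # fan_in = 100, so the target variance is 0.01

init_trunc = tf.variance_scaling_initializer(distribution="truncated_normal")
init_untrunc = tf.variance_scaling_initializer(distribution="untruncated_normal")

with tf.Session() as sess:
  for init in (init_trunc, init_untrunc):
    x = sess.run(init(shape))
    # Both variants are scaled so the realized variance is ~ 1 / fan_in.
    print(np.mean(x), np.var(x))  # mean ~ 0.0, var ~ 0.01
```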
diff --git a/tensorflow/python/ops/losses/losses_impl.py b/tensorflow/python/ops/losses/losses_impl.py
index 9ba91772f5..66633c8b12 100644
--- a/tensorflow/python/ops/losses/losses_impl.py
+++ b/tensorflow/python/ops/losses/losses_impl.py
@@ -878,7 +878,8 @@ def sparse_softmax_cross_entropy(
exception when this op is run on CPU, and return `NaN` for corresponding
loss and gradient rows on GPU.
logits: Unscaled log probabilities of shape
- `[d_0, d_1, ..., d_{r-1}, num_classes]` and dtype `float32` or `float64`.
+ `[d_0, d_1, ..., d_{r-1}, num_classes]` and dtype `float16`, `float32` or
+ `float64`.
weights: Coefficients for the loss. This must be scalar or broadcastable to
`labels` (i.e. same rank and each dimension is either 1 or the same).
scope: the scope for the operations performed in computing the loss.
diff --git a/tensorflow/python/ops/nn_ops.py b/tensorflow/python/ops/nn_ops.py
index 5a3b669c28..41d54a6c2f 100644
--- a/tensorflow/python/ops/nn_ops.py
+++ b/tensorflow/python/ops/nn_ops.py
@@ -2009,7 +2009,8 @@ def sparse_softmax_cross_entropy_with_logits(
exception when this op is run on CPU, and return `NaN` for corresponding
loss and gradient rows on GPU.
logits: Unscaled log probabilities of shape
- `[d_0, d_1, ..., d_{r-1}, num_classes]` and dtype `float32` or `float64`.
+ `[d_0, d_1, ..., d_{r-1}, num_classes]` and dtype `float16`, `float32`, or
+ `float64`.
name: A name for the operation (optional).
Returns:
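A short sketch exercising the `float16` path the docstring now mentions (toy values; float16 losses naturally carry lower numerical precision than float32):

```python
import tensorflow as tf

labels = tf.constant([0, 2])
logits = tf.constant([[2.0, 0.5, -1.0],
                      [0.1, 0.2, 3.0]], dtype=tf.float16)
loss = tf.nn.sparse_softmax_cross_entropy_with_logits(
    labels=labels, logits=logits)

with tf.Session() as sess:
  print(sess.run(loss))  # per-example cross-entropy, dtype float16
```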
diff --git a/tensorflow/python/ops/sparse_grad.py b/tensorflow/python/ops/sparse_grad.py
index 97353d6c74..1223b290ff 100644
--- a/tensorflow/python/ops/sparse_grad.py
+++ b/tensorflow/python/ops/sparse_grad.py
@@ -116,6 +116,35 @@ def _SparseReduceSumGrad(op, out_grad):
           None, None)
 
 
+@ops.RegisterGradient("SparseSlice")
+def _SparseSliceGrad(op, *grads):
+  """The backward operator for the SparseSlice op.
+
+  This op takes in the upstream gradient w.r.t. the non-empty values of
+  the sliced `SparseTensor`, and outputs the gradients w.r.t.
+  the non-empty values of the input `SparseTensor`.
+
+  Args:
+    op: the SparseSlice op
+    *grads: the incoming gradients, one element per output of `op`
+
+  Returns:
+    Gradient for each of the 5 input tensors of SparseSlice:
+      (indices, values, shape, start, size)
+    The gradients for indices, shape, start, and size are None.
+  """
+  backprop_val_grad = grads[1]
+  input_indices = op.inputs[0]
+  input_start = op.inputs[3]
+  output_indices = op.outputs[0]
+
+  val_grad = gen_sparse_ops.sparse_slice_grad(
+      backprop_val_grad, input_indices, input_start, output_indices)
+  val_grad.set_shape(op.inputs[1].get_shape())
+  # (indices, values, shape, start, size)
+  return (None, val_grad, None, None, None)
+
+
@ops.RegisterGradient("SparseTensorDenseMatMul")
def _SparseTensorDenseMatMulGrad(op, grad):
"""Gradients for the dense tensor in the SparseTensorDenseMatMul op.