diff options
Diffstat (limited to 'tensorflow/python')
-rw-r--r-- | tensorflow/python/estimator/model_fn.py | 4 | ||||
-rw-r--r-- | tensorflow/python/framework/ops.py | 30 | ||||
-rw-r--r-- | tensorflow/python/framework/ops_test.py | 9 | ||||
-rw-r--r-- | tensorflow/python/grappler/layout_optimizer_test.py | 4 | ||||
-rw-r--r-- | tensorflow/python/kernel_tests/BUILD | 1 | ||||
-rw-r--r-- | tensorflow/python/kernel_tests/init_ops_test.py | 40 | ||||
-rw-r--r-- | tensorflow/python/kernel_tests/shape_ops_test.py | 23 | ||||
-rw-r--r-- | tensorflow/python/kernel_tests/sparse_slice_op_test.py | 22 | ||||
-rw-r--r-- | tensorflow/python/ops/array_grad.py | 8 | ||||
-rw-r--r-- | tensorflow/python/ops/control_flow_ops.py | 1 | ||||
-rw-r--r-- | tensorflow/python/ops/init_ops.py | 24 | ||||
-rw-r--r-- | tensorflow/python/ops/losses/losses_impl.py | 3 | ||||
-rw-r--r-- | tensorflow/python/ops/nn_ops.py | 3 | ||||
-rw-r--r-- | tensorflow/python/ops/sparse_grad.py | 29 |
14 files changed, 174 insertions, 27 deletions
diff --git a/tensorflow/python/estimator/model_fn.py b/tensorflow/python/estimator/model_fn.py index 009ac9d8fd..a9fd8f8e1a 100644 --- a/tensorflow/python/estimator/model_fn.py +++ b/tensorflow/python/estimator/model_fn.py @@ -99,7 +99,7 @@ class EstimatorSpec( ignored in eval and infer modes. Example: ```python - def my_model_fn(mode, features, labels): + def my_model_fn(features, labels, mode): predictions = ... loss = ... train_op = ... @@ -114,7 +114,7 @@ class EstimatorSpec( given mode. Example: ```python - def my_model_fn(mode, features, labels): + def my_model_fn(features, labels, mode): if (mode == tf.estimator.ModeKeys.TRAIN or mode == tf.estimator.ModeKeys.EVAL): loss = ... diff --git a/tensorflow/python/framework/ops.py b/tensorflow/python/framework/ops.py index 89afd1d25b..cf0b1e36fb 100644 --- a/tensorflow/python/framework/ops.py +++ b/tensorflow/python/framework/ops.py @@ -3239,8 +3239,9 @@ class Graph(object): # the name will still appear in _names_in_use even though the name hasn't # been used. This is ok, just leave _names_in_use as-is in this case. # TODO(skyewm): make the C API guarantee no name conflicts. - if ret.name not in self._names_in_use: - self._names_in_use[ret.name] = 1 + name_key = ret.name.lower() + if name_key not in self._names_in_use: + self._names_in_use[name_key] = 1 self._create_op_helper(ret, compute_device=compute_device) return ret @@ -3949,20 +3950,27 @@ class Graph(object): """ if self._name_stack: name = self._name_stack + "/" + name - i = self._names_in_use.get(name, 0) - # Increment the number for "name". + + # For the sake of checking for names in use, we treat names as case + # insensitive (e.g. foo = Foo). + name_key = name.lower() + i = self._names_in_use.get(name_key, 0) + # Increment the number for "name_key". if mark_as_used: - self._names_in_use[name] = i + 1 + self._names_in_use[name_key] = i + 1 if i > 0: - base_name = name - # Make sure the composed name is not already used. - while name in self._names_in_use: - name = "%s_%d" % (base_name, i) + base_name_key = name_key + # Make sure the composed name key is not already used. + while name_key in self._names_in_use: + name_key = "%s_%d" % (base_name_key, i) i += 1 - # Mark the composed name as used in case someone wants + # Mark the composed name_key as used in case someone wants # to call unique_name("name_1"). if mark_as_used: - self._names_in_use[name] = 1 + self._names_in_use[name_key] = 1 + + # Return the new name with the original capitalization of the given name. + name = "%s_%d" % (name, i-1) return name def get_name_scope(self): diff --git a/tensorflow/python/framework/ops_test.py b/tensorflow/python/framework/ops_test.py index c72406e92b..150100d771 100644 --- a/tensorflow/python/framework/ops_test.py +++ b/tensorflow/python/framework/ops_test.py @@ -965,6 +965,15 @@ class NameStackTest(test_util.TensorFlowTestCase): self.assertEqual("foo_1", g.unique_name("foo")) self.assertEqual("foo_3", g.unique_name("foo")) + def testUniqueNameCaseInsensitivity(self): + g = ops.Graph() + self.assertEqual("foo", g.unique_name("foo")) + self.assertEqual("Foo_1", g.unique_name("Foo")) + with g.name_scope("bar"): + self.assertEqual("bar/foo", g.unique_name("foo")) + with g.name_scope("Bar"): + self.assertEqual("Bar_1/foo", g.unique_name("foo")) + def testInvalidNameRaisesError(self): g = ops.Graph() with g.name_scope(""): # Should not raise diff --git a/tensorflow/python/grappler/layout_optimizer_test.py b/tensorflow/python/grappler/layout_optimizer_test.py index 2c9f391d01..7d07c77c79 100644 --- a/tensorflow/python/grappler/layout_optimizer_test.py +++ b/tensorflow/python/grappler/layout_optimizer_test.py @@ -1390,7 +1390,7 @@ class LayoutOptimizerTest(test.TestCase): expected_num_transposes = 3 self.assertEqual(expected_num_transposes, num_transposes) self._assert_trans_nhwc_to_nchw('map/while/Conv2D-0', nodes) - self._assert_trans_nchw_to_nhwc('map/while/Add-0-2', nodes) + self._assert_trans_nchw_to_nhwc('map/while/Add_1-0-2', nodes) self.assertAllClose(output_val_ref, output_val, atol=1e-3) def testLoopWithVecAnd4D(self): @@ -1414,7 +1414,7 @@ class LayoutOptimizerTest(test.TestCase): expected_num_transposes = 2 self.assertEqual(expected_num_transposes, num_transposes) self._assert_trans_nhwc_to_nchw('map/while/Conv2D-0', nodes) - self._assert_trans_nchw_to_nhwc('map/while/Add-0-2', nodes) + self._assert_trans_nchw_to_nhwc('map/while/Add_1-0-2', nodes) self.assertAllClose(output_val_ref, output_val, atol=1e-3) def testBinaryOpSecondPort(self): diff --git a/tensorflow/python/kernel_tests/BUILD b/tensorflow/python/kernel_tests/BUILD index 5796c874f9..8a6614c837 100644 --- a/tensorflow/python/kernel_tests/BUILD +++ b/tensorflow/python/kernel_tests/BUILD @@ -893,6 +893,7 @@ tf_py_test( "//third_party/py/numpy", "//tensorflow/python:client_testlib", "//tensorflow/python:framework", + "//tensorflow/python:sparse_grad", "//tensorflow/python:sparse_ops", ], ) diff --git a/tensorflow/python/kernel_tests/init_ops_test.py b/tensorflow/python/kernel_tests/init_ops_test.py index 795aa67248..927ca012ae 100644 --- a/tensorflow/python/kernel_tests/init_ops_test.py +++ b/tensorflow/python/kernel_tests/init_ops_test.py @@ -364,14 +364,52 @@ class UniformUnitScalingInitializationTest(test.TestCase): class VarianceScalingInitializationTest(test.TestCase): + def testTruncatedNormalDistribution(self): + shape = [100, 100] + expect_mean = 0. + expect_var = 1. / shape[0] + init = init_ops.variance_scaling_initializer( + distribution='truncated_normal') + + with self.test_session(use_gpu=True), \ + test.mock.patch.object( + random_ops, 'truncated_normal', wraps=random_ops.truncated_normal) \ + as mock_truncated_normal: + x = init(shape).eval() + self.assertTrue(mock_truncated_normal.called) + + self.assertNear(np.mean(x), expect_mean, err=1e-2) + self.assertNear(np.var(x), expect_var, err=1e-2) + def testNormalDistribution(self): shape = [100, 100] expect_mean = 0. expect_var = 1. / shape[0] init = init_ops.variance_scaling_initializer(distribution='normal') - with self.test_session(use_gpu=True): + with self.test_session(use_gpu=True), \ + test.mock.patch.object( + random_ops, 'truncated_normal', wraps=random_ops.truncated_normal) \ + as mock_truncated_normal: + x = init(shape).eval() + self.assertTrue(mock_truncated_normal.called) + + self.assertNear(np.mean(x), expect_mean, err=1e-2) + self.assertNear(np.var(x), expect_var, err=1e-2) + + def testUntruncatedNormalDistribution(self): + shape = [100, 100] + expect_mean = 0. + expect_var = 1. / shape[0] + init = init_ops.variance_scaling_initializer( + distribution='untruncated_normal') + + with self.test_session(use_gpu=True), \ + test.mock.patch.object( + random_ops, 'random_normal', wraps=random_ops.random_normal) \ + as mock_random_normal: x = init(shape).eval() + self.assertTrue(mock_random_normal.called) self.assertNear(np.mean(x), expect_mean, err=1e-2) self.assertNear(np.var(x), expect_var, err=1e-2) diff --git a/tensorflow/python/kernel_tests/shape_ops_test.py b/tensorflow/python/kernel_tests/shape_ops_test.py index 7368251ab6..34e34d9d1b 100644 --- a/tensorflow/python/kernel_tests/shape_ops_test.py +++ b/tensorflow/python/kernel_tests/shape_ops_test.py @@ -642,6 +642,29 @@ class TileTest(test.TestCase): err = gradient_checker.compute_gradient_error(a, [4, 2], tiled, [4, 4]) self.assertLess(err, 1e-3) + def testGradientWithSparseGradWithRank1(self): + inputs = constant_op.constant([1.0, 2.0, 3.0, 4.0], + dtype=dtypes.float32) + outputs = array_ops.gather(array_ops.tile(inputs, [3]), + [1, 5, 9, 3, 7, 2, 2, 2]) + with self.test_session(): + error = gradient_checker.compute_gradient_error( + inputs, inputs.get_shape().as_list(), + outputs, outputs.get_shape().as_list()) + self.assertLess(error, 1e-4) + + def testGradientWithSparseGradWithRank3(self): + inputs = constant_op.constant([1.0, 2.0, 3.0, 4.0], + dtype=dtypes.float32) + inputs = array_ops.reshape(inputs, [-1, 1, 1]) + outputs = array_ops.gather(array_ops.tile(inputs, [3, 4, 2]), + [1, 5, 9, 3, 7, 2, 2, 2]) + with self.test_session(): + error = gradient_checker.compute_gradient_error( + inputs, inputs.get_shape().as_list(), + outputs, outputs.get_shape().as_list()) + self.assertLess(error, 1e-4) + def testShapeFunctionEdgeCases(self): # Unknown multiples shape. inp = constant_op.constant(0.0, shape=[4, 4, 4, 4]) diff --git a/tensorflow/python/kernel_tests/sparse_slice_op_test.py b/tensorflow/python/kernel_tests/sparse_slice_op_test.py index da116601f8..97f30daf4a 100644 --- a/tensorflow/python/kernel_tests/sparse_slice_op_test.py +++ b/tensorflow/python/kernel_tests/sparse_slice_op_test.py @@ -21,13 +21,15 @@ from __future__ import print_function import numpy as np from tensorflow.python.framework import sparse_tensor +from tensorflow.python.ops import gradient_checker from tensorflow.python.ops import sparse_ops +import tensorflow.python.ops.sparse_grad # pylint: disable=unused-import from tensorflow.python.platform import test class SparseSliceOpTest(test.TestCase): - def _SparseTensor_4x6(self): + def _SparseTensor_4x6(self, val_dtype=np.int64): # [0 | |2 | |4 |5 ] # [ |11| |13|14| ] # [20| | |23| |25] @@ -37,7 +39,7 @@ class SparseSliceOpTest(test.TestCase): [2, 3], [2, 5], [3, 0], [3, 2], [3, 3], [3, 5]]).astype( np.int64) val = np.array([0, 2, 4, 5, 11, 13, 14, 20, 23, 25, 30, 32, 33, 35]).astype( - np.int64) + val_dtype) shape = np.array([4, 6]).astype(np.int64) return sparse_tensor.SparseTensor(ind, val, shape) @@ -244,6 +246,22 @@ class SparseSliceOpTest(test.TestCase): self.assertAllEqual(sparse_tensor5.values.eval(), [5, 25, 35]) self.assertAllEqual(sparse_tensor5.dense_shape.eval(), [4, 1]) + def testGradients(self): + sp_input = self._SparseTensor_4x6(val_dtype=np.float32) + start_and_size = [([0, 0], [4, 2]), + ([0, 2], [5, 2]), + ([0, 4], [5, 3])] + + with self.test_session(use_gpu=False): + for start, size in start_and_size: + sp_output = sparse_ops.sparse_slice(sp_input, start, size) + nnz_in = len(sp_input.values.eval()) + nnz_out = len(sp_output.values.eval()) + + err = gradient_checker.compute_gradient_error( + [sp_input.values], [(nnz_in,)], sp_output.values, (nnz_out,)) + self.assertLess(err, 1e-3) + if __name__ == '__main__': test.main() diff --git a/tensorflow/python/ops/array_grad.py b/tensorflow/python/ops/array_grad.py index 3678bd4c1f..fe459a96b9 100644 --- a/tensorflow/python/ops/array_grad.py +++ b/tensorflow/python/ops/array_grad.py @@ -568,7 +568,6 @@ ops.NotDifferentiable("Size") @ops.RegisterGradient("Tile") def _TileGrad(op, grad): """Sum reduces grad along the tiled dimensions.""" - assert isinstance(grad, ops.Tensor) input_shape = array_ops.shape(op.inputs[0]) # We interleave multiples and input_shape to get split_shape, # reshape grad to split_shape, and reduce along all even @@ -581,6 +580,13 @@ def _TileGrad(op, grad): split_shape = array_ops.reshape( array_ops.transpose(array_ops.stack([op.inputs[1], input_shape])), [-1]) axes = math_ops.range(0, array_ops.size(split_shape), 2) + # Sum reduces grad along the first dimension for IndexedSlices + if isinstance(grad, ops.IndexedSlices): + grad = math_ops.unsorted_segment_sum( + grad.values, + math_ops.mod(grad.indices, input_shape[0]), + input_shape[0]) + split_shape = array_ops.concat([[1], split_shape[1:]], axis=0) input_grad = math_ops.reduce_sum(array_ops.reshape(grad, split_shape), axes) # Fix shape inference if not context.executing_eagerly(): diff --git a/tensorflow/python/ops/control_flow_ops.py b/tensorflow/python/ops/control_flow_ops.py index c8442b42d5..fc37805c79 100644 --- a/tensorflow/python/ops/control_flow_ops.py +++ b/tensorflow/python/ops/control_flow_ops.py @@ -3135,6 +3135,7 @@ def while_loop(cond, happen is that the thread updating `x` can never get ahead of the counter thread because the thread incrementing `x` depends on the value of the counter. + ```python import tensorflow as tf diff --git a/tensorflow/python/ops/init_ops.py b/tensorflow/python/ops/init_ops.py index c41e952167..5bfc5ce2a7 100644 --- a/tensorflow/python/ops/init_ops.py +++ b/tensorflow/python/ops/init_ops.py @@ -43,7 +43,8 @@ from tensorflow.python.ops import linalg_ops_impl from tensorflow.python.ops import gen_linalg_ops from tensorflow.python.ops import math_ops from tensorflow.python.ops import random_ops -from tensorflow.python.util.deprecation import deprecated +from tensorflow.python.util.deprecation import ( + deprecated, deprecated_arg_values) from tensorflow.python.util.tf_export import tf_export @@ -409,8 +410,10 @@ class UniformUnitScaling(Initializer): class VarianceScaling(Initializer): """Initializer capable of adapting its scale to the shape of weights tensors. - With `distribution="normal"`, samples are drawn from a truncated normal - distribution centered on zero, with `stddev = sqrt(scale / n)` + With `distribution="truncated_normal" or "untruncated_normal"`, + samples are drawn from a truncated/untruncated normal + distribution with a mean of zero and a standard deviation (after truncation, + if used) `stddev = sqrt(scale / n)` where n is: - number of input units in the weight tensor, if mode = "fan_in" - number of output units, if mode = "fan_out" @@ -433,10 +436,14 @@ class VarianceScaling(Initializer): "distribution" arguments. """ + @deprecated_arg_values( + None, + "`normal` is a deprecated alias for `truncated_normal`", + distribution="normal") def __init__(self, scale=1.0, mode="fan_in", - distribution="normal", + distribution="truncated_normal", seed=None, dtype=dtypes.float32): if scale <= 0.: @@ -444,7 +451,8 @@ class VarianceScaling(Initializer): if mode not in {"fan_in", "fan_out", "fan_avg"}: raise ValueError("Invalid `mode` argument:", mode) distribution = distribution.lower() - if distribution not in {"normal", "uniform"}: + if distribution not in {"normal", "uniform", + "truncated_normal", "untruncated_normal"}: raise ValueError("Invalid `distribution` argument:", distribution) self.scale = scale self.mode = mode @@ -466,11 +474,15 @@ class VarianceScaling(Initializer): scale /= max(1., fan_out) else: scale /= max(1., (fan_in + fan_out) / 2.) - if self.distribution == "normal": + if self.distribution == "normal" or self.distribution == "truncated_normal": # constant taken from scipy.stats.truncnorm.std(a=-2, b=2, loc=0., scale=1.) stddev = math.sqrt(scale) / .87962566103423978 return random_ops.truncated_normal( shape, 0.0, stddev, dtype, seed=self.seed) + elif self.distribution == "untruncated_normal": + stddev = math.sqrt(scale) + return random_ops.random_normal( + shape, 0.0, stddev, dtype, seed=self.seed) else: limit = math.sqrt(3.0 * scale) return random_ops.random_uniform( diff --git a/tensorflow/python/ops/losses/losses_impl.py b/tensorflow/python/ops/losses/losses_impl.py index 9ba91772f5..66633c8b12 100644 --- a/tensorflow/python/ops/losses/losses_impl.py +++ b/tensorflow/python/ops/losses/losses_impl.py @@ -878,7 +878,8 @@ def sparse_softmax_cross_entropy( exception when this op is run on CPU, and return `NaN` for corresponding loss and gradient rows on GPU. logits: Unscaled log probabilities of shape - `[d_0, d_1, ..., d_{r-1}, num_classes]` and dtype `float32` or `float64`. + `[d_0, d_1, ..., d_{r-1}, num_classes]` and dtype `float16`, `float32` or + `float64`. weights: Coefficients for the loss. This must be scalar or broadcastable to `labels` (i.e. same rank and each dimension is either 1 or the same). scope: the scope for the operations performed in computing the loss. diff --git a/tensorflow/python/ops/nn_ops.py b/tensorflow/python/ops/nn_ops.py index 5a3b669c28..41d54a6c2f 100644 --- a/tensorflow/python/ops/nn_ops.py +++ b/tensorflow/python/ops/nn_ops.py @@ -2009,7 +2009,8 @@ def sparse_softmax_cross_entropy_with_logits( exception when this op is run on CPU, and return `NaN` for corresponding loss and gradient rows on GPU. logits: Unscaled log probabilities of shape - `[d_0, d_1, ..., d_{r-1}, num_classes]` and dtype `float32` or `float64`. + `[d_0, d_1, ..., d_{r-1}, num_classes]` and dtype `float16`, `float32`, or + `float64`. name: A name for the operation (optional). Returns: diff --git a/tensorflow/python/ops/sparse_grad.py b/tensorflow/python/ops/sparse_grad.py index 97353d6c74..1223b290ff 100644 --- a/tensorflow/python/ops/sparse_grad.py +++ b/tensorflow/python/ops/sparse_grad.py @@ -116,6 +116,35 @@ def _SparseReduceSumGrad(op, out_grad): None, None) +@ops.RegisterGradient("SparseSlice") +def _SparseSliceGrad(op, *grads): + """The backward operator for the SparseSlice op. + + This op takes in the upstream gradient w.r.t. non-empty values of + the sliced `SparseTensor`, and outputs the gradients w.r.t. + the non-empty values of input `SparseTensor`. + + Args: + op: the SparseSlice op + *grads: the incoming gradients, one element per output of `op` + + Returns: + Gradient for each of the 5 input tensors of SparseSlice: + (indices, values, shape, start, size) + The gradients for the indices, shape, start and the size are None. + """ + backprop_val_grad = grads[1] + input_indices = op.inputs[0] + input_start = op.inputs[3] + output_indices = op.outputs[0] + + val_grad = gen_sparse_ops.sparse_slice_grad( + backprop_val_grad, input_indices, input_start, output_indices) + val_grad.set_shape(op.inputs[1].get_shape()) + # (indices, values, shape, start, size) + return (None, val_grad, None, None, None) + + @ops.RegisterGradient("SparseTensorDenseMatMul") def _SparseTensorDenseMatMulGrad(op, grad): """Gradients for the dense tensor in the SparseTensorDenseMatMul op. |