Diffstat (limited to 'tensorflow/python/ops')
 -rw-r--r--  tensorflow/python/ops/array_ops.py          2
 -rw-r--r--  tensorflow/python/ops/bitwise_ops.py        1
 -rw-r--r--  tensorflow/python/ops/bitwise_ops_test.py  23
 -rw-r--r--  tensorflow/python/ops/control_flow_ops.py  10
 -rw-r--r--  tensorflow/python/ops/init_ops.py           1
 -rw-r--r--  tensorflow/python/ops/nn_test.py           18
 -rw-r--r--  tensorflow/python/ops/rnn_cell_impl.py     18
7 files changed, 63 insertions, 10 deletions
diff --git a/tensorflow/python/ops/array_ops.py b/tensorflow/python/ops/array_ops.py
index bb86640eab..f64c89ac5d 100644
--- a/tensorflow/python/ops/array_ops.py
+++ b/tensorflow/python/ops/array_ops.py
@@ -330,7 +330,7 @@ def rank(input, name=None):
# pylint: disable=redefined-builtin
"""Returns the rank of a tensor.
- This operation returns an integer representing the rank of `input`.
+ Returns a 0-D `int32` `Tensor` representing the rank of `input`.
For example:
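As a quick illustration of the revised docstring (a minimal sketch using the TF 1.x session API; not part of this change):

import tensorflow as tf

t = tf.constant([[1, 2], [3, 4]])  # shape (2, 2)
r = tf.rank(t)                     # a 0-D `int32` Tensor
with tf.Session() as sess:
    print(sess.run(r))             # prints 2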
diff --git a/tensorflow/python/ops/bitwise_ops.py b/tensorflow/python/ops/bitwise_ops.py
index cbabc3ed9b..44daf13537 100644
--- a/tensorflow/python/ops/bitwise_ops.py
+++ b/tensorflow/python/ops/bitwise_ops.py
@@ -36,5 +36,6 @@ ops.NotDifferentiable("BitwiseAnd")
ops.NotDifferentiable("BitwiseOr")
ops.NotDifferentiable("BitwiseXor")
ops.NotDifferentiable("Invert")
+ops.NotDifferentiable("PopulationCount")
remove_undocumented(__name__)
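The registration above marks the new `PopulationCount` op as non-differentiable. A hedged sketch of the op itself, which counts the set bits in each element; the `gen_bitwise_ops` access path mirrors the test below, and the public endpoint may differ:

import tensorflow as tf
from tensorflow.python.ops import gen_bitwise_ops

x = tf.constant([0, 1, 3, 255], dtype=tf.int32)
y = gen_bitwise_ops.population_count(x)  # set-bit count per element
with tf.Session() as sess:
    print(sess.run(y))  # [0 1 2 8]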
diff --git a/tensorflow/python/ops/bitwise_ops_test.py b/tensorflow/python/ops/bitwise_ops_test.py
index 904cf99a5a..1d08c8f82d 100644
--- a/tensorflow/python/ops/bitwise_ops_test.py
+++ b/tensorflow/python/ops/bitwise_ops_test.py
@@ -18,10 +18,14 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
+import numpy as np
+import six
+
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import test_util
from tensorflow.python.ops import bitwise_ops
+from tensorflow.python.ops import gen_bitwise_ops
from tensorflow.python.platform import googletest
@@ -46,6 +50,25 @@ class BitwiseOpTest(test_util.TensorFlowTestCase):
self.assertAllEqual(or_result, [5, 5, 7, 15])
self.assertAllEqual(xor_result, [5, 5, 4, 5])
+ def testPopulationCountOp(self):
+ dtype_list = [dtypes.int8, dtypes.int16,
+ dtypes.int32, dtypes.int64,
+ dtypes.uint8, dtypes.uint16]
+ raw_inputs = [0, 1, -1, 3, -3, 5, -5, 14, -14,
+ 127, 128, 255, 256, 65535, 65536,
+ 2**31 - 1, 2**31, 2**32 - 1, 2**32, -2**32 + 1, -2**32,
+ -2**63 + 1, 2**63 - 1]
+ def count_bits(x):
+ return sum([bin(z).count("1") for z in six.iterbytes(x.tobytes())])
+ for dtype in dtype_list:
+ with self.test_session(use_gpu=True) as sess:
+ print("PopulationCount test: ", dtype)
+ inputs = np.array(raw_inputs, dtype=dtype.as_numpy_dtype)
+ truth = [count_bits(x) for x in inputs]
+ input_tensor = constant_op.constant(inputs, dtype=dtype)
+ popcnt_result = sess.run(gen_bitwise_ops.population_count(input_tensor))
+ self.assertAllEqual(truth, popcnt_result)
+
def testInvertOp(self):
dtype_list = [dtypes.int8, dtypes.int16, dtypes.int32, dtypes.int64,
dtypes.uint8, dtypes.uint16]
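Note that `count_bits` derives the ground truth from the raw byte representation, so negative integers are counted in two's complement. A standalone illustration using only NumPy and six:

import numpy as np
import six

def count_bits(x):
    # Sum set bits across the value's raw bytes (two's complement for ints).
    return sum(bin(z).count("1") for z in six.iterbytes(x.tobytes()))

print(count_bits(np.int8(-1)))   # 8, since -1 is 0xFF
print(count_bits(np.int16(-1)))  # 16, since -1 is 0xFFFF
print(count_bits(np.uint8(5)))   # 2, since 5 is 0b101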
diff --git a/tensorflow/python/ops/control_flow_ops.py b/tensorflow/python/ops/control_flow_ops.py
index 44d6c7e275..4ba812eaf5 100644
--- a/tensorflow/python/ops/control_flow_ops.py
+++ b/tensorflow/python/ops/control_flow_ops.py
@@ -61,6 +61,7 @@ from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import sparse_tensor
from tensorflow.python.framework import tensor_shape
+from tensorflow.python.framework import tensor_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import gen_array_ops
from tensorflow.python.ops import gen_control_flow_ops
@@ -983,9 +984,16 @@ class GradLoopState(object):
# the right control flow context.
real_value = self._grad_context.AddValue(cur_value)
break
+ elif constant_op.is_constant(cur_value):
+ # If the value to be forwarded is a constant, clone the constant in
+ # the gradient loop rather than using a stack.
+ # TODO(phawkins): consider hoisting the constant out of the loop
+ # instead.
+ real_value = constant_op.constant(
+ tensor_util.constant_value(cur_value), dtype=cur_value.dtype)
+ break
else:
# Record the history of this value in forward_ctxt.
- # TODO(yuanbyu): Avoid recording constants.
self._grad_context.Exit()
history_value = cur_grad_state.AddForwardAccumulator(cur_value)
self._grad_context.Enter()
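The new branch clones loop-invariant constants into the gradient context rather than accumulating them on a stack. A minimal sketch of the cloning pattern in isolation (outside any loop context, for illustration only):

from tensorflow.python.framework import constant_op
from tensorflow.python.framework import tensor_util

c = constant_op.constant([1.0, 2.0])
if constant_op.is_constant(c):
    # Rebuild an identical Const node from its statically known value.
    clone = constant_op.constant(tensor_util.constant_value(c), dtype=c.dtype)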
diff --git a/tensorflow/python/ops/init_ops.py b/tensorflow/python/ops/init_ops.py
index 1e2f999995..42b4f952bb 100644
--- a/tensorflow/python/ops/init_ops.py
+++ b/tensorflow/python/ops/init_ops.py
@@ -41,7 +41,6 @@ from tensorflow.python.ops import array_ops
from tensorflow.python.ops import linalg_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import random_ops
-from tensorflow.python.ops import math_ops
class Initializer(object):
diff --git a/tensorflow/python/ops/nn_test.py b/tensorflow/python/ops/nn_test.py
index 87f6f92a8a..cc8c623947 100644
--- a/tensorflow/python/ops/nn_test.py
+++ b/tensorflow/python/ops/nn_test.py
@@ -830,7 +830,8 @@ class ReluTest(test_lib.TestCase):
class MomentsTest(test_lib.TestCase):
- def doOutputTest(self, input_shape, moments_axes, tol=1e-4):
+ def doOutputTest(self, input_shape, moments_axes, tol=1e-4,
+ check_gradients=False):
for mu in [0.0, 1.0, 1e3]:
for sigma in [1.0, 0.1]:
for keep_dims in [True, False]:
@@ -846,6 +847,15 @@ class MomentsTest(test_lib.TestCase):
mean, variance = nn_impl.moments(
inputs, moments_axes, keep_dims=keep_dims)
+ if check_gradients:
+ err = gradient_checker.compute_gradient_error(
+ inputs, input_shape, mean, mean.shape.as_list())
+ self.assertLess(err, 1e-3)
+ err = gradient_checker.compute_gradient_error(
+ inputs, input_shape, variance, variance.shape.as_list())
+ self.assertLess(err, 1e-3)
+
+ # Evaluate.
[mean, variance] = sess.run([mean, variance])
# Make sure that there are no NaNs
self.assertFalse(np.isnan(mean).any())
@@ -853,6 +863,12 @@ class MomentsTest(test_lib.TestCase):
self.assertAllClose(mean, expected_mean, rtol=tol, atol=tol)
self.assertAllClose(variance, expected_var, rtol=tol, atol=tol)
+ def testOutputAndGradient2DInput0(self):
+ self.doOutputTest((10, 10), (0,), check_gradients=True)
+
+ def testOutputAndGradient2DInput01(self):
+ self.doOutputTest((10, 10), (0, 1), check_gradients=True)
+
def testOutput2DInput0(self):
self.doOutputTest((10, 300), (0,))
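The added checks use `gradient_checker.compute_gradient_error`, which compares the analytic Jacobian against a numeric estimate. A standalone sketch via the public `tf.test` endpoint (assuming the TF 1.x API):

import numpy as np
import tensorflow as tf

with tf.Session():
    x = tf.constant(np.random.randn(10, 10))  # float64 input
    mean, variance = tf.nn.moments(x, axes=[0])
    err = tf.test.compute_gradient_error(x, (10, 10), mean, mean.shape.as_list())
    print(err)  # near 0 when analytic and numeric gradients agree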
diff --git a/tensorflow/python/ops/rnn_cell_impl.py b/tensorflow/python/ops/rnn_cell_impl.py
index f7854e86c0..304b6ae665 100644
--- a/tensorflow/python/ops/rnn_cell_impl.py
+++ b/tensorflow/python/ops/rnn_cell_impl.py
@@ -786,13 +786,18 @@ class DropoutWrapper(RNNCell):
class ResidualWrapper(RNNCell):
"""RNNCell wrapper that ensures cell inputs are added to the outputs."""
- def __init__(self, cell):
+ def __init__(self, cell, residual_fn=None):
"""Constructs a `ResidualWrapper` for `cell`.
Args:
cell: An instance of `RNNCell`.
+ residual_fn: (Optional) The function used to map raw cell inputs and raw
+ cell outputs to the actual cell outputs of the residual network.
+ Defaults to `nest.map_structure(lambda i, o: i + o, inputs, outputs)`.
"""
self._cell = cell
+ self._residual_fn = residual_fn
@property
def state_size(self):
@@ -807,7 +812,7 @@ class ResidualWrapper(RNNCell):
return self._cell.zero_state(batch_size, dtype)
def __call__(self, inputs, state, scope=None):
- """Run the cell and add its inputs to its outputs.
+ """Run the cell and then apply the residual_fn on its inputs to its outputs.
Args:
inputs: cell inputs.
@@ -822,13 +827,14 @@ class ResidualWrapper(RNNCell):
ValueError: If cell inputs and outputs have different structure (value).
"""
outputs, new_state = self._cell(inputs, state, scope=scope)
- nest.assert_same_structure(inputs, outputs)
# Ensure shapes match
def assert_shape_match(inp, out):
inp.get_shape().assert_is_compatible_with(out.get_shape())
- nest.map_structure(assert_shape_match, inputs, outputs)
- res_outputs = nest.map_structure(
- lambda inp, out: inp + out, inputs, outputs)
+ def default_residual_fn(inputs, outputs):
+ nest.assert_same_structure(inputs, outputs)
+ nest.map_structure(assert_shape_match, inputs, outputs)
+ return nest.map_structure(lambda inp, out: inp + out, inputs, outputs)
+ res_outputs = (self._residual_fn or default_residual_fn)(inputs, outputs)
return (res_outputs, new_state)
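A hedged usage sketch of the new `residual_fn` hook, assuming the TF 1.x `tf.nn.rnn_cell` namespace; `scaled_residual` is a hypothetical example, not part of this commit:

import tensorflow as tf

cell = tf.nn.rnn_cell.GRUCell(32)

def scaled_residual(inputs, outputs):
    # Hypothetical residual function: down-weight the skip connection.
    return 0.5 * inputs + outputs

# Inputs must match the cell output size (32) for the addition to be valid.
wrapped = tf.nn.rnn_cell.ResidualWrapper(cell, residual_fn=scaled_residual)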