author     Huazuo Gao <gaohuazuo@gmail.com>   2016-06-28 11:13:44 +0800
committer  Huazuo Gao <gaohuazuo@gmail.com>   2016-06-28 11:13:44 +0800
commit     43925959ebcbf6eb5e48d8854c9550a109961aa7 (patch)
tree       e2829db19176c1cd3dcc0d56caa5105d1024dd0c /tensorflow/python
parent     19491230eed62fae0bdaf5c02417af389b314658 (diff)
parent     14b8ed02dbd4da8fd7a269fa6a5fef5abe405489 (diff)
Diffstat (limited to 'tensorflow/python')
-rw-r--r--  tensorflow/python/BUILD | 3
-rw-r--r--  tensorflow/python/client/session.py | 30
-rw-r--r--  tensorflow/python/client/tf_session.i | 16
-rw-r--r--  tensorflow/python/client/tf_session_helper.cc | 7
-rw-r--r--  tensorflow/python/client/tf_session_helper.h | 4
-rw-r--r--  tensorflow/python/framework/dtypes.py | 2
-rw-r--r--  tensorflow/python/framework/framework_lib.py | 2
-rw-r--r--  tensorflow/python/framework/ops.py | 84
-rw-r--r--  tensorflow/python/framework/tensor_util.py | 52
-rw-r--r--  tensorflow/python/framework/tensor_util_test.py | 33
-rw-r--r--  tensorflow/python/kernel_tests/array_ops_test.py | 105
-rw-r--r--  tensorflow/python/kernel_tests/control_flow_ops_py_test.py | 12
-rw-r--r--  tensorflow/python/kernel_tests/fifo_queue_test.py | 177
-rw-r--r--  tensorflow/python/kernel_tests/pack_op_test.py | 57
-rw-r--r--  tensorflow/python/kernel_tests/padding_fifo_queue_test.py | 195
-rw-r--r--  tensorflow/python/kernel_tests/reshape_op_test.py | 19
-rw-r--r--  tensorflow/python/kernel_tests/rnn_test.py | 109
-rw-r--r--  tensorflow/python/kernel_tests/unpack_op_test.py | 68
-rw-r--r--  tensorflow/python/kernel_tests/variables_test.py | 22
-rw-r--r--  tensorflow/python/ops/array_grad.py | 27
-rw-r--r--  tensorflow/python/ops/array_ops.py | 125
-rw-r--r--  tensorflow/python/ops/control_flow_grad.py | 2
-rw-r--r--  tensorflow/python/ops/control_flow_ops.py | 83
-rw-r--r--  tensorflow/python/ops/data_flow_ops.py | 18
-rw-r--r--  tensorflow/python/ops/image_grad.py | 63
-rw-r--r--  tensorflow/python/ops/image_grad_test.py | 129
-rw-r--r--  tensorflow/python/ops/image_ops.py | 22
-rw-r--r--  tensorflow/python/ops/math_ops.py | 30
-rw-r--r--  tensorflow/python/ops/nn_conv_test.py | 33
-rw-r--r--  tensorflow/python/ops/nn_grad.py | 21
-rw-r--r--  tensorflow/python/ops/rnn.py | 119
-rw-r--r--  tensorflow/python/ops/variables.py | 9
-rw-r--r--  tensorflow/python/platform/test.py | 6
-rw-r--r--  tensorflow/python/training/moving_averages.py | 6
-rw-r--r--  tensorflow/python/training/saver.py | 2
-rw-r--r--  tensorflow/python/training/server_lib_test.py | 148
36 files changed, 1718 insertions(+), 122 deletions(-)
diff --git a/tensorflow/python/BUILD b/tensorflow/python/BUILD
index e088fd61d1..ef7754a920 100644
--- a/tensorflow/python/BUILD
+++ b/tensorflow/python/BUILD
@@ -253,6 +253,7 @@ py_library(
],
srcs_version = "PY2AND3",
deps = [
+ ":device_lib",
":framework_test_lib",
":platform_test",
],
@@ -543,10 +544,12 @@ tf_gen_op_wrapper_py(
"HashTable",
"InitializeTable",
"InitializeTableFromTextFile",
+ "LookupTableExport",
"LookupTableFind",
"LookupTableInsert",
"LookupTableSize",
"MutableHashTable",
+ "MutableHashTableOfTensors",
"Mutex",
"MutexAcquire",
"MutexRelease",
diff --git a/tensorflow/python/client/session.py b/tensorflow/python/client/session.py
index 6dd213624b..a7de16de0f 100644
--- a/tensorflow/python/client/session.py
+++ b/tensorflow/python/client/session.py
@@ -853,6 +853,8 @@ class Session(BaseSession):
@@as_default
+ @@reset
+
"""
def __init__(self, target='', graph=None, config=None):
@@ -892,6 +894,34 @@ class Session(BaseSession):
self.close()
+ @staticmethod
+ def reset(target, containers=None, config=None):
+    """Resets resource containers on `target`, and closes all connected sessions.
+
+ A resource container is distributed across all workers in the
+ same cluster as `target`. When a resource container on `target`
+ is reset, resources associated with that container will be cleared.
+ In particular, all Variables in the container will become undefined:
+ they lose their values and shapes.
+
+ NOTE:
+ (i) reset() is currently only implemented for distributed sessions.
+ (ii) Any sessions on the master named by `target` will be closed.
+
+ If no resource containers are provided, all containers are reset.
+
+ Args:
+ target: The execution engine to connect to.
+      containers: A list of resource container name strings, or `None` if all
+        of the containers are to be reset.
+ config: (Optional.) Protocol buffer with configuration options.
+
+ Raises:
+ tf.errors.OpError: Or one of its subclasses if an error occurs while
+ resetting containers.
+ """
+ tf_session.TF_Reset(target, containers, config)
+
class InteractiveSession(BaseSession):
"""A TensorFlow `Session` for use in interactive contexts, such as a shell.
diff --git a/tensorflow/python/client/tf_session.i b/tensorflow/python/client/tf_session.i
index 6707542178..01c3bf7e51 100644
--- a/tensorflow/python/client/tf_session.i
+++ b/tensorflow/python/client/tf_session.i
@@ -290,6 +290,22 @@ tensorflow::ImportNumpy();
%unignore tensorflow;
%unignore TF_PRun;
+%unignore tensorflow::TF_Reset_wrapper;
+%insert("python") %{
+def TF_Reset(target, containers=None, config=None):
+ from tensorflow.python.framework import errors
+ try:
+ opts = TF_NewSessionOptions(target=target, config=config)
+ with errors.raise_exception_on_not_ok_status() as status:
+ from tensorflow.python.util import compat
+ if containers is None:
+ containers = []
+ TF_Reset_wrapper(
+ opts, [compat.as_bytes(c) for c in containers], status)
+ finally:
+ TF_DeleteSessionOptions(opts)
+%}
+
%include "tensorflow/python/client/tf_session_helper.h"
%unignoreall
diff --git a/tensorflow/python/client/tf_session_helper.cc b/tensorflow/python/client/tf_session_helper.cc
index e9b6b44613..31f4ad6ba9 100644
--- a/tensorflow/python/client/tf_session_helper.cc
+++ b/tensorflow/python/client/tf_session_helper.cc
@@ -573,6 +573,13 @@ void TF_PRun_wrapper(TF_Session* session, const char* handle,
NameVector(), out_status, out_values, nullptr);
}
+// Wrapper for TF_Reset that converts the string vectors to character arrays.
+void TF_Reset_wrapper(const TF_SessionOptions* opt,
+ const NameVector& containers, TF_Status* out_status) {
+ TF_Reset(opt, const_cast<const char**>(containers.data()), containers.size(),
+ out_status);
+}
+
string EqualGraphDefWrapper(const string& actual, const string& expected) {
GraphDef actual_def;
if (!actual_def.ParseFromString(actual)) {
diff --git a/tensorflow/python/client/tf_session_helper.h b/tensorflow/python/client/tf_session_helper.h
index d360efa8f3..285bd902a4 100644
--- a/tensorflow/python/client/tf_session_helper.h
+++ b/tensorflow/python/client/tf_session_helper.h
@@ -96,6 +96,10 @@ void TF_PRun_wrapper(TF_Session* session, const char* handle,
const FeedVector& inputs, const NameVector& output_names,
TF_Status* out_status, PyObjectVector* out_values);
+// Wrapper for TF_Reset that converts the string vectors to character arrays.
+void TF_Reset_wrapper(const TF_SessionOptions* opt,
+ const NameVector& containers, TF_Status* out_status);
+
// Convenience wrapper around EqualGraphDef to make it easier to wrap.
// Returns an explanation if a difference is found, or the empty string
// for no difference.
diff --git a/tensorflow/python/framework/dtypes.py b/tensorflow/python/framework/dtypes.py
index d45ab2f2e2..1f29426b4c 100644
--- a/tensorflow/python/framework/dtypes.py
+++ b/tensorflow/python/framework/dtypes.py
@@ -205,7 +205,7 @@ class DType(object):
(bool, string, complex64, complex128)):
raise TypeError("Cannot find maximum value of %s." % self)
- # there is no simple way to get the min value of a dtype, we have to check
+ # there is no simple way to get the max value of a dtype, we have to check
# float and int types separately
try:
return np.finfo(self.as_numpy_dtype()).max
diff --git a/tensorflow/python/framework/framework_lib.py b/tensorflow/python/framework/framework_lib.py
index 480e99abae..b06605cf59 100644
--- a/tensorflow/python/framework/framework_lib.py
+++ b/tensorflow/python/framework/framework_lib.py
@@ -30,6 +30,7 @@
## Utility functions
@@device
+@@container
@@name_scope
@@control_dependencies
@@convert_to_tensor
@@ -77,6 +78,7 @@ from tensorflow.python.framework.ops import IndexedSlices
# Utilities used when building a Graph.
from tensorflow.python.framework.ops import device
+from tensorflow.python.framework.ops import container
from tensorflow.python.framework.ops import name_scope
from tensorflow.python.framework.ops import op_scope
from tensorflow.python.framework.ops import control_dependencies
diff --git a/tensorflow/python/framework/ops.py b/tensorflow/python/framework/ops.py
index 912928778b..b7d131cf78 100644
--- a/tensorflow/python/framework/ops.py
+++ b/tensorflow/python/framework/ops.py
@@ -35,6 +35,7 @@ from tensorflow.core.framework import graph_pb2
from tensorflow.core.framework import versions_pb2
from tensorflow.python.framework import device as pydev
from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import op_def_registry
from tensorflow.python.framework import registry
from tensorflow.python.framework import tensor_shape
from tensorflow.python.framework import versions
@@ -1984,6 +1985,9 @@ class Graph(object):
self._handle_movers = {}
# A map from tensor handle to its delete op.
self._handle_deleters = {}
+ # Resource container.
+ self._container = ""
+ self._registered_ops = op_def_registry.get_registered_ops()
def _check_not_finalized(self):
"""Check if the graph is finalized.
@@ -2320,6 +2324,19 @@ class Graph(object):
ret.node_def.attr["_class"].CopyFrom(attr_value_pb2.AttrValue(
list=attr_value_pb2.AttrValue.ListValue(s=all_colocation_groups)))
+    # Sets the "container" attribute if all of the following hold:
+    # (1) self._container is not empty
+    # (2) the op is marked "is_stateful" in its OpDef
+    # (3) the op's NodeDef has a "container" attribute
+    # (4) the "container" attribute is not already set
+ if (self._container and
+ op_type in self._registered_ops and
+ self._registered_ops[op_type].is_stateful and
+ "container" in ret.node_def.attr and
+ not ret.node_def.attr["container"].s):
+ ret.node_def.attr["container"].CopyFrom(
+ attr_value_pb2.AttrValue(s=compat.as_bytes(self._container)))
+
return ret
def as_graph_element(self, obj, allow_tensor=True, allow_operation=True):
@@ -2958,6 +2975,60 @@ class Graph(object):
break
op._set_device(device_function(op))
+ # pylint: disable=g-doc-return-or-yield
+ @contextlib.contextmanager
+ def container(self, container_name):
+ """Returns a context manager that specifies the resource container to use.
+
+ Stateful operations, such as variables and queues, can maintain their
+ states on devices so that they can be shared by multiple processes.
+ A resource container is a string name under which these stateful
+ operations are tracked. These resources can be released or cleared
+ with `tf.Session.reset()`.
+
+ For example:
+
+ ```python
+ with g.container('experiment0'):
+ # All stateful Operations constructed in this context will be placed
+ # in resource container "experiment0".
+ v1 = tf.Variable([1.0])
+ v2 = tf.Variable([2.0])
+ with g.container("experiment1"):
+ # All stateful Operations constructed in this context will be
+ # placed in resource container "experiment1".
+ v3 = tf.Variable([3.0])
+ q1 = tf.FIFOQueue(10, tf.float32)
+ # All stateful Operations constructed in this context will be
+      # placed in resource container "experiment0".
+ v4 = tf.Variable([4.0])
+ q1 = tf.FIFOQueue(20, tf.float32)
+ with g.container(""):
+ # All stateful Operations constructed in this context will be
+        # placed in the default resource container.
+ v5 = tf.Variable([5.0])
+ q3 = tf.FIFOQueue(30, tf.float32)
+
+ # Resets container "experiment0", after which the state of v1, v2, v4, q1
+    # will become undefined (e.g., uninitialized).
+ tf.Session.reset(target, ["experiment0"])
+ ```
+
+ Args:
+ container_name: container name string.
+
+ Returns:
+      A context manager that sets the resource container for stateful ops and
+      yields the container name.
+ """
+ original_container = self._container
+ try:
+ self._container = container_name
+ yield self._container
+ finally:
+ self._container = original_container
+ # pylint: enable=g-doc-return-or-yield
+
class _ControlDependenciesController(object):
"""Context manager for `control_dependencies()`."""
@@ -3335,6 +3406,19 @@ def device(device_name_or_function):
return get_default_graph().device(device_name_or_function)
+def container(container_name):
+ """Wrapper for `Graph.container()` using the default graph.
+
+ Args:
+ container_name: The container string to use in the context.
+
+ Returns:
+ A context manager that specifies the default container to use for newly
+ created stateful ops.
+ """
+ return get_default_graph().container(container_name)
+
+
def colocate_with(op, ignore_existing=False):
return get_default_graph().colocate_with(op, ignore_existing)
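
Editor's note: pulling the new `Graph.container()` / `tf.container()` pieces together, stateful ops created inside the context record the container name as a node attribute, which is what `tf.Session.reset()` later matches against. A small sketch against the default graph (the expected byte-string values assume Python 3):

```python
import tensorflow as tf

with tf.Graph().as_default():
  v0 = tf.Variable([0.0])                      # default (empty) container
  with tf.container("outer"):
    v1 = tf.Variable([1.0])
    with tf.container("inner"):
      q = tf.FIFOQueue(10, tf.float32)
    v2 = tf.Variable([2.0])                    # back in "outer"

  print(v0.op.get_attr("container"))           # b''
  print(v1.op.get_attr("container"))           # b'outer'
  print(q.queue_ref.op.get_attr("container"))  # b'inner'
  print(v2.op.get_attr("container"))           # b'outer'
```
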
diff --git a/tensorflow/python/framework/tensor_util.py b/tensorflow/python/framework/tensor_util.py
index a213dc2262..098558fd3e 100644
--- a/tensorflow/python/framework/tensor_util.py
+++ b/tensorflow/python/framework/tensor_util.py
@@ -628,3 +628,55 @@ def constant_value(tensor):
# conservatively prevent it from being fed.
tensor.graph.prevent_feeding(tensor)
return ret
+
+
+def constant_value_as_shape(tensor): # pylint: disable=invalid-name
+ """A version of `constant_value()` that returns a `TensorShape`.
+
+ This version should be used when a constant tensor value is
+ interpreted as a (possibly partial) shape, e.g. in the shape
+ function for `tf.reshape()`. By explicitly requesting a
+ `TensorShape` as the return value, it is possible to represent
+ unknown dimensions; by contrast, `constant_value()` is
+ all-or-nothing.
+
+ Args:
+ tensor: The rank-1 Tensor to be evaluated.
+
+ Returns:
+ A `TensorShape` based on the constant value of the given `tensor`.
+ """
+ shape = tensor.get_shape().with_rank(1)
+ if tensor.get_shape() == [0]:
+ return tensor_shape.scalar()
+ elif tensor.op.type == "Shape":
+ return tensor.op.inputs[0].get_shape()
+ elif tensor.op.type == "Pack":
+ ret = tensor_shape.scalar() # Empty list.
+ for pack_input in tensor.op.inputs:
+ # `pack_input` must be a scalar. Attempt to evaluate it, and append it
+ # to `ret`.
+ pack_input_val = constant_value(pack_input)
+ if pack_input_val is None or pack_input_val < 0:
+ new_dim = tensor_shape.Dimension(None)
+ else:
+ new_dim = tensor_shape.Dimension(pack_input_val)
+ ret = ret.concatenate([new_dim])
+ return ret
+ elif tensor.op.type == "Concat":
+ # We assume that `tensor.op.inputs[0]` evaluates to 0, as this is
+ # the only legal value when concatenating vectors, and it will
+ # have been checked by a previous shape function.
+ ret = tensor_shape.scalar() # Empty list.
+ for concat_input in tensor.op.inputs[1:]:
+ # `concat_input` must be a vector. Attempt to evaluate it as a shape,
+ # and concatenate it with `ret`.
+ ret = ret.concatenate(constant_value_as_shape(concat_input))
+ return ret
+ else:
+ ret = tensor_shape.unknown_shape(shape[0].value)
+ value = constant_value(tensor)
+ if value is not None:
+ ret = ret.merge_with(tensor_shape.TensorShape(
+ [d if d != -1 else None for d in value]))
+ return ret
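
Editor's note: a quick illustration of the new helper, mirroring the tests in tensor_util_test.py below. A shape tensor built from a mix of constants and an unknown placeholder yields a partially known `TensorShape` instead of failing outright:

```python
import tensorflow as tf
from tensorflow.python.framework import tensor_util

# A shape tensor whose last dimension is only known at run time.
shape_t = tf.pack([tf.constant(16), 37, tf.placeholder(tf.int32)])

# constant_value() would return None here; constant_value_as_shape()
# still recovers the statically known dimensions.
partial = tensor_util.constant_value_as_shape(shape_t)
print(partial.as_list())   # [16, 37, None]
```
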
diff --git a/tensorflow/python/framework/tensor_util_test.py b/tensorflow/python/framework/tensor_util_test.py
index c56c5e948b..3fdcba5416 100644
--- a/tensorflow/python/framework/tensor_util_test.py
+++ b/tensorflow/python/framework/tensor_util_test.py
@@ -565,5 +565,38 @@ class ConstantValueTest(tf.test.TestCase):
self.assertIs(None, c_val)
+class ConstantValueAsShapeTest(tf.test.TestCase):
+
+ def testConstant(self):
+ np_val = np.random.rand(3).astype(np.int32)
+ tf_val = tf.constant(np_val)
+ self.assertEqual(tf.TensorShape(np_val),
+ tensor_util.constant_value_as_shape(tf_val))
+
+ tf_val = tf.constant([], dtype=tf.int32)
+ self.assertEqual(tf.TensorShape([]),
+ tensor_util.constant_value_as_shape(tf_val))
+
+ def testShape(self):
+ tf_val = tf.shape(tf.constant(0.0, shape=[1, 2, 3]))
+ c_val = tensor_util.constant_value_as_shape(tf_val)
+ self.assertEqual(tf.TensorShape([1, 2, 3]), c_val)
+
+ def testPack(self):
+ tf_val = tf.pack([tf.constant(16), 37, tf.placeholder(tf.int32)])
+ c_val = tensor_util.constant_value_as_shape(tf_val)
+ self.assertEqual([16, 37, None], c_val.as_list())
+
+ def testConcat(self):
+ tf_val = tf.concat(0, [[16, 37], tf.placeholder(tf.int32, shape=(2,))])
+ c_val = tensor_util.constant_value_as_shape(tf_val)
+ self.assertEqual([16, 37, None, None], c_val.as_list())
+
+ tf_val = tf.concat(0,
+ [[16, 37], tf.placeholder(tf.int32, shape=(1,)), [48]])
+ c_val = tensor_util.constant_value_as_shape(tf_val)
+ self.assertEqual([16, 37, None, 48], c_val.as_list())
+
+
if __name__ == "__main__":
tf.test.main()
diff --git a/tensorflow/python/kernel_tests/array_ops_test.py b/tensorflow/python/kernel_tests/array_ops_test.py
index 4353de4007..8435d6b9cd 100644
--- a/tensorflow/python/kernel_tests/array_ops_test.py
+++ b/tensorflow/python/kernel_tests/array_ops_test.py
@@ -19,6 +19,7 @@ from __future__ import division
from __future__ import print_function
import math
+import time
import numpy as np
import tensorflow as tf
@@ -28,7 +29,6 @@ from tensorflow.python.framework import errors
from tensorflow.python.framework import tensor_shape
from tensorflow.python.framework import test_util
from tensorflow.python.ops import array_ops
-from tensorflow.python.platform import googletest
class BooleanMaskTest(test_util.TensorFlowTestCase):
@@ -383,5 +383,106 @@ class StridedSliceShapeTest(test_util.TensorFlowTestCase):
tensor_shape.TensorShape([5, None, 1, 4]))
+class GradSliceChecker(object):
+ """Tests that we can compute a gradient for var^2."""
+
+ def __init__(self, test, sess, var, varnp):
+ self.test = test
+ self.sess = sess
+ self.val = var * var
+ self.var = var
+ self.varnp = varnp
+
+ def __getitem__(self, spec):
+ val_grad_op = tf.gradients(self.val, self.var)
+ sliceval_grad_op = tf.gradients(
+ array_ops._NewSliceHelper(self.val, spec), self.var)
+ slice1_op = array_ops._NewSliceHelper(val_grad_op, spec)
+ slice2_op = array_ops._NewSliceHelper(sliceval_grad_op, spec)
+ val_grad, sliceval_grad, slice1, slice2 = self.sess.run(
+ [val_grad_op, sliceval_grad_op, slice1_op, slice2_op])
+ np_val_grad = (2 * self.varnp)
+ np_sliceval_grad = np.zeros(self.var.get_shape())
+ np_sliceval_grad[spec] = np.array(val_grad[0])[spec]
+ # make sure np val grad is correct
+ self.test.assertAllEqual(np_val_grad, val_grad[0])
+ # make sure slice gradient is correct
+ self.test.assertAllEqual(np_sliceval_grad, sliceval_grad[0])
+ # make sure val grad and sliceval grad are the same in sliced area
+ self.test.assertAllEqual(slice1, slice2)
+
+
+class StridedSliceGradTest(test_util.TensorFlowTestCase):
+ """Test that strided slice's custom gradient produces correct gradients."""
+
+ def testGradient(self):
+ for use_gpu in [False, True]:
+ with self.test_session(use_gpu=use_gpu) as sess:
+ var = tf.Variable(tf.reshape(tf.range(1, 97, 1), shape=(6, 4, 4)))
+ init = tf.initialize_all_variables()
+ sess.run(init)
+
+ grad = GradSliceChecker(self, sess, var,
+ np.array(range(1, 97, 1)).reshape((6, 4, 4)))
+ _ = grad[2:6:2, 1:3, 1:3]
+ _ = grad[3:0:-2, 1:3, 1:3]
+ _ = grad[3:0:-2, tf.newaxis, 1:3, 2, tf.newaxis]
+ _ = grad[3:0:-2, 1:3, 2]
+
+
+class BenchmarkSlice(object):
+
+ def __init__(self, tensor):
+ self.tensor = tensor
+
+ def __getitem__(self, x):
+ return array_ops._NewSliceHelper(self.tensor, x)
+
+
+class StridedSliceBenchmark(tf.test.Benchmark):
+ """Benchmark new strided slice operation on non-trivial case."""
+
+ def run_and_time(self, slice_op):
+ tf.initialize_all_variables().run()
+ for _ in range(10):
+ _ = slice_op.eval()
+ iters = 1000
+ t0 = time.time()
+ for _ in range(iters):
+ slice_op.eval()
+ t1 = time.time()
+ self.report_benchmark(iters=iters, wall_time=(t1 - t0) / 1000.0)
+
+ def make_variable(self):
+ n = 256
+ shape = (n, n, n)
+ items = n**3
+ var = tf.Variable(
+ tf.reshape(
+ tf.linspace(1., float(items), items), shape),
+ dtype=tf.float32)
+ return var
+
+ def benchmark_strided_slice_skip(self):
+ with tf.Session():
+ var = self.make_variable()
+ helper = BenchmarkSlice(var)
+ slice_op = helper[::2, ::1, ::2]
+ self.run_and_time(slice_op)
+
+ def benchmark_strided_slice_easy(self):
+ with tf.Session():
+ var = self.make_variable()
+ helper = BenchmarkSlice(var)
+ slice_op = helper[3::1, 3::1, 3::1]
+ self.run_and_time(slice_op)
+
+ def benchmark_slice_easy(self):
+ with tf.Session():
+ var = self.make_variable()
+ slice_op = var[3::1, 3::1, 3::1]
+ self.run_and_time(slice_op)
+
+
if __name__ == "__main__":
- googletest.main()
+ tf.test.main()
diff --git a/tensorflow/python/kernel_tests/control_flow_ops_py_test.py b/tensorflow/python/kernel_tests/control_flow_ops_py_test.py
index 96fc24fba0..f1a1683332 100644
--- a/tensorflow/python/kernel_tests/control_flow_ops_py_test.py
+++ b/tensorflow/python/kernel_tests/control_flow_ops_py_test.py
@@ -950,6 +950,18 @@ class ControlFlowTest(tf.test.TestCase):
self.assertEqual([None], r.get_shape().as_list())
self.assertAllClose([810.0, 2560.0], r.eval(feed_dict={x: [3.0, 4.0]}))
+ def testWhileGrad_BaseShape(self):
+ with self.test_session() as sess:
+ x = tf.placeholder(tf.float32, [None])
+ v0 = tf.constant([2.0, 2.0], name="v")
+ c = lambda v: tf.constant(False)
+ b = lambda v: tf.mul(v, x)
+ r = tf.while_loop(c, b, [v0])
+ y = tf.square(x)
+
+ r = tf.gradients([r, y], x)[0]
+ self.assertAllClose([2.0, 4.0], sess.run(r, feed_dict={x: [1.0, 2.0]}))
+
def testWhileGrad_MultipleUses(self):
with self.test_session():
v = tf.constant(2.0, name="v")
diff --git a/tensorflow/python/kernel_tests/fifo_queue_test.py b/tensorflow/python/kernel_tests/fifo_queue_test.py
index 622459e545..8ec04d3e80 100644
--- a/tensorflow/python/kernel_tests/fifo_queue_test.py
+++ b/tensorflow/python/kernel_tests/fifo_queue_test.py
@@ -21,7 +21,6 @@ from __future__ import print_function
import random
import re
import time
-
import numpy as np
from six.moves import xrange # pylint: disable=redefined-builtin
import tensorflow as tf
@@ -293,6 +292,16 @@ class FIFOQueueTest(tf.test.TestCase):
enqueue_op.run()
self.assertEqual([], dequeued_t.eval().tolist())
+ def testEmptyDequeueUpTo(self):
+ with self.test_session():
+ q = tf.FIFOQueue(10, tf.float32, shapes=())
+ enqueue_op = q.enqueue((10.0,))
+ dequeued_t = q.dequeue_up_to(0)
+
+ self.assertEqual([], dequeued_t.eval().tolist())
+ enqueue_op.run()
+ self.assertEqual([], dequeued_t.eval().tolist())
+
def testEmptyDequeueManyWithNoShape(self):
with self.test_session():
q = tf.FIFOQueue(10, tf.float32)
@@ -328,6 +337,18 @@ class FIFOQueueTest(tf.test.TestCase):
self.assertAllEqual(elems[0:4], dequeued_t.eval())
self.assertAllEqual(elems[4:8], dequeued_t.eval())
+ def testDequeueUpToNoBlocking(self):
+ with self.test_session():
+ q = tf.FIFOQueue(10, tf.float32, ())
+ elems = [10.0, 20.0, 30.0, 40.0, 50.0, 60.0, 70.0, 80.0, 90.0, 100.0]
+ enqueue_op = q.enqueue_many((elems,))
+ dequeued_t = q.dequeue_up_to(4)
+
+ enqueue_op.run()
+
+ self.assertAllEqual(elems[0:4], dequeued_t.eval())
+ self.assertAllEqual(elems[4:8], dequeued_t.eval())
+
def testMultiDequeueMany(self):
with self.test_session() as sess:
q = tf.FIFOQueue(10, (tf.float32, tf.int32),
@@ -358,6 +379,29 @@ class FIFOQueueTest(tf.test.TestCase):
self.assertEqual(float_val.shape, dequeued_single_t[0].get_shape())
self.assertEqual(int_val.shape, dequeued_single_t[1].get_shape())
+ def testMultiDequeueUpToNoBlocking(self):
+ with self.test_session() as sess:
+ q = tf.FIFOQueue(10, (tf.float32, tf.int32),
+ shapes=((), (2,)))
+ float_elems = [
+ 10.0, 20.0, 30.0, 40.0, 50.0, 60.0, 70.0, 80.0, 90.0, 100.0]
+ int_elems = [[1, 2], [3, 4], [5, 6], [7, 8], [9, 10],
+ [11, 12], [13, 14], [15, 16], [17, 18], [19, 20]]
+ enqueue_op = q.enqueue_many((float_elems, int_elems))
+ dequeued_t = q.dequeue_up_to(4)
+
+ enqueue_op.run()
+
+ float_val, int_val = sess.run(dequeued_t)
+ self.assertAllEqual(float_elems[0:4], float_val)
+ self.assertAllEqual(int_elems[0:4], int_val)
+ self.assertEqual([None], dequeued_t[0].get_shape().as_list())
+ self.assertEqual([None, 2], dequeued_t[1].get_shape().as_list())
+
+ float_val, int_val = sess.run(dequeued_t)
+ self.assertAllEqual(float_elems[4:8], float_val)
+ self.assertAllEqual(int_elems[4:8], int_val)
+
def testHighDimension(self):
with self.test_session():
q = tf.FIFOQueue(10, tf.int32, (4, 4, 4, 4))
@@ -469,6 +513,29 @@ class FIFOQueueTest(tf.test.TestCase):
thread.join()
self.assertItemsEqual(elems, dequeued_elems)
+ def testParallelDequeueUpTo(self):
+ with self.test_session() as sess:
+ q = tf.FIFOQueue(1000, tf.float32, shapes=())
+ elems = [10.0 * x for x in range(1000)]
+ enqueue_op = q.enqueue_many((elems,))
+ close_op = q.close()
+ dequeued_t = q.dequeue_up_to(101)
+
+ enqueue_op.run()
+ close_op.run()
+
+ # Dequeue up to 101 items in parallel on 10 threads, from closed queue.
+ dequeued_elems = []
+
+ def dequeue():
+ dequeued_elems.extend(sess.run(dequeued_t))
+ threads = [self.checkedThread(target=dequeue) for _ in range(10)]
+ for thread in threads:
+ thread.start()
+ for thread in threads:
+ thread.join()
+ self.assertItemsEqual(elems, dequeued_elems)
+
def testParallelEnqueueAndDequeue(self):
with self.test_session() as sess:
q = tf.FIFOQueue(50, tf.float32, shapes=())
@@ -596,6 +663,33 @@ class FIFOQueueTest(tf.test.TestCase):
self.assertAllEqual(elems, dequeued_elems)
+ def testBlockingDequeueUpTo(self):
+ with self.test_session() as sess:
+ q = tf.FIFOQueue(10, tf.float32, ())
+ elems = [10.0, 20.0, 30.0, 40.0]
+ enqueue_op = q.enqueue_many((elems,))
+ dequeued_t = q.dequeue_up_to(4)
+
+ dequeued_elems = []
+
+ def enqueue():
+ # The enqueue_op should run after the dequeue op has blocked.
+ # TODO(mrry): Figure out how to do this without sleeping.
+ time.sleep(0.1)
+ sess.run(enqueue_op)
+
+ def dequeue():
+ dequeued_elems.extend(sess.run(dequeued_t).tolist())
+
+ enqueue_thread = self.checkedThread(target=enqueue)
+ dequeue_thread = self.checkedThread(target=dequeue)
+ enqueue_thread.start()
+ dequeue_thread.start()
+ enqueue_thread.join()
+ dequeue_thread.join()
+
+ self.assertAllEqual(elems, dequeued_elems)
+
def testDequeueManyWithTensorParameter(self):
with self.test_session():
# Define a first queue that contains integer counts.
@@ -710,6 +804,53 @@ class FIFOQueueTest(tf.test.TestCase):
close_op.run()
dequeue_thread.join()
+ def testBlockingDequeueManyButNotAllFromClosedQueue(self):
+ with self.test_session() as sess:
+ q = tf.FIFOQueue(10, tf.float32, ())
+ elems = [10.0, 20.0, 30.0, 40.0]
+ enqueue_op = q.enqueue_many((elems,))
+ close_op = q.close()
+ dequeued_t = q.dequeue_many(3)
+
+ enqueue_op.run()
+
+ def dequeue():
+ self.assertAllEqual(elems[:3], sess.run(dequeued_t))
+ # Expect the operation to fail due to the queue being closed.
+ with self.assertRaisesRegexp(tf.errors.OutOfRangeError,
+ "is closed and has insufficient"):
+ sess.run(dequeued_t)
+
+ dequeue_thread = self.checkedThread(target=dequeue)
+ dequeue_thread.start()
+ # The close_op should run after the dequeue_thread has blocked.
+ # TODO(mrry): Figure out how to do this without sleeping.
+ time.sleep(0.1)
+ close_op.run()
+ dequeue_thread.join()
+
+ def testDequeueUpToFromClosedQueueReturnsRemainder(self):
+ with self.test_session() as sess:
+ q = tf.FIFOQueue(10, tf.float32, ())
+ elems = [10.0, 20.0, 30.0, 40.0]
+ enqueue_op = q.enqueue_many((elems,))
+ close_op = q.close()
+ dequeued_t = q.dequeue_up_to(3)
+
+ enqueue_op.run()
+
+ def dequeue():
+ self.assertAllEqual(elems[:3], sess.run(dequeued_t))
+ self.assertAllEqual(elems[3:], sess.run(dequeued_t))
+
+ dequeue_thread = self.checkedThread(target=dequeue)
+ dequeue_thread.start()
+ # The close_op should run after the dequeue_thread has blocked.
+ # TODO(mrry): Figure out how to do this without sleeping.
+ time.sleep(0.1)
+ close_op.run()
+ dequeue_thread.join()
+
def testEnqueueManyLargerThanCapacityWithConcurrentDequeueMany(self):
with self.test_session() as sess:
q = tf.FIFOQueue(4, tf.float32, ())
@@ -788,7 +929,27 @@ class FIFOQueueTest(tf.test.TestCase):
def dequeue():
# Expect the operation to fail due to the queue being closed.
with self.assertRaisesRegexp(tf.errors.OutOfRangeError,
- "is closed and has insufficient"):
+ "is closed and has insufficient"):
+ sess.run(dequeued_t)
+
+ dequeue_thread = self.checkedThread(target=dequeue)
+ dequeue_thread.start()
+ # The close_op should run after the dequeue_thread has blocked.
+ # TODO(mrry): Figure out how to do this without sleeping.
+ time.sleep(0.1)
+ close_op.run()
+ dequeue_thread.join()
+
+ def testBlockingDequeueUpToFromClosedEmptyQueue(self):
+ with self.test_session() as sess:
+ q = tf.FIFOQueue(10, tf.float32, ())
+ close_op = q.close()
+ dequeued_t = q.dequeue_up_to(4)
+
+ def dequeue():
+ # Expect the operation to fail due to the queue being closed.
+ with self.assertRaisesRegexp(tf.errors.OutOfRangeError,
+ "is closed and has insufficient"):
sess.run(dequeued_t)
dequeue_thread = self.checkedThread(target=dequeue)
@@ -1341,6 +1502,8 @@ class FIFOQueueWithTimeoutTest(tf.test.TestCase):
with self.test_session(
config=tf.ConfigProto(operation_timeout_in_ms=20)) as sess:
q = tf.FIFOQueue(10, tf.float32)
+ self.assertEqual(tf.compat.as_bytes(""),
+ q.queue_ref.op.get_attr("container"))
dequeued_t = q.dequeue()
# Intentionally do not run any enqueue_ops so that dequeue will block
@@ -1350,5 +1513,15 @@ class FIFOQueueWithTimeoutTest(tf.test.TestCase):
sess.run(dequeued_t)
+class QueueContainerTest(tf.test.TestCase):
+
+ def testContainer(self):
+ with tf.Graph().as_default():
+ with tf.container("test"):
+ q = tf.FIFOQueue(10, tf.float32)
+ self.assertEqual(tf.compat.as_bytes("test"),
+ q.queue_ref.op.get_attr("container"))
+
+
if __name__ == "__main__":
tf.test.main()
diff --git a/tensorflow/python/kernel_tests/pack_op_test.py b/tensorflow/python/kernel_tests/pack_op_test.py
index 349def7181..5d7055824d 100644
--- a/tensorflow/python/kernel_tests/pack_op_test.py
+++ b/tensorflow/python/kernel_tests/pack_op_test.py
@@ -22,6 +22,14 @@ import numpy as np
import tensorflow as tf
+def np_split_sqeeze(array, axis):
+ axis_len = array.shape[axis]
+ return [
+ np.squeeze(arr, axis=(axis,))
+ for arr in np.split(array, axis_len, axis=axis)
+ ]
+
+
class PackOpTest(tf.test.TestCase):
def testSimple(self):
@@ -61,7 +69,7 @@ class PackOpTest(tf.test.TestCase):
b = tf.reshape(a, tf.pack([2, 3]))
self.assertAllEqual(b.get_shape(), [2, 3])
- def testGradients(self):
+ def testGradientsAxis0(self):
np.random.seed(7)
for use_gpu in False, True:
for shape in (2,), (3,), (2, 3), (3, 2), (4, 3, 2):
@@ -74,6 +82,21 @@ class PackOpTest(tf.test.TestCase):
err = tf.test.compute_gradient_error(xs, shapes, c, shape)
self.assertLess(err, 1e-6)
+ def testGradientsAxis1(self):
+ np.random.seed(7)
+ for use_gpu in False, True:
+ for shape in (2, 3), (3, 2), (4, 3, 2):
+ data = np.random.randn(*shape)
+ shapes = [shape[1:]] * shape[0]
+ out_shape = list(shape[1:])
+ out_shape.insert(1, shape[0])
+ with self.test_session(use_gpu=use_gpu):
+ # TODO(irving): Remove list() once we handle maps correctly
+ xs = list(map(tf.constant, data))
+ c = tf.pack(xs, axis=1)
+ err = tf.test.compute_gradient_error(xs, shapes, c, out_shape)
+ self.assertLess(err, 1e-6)
+
def testZeroSize(self):
# Verify that pack doesn't crash for zero size inputs
for use_gpu in False, True:
@@ -83,6 +106,38 @@ class PackOpTest(tf.test.TestCase):
p = tf.pack(list(x)).eval()
self.assertAllEqual(p, x)
+ def testAxis0Default(self):
+ with self.test_session():
+ t = [tf.constant([1, 2, 3]), tf.constant([4, 5, 6])]
+
+ packed = tf.pack(t).eval()
+
+ self.assertAllEqual(packed, np.array([[1, 2, 3], [4, 5, 6]]))
+
+ def testAgainstNumpy(self):
+ # For 1 to 5 dimensions.
+ for i in range(1, 6):
+ expected = np.random.random(np.random.permutation(i) + 1)
+
+ # For all the possible axis to split it, including negative indices.
+ for j in range(-i, i):
+ test_arrays = np_split_sqeeze(expected, j)
+
+ with self.test_session():
+ actual = tf.pack(test_arrays, axis=j).eval()
+
+ self.assertNDArrayNear(expected, actual, 1e-6)
+
+ def testDimOutOfRange(self):
+ t = [tf.constant([1, 2, 3]), tf.constant([4, 5, 6])]
+ with self.assertRaisesRegexp(ValueError, r"axis = 2 not in \[-2, 2\)"):
+ tf.unpack(t, axis=2)
+
+ def testDimOutOfNegativeRange(self):
+ t = [tf.constant([1, 2, 3]), tf.constant([4, 5, 6])]
+ with self.assertRaisesRegexp(ValueError, r"axis = -3 not in \[-2, 2\)"):
+ tf.unpack(t, axis=-3)
+
class AutomaticPackingTest(tf.test.TestCase):
diff --git a/tensorflow/python/kernel_tests/padding_fifo_queue_test.py b/tensorflow/python/kernel_tests/padding_fifo_queue_test.py
index 8e9449587a..ad2a52cd43 100644
--- a/tensorflow/python/kernel_tests/padding_fifo_queue_test.py
+++ b/tensorflow/python/kernel_tests/padding_fifo_queue_test.py
@@ -276,6 +276,16 @@ class PaddingFIFOQueueTest(tf.test.TestCase):
enqueue_op.run()
self.assertEqual([], dequeued_t.eval().tolist())
+ def testEmptyDequeueUpToWithDynamicShape(self):
+ with self.test_session():
+ q = tf.PaddingFIFOQueue(10, tf.float32, shapes=((None,),))
+ enqueue_op = q.enqueue(([10.0],))
+ dequeued_t = q.dequeue_up_to(0)
+
+ self.assertEqual([], dequeued_t.eval().tolist())
+ enqueue_op.run()
+ self.assertEqual([], dequeued_t.eval().tolist())
+
def testConstructPaddingFIFOQueueWithNoShape(self):
with self.test_session():
with self.assertRaisesRegexp(
@@ -328,6 +338,18 @@ class PaddingFIFOQueueTest(tf.test.TestCase):
self.assertAllEqual(elems[0:4], dequeued_t.eval())
self.assertAllEqual(elems[4:8], dequeued_t.eval())
+ def testDequeueUpToNoBlocking(self):
+ with self.test_session():
+ q = tf.PaddingFIFOQueue(10, tf.float32, ((),))
+ elems = [10.0, 20.0, 30.0, 40.0, 50.0, 60.0, 70.0, 80.0, 90.0, 100.0]
+ enqueue_op = q.enqueue_many((elems,))
+ dequeued_t = q.dequeue_up_to(4)
+
+ enqueue_op.run()
+
+ self.assertAllEqual(elems[0:4], dequeued_t.eval())
+ self.assertAllEqual(elems[4:8], dequeued_t.eval())
+
def testMultiDequeueMany(self):
with self.test_session() as sess:
q = tf.PaddingFIFOQueue(10, (tf.float32, tf.int32),
@@ -451,6 +473,62 @@ class PaddingFIFOQueueTest(tf.test.TestCase):
tf.TensorShape(int_val.shape).is_compatible_with(
dequeued_single_t[1].get_shape()))
+ def testMultiDequeueUpToPartiallyKnownShapesAndVariableInputNoBlocking(self):
+ with self.test_session() as sess:
+ q = tf.PaddingFIFOQueue(10, (tf.string, tf.int32),
+ shapes=((None,), (1, None)))
+ str_elems = [
+ ["a"],
+ ["ab"],
+ ["abc"],
+ ["abc", "d"],
+ ["abc", "d", "e"],
+ ["abc", "d", "e", "f"]]
+
+ int_elems = [
+ [[1]],
+ [[2]],
+ [[3]],
+ [[1, 2]],
+ [[1, 2, 3]],
+ [[1, 2, 3, 4]]]
+
+ enqueue_ops = [q.enqueue((str_elems[i], int_elems[i])) for i in range(6)]
+
+ dequeued_t = q.dequeue_up_to(5)
+ dequeued_single_t = q.dequeue()
+
+ for enqueue_op in enqueue_ops:
+ enqueue_op.run()
+ string_val, int_val = sess.run(dequeued_t)
+
+ self.assertAllEqual(
+ [[b"a", b"", b""], [b"ab", b"", b""], [b"abc", b"", b""],
+ [b"abc", b"d", b""], [b"abc", b"d", b"e"]], string_val)
+ self.assertAllEqual(
+ [[[1, 0, 0]],
+ [[2, 0, 0]],
+ [[3, 0, 0]],
+ [[1, 2, 0]],
+ [[1, 2, 3]]],
+ int_val)
+ self.assertTrue(
+ tf.TensorShape(string_val.shape).is_compatible_with(
+ dequeued_t[0].get_shape()))
+ self.assertTrue(
+ tf.TensorShape(int_val.shape).is_compatible_with(
+ dequeued_t[1].get_shape()))
+
+ string_val, int_val = sess.run(dequeued_single_t)
+ self.assertAllEqual([b"abc", b"d", b"e", b"f"], string_val)
+ self.assertAllEqual([[1, 2, 3, 4]], int_val)
+ self.assertTrue(
+ tf.TensorShape(string_val.shape).is_compatible_with(
+ dequeued_single_t[0].get_shape()))
+ self.assertTrue(
+ tf.TensorShape(int_val.shape).is_compatible_with(
+ dequeued_single_t[1].get_shape()))
+
def testHighDimension(self):
with self.test_session():
q = tf.PaddingFIFOQueue(10, tf.int32, ((4, 4, 4, 4),))
@@ -576,6 +654,29 @@ class PaddingFIFOQueueTest(tf.test.TestCase):
thread.join()
self.assertItemsEqual(elems, dequeued_elems)
+ def testParallelDequeueUpTo(self):
+ with self.test_session() as sess:
+ q = tf.PaddingFIFOQueue(1000, tf.float32, shapes=((),))
+ elems = [10.0 * x for x in range(1000)]
+ enqueue_op = q.enqueue_many((elems,))
+ close_op = q.close()
+ dequeued_t = q.dequeue_up_to(101)
+
+ enqueue_op.run()
+ close_op.run()
+
+ # Dequeue up to 101 items in parallel on 10 threads, from closed queue.
+ dequeued_elems = []
+
+ def dequeue():
+ dequeued_elems.extend(sess.run(dequeued_t))
+ threads = [self.checkedThread(target=dequeue) for _ in range(10)]
+ for thread in threads:
+ thread.start()
+ for thread in threads:
+ thread.join()
+ self.assertItemsEqual(elems, dequeued_elems)
+
def testParallelEnqueueAndDequeue(self):
with self.test_session() as sess:
q = tf.PaddingFIFOQueue(50, tf.float32, shapes=((),))
@@ -703,6 +804,33 @@ class PaddingFIFOQueueTest(tf.test.TestCase):
self.assertAllEqual(elems, dequeued_elems)
+ def testBlockingDequeueUpTo(self):
+ with self.test_session() as sess:
+ q = tf.PaddingFIFOQueue(10, tf.float32, ((),))
+ elems = [10.0, 20.0, 30.0, 40.0]
+ enqueue_op = q.enqueue_many((elems,))
+ dequeued_t = q.dequeue_up_to(4)
+
+ dequeued_elems = []
+
+ def enqueue():
+ # The enqueue_op should run after the dequeue op has blocked.
+ # TODO(mrry): Figure out how to do this without sleeping.
+ time.sleep(0.1)
+ sess.run(enqueue_op)
+
+ def dequeue():
+ dequeued_elems.extend(sess.run(dequeued_t).tolist())
+
+ enqueue_thread = self.checkedThread(target=enqueue)
+ dequeue_thread = self.checkedThread(target=dequeue)
+ enqueue_thread.start()
+ dequeue_thread.start()
+ enqueue_thread.join()
+ dequeue_thread.join()
+
+ self.assertAllEqual(elems, dequeued_elems)
+
def testDequeueManyWithTensorParameter(self):
with self.test_session():
# Define a first queue that contains integer counts.
@@ -772,6 +900,28 @@ class PaddingFIFOQueueTest(tf.test.TestCase):
close_op.run()
dequeue_thread.join()
+ def testDequeueUpToFromClosedQueueReturnsRemainder(self):
+ with self.test_session() as sess:
+ q = tf.PaddingFIFOQueue(10, tf.float32, ((),))
+ elems = [10.0, 20.0, 30.0, 40.0]
+ enqueue_op = q.enqueue_many((elems,))
+ close_op = q.close()
+ dequeued_t = q.dequeue_up_to(3)
+
+ enqueue_op.run()
+
+ def dequeue():
+ self.assertAllEqual(elems[:3], sess.run(dequeued_t))
+ self.assertAllEqual(elems[3:], sess.run(dequeued_t))
+
+ dequeue_thread = self.checkedThread(target=dequeue)
+ dequeue_thread.start()
+ # The close_op should run after the dequeue_thread has blocked.
+ # TODO(mrry): Figure out how to do this without sleeping.
+ time.sleep(0.1)
+ close_op.run()
+ dequeue_thread.join()
+
def testBlockingDequeueFromClosedEmptyQueue(self):
with self.test_session() as sess:
q = tf.PaddingFIFOQueue(10, tf.float32, ((),))
@@ -817,6 +967,31 @@ class PaddingFIFOQueueTest(tf.test.TestCase):
close_op.run()
dequeue_thread.join()
+ def testBlockingDequeueManyButNotAllFromClosedQueue(self):
+ with self.test_session() as sess:
+ q = tf.PaddingFIFOQueue(10, tf.float32, ((),))
+ elems = [10.0, 20.0, 30.0, 40.0]
+ enqueue_op = q.enqueue_many((elems,))
+ close_op = q.close()
+ dequeued_t = q.dequeue_many(3)
+
+ enqueue_op.run()
+
+ def dequeue():
+ self.assertAllEqual(elems[:3], sess.run(dequeued_t))
+ # Expect the operation to fail due to the queue being closed.
+ with self.assertRaisesRegexp(tf.errors.OutOfRangeError,
+ "is closed and has insufficient"):
+ sess.run(dequeued_t)
+
+ dequeue_thread = self.checkedThread(target=dequeue)
+ dequeue_thread.start()
+ # The close_op should run after the dequeue_thread has blocked.
+ # TODO(mrry): Figure out how to do this without sleeping.
+ time.sleep(0.1)
+ close_op.run()
+ dequeue_thread.join()
+
def testEnqueueManyLargerThanCapacityWithConcurrentDequeueMany(self):
with self.test_session() as sess:
q = tf.PaddingFIFOQueue(4, tf.float32, ((),))
@@ -906,6 +1081,26 @@ class PaddingFIFOQueueTest(tf.test.TestCase):
close_op.run()
dequeue_thread.join()
+ def testBlockingDequeueUpToFromClosedEmptyQueue(self):
+ with self.test_session() as sess:
+ q = tf.PaddingFIFOQueue(10, tf.float32, ((),))
+ close_op = q.close()
+ dequeued_t = q.dequeue_up_to(4)
+
+ def dequeue():
+ # Expect the operation to fail due to the queue being closed.
+ with self.assertRaisesRegexp(tf.errors.OutOfRangeError,
+ "is closed and has insufficient"):
+ sess.run(dequeued_t)
+
+ dequeue_thread = self.checkedThread(target=dequeue)
+ dequeue_thread.start()
+ # The close_op should run after the dequeue_thread has blocked.
+ # TODO(mrry): Figure out how to do this without sleeping.
+ time.sleep(0.1)
+ close_op.run()
+ dequeue_thread.join()
+
def testEnqueueToClosedQueue(self):
with self.test_session():
q = tf.PaddingFIFOQueue(10, tf.float32, ((),))
diff --git a/tensorflow/python/kernel_tests/reshape_op_test.py b/tensorflow/python/kernel_tests/reshape_op_test.py
index 0487621b46..a68f722244 100644
--- a/tensorflow/python/kernel_tests/reshape_op_test.py
+++ b/tensorflow/python/kernel_tests/reshape_op_test.py
@@ -99,11 +99,6 @@ class ReshapeTest(tf.test.TestCase):
self._testBothReshape(x, [1, -1, 5])
def testErrors(self):
- x = tf.constant(0.0, shape=[1, 0, 3])
- with self.assertRaisesRegexp(
- ValueError, "cannot infer the missing input size"):
- tf.reshape(x, [0, -1, 5])
-
y = tf.constant(0.0, shape=[23, 29, 31])
with self.assertRaisesRegexp(ValueError, "isn't divisible by 17"):
tf.reshape(y, [17, -1])
@@ -128,6 +123,20 @@ class ReshapeTest(tf.test.TestCase):
y = tf.reshape(x, tf.placeholder(tf.int32, shape=(3,)))
self.assertEqual([None, None, None], y.get_shape().as_list())
+ # Unknown input shape, partial new shape using `tf.pack()`.
+ y = tf.reshape(x, [tf.placeholder(tf.int32), 37])
+ self.assertEqual([None, 37], y.get_shape().as_list())
+
+ # Unknown input shape, partial new shape using `tf.concat()`.
+ y = tf.reshape(x, tf.concat(0, [tf.placeholder(tf.int32, shape=(2,)),
+ [37, 42]]))
+ self.assertEqual([None, None, 37, 42], y.get_shape().as_list())
+
+ # Unknown input shape, partial new shape using `tf.shape()`.
+ y = tf.reshape(x, tf.shape(tf.placeholder(tf.float32,
+ shape=[None, 37, None])))
+ self.assertEqual([None, 37, None], y.get_shape().as_list())
+
if __name__ == "__main__":
tf.test.main()
diff --git a/tensorflow/python/kernel_tests/rnn_test.py b/tensorflow/python/kernel_tests/rnn_test.py
index 34b1c81e77..707133295c 100644
--- a/tensorflow/python/kernel_tests/rnn_test.py
+++ b/tensorflow/python/kernel_tests/rnn_test.py
@@ -19,6 +19,7 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
+import itertools
import time
import timeit
@@ -1157,6 +1158,114 @@ class BidirectionalRNNTest(tf.test.TestCase):
self._testBidirectionalRNNWithoutSequenceLength(use_gpu=True,
use_shape=True)
+ def _createBidirectionalDynamicRNN(self, use_gpu, use_shape,
+ use_state_tuple, use_time_major):
+ num_units = 3
+ input_size = 5
+ batch_size = 2
+ max_length = 8
+
+ initializer = tf.random_uniform_initializer(-0.01, 0.01, seed=self._seed)
+ sequence_length = tf.placeholder(tf.int64)
+ cell_fw = tf.nn.rnn_cell.LSTMCell(num_units,
+ initializer=initializer,
+ state_is_tuple=use_state_tuple)
+ cell_bw = tf.nn.rnn_cell.LSTMCell(num_units,
+ initializer=initializer,
+ state_is_tuple=use_state_tuple)
+ inputs = max_length * [
+ tf.placeholder(tf.float32,
+ shape=(batch_size if use_shape else None, input_size))]
+ inputs_c = tf.pack(inputs)
+ if not use_time_major:
+ inputs_c = tf.transpose(inputs_c, [1, 0, 2])
+ outputs, states = tf.nn.bidirectional_dynamic_rnn(
+ cell_fw,
+ cell_bw,
+ inputs_c,
+ sequence_length,
+ dtype=tf.float32,
+ time_major=use_time_major)
+ outputs = tf.concat(2, outputs)
+ state_fw, state_bw = states
+ outputs_shape = [None, max_length, 2 * num_units]
+ if use_shape:
+ outputs_shape[0] = batch_size
+ if use_time_major:
+ outputs_shape[0], outputs_shape[1] = outputs_shape[1], outputs_shape[0]
+ self.assertEqual(
+ outputs.get_shape().as_list(),
+ outputs_shape)
+
+ input_value = np.random.randn(batch_size, input_size)
+
+ return input_value, inputs, outputs, state_fw, state_bw, sequence_length
+
+ def _testBidirectionalDynamicRNN(self, use_gpu, use_shape,
+ use_state_tuple, use_time_major):
+ with self.test_session(use_gpu=use_gpu, graph=tf.Graph()) as sess:
+ input_value, inputs, outputs, state_fw, state_bw, sequence_length = (
+ self._createBidirectionalDynamicRNN(
+ use_gpu, use_shape, use_state_tuple, use_time_major))
+ tf.initialize_all_variables().run()
+ # Run with pre-specified sequence length of 2, 3
+ if use_state_tuple:
+ out, c_fw, m_fw, c_bw, m_bw = sess.run(
+ [outputs, state_fw[0], state_fw[1], state_bw[0], state_bw[1]],
+ feed_dict={inputs[0]: input_value,
+ sequence_length: [2, 3]})
+ s_fw = (c_fw, m_fw)
+ s_bw = (c_bw, m_bw)
+ else:
+ out, s_fw, s_bw = sess.run([outputs, state_fw, state_bw],
+ feed_dict={inputs[0]: input_value,
+ sequence_length: [2, 3]})
+
+ # Since the forward and backward LSTM cells were initialized with the
+ # same parameters, the forward and backward output has to be the same,
+ # but reversed in time. The format is output[time][batch][depth], and
+ # due to depth concatenation (as num_units=3 for both RNNs):
+ # - forward output: out[][][depth] for 0 <= depth < 3
+      # - backward output: out[][][depth] for 3 <= depth < 6
+ #
+ # First sequence in batch is length=2
+ # Check that the time=0 forward output is equal to time=1 backward output
+ if not use_time_major:
+ out = np.swapaxes(out, 0, 1)
+ self.assertEqual(out[0][0][0], out[1][0][3])
+ self.assertEqual(out[0][0][1], out[1][0][4])
+ self.assertEqual(out[0][0][2], out[1][0][5])
+ # Check that the time=1 forward output is equal to time=0 backward output
+ self.assertEqual(out[1][0][0], out[0][0][3])
+ self.assertEqual(out[1][0][1], out[0][0][4])
+ self.assertEqual(out[1][0][2], out[0][0][5])
+
+ # Second sequence in batch is length=3
+ # Check that the time=0 forward output is equal to time=2 backward output
+ self.assertEqual(out[0][1][0], out[2][1][3])
+ self.assertEqual(out[0][1][1], out[2][1][4])
+ self.assertEqual(out[0][1][2], out[2][1][5])
+ # Check that the time=1 forward output is equal to time=1 backward output
+ self.assertEqual(out[1][1][0], out[1][1][3])
+ self.assertEqual(out[1][1][1], out[1][1][4])
+ self.assertEqual(out[1][1][2], out[1][1][5])
+ # Check that the time=2 forward output is equal to time=0 backward output
+ self.assertEqual(out[2][1][0], out[0][1][3])
+ self.assertEqual(out[2][1][1], out[0][1][4])
+ self.assertEqual(out[2][1][2], out[0][1][5])
+ # Via the reasoning above, the forward and backward final state should be
+ # exactly the same
+ self.assertAllClose(s_fw, s_bw)
+
+ def testBidirectionalDynamicRNN(self):
+ # Generate 2^4 option values
+ # from [True, True, True, True] to [False, False, False, False]
+ options = itertools.product([True, False], repeat=4)
+ for option in options:
+ self._testBidirectionalDynamicRNN(use_gpu=option[0], use_shape=option[1],
+ use_state_tuple=option[2],
+ use_time_major=option[3])
+
class MultiDimensionalLSTMTest(tf.test.TestCase):
diff --git a/tensorflow/python/kernel_tests/unpack_op_test.py b/tensorflow/python/kernel_tests/unpack_op_test.py
index 7cc6d31efb..0cb701db82 100644
--- a/tensorflow/python/kernel_tests/unpack_op_test.py
+++ b/tensorflow/python/kernel_tests/unpack_op_test.py
@@ -23,6 +23,14 @@ from six.moves import xrange # pylint: disable=redefined-builtin
import tensorflow as tf
+def np_split_sqeeze(array, axis):
+ axis_len = array.shape[axis]
+ return [
+ np.squeeze(arr, axis=(axis,))
+ for arr in np.split(array, axis_len, axis=axis)
+ ]
+
+
class UnpackOpTest(tf.test.TestCase):
def testSimple(self):
@@ -40,7 +48,7 @@ class UnpackOpTest(tf.test.TestCase):
cs = [c.eval() for c in cs]
self.assertAllEqual(cs, data)
- def testGradients(self):
+ def testGradientsAxis0(self):
for use_gpu in False, True:
for shape in (2,), (3,), (2, 3), (3, 2), (4, 3, 2):
data = np.random.randn(*shape)
@@ -52,6 +60,19 @@ class UnpackOpTest(tf.test.TestCase):
err = tf.test.compute_gradient_error(x, shape, cs[i], shapes[i])
self.assertLess(err, 1e-6)
+ def testGradientsAxis1(self):
+ for use_gpu in False, True:
+ for shape in (2, 3), (3, 2), (4, 3, 2):
+ data = np.random.randn(*shape)
+ out_shape = list(shape)
+ del out_shape[1]
+ for i in xrange(shape[1]):
+ with self.test_session(use_gpu=use_gpu):
+ x = tf.constant(data)
+ cs = tf.unpack(x, num=shape[1], axis=1)
+ err = tf.test.compute_gradient_error(x, shape, cs[i], out_shape)
+ self.assertLess(err, 1e-6)
+
def testInferNum(self):
with self.test_session():
for shape in (2,), (3,), (2, 3), (3, 2), (4, 3, 2):
@@ -60,12 +81,55 @@ class UnpackOpTest(tf.test.TestCase):
self.assertEqual(type(cs), list)
self.assertEqual(len(cs), shape[0])
- def testCannotInferNum(self):
+ def testCannotInferNumFromUnknownShape(self):
x = tf.placeholder(np.float32)
with self.assertRaisesRegexp(
ValueError, r'Cannot infer num from shape <unknown>'):
tf.unpack(x)
+ def testUnknownShapeOkWithNum(self):
+ x = tf.placeholder(np.float32)
+ tf.unpack(x, num=2)
+
+ def testCannotInferNumFromNoneShape(self):
+ x = tf.placeholder(np.float32, shape=(None,))
+ with self.assertRaisesRegexp(ValueError,
+ r'Cannot infer num from shape \(\?,\)'):
+ tf.unpack(x)
+
+ def testAgainstNumpy(self):
+ # For 1 to 5 dimensions.
+ for i in range(1, 6):
+ a = np.random.random(np.random.permutation(i) + 1)
+
+ # For all the possible axis to split it, including negative indices.
+ for j in range(-i, i):
+ expected = np_split_sqeeze(a, j)
+
+ with self.test_session() as sess:
+ actual = sess.run(tf.unpack(a, axis=j))
+
+ self.assertAllEqual(expected, actual)
+
+ def testAxis0Default(self):
+ with self.test_session() as sess:
+ a = tf.constant([[1, 2, 3], [4, 5, 6]], name='a')
+
+ unpacked = sess.run(tf.unpack(a))
+
+ self.assertEqual(len(unpacked), 2)
+ self.assertAllEqual(unpacked[0], [1, 2, 3])
+ self.assertAllEqual(unpacked[1], [4, 5, 6])
+
+ def testAxisOutOfRange(self):
+ a = tf.constant([[1, 2, 3], [4, 5, 6]], name='a')
+ with self.assertRaisesRegexp(ValueError, r'axis = 2 not in \[-2, 2\)'):
+ tf.unpack(a, axis=2)
+
+ def testAxisOutOfNegativeRange(self):
+ a = tf.constant([[1, 2, 3], [4, 5, 6]], name='a')
+ with self.assertRaisesRegexp(ValueError, r'axis = -3 not in \[-2, 2\)'):
+ tf.unpack(a, axis=-3)
if __name__ == '__main__':
tf.test.main()
diff --git a/tensorflow/python/kernel_tests/variables_test.py b/tensorflow/python/kernel_tests/variables_test.py
index 5b70f94723..82f010cf5b 100644
--- a/tensorflow/python/kernel_tests/variables_test.py
+++ b/tensorflow/python/kernel_tests/variables_test.py
@@ -25,6 +25,7 @@ import tensorflow as tf
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import random_ops
+from tensorflow.python.ops import state_ops
class VariablesTestCase(tf.test.TestCase):
@@ -418,5 +419,26 @@ class ObsoleteIsInitializedTest(tf.test.TestCase):
inited.op.run()
+class VariableContainerTest(tf.test.TestCase):
+
+ def testContainer(self):
+ with tf.Graph().as_default():
+ v0 = tf.Variable([0])
+ with tf.container("l1"):
+ v1 = tf.Variable([1])
+ with tf.container("l2"):
+ v2 = tf.Variable([2])
+ special_v = state_ops.variable_op([1], tf.float32, container="l3")
+ v3 = tf.Variable([3])
+ v4 = tf.Variable([4])
+ self.assertEqual(tf.compat.as_bytes(""), v0.op.get_attr("container"))
+ self.assertEqual(tf.compat.as_bytes("l1"), v1.op.get_attr("container"))
+ self.assertEqual(tf.compat.as_bytes("l2"), v2.op.get_attr("container"))
+ self.assertEqual(tf.compat.as_bytes("l3"),
+ special_v.op.get_attr("container"))
+ self.assertEqual(tf.compat.as_bytes("l1"), v3.op.get_attr("container"))
+ self.assertEqual(tf.compat.as_bytes(""), v4.op.get_attr("container"))
+
+
if __name__ == "__main__":
tf.test.main()
diff --git a/tensorflow/python/ops/array_grad.py b/tensorflow/python/ops/array_grad.py
index 0e85aaf80f..7ccef03b1a 100644
--- a/tensorflow/python/ops/array_grad.py
+++ b/tensorflow/python/ops/array_grad.py
@@ -30,13 +30,13 @@ from tensorflow.python.ops import math_ops
@ops.RegisterGradient("Pack")
def _PackGrad(op, grad):
"""Gradient for pack op."""
- return array_ops.unpack(grad, num=op.get_attr("N"))
+ return array_ops.unpack(grad, num=op.get_attr("N"), axis=op.get_attr("axis"))
@ops.RegisterGradient("Unpack")
-def _UnpackGrad(_, *grads):
+def _UnpackGrad(op, *grads):
"""Gradient for unpack op."""
- return array_ops.pack(grads)
+ return array_ops.pack(grads, axis=op.get_attr("axis"))
@ops.RegisterGradient("Concat")
@@ -149,6 +149,27 @@ def _SliceGrad(op, grad):
return array_ops.pad(grad, paddings), None, None
+@ops.RegisterGradient("StridedSlice")
+def _StridedSliceGrad(op, grad):
+  """Gradient for StridedSlice op."""
+ x = array_ops.shape(op.inputs[0])
+ begin = op.inputs[1]
+ end = op.inputs[2]
+ strides = op.inputs[3]
+
+ return array_ops.strided_slice_grad(
+ x,
+ begin,
+ end,
+ strides,
+ grad,
+ begin_mask=op.get_attr("begin_mask"),
+ end_mask=op.get_attr("end_mask"),
+ ellipse_mask=op.get_attr("ellipse_mask"),
+ new_axis_mask=op.get_attr("new_axis_mask"),
+ shrink_axis_mask=op.get_attr("shrink_axis_mask")), None, None, None
+
+
@ops.RegisterGradient("Split")
def _SplitGrad(op, *grads):
return None, array_ops.concat(op.inputs[0], list(grads))
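
Editor's note: with the `StridedSlice` gradient registered above, `tf.gradients()` can flow through the new slicing path. A minimal sketch using the internal `array_ops._NewSliceHelper`, as the tests in this change do (not a public API):

```python
import tensorflow as tf
from tensorflow.python.ops import array_ops

var = tf.Variable(tf.reshape(tf.linspace(1., 24., 24), [2, 3, 4]))
# Equivalent to var[0:2, 1:3, 0:4:2] through the new strided-slice helper.
sliced = array_ops._NewSliceHelper(var, (slice(0, 2), slice(1, 3), slice(0, 4, 2)))
grad = tf.gradients(tf.reduce_sum(sliced), var)[0]

with tf.Session() as sess:
  sess.run(tf.initialize_all_variables())
  # The gradient is 1 for every element covered by the slice, 0 elsewhere.
  print(sess.run(grad))
```
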
diff --git a/tensorflow/python/ops/array_ops.py b/tensorflow/python/ops/array_ops.py
index 70b4c5bd35..509f627170 100644
--- a/tensorflow/python/ops/array_ops.py
+++ b/tensorflow/python/ops/array_ops.py
@@ -495,7 +495,7 @@ def strided_slice(input_,
ops.Tensor._override_operator("__getitem__", _SliceHelper)
-def pack(values, name="pack"):
+def pack(values, axis=0, name="pack"):
"""Packs a list of rank-`R` tensors into one rank-`(R+1)` tensor.
Packs tensors in `values` into a tensor with rank one higher than each tensor
@@ -508,17 +508,31 @@ def pack(values, name="pack"):
Args:
values: A list of `Tensor` objects with the same shape and type.
+ axis: An `int`. The axis to pack along. Defaults to the first dimension.
+ Supports negative indexes.
name: A name for this operation (optional).
Returns:
output: A packed `Tensor` with the same type as `values`.
+
+ Raises:
+ ValueError: If `axis` is out of the range [-(R+1), R+1).
"""
- try:
- # If the input is a constant list, it can just be converted to a constant op
- return ops.convert_to_tensor(values, name=name)
- except (TypeError, ValueError):
- # Input list contains non-constant tensors
- return gen_array_ops._pack(values, name=name)
+ if axis == 0:
+ try:
+ # If the input is a constant list, it can be converted to a constant op
+ return ops.convert_to_tensor(values, name=name)
+ except (TypeError, ValueError):
+ pass # Input list contains non-constant tensors
+
+ value_shape = ops.convert_to_tensor(values[0], name=name).get_shape()
+ if value_shape.ndims is not None:
+ expanded_num_dims = value_shape.ndims + 1
+ if axis < -expanded_num_dims or axis >= expanded_num_dims:
+ raise ValueError("axis = %d not in [%d, %d)" %
+ (axis, -expanded_num_dims, expanded_num_dims))
+
+ return gen_array_ops._pack(values, axis=axis, name=name)
# pylint: disable=invalid-name
@@ -609,12 +623,12 @@ ops.register_tensor_conversion_function(
(list, tuple), _autopacking_conversion_function, 99)
-def unpack(value, num=None, name="unpack"):
- """Unpacks the outer dimension of a rank-`R` tensor into rank-`(R-1)` tensors.
+def unpack(value, num=None, axis=0, name="unpack"):
+ """Unpacks the given dimension of a rank-`R` tensor into rank-`(R-1)` tensors.
- Unpacks `num` tensors from `value` along the first dimension.
+ Unpacks `num` tensors from `value` along the given dimension.
If `num` is not specified (the default), it is inferred from `value`'s shape.
- If `value.shape[0]` is not known, `ValueError` is raised.
+ If `value.shape[axis]` is not known, `ValueError` is raised.
The ith tensor in `output` is the slice `value[i, ...]`. Each tensor in
`output` has shape `value.shape[1:]`.
@@ -625,8 +639,10 @@ def unpack(value, num=None, name="unpack"):
Args:
value: A rank `R > 0` `Tensor` to be unpacked.
- num: An `int`. The first dimension of value. Automatically inferred if
- `None` (the default).
+ num: An `int`. The length of the dimension `axis`. Automatically inferred
+ if `None` (the default).
+ axis: An `int`. The axis to unpack along. Defaults to the first
+ dimension. Supports negative indexes.
name: A name for the operation (optional).
Returns:
@@ -634,14 +650,19 @@ def unpack(value, num=None, name="unpack"):
Raises:
ValueError: If `num` is unspecified and cannot be inferred.
+ ValueError: If `axis` is out of the range [-R, R).
"""
if num is None:
value = ops.convert_to_tensor(value)
- shape = value.get_shape()
- num = shape[0].value
- if num is None:
- raise ValueError("Cannot infer num from shape %s" % shape)
- return gen_array_ops._unpack(value, num=num, name=name)
+ value_shape = value.get_shape()
+ if value_shape.ndims is not None:
+ if axis < -value_shape.ndims or axis >= value_shape.ndims:
+ raise ValueError("axis = %d not in [%d, %d)" %
+ (axis, -value_shape.ndims, value_shape.ndims))
+ num = value_shape[axis].value
+ if num is None:
+ raise ValueError("Cannot infer num from shape %s" % value_shape)
+ return gen_array_ops._unpack(value, num=num, axis=axis, name=name)
def concat(concat_dim, values, name="concat"):
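Similarly, a hedged sketch of `unpack` with the new `axis` argument (shapes are illustrative only):

```python
import tensorflow as tf

m = tf.constant([[1, 2, 3],
                 [4, 5, 6]])       # shape [2, 3]

rows = tf.unpack(m)                # 2 tensors of shape [3]
cols = tf.unpack(m, axis=1)        # 3 tensors of shape [2]

with tf.Session() as sess:
    print(sess.run(cols[0]))       # [1 4]
```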
@@ -707,15 +728,26 @@ def concat(concat_dim, values, name="concat"):
@ops.RegisterShape("Pack")
def _PackShape(op):
input_shape = op.inputs[0].get_shape()
+ if input_shape.ndims is None:
+ return [tensor_shape.unknown_shape()]
+
for inp in op.inputs[1:]:
input_shape = input_shape.merge_with(inp.get_shape())
- return [tensor_shape.TensorShape([len(op.inputs)]).concatenate(input_shape)]
+
+ input_shape = input_shape.as_list()
+ input_shape.insert(op.get_attr("axis"), len(op.inputs))
+ return [tensor_shape.TensorShape(input_shape)]
@ops.RegisterShape("Unpack")
def _UnpackShape(op):
input_shape = op.inputs[0].get_shape()
- return [input_shape[1:]] * op.get_attr("num")
+ if input_shape.ndims is None:
+ return [tensor_shape.unknown_shape()] * op.get_attr("num")
+
+ input_shape = input_shape.as_list()
+ del input_shape[op.get_attr("axis")]
+ return [tensor_shape.TensorShape(input_shape)] * op.get_attr("num")
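The shape bookkeeping above amounts to inserting (for Pack) or deleting (for Unpack) one entry of a plain Python list; a tiny sketch with assumed example values:

```python
# Pack: insert the number of inputs at position `axis`.
input_shape = [3, 4]        # common static shape of each packed tensor (assumed)
axis, num_inputs = 1, 5
pack_shape = list(input_shape)
pack_shape.insert(axis, num_inputs)
assert pack_shape == [3, 5, 4]

# Unpack: delete the dimension at position `axis`.
unpack_shape = list(pack_shape)
del unpack_shape[axis]
assert unpack_shape == [3, 4]
```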
@ops.RegisterShape("Concat")
@@ -1437,6 +1469,12 @@ def _compute_size_of_strided_dim(spec, size):
return unknown # unknown because stride is unknown
+@ops.RegisterShape("StridedSliceGrad")
+def _StridedSliceGradShape(op):
+ """Shape function for gradient of array_ops.slice."""
+ return [tensor_util.constant_value(op.inputs[0])]
+
+
@ops.RegisterShape("StridedSlice")
def _StridedSliceShape(op):
"""Shape function for array_ops.slice."""
@@ -1742,45 +1780,38 @@ def _ReshapeShape(op):
num_elements *= dim
else:
num_elements = tensor_shape.Dimension(None)
- new_shape_shape = op.inputs[1].get_shape().with_rank(1)
- new_shape = tensor_util.constant_value(op.inputs[1])
- if new_shape is None:
- # Attempt to infer the rank of the output from the length of
- # new_shape.
- return [tensor_shape.unknown_shape(ndims=new_shape_shape[0].value)]
- new_shape = np.reshape(new_shape, -1).tolist()
- if -1 not in new_shape:
+ new_shape = tensor_util.constant_value_as_shape(op.inputs[1])
+ if new_shape.ndims is None:
+ # We have no information about the shape of the output.
+ return [new_shape]
+ if None not in new_shape.as_list():
# The new shape is fully defined.
if (num_elements.value is not None
and num_elements.value != np.prod(new_shape)):
raise ValueError(
"Cannot reshape a tensor with %d elements to shape %s (%d elements)"
% (num_elements.value, new_shape, np.prod(new_shape)))
- return [tensor_shape.TensorShape(new_shape)]
elif num_elements.value is not None:
# We know the number of elements, so we can calculate the missing
# dimension in the new_shape.
known_elements = 1
- unknown_index = None
+ unknown_indices = []
for i, dim in enumerate(new_shape):
- if dim == -1:
- unknown_index = i
+ if dim.value is None:
+ unknown_indices.append(i)
else:
- known_elements *= dim
- if known_elements == 0:
- raise ValueError("cannot infer the missing input size for "
- "an empty tensor unless all specified "
- "input sizes are non-zero")
- if num_elements % known_elements != 0:
- raise ValueError("input has %s elements, which isn't divisible by %d" %
- (num_elements, known_elements))
- new_shape[unknown_index] = num_elements // known_elements
- return [tensor_shape.TensorShape(new_shape)]
- else:
- # We don't know the input shape, but we know n-1 of the dimensions
- # in the new shape.
- new_shape[new_shape.index(-1)] = None
- return [tensor_shape.TensorShape(new_shape)]
+ known_elements *= dim.value
+ if known_elements != 0:
+ if num_elements % known_elements != 0:
+ raise ValueError("input has %s elements, which isn't divisible by %d" %
+ (num_elements, known_elements))
+ if len(unknown_indices) == 1:
+ unknown_index = unknown_indices[0]
+ new_shape = new_shape.merge_with(
+ new_shape[:unknown_index].concatenate(
+ [num_elements // known_elements]).concatenate(
+ new_shape[unknown_index+1:]))
+ return [new_shape]
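The single-unknown-dimension case handled above is just integer division of the known element count; a small sketch under assumed values:

```python
num_elements = 24            # total elements in the input (assumed known)
new_shape = [2, None, 4]     # exactly one unknown dimension

known_elements = 1
for dim in new_shape:
    if dim is not None:
        known_elements *= dim

assert num_elements % known_elements == 0
inferred = [d if d is not None else num_elements // known_elements
            for d in new_shape]
assert inferred == [2, 3, 4]
```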
@ops.RegisterShape("BroadcastGradientArgs")
diff --git a/tensorflow/python/ops/control_flow_grad.py b/tensorflow/python/ops/control_flow_grad.py
index 24d4f4a9a2..a977bb57e0 100644
--- a/tensorflow/python/ops/control_flow_grad.py
+++ b/tensorflow/python/ops/control_flow_grad.py
@@ -199,7 +199,7 @@ def _EnterGrad(op, grad):
if op.get_attr("is_constant"):
# Add a gradient accumulator for each loop invariant.
if isinstance(grad, ops.Tensor):
- result = grad_ctxt.AddBackPropAccumulator(grad)
+ result = grad_ctxt.AddBackPropAccumulator(op, grad)
elif isinstance(grad, ops.IndexedSlices):
result = grad_ctxt.AddBackPropIndexedSlicesAccumulator(grad)
else:
diff --git a/tensorflow/python/ops/control_flow_ops.py b/tensorflow/python/ops/control_flow_ops.py
index 3063675e40..8c02399f81 100644
--- a/tensorflow/python/ops/control_flow_ops.py
+++ b/tensorflow/python/ops/control_flow_ops.py
@@ -773,13 +773,13 @@ class GradLoopState(object):
# Record the history of this value in forward_ctxt.
# TODO(yuanbyu): Avoid recording constants.
self._grad_context.Exit()
- h_value = cur_grad_state.AddForwardAccumulator(cur_value)
+ history_value = cur_grad_state.AddForwardAccumulator(cur_value)
self._grad_context.Enter()
break
if real_value is None:
# Add the stack pop op in the grad context.
- real_value = self.AddBackPropAccumulatedValue(h_value, value)
+ real_value = self.AddBackPropAccumulatedValue(history_value, value)
self._history_map[value.name] = real_value
return real_value
@@ -966,13 +966,13 @@ class ControlFlowState(object):
# Add forward accumulator for shape.
grad_state.grad_context.Exit()
- h_shape = grad_state.AddForwardAccumulator(
+ history_zeros_shape = grad_state.AddForwardAccumulator(
zeros_shape, dead_branch=dead_branch)
grad_state.grad_context.Enter()
# Create a zero tensor with the right shape.
shape = grad_state.AddBackPropAccumulatedValue(
- h_shape, zeros_shape, dead_branch)
+ history_zeros_shape, zeros_shape, dead_branch)
result = array_ops.zeros(shape, val.dtype)
return result
@@ -1596,36 +1596,56 @@ class WhileContext(ControlFlowContext):
self.Exit()
return next_count
- def AddBackPropAccumulator(self, value):
+ def AddBackPropAccumulator(self, op, grad):
"""Add an accumulation loop for every loop invariant.
- This is added to the backprop loop. It is used to accumulate
- partial gradients within each loop iteration. Called when in the
- gradient while context.
+ This is added to the backprop loop. It is used to accumulate partial
+ gradients within each loop iteration. Called when in the gradient while
+ context.
The pseudocode is:
```
acc = 0.0;
while (_pivot) {
- acc += value;
+ acc += grad;
}
```
Args:
- value: The partial gradient of an iteration for a loop invariant.
+ op: The Enter op for a loop invariant.
+ grad: The partial gradient of an iteration for a loop invariant.
Returns:
The gradient for a loop invariant.
"""
self.Exit()
- shape = value.get_shape()
- if not shape.is_fully_defined():
- shape = None
- if self.outer_context: self.outer_context.Enter()
- acc = constant_op.constant(0, value.dtype, shape=shape, name="b_acc")
- if not shape:
- acc._shape = value.get_shape() # pylint: disable=protected-access
- if self.outer_context: self.outer_context.Exit()
+ # Create a zeros tensor with the right shape for acc. If we don't
+ # know the full shape statically, we will have to get the shape
+ # dynamically from the forward inference. Getting the shape right
+ # for the zeros is only needed for the base case when the loop exits
+ # without running any iterations.
+ shape = grad.get_shape()
+ if shape.is_fully_defined():
+ if self.outer_context: self.outer_context.Enter()
+ acc = constant_op.constant(0, grad.dtype, shape=shape, name="b_acc")
+ if self.outer_context: self.outer_context.Exit()
+ else:
+ value = op.inputs[0]
+ if self.outer_context:
+ forward_ctxt = self.grad_state.forward_ctxt
+ forward_ctxt.outer_context.Enter()
+ zeros_shape = array_ops.shape(value)
+ forward_ctxt.outer_context.Exit()
+ outer_grad_state = self.grad_state.outer_grad_state
+ history_zeros_shape = outer_grad_state.AddForwardAccumulator(zeros_shape)
+ self.outer_context.Enter()
+ real_shape = outer_grad_state.AddBackPropAccumulatedValue(
+ history_zeros_shape, zeros_shape)
+ acc = array_ops.zeros(real_shape, grad.dtype)
+ self.outer_context.Exit()
+ else:
+ zeros_shape = array_ops.shape(value)
+ acc = array_ops.zeros(zeros_shape, grad.dtype)
+ acc._shape = grad.get_shape() # pylint: disable=protected-access
self.Enter()
self.AddName(acc.name)
@@ -1633,30 +1653,30 @@ class WhileContext(ControlFlowContext):
parallel_iterations=self._parallel_iterations,
name="b_acc")
merge_acc = merge([enter_acc, enter_acc], name="b_acc")[0]
- switch_acc = switch(merge_acc, self._pivot)
+ switch_acc_false, switch_acc_true = switch(merge_acc, self._pivot)
- add_acc = math_ops.add(switch_acc[1], value)
+ add_acc = math_ops.add(switch_acc_true, grad)
next_acc = _NextIteration(add_acc)
merge_acc.op._update_input(1, next_acc) # pylint: disable=protected-access
- acc_result = exit(switch_acc[0], name="b_acc")
+ acc_result = exit(switch_acc_false, name="b_acc")
self.ExitResult([acc_result])
return acc_result
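From the user's side, this accumulator is what makes gradients of loop invariants come out right; a hedged end-to-end sketch (constant values and loop bounds are illustrative):

```python
import tensorflow as tf

x = tf.constant(2.0)            # loop invariant entering the while loop
i0 = tf.constant(0)
acc0 = tf.constant(0.0)

def cond(i, acc):
    return i < 3

def body(i, acc):
    return i + 1, acc + x       # x contributes a partial gradient each iteration

_, total = tf.while_loop(cond, body, [i0, acc0])
grad_x = tf.gradients(total, x)[0]

with tf.Session() as sess:
    print(sess.run(grad_x))     # 3.0: one accumulated partial gradient per step
```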
- def AddBackPropIndexedSlicesAccumulator(self, value):
+ def AddBackPropIndexedSlicesAccumulator(self, grad):
"""This is used for accumulating gradients that are IndexedSlices.
This is essentially the equivalent of AddBackPropAccumulator but optimized
for things like updating embeddings from within a while loop.
Args:
- value: The partial gradients represented as an IndexedSlices.
+ grad: The partial gradients represented as an IndexedSlices.
Returns:
The accumulated IndexedSlices gradient of the loop invariant.
"""
- values = value.values
- indices = value.indices
+ values = grad.values
+ indices = grad.indices
self.Exit()
shape = tensor_shape.TensorShape([tensor_shape.Dimension(1)] +
@@ -1670,6 +1690,7 @@ class WhileContext(ControlFlowContext):
values_acc._shape = shape # pylint: disable=protected-access
indices_acc = constant_op.constant([0], indices.dtype)
if self.outer_context: self.outer_context.Exit()
+
self.Enter()
self.AddName(values_acc.name)
self.AddName(indices_acc.name)
@@ -1687,10 +1708,10 @@ class WhileContext(ControlFlowContext):
for xm, xn in zip(merge_acc, next_acc):
xm.op._update_input(1, xn) # pylint: disable=protected-access
- acc_result = [exit(x[0], name="b_acc") for x in switch_acc]
- self.ExitResult(acc_result)
- return ops.IndexedSlices(values=acc_result[1], indices=acc_result[0],
- dense_shape=self.ExitResult(value.dense_shape))
+ acc_exits = [exit(x[0], name="b_acc") for x in switch_acc]
+ self.ExitResult(acc_exits)
+ return ops.IndexedSlices(values=acc_exits[1], indices=acc_exits[0],
+ dense_shape=grad.dense_shape)
def _InitializeValues(self, values):
self._values = set()
@@ -1882,8 +1903,8 @@ def while_loop(cond, body, loop_vars, parallel_iterations=10, back_prop=True,
```python
ijk_0 = (tf.constant(0), (tf.constant(1), tf.constant(2)))
- c = lambda i, j, k: i < 10
- b = lambda i, j, k: (i, (j + k), (j - k))
+ c = lambda i, (j, k): i < 10
+ b = lambda i, (j, k): (i + 1, ((j + k), (j - k)))
ijk_final = tf.while_loop(c, b, ijk_0)
```
"""
diff --git a/tensorflow/python/ops/data_flow_ops.py b/tensorflow/python/ops/data_flow_ops.py
index 00033d4bc7..a4a3c4e669 100644
--- a/tensorflow/python/ops/data_flow_ops.py
+++ b/tensorflow/python/ops/data_flow_ops.py
@@ -456,7 +456,7 @@ class QueueBase(object):
If the queue is closed and there are more than `0` but fewer than
`n` elements remaining, then instead of raising a
`tf.errors.OutOfRangeError` like [`dequeue_many`](#QueueBase.dequeue_many),
- the remaining elements are returned immediately. If the queue is
+ fewer than `n` elements are returned immediately. If the queue is
closed and there are `0` elements left in the queue, then a
`tf.errors.OutOfRangeError` is raised just like in `dequeue_many`.
Otherwise the behavior is identical to `dequeue_many`.
@@ -732,6 +732,7 @@ ops.NoGradient("HashTable")
ops.NoGradient("InitializeTable")
ops.NoGradient("InitializeTableFromTextFile")
ops.NoGradient("MutableHashTable")
+ops.NoGradient("MutableHashTableOfTensors")
ops.RegisterShape("QueueSize")(common_shapes.scalar_shape)
@@ -807,16 +808,13 @@ def _DynamicStitchShape(op):
def _LookupTableFindShape(op):
"""Shape function for data_flow_ops._lookup_table_find."""
op.inputs[0].get_shape().merge_with(tensor_shape.scalar())
- shape_in = op.inputs[1].get_shape()
- return [shape_in]
+ return [tensor_shape.unknown_shape()]
@ops.RegisterShape("LookupTableInsert")
def _LookupTableInsertShape(op):
"""Shape function for data_flow_ops._lookup_table_insert."""
op.inputs[0].get_shape().merge_with(tensor_shape.scalar())
- keys_shape = op.inputs[1].get_shape()
- op.inputs[2].get_shape().merge_with(keys_shape)
return []
@@ -827,8 +825,18 @@ def _LookupTableSizeShape(op):
return [tensor_shape.scalar()]
+@ops.RegisterShape("LookupTableExport")
+def _LookupTableExportShape(op):
+ """Shape function for data_flow_ops._lookup_table_export_values."""
+ op.inputs[0].get_shape().merge_with(tensor_shape.scalar())
+ keys_shape = tensor_shape.vector(None)
+ values_shape = tensor_shape.unknown_shape()
+ return [keys_shape, values_shape]
+
+
@ops.RegisterShape("HashTable")
@ops.RegisterShape("MutableHashTable")
+@ops.RegisterShape("MutableHashTableOfTensors")
def _HashTableShape(_):
"""Shape function for data_flow_ops._hash_table."""
return [tensor_shape.scalar()]
diff --git a/tensorflow/python/ops/image_grad.py b/tensorflow/python/ops/image_grad.py
index ceaf518bd4..55fe7b067c 100644
--- a/tensorflow/python/ops/image_grad.py
+++ b/tensorflow/python/ops/image_grad.py
@@ -22,6 +22,7 @@ from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_shape
from tensorflow.python.framework import tensor_util
+from tensorflow.python.ops import array_ops
from tensorflow.python.ops import gen_image_ops
@@ -88,3 +89,65 @@ def _ResizeShape(op):
def _ResizeBilinearGradShape(op):
"""Shape function for ResizeBilinearGrad."""
return [op.inputs[1].get_shape()]
+
+
+@ops.RegisterShape("CropAndResizeGradImage")
+def _CropAndResizeGradImageShape(op):
+ """Shape function for CropAndResizeGradImage."""
+ image_size = tensor_util.constant_value(op.inputs[3])
+ if image_size is not None:
+ batch = image_size[0]
+ height = image_size[1]
+ width = image_size[2]
+ depth = image_size[3]
+ else:
+ batch = None
+ height = None
+ width = None
+ depth = None
+ return [tensor_shape.TensorShape([batch, height, width, depth])]
+
+
+@ops.RegisterShape("CropAndResizeGradBoxes")
+def _CropAndResizeGradBoxesShape(op):
+ """Shape function for CropAndResizeGradBoxes."""
+ return [op.inputs[2].get_shape()]
+
+
+@ops.RegisterGradient("CropAndResize")
+def _CropAndResizeGrad(op, grad):
+ """The derivatives for crop_and_resize.
+
+ We back-propagate to the image only when the input image tensor has floating
+ point dtype but we always back-propagate to the input boxes tensor.
+
+ Args:
+ op: The CropAndResize op.
+ grad: The tensor representing the gradient w.r.t. the output.
+
+ Returns:
+ The gradients w.r.t. the input image and boxes, as well as the always-None
+ gradients w.r.t. box_ind and crop_size.
+ """
+ image = op.inputs[0]
+ if image.get_shape().is_fully_defined():
+ image_shape = image.get_shape().as_list()
+ else:
+ image_shape = array_ops.shape(image)
+
+ allowed_types = [dtypes.float16, dtypes.float32, dtypes.float64]
+ if op.inputs[0].dtype in allowed_types:
+ # pylint: disable=protected-access
+ grad0 = gen_image_ops.crop_and_resize_grad_image(grad,
+ op.inputs[1],
+ op.inputs[2],
+ image_shape,
+ T=op.get_attr("T"))
+ # pylint: enable=protected-access
+ else:
+ grad0 = None
+
+ grad1 = gen_image_ops.crop_and_resize_grad_boxes(grad, op.inputs[0],
+ op.inputs[1], op.inputs[2])
+
+ return [grad0, grad1, None, None]
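With this gradient registered, `crop_and_resize` can sit inside a trainable graph; a hedged usage sketch (shapes and values are illustrative only):

```python
import tensorflow as tf

image = tf.random_uniform([2, 8, 8, 3])      # float image, so the image grad flows
boxes = tf.constant([[0.0, 0.0, 0.5, 0.5],
                     [0.2, 0.2, 0.9, 0.9]])
box_ind = tf.constant([0, 1])

crops = tf.image.crop_and_resize(image, boxes, box_ind, crop_size=[4, 4])
loss = tf.reduce_sum(crops)

# Gradients w.r.t. both the image and the boxes; box_ind and crop_size get None.
grad_image, grad_boxes = tf.gradients(loss, [image, boxes])
```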
diff --git a/tensorflow/python/ops/image_grad_test.py b/tensorflow/python/ops/image_grad_test.py
index 9d9cdd4ed0..dab9619424 100644
--- a/tensorflow/python/ops/image_grad_test.py
+++ b/tensorflow/python/ops/image_grad_test.py
@@ -174,5 +174,134 @@ class ResizeBilinearOpTest(tf.test.TestCase):
grad = tf.gradients(input_tensor, [resize_out])
self.assertEqual([None], grad)
+
+class CropAndResizeOpTest(tf.test.TestCase):
+
+ def testShapeIsCorrectAfterOp(self):
+ batch = 2
+ image_height = 3
+ image_width = 4
+ crop_height = 4
+ crop_width = 5
+ depth = 2
+ num_boxes = 2
+
+ image_shape = [batch, image_height, image_width, depth]
+ crop_size = [crop_height, crop_width]
+ crops_shape = [num_boxes, crop_height, crop_width, depth]
+
+ image = np.arange(0, batch * image_height * image_width *
+ depth).reshape(image_shape).astype(np.float32)
+ boxes = np.array([[0, 0, 1, 1], [.1, .2, .7, .8]], dtype=np.float32)
+ box_ind = np.array([0, 1], dtype=np.int32)
+
+ for use_gpu in [False, True]:
+ with self.test_session(use_gpu=use_gpu) as sess:
+ crops = tf.image.crop_and_resize(
+ tf.constant(image, shape=image_shape),
+ tf.constant(boxes, shape=[num_boxes, 4]),
+ tf.constant(box_ind, shape=[num_boxes]),
+ tf.constant(crop_size, shape=[2]))
+ self.assertEqual(crops_shape, list(crops.get_shape()))
+ crops = sess.run(crops)
+ self.assertEqual(crops_shape, list(crops.shape))
+
+ def _randomUniformAvoidAnchors(self, low, high, anchors, radius, num_samples):
+ """Generate samples that are far enough from a set of anchor points.
+
+ We generate uniform samples in [low, high], then reject those that are less
+ than radius away from any point in anchors. We stop after we have accepted
+ num_samples samples.
+
+ Args:
+ low: The lower end of the interval.
+ high: The upper end of the interval.
+ anchors: A list of length num_crops with anchor points to avoid.
+ radius: Distance threshold for the samples from the anchors.
+ num_samples: How many samples to produce.
+
+ Returns:
+ samples: A list of length num_samples with the accepted samples.
+ """
+ self.assertTrue(low < high)
+ self.assertTrue(radius >= 0)
+ num_anchors = len(anchors)
+ # Make sure that at least half of the interval is not forbidden.
+ self.assertTrue(2 * radius * num_anchors < 0.5 * (high - low))
+ anchors = np.reshape(anchors, num_anchors)
+ samples = []
+ while len(samples) < num_samples:
+ sample = np.random.uniform(low, high)
+ if np.all(np.fabs(sample - anchors) > radius):
+ samples.append(sample)
+ return samples
+
+ def testGradRandomBoxes(self):
+ """Test that the gradient is correct for randomly generated boxes.
+
+ The mapping is piecewise differentiable with respect to the box coordinates.
+ The points where the function is not differentiable are those which are
+ mapped to image pixels, i.e., the normalized y coordinates in
+ np.linspace(0, 1, image_height) and normalized x coordinates in
+ np.linspace(0, 1, image_width). Make sure that the box coordinates are
+ sufficiently far away from those rectangular grid centers that are points of
+ discontinuity, so that the finite difference Jacobian is close to the
+ computed one.
+ """
+ np.random.seed(1) # Make it reproducible.
+ delta = 1e-3
+ radius = 2 * delta
+ low, high = -0.5, 1.5 # Also covers the case of extrapolation.
+
+ for image_height in range(1, 5):
+ for image_width in range(1, 3):
+ for crop_height in range(1, 3):
+ for crop_width in range(2, 4):
+ for depth in range(1, 3):
+ for num_boxes in range(1, 3):
+
+ batch = num_boxes
+ image_shape = [batch, image_height, image_width, depth]
+ crop_size = [crop_height, crop_width]
+ crops_shape = [num_boxes, crop_height, crop_width, depth]
+ boxes_shape = [num_boxes, 4]
+
+ image = np.arange(0, batch * image_height * image_width *
+ depth).reshape(image_shape).astype(np.float32)
+ boxes = []
+ for _ in range(num_boxes):
+ # pylint: disable=unbalanced-tuple-unpacking
+ y1, y2 = self._randomUniformAvoidAnchors(
+ low, high, np.linspace(0, 1, image_height), radius, 2)
+ x1, x2 = self._randomUniformAvoidAnchors(
+ low, high, np.linspace(0, 1, image_width), radius, 2)
+ # pylint: enable=unbalanced-tuple-unpacking
+ boxes.append([y1, x1, y2, x2])
+
+ boxes = np.array(boxes, dtype=np.float32)
+ box_ind = np.arange(batch, dtype=np.int32)
+
+ for use_gpu in [False, True]:
+ with self.test_session(use_gpu=use_gpu):
+ image_tensor = tf.constant(image, shape=image_shape)
+ boxes_tensor = tf.constant(boxes, shape=[num_boxes, 4])
+ box_ind_tensor = tf.constant(box_ind, shape=[num_boxes])
+ crops = tf.image.crop_and_resize(
+ image_tensor,
+ boxes_tensor,
+ box_ind_tensor,
+ tf.constant(crop_size, shape=[2]))
+
+ err = tf.test.compute_gradient_error(
+ [image_tensor, boxes_tensor],
+ [image_shape, boxes_shape],
+ crops,
+ crops_shape,
+ delta=delta,
+ x_init_value=[image, boxes])
+
+ self.assertLess(err, 2e-3)
+
+
if __name__ == "__main__":
tf.test.main()
diff --git a/tensorflow/python/ops/image_ops.py b/tensorflow/python/ops/image_ops.py
index 2df7b9c4c7..a72690c0f2 100644
--- a/tensorflow/python/ops/image_ops.py
+++ b/tensorflow/python/ops/image_ops.py
@@ -73,6 +73,8 @@ resized_image = tf.image.resize_images(image, 299, 299)
@@crop_to_bounding_box
@@extract_glimpse
+@@crop_and_resize
+
## Flipping and Transposing
@@flip_up_down
@@ -1327,8 +1329,8 @@ def _random_crop_shape(op):
return [tensor_shape.TensorShape(output_shape)]
-@ops.RegisterShape("ExtractGlimpse")
-def _ExtractGlimpseShape(op):
+@ops.RegisterShape('ExtractGlimpse')
+def _extract_glimpse_shape(op):
"""Shape function for ExtractGlimpse op."""
input_shape = op.inputs[0].get_shape().with_rank(4)
unused_size_shape = op.inputs[1].get_shape().merge_with(
@@ -1347,6 +1349,22 @@ def _ExtractGlimpseShape(op):
[input_shape[0], height, width, input_shape[3]])]
+@ops.RegisterShape('CropAndResize')
+def _crop_and_resize_shape(op):
+ """Shape function for the CropAndResize op."""
+ image_shape = op.inputs[0].get_shape().with_rank(4)
+ box_shape = op.inputs[1].get_shape().with_rank(2)
+ crop_size = tensor_util.constant_value(op.inputs[3])
+ if crop_size is not None:
+ crop_height = crop_size[0]
+ crop_width = crop_size[1]
+ else:
+ crop_height = None
+ crop_width = None
+ return [tensor_shape.TensorShape(
+ [box_shape[0], crop_height, crop_width, image_shape[3]])]
+
+
__all__ = make_all(__name__)
# ResizeMethod is not documented, but is documented in functions that use it.
__all__.append('ResizeMethod')
diff --git a/tensorflow/python/ops/math_ops.py b/tensorflow/python/ops/math_ops.py
index 9397b83e7c..d27cefc61d 100644
--- a/tensorflow/python/ops/math_ops.py
+++ b/tensorflow/python/ops/math_ops.py
@@ -239,7 +239,7 @@ def abs(x, name=None):
number.
Args:
- x: A `Tensor` or `SparseTensor` of type `float`, `double`, `int32`, or
+ x: A `Tensor` or `SparseTensor` of type `float32`, `float64`, `int32`, or
`int64`.
name: A name for the operation (optional).
@@ -352,7 +352,7 @@ def complex_abs(x, name=None):
r"""Computes the complex absolute value of a tensor.
Given a tensor `x` of complex numbers, this operation returns a tensor of type
- `float` or `double` that is the absolute value of each element in `x`. All
+ `float32` or `float64` that is the absolute value of each element in `x`. All
elements in `x` must be complex numbers of the form \\(a + bj\\). The
absolute value is computed as \\( \sqrt{a^2 + b^2}\\).
@@ -414,10 +414,10 @@ def pow(x, y, name=None):
```
Args:
- x: A `Tensor` of type `float`, `double`, `int32`, `int64`, `complex64`, or
- `complex128`.
- y: A `Tensor` of type `float`, `double`, `int32`, `int64`, `complex64`, or
- `complex128`.
+ x: A `Tensor` of type `float32`, `float64`, `int32`, `int64`, `complex64`,
+ or `complex128`.
+ y: A `Tensor` of type `float32`, `float64`, `int32`, `int64`, `complex64`,
+ or `complex128`.
name: A name for the operation (optional).
Returns:
@@ -471,7 +471,7 @@ def real(input, name=None):
"""Returns the real part of a complex number.
Given a tensor `input` of complex numbers, this operation returns a tensor of
- type `float` or `double` that is the real part of each element in `input`.
+ type `float32` or `float64` that is the real part of each element in `input`.
All elements in `input` must be complex numbers of the form \\(a + bj\\),
where *a* is the real part returned by this operation and *b* is the
imaginary part.
@@ -489,7 +489,7 @@ def real(input, name=None):
name: A name for the operation (optional).
Returns:
- A `Tensor` of type `float` or `double`.
+ A `Tensor` of type `float32` or `float64`.
"""
with ops.op_scope([input], name, "Real") as name:
return gen_math_ops.real(input, Tout=input.dtype.real_dtype, name=name)
@@ -499,7 +499,7 @@ def imag(input, name=None):
"""Returns the imaginary part of a complex number.
Given a tensor `input` of complex numbers, this operation returns a tensor of
- type `float` or `double` that is the imaginary part of each element in
+ type `float32` or `float64` that is the imaginary part of each element in
`input`. All elements in `input` must be complex numbers of the form \\(a +
bj\\), where *a* is the real part and *b* is the imaginary part returned by
this operation.
@@ -516,7 +516,7 @@ def imag(input, name=None):
name: A name for the operation (optional).
Returns:
- A `Tensor` of type `float` or `double`.
+ A `Tensor` of type `float32` or `float64`.
"""
with ops.op_scope([input], name, "Imag") as name:
return gen_math_ops.imag(input, Tout=input.dtype.real_dtype, name=name)
@@ -533,7 +533,7 @@ def round(x, name=None):
```
Args:
- x: A `Tensor` of type `float` or `double`.
+ x: A `Tensor` of type `float32` or `float64`.
name: A name for the operation (optional).
Returns:
@@ -1255,7 +1255,7 @@ def matmul(a, b,
possibly after transposition.
Both matrices must be of the same type. The supported types are:
- `float`, `double`, `int32`, `complex64`.
+ `float32`, `float64`, `int32`, `complex64`.
Either matrix can be transposed on the fly by setting the corresponding flag
to `True`. This is `False` by default.
@@ -1279,7 +1279,7 @@ def matmul(a, b,
```
Args:
- a: `Tensor` of type `float`, `double`, `int32` or `complex64`.
+ a: `Tensor` of type `float32`, `float64`, `int32` or `complex64`.
b: `Tensor` with same type as `a`.
transpose_a: If `True`, `a` is transposed before multiplication.
transpose_b: If `True`, `b` is transposed before multiplication.
@@ -1531,7 +1531,7 @@ def sigmoid(x, name=None):
Specifically, `y = 1 / (1 + exp(-x))`.
Args:
- x: A Tensor with type `float`, `double`, `int32`, `complex64`, `int64`,
+ x: A Tensor with type `float32`, `float64`, `int32`, `complex64`, `int64`,
or `qint32`.
name: A name for the operation (optional).
@@ -1548,7 +1548,7 @@ def tanh(x, name=None):
"""Computes hyperbolic tangent of `x` element-wise.
Args:
- x: A Tensor with type `float`, `double`, `int32`, `complex64`, `int64`,
+ x: A Tensor with type `float32`, `float64`, `int32`, `complex64`, `int64`,
or `qint32`.
name: A name for the operation (optional).
diff --git a/tensorflow/python/ops/nn_conv_test.py b/tensorflow/python/ops/nn_conv_test.py
index 465e39350f..8c771fc2a3 100644
--- a/tensorflow/python/ops/nn_conv_test.py
+++ b/tensorflow/python/ops/nn_conv_test.py
@@ -283,11 +283,42 @@ class Conv2DTransposeTest(tf.test.TestCase):
padding="SAME")
err = tf.test.compute_gradient_error(
[x, f], [x_shape, f_shape], output, y_shape)
- print("DeConv gradient err = %g " % err)
+ print("conv2d_transpose gradient err = %g " % err)
err_tolerance = 0.0005
self.assertLess(err, err_tolerance)
+class Conv2DBackpropFilterGradTest(tf.test.TestCase):
+
+ def testGradient(self):
+ with self.test_session():
+ for padding in ["SAME", "VALID"]:
+ for stride in [1, 2]:
+ np.random.seed(1)
+ in_shape = [5, 8, 6, 4]
+ in_val = tf.constant(
+ 2 * np.random.random_sample(in_shape) - 1,
+ dtype=tf.float32)
+ filter_shape = [3, 3, 4, 6]
+ # Make a convolution op with the current settings, just to easily get
+ # the shape of the output.
+ conv_out = tf.nn.conv2d(in_val, tf.zeros(filter_shape),
+ [1, stride, stride, 1], padding)
+ out_backprop_shape = conv_out.get_shape().as_list()
+ out_backprop_val = tf.constant(
+ 2 * np.random.random_sample(out_backprop_shape) - 1,
+ dtype=tf.float32)
+ output = tf.nn.conv2d_backprop_filter(in_val, filter_shape,
+ out_backprop_val,
+ [1, stride, stride, 1], padding)
+ err = tf.test.compute_gradient_error([in_val, out_backprop_val],
+ [in_shape, out_backprop_shape],
+ output, filter_shape)
+ print("conv2d_backprop_filter gradient err = %g " % err)
+ err_tolerance = 1e-3
+ self.assertLess(err, err_tolerance)
+
+
class Conv1DTest(tf.test.TestCase):
def testBasic(self):
diff --git a/tensorflow/python/ops/nn_grad.py b/tensorflow/python/ops/nn_grad.py
index ba2b9d957e..3f4bb0e068 100644
--- a/tensorflow/python/ops/nn_grad.py
+++ b/tensorflow/python/ops/nn_grad.py
@@ -28,7 +28,7 @@ from tensorflow.python.ops import gen_nn_ops
@ops.RegisterGradient("Conv2DBackpropInput")
-def _Conv2DBackpropGrad(op, grad):
+def _Conv2DBackpropInputGrad(op, grad):
"""The derivatives for deconvolution.
Args:
@@ -49,6 +49,25 @@ def _Conv2DBackpropGrad(op, grad):
op.get_attr("data_format"))]
+@ops.RegisterGradient("Conv2DBackpropFilter")
+def _Conv2DBackpropFilterGrad(op, grad):
+ return [
+ nn_ops.conv2d_backprop_input(
+ array_ops.shape(op.inputs[0]), grad, op.inputs[2],
+ op.get_attr("strides"),
+ op.get_attr("padding"),
+ op.get_attr("use_cudnn_on_gpu"),
+ op.get_attr("data_format")),
+ None,
+ nn_ops.conv2d(
+ op.inputs[0], grad,
+ op.get_attr("strides"),
+ op.get_attr("padding"),
+ op.get_attr("use_cudnn_on_gpu"),
+ op.get_attr("data_format"))
+ ]
+
+
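A hedged sketch of what the new registration enables: differentiating through a filter gradient, i.e. a second application of tf.gradients (shapes are illustrative only):

```python
import tensorflow as tf

x = tf.random_uniform([1, 8, 8, 2])
w = tf.random_uniform([3, 3, 2, 4])
y = tf.nn.conv2d(x, w, strides=[1, 1, 1, 1], padding="SAME")

# The first gradient w.r.t. the filter is a Conv2DBackpropFilter op...
dw = tf.gradients(tf.reduce_sum(y), w)[0]
# ...so differentiating it again exercises the gradient registered above.
ddx = tf.gradients(tf.reduce_sum(dw), x)[0]
```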
@ops.RegisterGradient("Conv3D")
def _Conv3DGrad(op, grad):
return [nn_ops.conv3d_backprop_input_v2(array_ops.shape(op.inputs[0]),
diff --git a/tensorflow/python/ops/rnn.py b/tensorflow/python/ops/rnn.py
index d8a96db0a2..d2832541dd 100644
--- a/tensorflow/python/ops/rnn.py
+++ b/tensorflow/python/ops/rnn.py
@@ -469,6 +469,125 @@ def bidirectional_rnn(cell_fw, cell_bw, inputs,
return (outputs, output_state_fw, output_state_bw)
+def bidirectional_dynamic_rnn(cell_fw, cell_bw, inputs, sequence_length=None,
+ initial_state_fw=None, initial_state_bw=None,
+ dtype=None, parallel_iterations=None,
+ swap_memory=False, time_major=False, scope=None):
+ """Creates a dynamic version of bidirectional recurrent neural network.
+
+ Similar to the unidirectional case above (rnn) but takes input and builds
+ independent forward and backward RNNs. The input_size of forward and
+ backward cell must match. The initial state for both directions is zero by
+ default (but can be set optionally) and no intermediate states are ever
+ returned -- the network is fully unrolled for the given (passed in)
+ length(s) of the sequence(s) or completely unrolled if length(s) is not
+ given.
+
+ Args:
+ cell_fw: An instance of RNNCell, to be used for forward direction.
+ cell_bw: An instance of RNNCell, to be used for backward direction.
+ inputs: The RNN inputs.
+ If time_major == False (default), this must be a tensor of shape:
+ `[batch_size, max_time, input_size]`.
+ If time_major == True, this must be a tensor of shape:
+ `[max_time, batch_size, input_size]`.
+ sequence_length: An int32/int64 vector, size `[batch_size]`,
+ containing the actual lengths for each of the sequences.
+ initial_state_fw: (optional) An initial state for the forward RNN.
+ This must be a tensor of appropriate type and shape
+ `[batch_size x cell_fw.state_size]`.
+ If `cell_fw.state_size` is a tuple, this should be a tuple of
+ tensors having shapes `[batch_size, s] for s in cell_fw.state_size`.
+ initial_state_bw: (optional) Same as for `initial_state_fw`, but using
+ the corresponding properties of `cell_bw`.
+ parallel_iterations: (Default: 32). The number of iterations to run in
+ parallel. Those operations that do not have any temporal dependency
+ and can be run in parallel will be. This parameter trades off
+ time for space. Values >> 1 use more memory but take less time,
+ while smaller values use less memory but computations take longer.
+ swap_memory: Transparently swap the tensors produced in forward inference
+ but needed for back prop from GPU to CPU. This allows training RNNs
+ which would typically not fit on a single GPU, with very minimal (or no)
+ performance penalty.
+ time_major: The shape format of the `inputs` and `outputs` Tensors.
+ If true, these `Tensors` must be shaped `[max_time, batch_size, depth]`.
+ If false, these `Tensors` must be shaped `[batch_size, max_time, depth]`.
+ Using `time_major = True` is a bit more efficient because it avoids
+ transposes at the beginning and end of the RNN calculation. However,
+ most TensorFlow data is batch-major, so by default this function
+ accepts input and emits output in batch-major form.
+ dtype: (optional) The data type for the initial state. Required if
+ either of the initial states is not provided.
+ scope: VariableScope for the created subgraph; defaults to "BiRNN"
+
+ Returns:
+ A tuple (outputs, output_states) where:
+ outputs: A tuple (output_fw, output_bw) containing the forward and
+ the backward rnn output `Tensor`.
+ If time_major == False (default),
+ output_fw will be a `Tensor` shaped:
+ `[batch_size, max_time, cell_fw.output_size]`
+ and output_bw will be a `Tensor` shaped:
+ `[batch_size, max_time, cell_bw.output_size]`.
+ If time_major == True,
+ output_fw will be a `Tensor` shaped:
+ `[max_time, batch_size, cell_fw.output_size]`
+ and output_bw will be a `Tensor` shaped:
+ `[max_time, batch_size, cell_bw.output_size]`.
+ It returns a tuple instead of a single concatenated `Tensor`, unlike
+ in the `bidirectional_rnn`. If the concatenated one is preferred,
+ the forward and backward outputs can be concatenated as
+ `tf.concat(2, outputs)`.
+ output_states: A tuple (output_state_fw, output_state_bw) containing
+ the forward and the backward final states of bidirectional rnn.
+
+ Raises:
+ TypeError: If `cell_fw` or `cell_bw` is not an instance of `RNNCell`.
+ """
+
+ if not isinstance(cell_fw, rnn_cell.RNNCell):
+ raise TypeError("cell_fw must be an instance of RNNCell")
+ if not isinstance(cell_bw, rnn_cell.RNNCell):
+ raise TypeError("cell_bw must be an instance of RNNCell")
+
+ name = scope or "BiRNN"
+ # Forward direction
+ with vs.variable_scope(name + "_FW") as fw_scope:
+ output_fw, output_state_fw = dynamic_rnn(
+ cell=cell_fw, inputs=inputs, sequence_length=sequence_length,
+ initial_state=initial_state_fw, dtype=dtype,
+ parallel_iterations=parallel_iterations, swap_memory=swap_memory,
+ time_major=time_major, scope=fw_scope)
+ # Backward direction
+ if not time_major:
+ time_dim = 1
+ batch_dim = 0
+ else:
+ time_dim = 0
+ batch_dim = 1
+ with vs.variable_scope(name + "_BW") as bw_scope:
+ inputs_reverse = array_ops.reverse_sequence(
+ input=inputs, seq_lengths=sequence_length,
+ seq_dim=time_dim, batch_dim=batch_dim)
+ tmp, output_state_bw = dynamic_rnn(
+ cell=cell_bw, inputs=inputs_reverse, sequence_length=sequence_length,
+ initial_state=initial_state_bw, dtype=dtype,
+ parallel_iterations=parallel_iterations, swap_memory=swap_memory,
+ time_major=time_major, scope=bw_scope)
+ output_bw = array_ops.reverse_sequence(
+ input=tmp, seq_lengths=sequence_length,
+ seq_dim=time_dim, batch_dim=batch_dim)
+
+ outputs = (output_fw, output_bw)
+ output_states = (output_state_fw, output_state_bw)
+
+ return (outputs, output_states)
+
+
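A hedged usage sketch of the new `bidirectional_dynamic_rnn`, assuming the module paths of this tree and batch-major inputs (all sizes are illustrative):

```python
import tensorflow as tf
from tensorflow.python.ops import rnn, rnn_cell

batch_size, max_time, input_size, num_units = 4, 7, 5, 8
inputs = tf.random_uniform([batch_size, max_time, input_size])
seq_len = tf.constant([7, 5, 3, 7], dtype=tf.int64)

cell_fw = rnn_cell.GRUCell(num_units)
cell_bw = rnn_cell.GRUCell(num_units)

(out_fw, out_bw), (state_fw, state_bw) = rnn.bidirectional_dynamic_rnn(
    cell_fw, cell_bw, inputs, sequence_length=seq_len, dtype=tf.float32)

# Concatenate the two directions on the feature axis if a single tensor is wanted.
outputs = tf.concat(2, [out_fw, out_bw])   # [batch_size, max_time, 2 * num_units]
```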
def dynamic_rnn(cell, inputs, sequence_length=None, initial_state=None,
dtype=None, parallel_iterations=None, swap_memory=False,
time_major=False, scope=None):
diff --git a/tensorflow/python/ops/variables.py b/tensorflow/python/ops/variables.py
index efb1fc919f..4be42ce1cb 100644
--- a/tensorflow/python/ops/variables.py
+++ b/tensorflow/python/ops/variables.py
@@ -877,6 +877,15 @@ def local_variables():
return ops.get_collection(ops.GraphKeys.LOCAL_VARIABLES)
+def model_variables():
+ """Returns all variables in the MODEL_VARIABLES collection.
+
+ Returns:
+ A list of Variable objects in the MODEL_VARIABLES collection.
+ """
+ return ops.get_collection(ops.GraphKeys.MODEL_VARIABLES)
+
+
def moving_average_variables():
"""Returns all variables that maintain their moving averages.
diff --git a/tensorflow/python/platform/test.py b/tensorflow/python/platform/test.py
index 63bac26b52..0274d62c48 100644
--- a/tensorflow/python/platform/test.py
+++ b/tensorflow/python/platform/test.py
@@ -59,6 +59,7 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
+from tensorflow.python.client import device_lib
from tensorflow.python.framework import test_util
from tensorflow.python.platform import googletest
from tensorflow.python.util.all_util import make_all
@@ -97,6 +98,11 @@ def is_built_with_cuda():
return test_util.IsGoogleCudaEnabled()
+def is_gpu_available():
+ """Returns whether TensorFlow can access a GPU."""
+ return any(x.device_type == 'GPU' for x in device_lib.list_local_devices())
+
+
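A hedged sketch of how the new helper can gate GPU-specific code in tests (the device strings are the conventional ones, assumed here):

```python
import tensorflow as tf

def maybe_gpu_matmul(a, b):
    # Place the matmul on the GPU only when one is actually available.
    device = "/gpu:0" if tf.test.is_gpu_available() else "/cpu:0"
    with tf.device(device):
        return tf.matmul(a, b)
```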
__all__ = make_all(__name__)
# TODO(irving,vrv): Remove once TestCase is documented
__all__.append('TestCase')
diff --git a/tensorflow/python/training/moving_averages.py b/tensorflow/python/training/moving_averages.py
index 6ecd20ad48..9cc1d917ab 100644
--- a/tensorflow/python/training/moving_averages.py
+++ b/tensorflow/python/training/moving_averages.py
@@ -266,8 +266,10 @@ class ExponentialMovingAverage(object):
if var_list is None:
var_list = variables.trainable_variables()
for var in var_list:
- if var.dtype.base_dtype not in [dtypes.float32, dtypes.float64]:
- raise TypeError("The variables must be float or double: %s" % var.name)
+ if var.dtype.base_dtype not in [dtypes.float16, dtypes.float32,
+ dtypes.float64]:
+ raise TypeError("The variables must be half, float, or double: %s" %
+ var.name)
if var in self._averages:
raise ValueError("Moving average already computed for: %s" % var.name)
diff --git a/tensorflow/python/training/saver.py b/tensorflow/python/training/saver.py
index 2a943b5911..1d6a8a6632 100644
--- a/tensorflow/python/training/saver.py
+++ b/tensorflow/python/training/saver.py
@@ -285,7 +285,7 @@ class BaseSaverBuilder(object):
def _AddShardedRestoreOps(self, filename_tensor, per_device,
restore_sequentially, reshape):
- """Add Ops to save variables from multiple devices.
+ """Add Ops to restore variables from multiple devices.
Args:
filename_tensor: Tensor for the path of the file to load.
diff --git a/tensorflow/python/training/server_lib_test.py b/tensorflow/python/training/server_lib_test.py
index d939b5aabd..94fa6c295f 100644
--- a/tensorflow/python/training/server_lib_test.py
+++ b/tensorflow/python/training/server_lib_test.py
@@ -52,6 +52,154 @@ class GrpcServerTest(tf.test.TestCase):
sess_2.close()
# TODO(mrry): Add `server.stop()` and `server.join()` when these work.
+ # Verifies behavior of multiple variables with multiple sessions connecting to
+ # the same server.
+ def testSameVariablesNoClear(self):
+ server = tf.train.Server.create_local_server()
+
+ with tf.Session(server.target) as sess_1:
+ v0 = tf.Variable([[2, 1]], name="v0")
+ v1 = tf.Variable([[1], [2]], name="v1")
+ v2 = tf.matmul(v0, v1)
+ sess_1.run([v0.initializer, v1.initializer])
+ self.assertAllEqual([[4]], sess_1.run(v2))
+
+ with tf.Session(server.target) as sess_2:
+ new_v0 = tf.get_default_graph().get_tensor_by_name("v0:0")
+ new_v1 = tf.get_default_graph().get_tensor_by_name("v1:0")
+ new_v2 = tf.matmul(new_v0, new_v1)
+ self.assertAllEqual([[4]], sess_2.run(new_v2))
+
+ # Verifies behavior of tf.Session.reset().
+ def testSameVariablesClear(self):
+ server = tf.train.Server.create_local_server()
+
+ # Creates a graph with 2 variables.
+ v0 = tf.Variable([[2, 1]], name="v0")
+ v1 = tf.Variable([[1], [2]], name="v1")
+ v2 = tf.matmul(v0, v1)
+
+ # Verifies that both sessions connecting to the same target return
+ # the same results.
+ sess_1 = tf.Session(server.target)
+ sess_2 = tf.Session(server.target)
+ sess_1.run(tf.initialize_all_variables())
+ self.assertAllEqual([[4]], sess_1.run(v2))
+ self.assertAllEqual([[4]], sess_2.run(v2))
+
+ # Resets target. Sessions abort. Use sess_2 to verify.
+ tf.Session.reset(server.target)
+ with self.assertRaises(tf.errors.AbortedError):
+ self.assertAllEqual([[4]], sess_2.run(v2))
+
+ # Connects to the same target. Device memory for the variables would have
+ # been released, so they will be uninitialized.
+ sess_2 = tf.Session(server.target)
+ with self.assertRaises(tf.errors.FailedPreconditionError):
+ sess_2.run(v2)
+ # Reinitializes the variables.
+ sess_2.run(tf.initialize_all_variables())
+ self.assertAllEqual([[4]], sess_2.run(v2))
+ sess_2.close()
+
+ # Verifies behavior of tf.Session.reset() with multiple containers using
+ # default container names as defined by the target name.
+ def testSameVariablesClearContainer(self):
+ # Starts two servers with different names so they map to different
+ # resource "containers".
+ server0 = tf.train.Server({"local0": ["localhost:0"]}, protocol="grpc",
+ start=True)
+ server1 = tf.train.Server({"local1": ["localhost:0"]}, protocol="grpc",
+ start=True)
+
+ # Creates a graph with 2 variables.
+ v0 = tf.Variable(1.0, name="v0")
+ v1 = tf.Variable(2.0, name="v0")
+
+ # Initializes the variables. Verifies that the values are correct.
+ sess_0 = tf.Session(server0.target)
+ sess_1 = tf.Session(server1.target)
+ sess_0.run(v0.initializer)
+ sess_1.run(v1.initializer)
+ self.assertAllEqual(1.0, sess_0.run(v0))
+ self.assertAllEqual(2.0, sess_1.run(v1))
+
+ # Resets container "local0". Verifies that v0 is no longer initialized.
+ tf.Session.reset(server0.target, ["local0"])
+ sess = tf.Session(server0.target)
+ with self.assertRaises(tf.errors.FailedPreconditionError):
+ sess.run(v0)
+ # Reinitializes v0 for the following test.
+ sess.run(v0.initializer)
+
+ # Verifies that v1 is still valid.
+ self.assertAllEqual(2.0, sess_1.run(v1))
+
+ # Resets container "local1". Verifies that v1 is no longer initialized.
+ tf.Session.reset(server1.target, ["local1"])
+ sess = tf.Session(server1.target)
+ with self.assertRaises(tf.errors.FailedPreconditionError):
+ sess.run(v1)
+ # Verifies that v0 is still valid.
+ sess = tf.Session(server0.target)
+ self.assertAllEqual(1.0, sess.run(v0))
+
+ # Verifies behavior of tf.Session.reset() with multiple containers using
+ # tf.container.
+ def testMultipleContainers(self):
+ with tf.container("test0"):
+ v0 = tf.Variable(1.0, name="v0")
+ with tf.container("test1"):
+ v1 = tf.Variable(2.0, name="v0")
+ server = tf.train.Server.create_local_server()
+ sess = tf.Session(server.target)
+ sess.run(tf.initialize_all_variables())
+ self.assertAllEqual(1.0, sess.run(v0))
+ self.assertAllEqual(2.0, sess.run(v1))
+
+ # Resets container. Session aborts.
+ tf.Session.reset(server.target, ["test0"])
+ with self.assertRaises(tf.errors.AbortedError):
+ sess.run(v1)
+
+ # Connects to the same target. Device memory for the v0 would have
+ # been released, so it will be uninitialized. But v1 should still
+ # be valid.
+ sess = tf.Session(server.target)
+ with self.assertRaises(tf.errors.FailedPreconditionError):
+ sess.run(v0)
+ self.assertAllEqual(2.0, sess.run(v1))
+
+ # Verifies various reset failures.
+ def testResetFails(self):
+ # Creates variable with container name.
+ with tf.container("test0"):
+ v0 = tf.Variable(1.0, name="v0")
+ # Creates variable with default container.
+ v1 = tf.Variable(2.0, name="v1")
+ # Verifies resetting the non-existent target returns error.
+ with self.assertRaises(tf.errors.NotFoundError):
+ tf.Session.reset("nonexistent", ["test0"])
+
+ # Verifies that resetting a target with no server listening, using a config
+ # with a short timeout, raises DeadlineExceededError.
+ with self.assertRaises(tf.errors.DeadlineExceededError):
+ tf.Session.reset("grpc://localhost:0", ["test0"],
+ config=tf.ConfigProto(operation_timeout_in_ms=5))
+
+ # Verifies no containers are reset with non-existent container.
+ server = tf.train.Server.create_local_server()
+ sess = tf.Session(server.target)
+ sess.run(tf.initialize_all_variables())
+ self.assertAllEqual(1.0, sess.run(v0))
+ self.assertAllEqual(2.0, sess.run(v1))
+ # No container is reset, but the server is reset.
+ tf.Session.reset(server.target, ["test1"])
+ # Verifies that both variables are still valid.
+ sess = tf.Session(server.target)
+ self.assertAllEqual(1.0, sess.run(v0))
+ self.assertAllEqual(2.0, sess.run(v1))
+
def testLargeConstant(self):
server = tf.train.Server.create_local_server()
with tf.Session(server.target) as sess: