author    2016-06-28 11:13:44 +0800
committer 2016-06-28 11:13:44 +0800
commit    43925959ebcbf6eb5e48d8854c9550a109961aa7 (patch)
tree      e2829db19176c1cd3dcc0d56caa5105d1024dd0c /tensorflow/python
parent    19491230eed62fae0bdaf5c02417af389b314658 (diff)
parent    14b8ed02dbd4da8fd7a269fa6a5fef5abe405489 (diff)
Diffstat (limited to 'tensorflow/python')
36 files changed, 1718 insertions, 122 deletions
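The headline change in this merge is the resource-container API: `tf.container()` scopes stateful ops to a named container, and the new `tf.Session.reset()` clears those containers on a distributed runtime. A minimal usage sketch, not part of the diff; the gRPC address is a placeholder for a hypothetical TensorFlow server, and `reset()` only works against distributed sessions:

```python
import tensorflow as tf

target = "grpc://localhost:2222"  # hypothetical server address

with tf.container("experiment0"):
  v = tf.Variable([1.0])  # stateful op, tracked under "experiment0"

with tf.Session(target) as sess:
  sess.run(tf.initialize_all_variables())
  print(sess.run(v))  # => [1.]

# Clears every resource in "experiment0" on all workers in the cluster and
# closes any session connected to `target`; `v` becomes uninitialized.
tf.Session.reset(target, ["experiment0"])
```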
diff --git a/tensorflow/python/BUILD b/tensorflow/python/BUILD
index e088fd61d1..ef7754a920 100644
--- a/tensorflow/python/BUILD
+++ b/tensorflow/python/BUILD
@@ -253,6 +253,7 @@ py_library(
     ],
     srcs_version = "PY2AND3",
     deps = [
+        ":device_lib",
         ":framework_test_lib",
         ":platform_test",
     ],
@@ -543,10 +544,12 @@ tf_gen_op_wrapper_py(
         "HashTable",
         "InitializeTable",
         "InitializeTableFromTextFile",
+        "LookupTableExport",
         "LookupTableFind",
         "LookupTableInsert",
         "LookupTableSize",
         "MutableHashTable",
+        "MutableHashTableOfTensors",
         "Mutex",
         "MutexAcquire",
         "MutexRelease",
diff --git a/tensorflow/python/client/session.py b/tensorflow/python/client/session.py
index 6dd213624b..a7de16de0f 100644
--- a/tensorflow/python/client/session.py
+++ b/tensorflow/python/client/session.py
@@ -853,6 +853,8 @@ class Session(BaseSession):
 
   @@as_default
 
+  @@reset
+
   """
 
   def __init__(self, target='', graph=None, config=None):
@@ -892,6 +894,34 @@ class Session(BaseSession):
     self.close()
 
+  @staticmethod
+  def reset(target, containers=None, config=None):
+    """Resets resource containers on `target`, and closes all connected sessions.
+
+    A resource container is distributed across all workers in the
+    same cluster as `target`. When a resource container on `target`
+    is reset, resources associated with that container will be cleared.
+    In particular, all Variables in the container will become undefined:
+    they lose their values and shapes.
+
+    NOTE:
+    (i) reset() is currently only implemented for distributed sessions.
+    (ii) Any sessions on the master named by `target` will be closed.
+
+    If no resource containers are provided, all containers are reset.
+
+    Args:
+      target: The execution engine to connect to.
+      containers: A list of resource container name strings, or `None`
+        if all the containers are to be reset.
+      config: (Optional.) Protocol buffer with configuration options.
+
+    Raises:
+      tf.errors.OpError: Or one of its subclasses if an error occurs while
+        resetting containers.
+    """
+    tf_session.TF_Reset(target, containers, config)
+
 
 class InteractiveSession(BaseSession):
   """A TensorFlow `Session` for use in interactive contexts, such as a shell.
diff --git a/tensorflow/python/client/tf_session.i b/tensorflow/python/client/tf_session.i
index 6707542178..01c3bf7e51 100644
--- a/tensorflow/python/client/tf_session.i
+++ b/tensorflow/python/client/tf_session.i
@@ -290,6 +290,22 @@ tensorflow::ImportNumpy();
 %unignore tensorflow;
 %unignore TF_PRun;
+%unignore tensorflow::TF_Reset_wrapper;
+%insert("python") %{
+def TF_Reset(target, containers=None, config=None):
+  from tensorflow.python.framework import errors
+  try:
+    opts = TF_NewSessionOptions(target=target, config=config)
+    with errors.raise_exception_on_not_ok_status() as status:
+      from tensorflow.python.util import compat
+      if containers is None:
+        containers = []
+      TF_Reset_wrapper(
+          opts, [compat.as_bytes(c) for c in containers], status)
+  finally:
+    TF_DeleteSessionOptions(opts)
+%}
+
 %include "tensorflow/python/client/tf_session_helper.h"
 
 %unignoreall
diff --git a/tensorflow/python/client/tf_session_helper.cc b/tensorflow/python/client/tf_session_helper.cc
index e9b6b44613..31f4ad6ba9 100644
--- a/tensorflow/python/client/tf_session_helper.cc
+++ b/tensorflow/python/client/tf_session_helper.cc
@@ -573,6 +573,13 @@ void TF_PRun_wrapper(TF_Session* session, const char* handle,
                      NameVector(), out_status, out_values, nullptr);
 }
 
+// Wrapper for TF_Reset that converts the string vectors to character arrays.
+void TF_Reset_wrapper(const TF_SessionOptions* opt,
+                      const NameVector& containers, TF_Status* out_status) {
+  TF_Reset(opt, const_cast<const char**>(containers.data()), containers.size(),
+           out_status);
+}
+
 string EqualGraphDefWrapper(const string& actual, const string& expected) {
   GraphDef actual_def;
   if (!actual_def.ParseFromString(actual)) {
diff --git a/tensorflow/python/client/tf_session_helper.h b/tensorflow/python/client/tf_session_helper.h
index d360efa8f3..285bd902a4 100644
--- a/tensorflow/python/client/tf_session_helper.h
+++ b/tensorflow/python/client/tf_session_helper.h
@@ -96,6 +96,10 @@ void TF_PRun_wrapper(TF_Session* session, const char* handle,
                      const FeedVector& inputs, const NameVector& output_names,
                      TF_Status* out_status, PyObjectVector* out_values);
 
+// Wrapper for TF_Reset that converts the string vectors to character arrays.
+void TF_Reset_wrapper(const TF_SessionOptions* opt,
+                      const NameVector& containers, TF_Status* out_status);
+
 // Convenience wrapper around EqualGraphDef to make it easier to wrap.
 // Returns an explanation if a difference is found, or the empty string
 // for no difference.
diff --git a/tensorflow/python/framework/dtypes.py b/tensorflow/python/framework/dtypes.py
index d45ab2f2e2..1f29426b4c 100644
--- a/tensorflow/python/framework/dtypes.py
+++ b/tensorflow/python/framework/dtypes.py
@@ -205,7 +205,7 @@ class DType(object):
                  (bool, string, complex64, complex128)):
       raise TypeError("Cannot find maximum value of %s." % self)
 
-    # there is no simple way to get the min value of a dtype, we have to check
+    # there is no simple way to get the max value of a dtype, we have to check
     # float and int types separately
     try:
       return np.finfo(self.as_numpy_dtype()).max
diff --git a/tensorflow/python/framework/framework_lib.py b/tensorflow/python/framework/framework_lib.py
index 480e99abae..b06605cf59 100644
--- a/tensorflow/python/framework/framework_lib.py
+++ b/tensorflow/python/framework/framework_lib.py
@@ -30,6 +30,7 @@
 ## Utility functions
 
 @@device
+@@container
 @@name_scope
 @@control_dependencies
 @@convert_to_tensor
@@ -77,6 +78,7 @@ from tensorflow.python.framework.ops import IndexedSlices
 
 # Utilities used when building a Graph.
 from tensorflow.python.framework.ops import device
+from tensorflow.python.framework.ops import container
 from tensorflow.python.framework.ops import name_scope
 from tensorflow.python.framework.ops import op_scope
 from tensorflow.python.framework.ops import control_dependencies
diff --git a/tensorflow/python/framework/ops.py b/tensorflow/python/framework/ops.py
index 912928778b..b7d131cf78 100644
--- a/tensorflow/python/framework/ops.py
+++ b/tensorflow/python/framework/ops.py
@@ -35,6 +35,7 @@ from tensorflow.core.framework import graph_pb2
 from tensorflow.core.framework import versions_pb2
 from tensorflow.python.framework import device as pydev
 from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import op_def_registry
 from tensorflow.python.framework import registry
 from tensorflow.python.framework import tensor_shape
 from tensorflow.python.framework import versions
@@ -1984,6 +1985,9 @@ class Graph(object):
     self._handle_movers = {}
     # A map from tensor handle to its delete op.
     self._handle_deleters = {}
+    # Resource container.
+    self._container = ""
+    self._registered_ops = op_def_registry.get_registered_ops()
 
   def _check_not_finalized(self):
     """Check if the graph is finalized.
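The `_container` field added above is consumed in the next hunk, which stamps a "container" attribute onto qualifying ops at creation time. A small illustration, not from the diff, of which ops pick up the attribute: only registered-stateful ops that declare a `container` attr, so stateless ops pass through untouched.

```python
import tensorflow as tf

with tf.Graph().as_default():
  with tf.container("experiment0"):
    q = tf.FIFOQueue(10, tf.float32)  # stateful op with a "container" attr
    s = tf.add(1.0, 2.0)              # stateless op: attribute is never set

  print(q.queue_ref.op.get_attr("container"))  # => b'experiment0'
  # s.op has no "container" attr; s.op.get_attr("container") would raise.
```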
@@ -2320,6 +2324,19 @@ class Graph(object):
       ret.node_def.attr["_class"].CopyFrom(attr_value_pb2.AttrValue(
           list=attr_value_pb2.AttrValue.ListValue(s=all_colocation_groups)))
 
+    # Sets the "container" attribute if all of the following hold:
+    #   (1) self._container is not empty
+    #   (2) "is_stateful" is set in the OpDef
+    #   (3) a "container" attribute is present in the NodeDef
+    #   (4) the "container" attribute is still empty
+    if (self._container and
+        op_type in self._registered_ops and
+        self._registered_ops[op_type].is_stateful and
+        "container" in ret.node_def.attr and
+        not ret.node_def.attr["container"].s):
+      ret.node_def.attr["container"].CopyFrom(
+          attr_value_pb2.AttrValue(s=compat.as_bytes(self._container)))
+
     return ret
 
   def as_graph_element(self, obj, allow_tensor=True, allow_operation=True):
@@ -2958,6 +2975,60 @@ class Graph(object):
           break
       op._set_device(device_function(op))
 
+  # pylint: disable=g-doc-return-or-yield
+  @contextlib.contextmanager
+  def container(self, container_name):
+    """Returns a context manager that specifies the resource container to use.
+
+    Stateful operations, such as variables and queues, can maintain their
+    states on devices so that they can be shared by multiple processes.
+    A resource container is a string name under which these stateful
+    operations are tracked. These resources can be released or cleared
+    with `tf.Session.reset()`.
+
+    For example:
+
+    ```python
+    with g.container('experiment0'):
+      # All stateful Operations constructed in this context will be placed
+      # in resource container "experiment0".
+      v1 = tf.Variable([1.0])
+      v2 = tf.Variable([2.0])
+      with g.container("experiment1"):
+        # All stateful Operations constructed in this context will be
+        # placed in resource container "experiment1".
+        v3 = tf.Variable([3.0])
+        q1 = tf.FIFOQueue(10, tf.float32)
+      # All stateful Operations constructed in this context will be
+      # placed in resource container "experiment0" again.
+      v4 = tf.Variable([4.0])
+      q2 = tf.FIFOQueue(20, tf.float32)
+      with g.container(""):
+        # All stateful Operations constructed in this context will be
+        # placed in the default resource container.
+        v5 = tf.Variable([5.0])
+        q3 = tf.FIFOQueue(30, tf.float32)
+
+    # Resets container "experiment0", after which the state of v1, v2, v4,
+    # and q2 will become undefined (e.g. uninitialized).
+    tf.Session.reset(target, ["experiment0"])
+    ```
+
+    Args:
+      container_name: container name string.
+
+    Returns:
+      A context manager for defining resource containers for stateful ops,
+        yields the container name.
+    """
+    original_container = self._container
+    try:
+      self._container = container_name
+      yield self._container
+    finally:
+      self._container = original_container
+  # pylint: enable=g-doc-return-or-yield
+
 
 class _ControlDependenciesController(object):
   """Context manager for `control_dependencies()`."""
@@ -3335,6 +3406,19 @@ def device(device_name_or_function):
   return get_default_graph().device(device_name_or_function)
 
 
+def container(container_name):
+  """Wrapper for `Graph.container()` using the default graph.
+
+  Args:
+    container_name: The container string to use in the context.
+
+  Returns:
+    A context manager that specifies the default container to use for newly
+    created stateful ops.
+ """ + return get_default_graph().container(container_name) + + def colocate_with(op, ignore_existing=False): return get_default_graph().colocate_with(op, ignore_existing) diff --git a/tensorflow/python/framework/tensor_util.py b/tensorflow/python/framework/tensor_util.py index a213dc2262..098558fd3e 100644 --- a/tensorflow/python/framework/tensor_util.py +++ b/tensorflow/python/framework/tensor_util.py @@ -628,3 +628,55 @@ def constant_value(tensor): # conservatively prevent it from being fed. tensor.graph.prevent_feeding(tensor) return ret + + +def constant_value_as_shape(tensor): # pylint: disable=invalid-name + """A version of `constant_value()` that returns a `TensorShape`. + + This version should be used when a constant tensor value is + interpreted as a (possibly partial) shape, e.g. in the shape + function for `tf.reshape()`. By explicitly requesting a + `TensorShape` as the return value, it is possible to represent + unknown dimensions; by contrast, `constant_value()` is + all-or-nothing. + + Args: + tensor: The rank-1 Tensor to be evaluated. + + Returns: + A `TensorShape` based on the constant value of the given `tensor`. + """ + shape = tensor.get_shape().with_rank(1) + if tensor.get_shape() == [0]: + return tensor_shape.scalar() + elif tensor.op.type == "Shape": + return tensor.op.inputs[0].get_shape() + elif tensor.op.type == "Pack": + ret = tensor_shape.scalar() # Empty list. + for pack_input in tensor.op.inputs: + # `pack_input` must be a scalar. Attempt to evaluate it, and append it + # to `ret`. + pack_input_val = constant_value(pack_input) + if pack_input_val is None or pack_input_val < 0: + new_dim = tensor_shape.Dimension(None) + else: + new_dim = tensor_shape.Dimension(pack_input_val) + ret = ret.concatenate([new_dim]) + return ret + elif tensor.op.type == "Concat": + # We assume that `tensor.op.inputs[0]` evaluates to 0, as this is + # the only legal value when concatenating vectors, and it will + # have been checked by a previous shape function. + ret = tensor_shape.scalar() # Empty list. + for concat_input in tensor.op.inputs[1:]: + # `concat_input` must be a vector. Attempt to evaluate it as a shape, + # and concatenate it with `ret`. 
+ ret = ret.concatenate(constant_value_as_shape(concat_input)) + return ret + else: + ret = tensor_shape.unknown_shape(shape[0].value) + value = constant_value(tensor) + if value is not None: + ret = ret.merge_with(tensor_shape.TensorShape( + [d if d != -1 else None for d in value])) + return ret diff --git a/tensorflow/python/framework/tensor_util_test.py b/tensorflow/python/framework/tensor_util_test.py index c56c5e948b..3fdcba5416 100644 --- a/tensorflow/python/framework/tensor_util_test.py +++ b/tensorflow/python/framework/tensor_util_test.py @@ -565,5 +565,38 @@ class ConstantValueTest(tf.test.TestCase): self.assertIs(None, c_val) +class ConstantValueAsShapeTest(tf.test.TestCase): + + def testConstant(self): + np_val = np.random.rand(3).astype(np.int32) + tf_val = tf.constant(np_val) + self.assertEqual(tf.TensorShape(np_val), + tensor_util.constant_value_as_shape(tf_val)) + + tf_val = tf.constant([], dtype=tf.int32) + self.assertEqual(tf.TensorShape([]), + tensor_util.constant_value_as_shape(tf_val)) + + def testShape(self): + tf_val = tf.shape(tf.constant(0.0, shape=[1, 2, 3])) + c_val = tensor_util.constant_value_as_shape(tf_val) + self.assertEqual(tf.TensorShape([1, 2, 3]), c_val) + + def testPack(self): + tf_val = tf.pack([tf.constant(16), 37, tf.placeholder(tf.int32)]) + c_val = tensor_util.constant_value_as_shape(tf_val) + self.assertEqual([16, 37, None], c_val.as_list()) + + def testConcat(self): + tf_val = tf.concat(0, [[16, 37], tf.placeholder(tf.int32, shape=(2,))]) + c_val = tensor_util.constant_value_as_shape(tf_val) + self.assertEqual([16, 37, None, None], c_val.as_list()) + + tf_val = tf.concat(0, + [[16, 37], tf.placeholder(tf.int32, shape=(1,)), [48]]) + c_val = tensor_util.constant_value_as_shape(tf_val) + self.assertEqual([16, 37, None, 48], c_val.as_list()) + + if __name__ == "__main__": tf.test.main() diff --git a/tensorflow/python/kernel_tests/array_ops_test.py b/tensorflow/python/kernel_tests/array_ops_test.py index 4353de4007..8435d6b9cd 100644 --- a/tensorflow/python/kernel_tests/array_ops_test.py +++ b/tensorflow/python/kernel_tests/array_ops_test.py @@ -19,6 +19,7 @@ from __future__ import division from __future__ import print_function import math +import time import numpy as np import tensorflow as tf @@ -28,7 +29,6 @@ from tensorflow.python.framework import errors from tensorflow.python.framework import tensor_shape from tensorflow.python.framework import test_util from tensorflow.python.ops import array_ops -from tensorflow.python.platform import googletest class BooleanMaskTest(test_util.TensorFlowTestCase): @@ -383,5 +383,106 @@ class StridedSliceShapeTest(test_util.TensorFlowTestCase): tensor_shape.TensorShape([5, None, 1, 4])) +class GradSliceChecker(object): + """Tests that we can compute a gradient for var^2.""" + + def __init__(self, test, sess, var, varnp): + self.test = test + self.sess = sess + self.val = var * var + self.var = var + self.varnp = varnp + + def __getitem__(self, spec): + val_grad_op = tf.gradients(self.val, self.var) + sliceval_grad_op = tf.gradients( + array_ops._NewSliceHelper(self.val, spec), self.var) + slice1_op = array_ops._NewSliceHelper(val_grad_op, spec) + slice2_op = array_ops._NewSliceHelper(sliceval_grad_op, spec) + val_grad, sliceval_grad, slice1, slice2 = self.sess.run( + [val_grad_op, sliceval_grad_op, slice1_op, slice2_op]) + np_val_grad = (2 * self.varnp) + np_sliceval_grad = np.zeros(self.var.get_shape()) + np_sliceval_grad[spec] = np.array(val_grad[0])[spec] + # make sure np val grad is correct + 
self.test.assertAllEqual(np_val_grad, val_grad[0]) + # make sure slice gradient is correct + self.test.assertAllEqual(np_sliceval_grad, sliceval_grad[0]) + # make sure val grad and sliceval grad are the same in sliced area + self.test.assertAllEqual(slice1, slice2) + + +class StridedSliceGradTest(test_util.TensorFlowTestCase): + """Test that strided slice's custom gradient produces correct gradients.""" + + def testGradient(self): + for use_gpu in [False, True]: + with self.test_session(use_gpu=use_gpu) as sess: + var = tf.Variable(tf.reshape(tf.range(1, 97, 1), shape=(6, 4, 4))) + init = tf.initialize_all_variables() + sess.run(init) + + grad = GradSliceChecker(self, sess, var, + np.array(range(1, 97, 1)).reshape((6, 4, 4))) + _ = grad[2:6:2, 1:3, 1:3] + _ = grad[3:0:-2, 1:3, 1:3] + _ = grad[3:0:-2, tf.newaxis, 1:3, 2, tf.newaxis] + _ = grad[3:0:-2, 1:3, 2] + + +class BenchmarkSlice(object): + + def __init__(self, tensor): + self.tensor = tensor + + def __getitem__(self, x): + return array_ops._NewSliceHelper(self.tensor, x) + + +class StridedSliceBenchmark(tf.test.Benchmark): + """Benchmark new strided slice operation on non-trivial case.""" + + def run_and_time(self, slice_op): + tf.initialize_all_variables().run() + for _ in range(10): + _ = slice_op.eval() + iters = 1000 + t0 = time.time() + for _ in range(iters): + slice_op.eval() + t1 = time.time() + self.report_benchmark(iters=iters, wall_time=(t1 - t0) / 1000.0) + + def make_variable(self): + n = 256 + shape = (n, n, n) + items = n**3 + var = tf.Variable( + tf.reshape( + tf.linspace(1., float(items), items), shape), + dtype=tf.float32) + return var + + def benchmark_strided_slice_skip(self): + with tf.Session(): + var = self.make_variable() + helper = BenchmarkSlice(var) + slice_op = helper[::2, ::1, ::2] + self.run_and_time(slice_op) + + def benchmark_strided_slice_easy(self): + with tf.Session(): + var = self.make_variable() + helper = BenchmarkSlice(var) + slice_op = helper[3::1, 3::1, 3::1] + self.run_and_time(slice_op) + + def benchmark_slice_easy(self): + with tf.Session(): + var = self.make_variable() + slice_op = var[3::1, 3::1, 3::1] + self.run_and_time(slice_op) + + if __name__ == "__main__": - googletest.main() + tf.test.main() diff --git a/tensorflow/python/kernel_tests/control_flow_ops_py_test.py b/tensorflow/python/kernel_tests/control_flow_ops_py_test.py index 96fc24fba0..f1a1683332 100644 --- a/tensorflow/python/kernel_tests/control_flow_ops_py_test.py +++ b/tensorflow/python/kernel_tests/control_flow_ops_py_test.py @@ -950,6 +950,18 @@ class ControlFlowTest(tf.test.TestCase): self.assertEqual([None], r.get_shape().as_list()) self.assertAllClose([810.0, 2560.0], r.eval(feed_dict={x: [3.0, 4.0]})) + def testWhileGrad_BaseShape(self): + with self.test_session() as sess: + x = tf.placeholder(tf.float32, [None]) + v0 = tf.constant([2.0, 2.0], name="v") + c = lambda v: tf.constant(False) + b = lambda v: tf.mul(v, x) + r = tf.while_loop(c, b, [v0]) + y = tf.square(x) + + r = tf.gradients([r, y], x)[0] + self.assertAllClose([2.0, 4.0], sess.run(r, feed_dict={x: [1.0, 2.0]})) + def testWhileGrad_MultipleUses(self): with self.test_session(): v = tf.constant(2.0, name="v") diff --git a/tensorflow/python/kernel_tests/fifo_queue_test.py b/tensorflow/python/kernel_tests/fifo_queue_test.py index 622459e545..8ec04d3e80 100644 --- a/tensorflow/python/kernel_tests/fifo_queue_test.py +++ b/tensorflow/python/kernel_tests/fifo_queue_test.py @@ -21,7 +21,6 @@ from __future__ import print_function import random import re import time - 
import numpy as np from six.moves import xrange # pylint: disable=redefined-builtin import tensorflow as tf @@ -293,6 +292,16 @@ class FIFOQueueTest(tf.test.TestCase): enqueue_op.run() self.assertEqual([], dequeued_t.eval().tolist()) + def testEmptyDequeueUpTo(self): + with self.test_session(): + q = tf.FIFOQueue(10, tf.float32, shapes=()) + enqueue_op = q.enqueue((10.0,)) + dequeued_t = q.dequeue_up_to(0) + + self.assertEqual([], dequeued_t.eval().tolist()) + enqueue_op.run() + self.assertEqual([], dequeued_t.eval().tolist()) + def testEmptyDequeueManyWithNoShape(self): with self.test_session(): q = tf.FIFOQueue(10, tf.float32) @@ -328,6 +337,18 @@ class FIFOQueueTest(tf.test.TestCase): self.assertAllEqual(elems[0:4], dequeued_t.eval()) self.assertAllEqual(elems[4:8], dequeued_t.eval()) + def testDequeueUpToNoBlocking(self): + with self.test_session(): + q = tf.FIFOQueue(10, tf.float32, ()) + elems = [10.0, 20.0, 30.0, 40.0, 50.0, 60.0, 70.0, 80.0, 90.0, 100.0] + enqueue_op = q.enqueue_many((elems,)) + dequeued_t = q.dequeue_up_to(4) + + enqueue_op.run() + + self.assertAllEqual(elems[0:4], dequeued_t.eval()) + self.assertAllEqual(elems[4:8], dequeued_t.eval()) + def testMultiDequeueMany(self): with self.test_session() as sess: q = tf.FIFOQueue(10, (tf.float32, tf.int32), @@ -358,6 +379,29 @@ class FIFOQueueTest(tf.test.TestCase): self.assertEqual(float_val.shape, dequeued_single_t[0].get_shape()) self.assertEqual(int_val.shape, dequeued_single_t[1].get_shape()) + def testMultiDequeueUpToNoBlocking(self): + with self.test_session() as sess: + q = tf.FIFOQueue(10, (tf.float32, tf.int32), + shapes=((), (2,))) + float_elems = [ + 10.0, 20.0, 30.0, 40.0, 50.0, 60.0, 70.0, 80.0, 90.0, 100.0] + int_elems = [[1, 2], [3, 4], [5, 6], [7, 8], [9, 10], + [11, 12], [13, 14], [15, 16], [17, 18], [19, 20]] + enqueue_op = q.enqueue_many((float_elems, int_elems)) + dequeued_t = q.dequeue_up_to(4) + + enqueue_op.run() + + float_val, int_val = sess.run(dequeued_t) + self.assertAllEqual(float_elems[0:4], float_val) + self.assertAllEqual(int_elems[0:4], int_val) + self.assertEqual([None], dequeued_t[0].get_shape().as_list()) + self.assertEqual([None, 2], dequeued_t[1].get_shape().as_list()) + + float_val, int_val = sess.run(dequeued_t) + self.assertAllEqual(float_elems[4:8], float_val) + self.assertAllEqual(int_elems[4:8], int_val) + def testHighDimension(self): with self.test_session(): q = tf.FIFOQueue(10, tf.int32, (4, 4, 4, 4)) @@ -469,6 +513,29 @@ class FIFOQueueTest(tf.test.TestCase): thread.join() self.assertItemsEqual(elems, dequeued_elems) + def testParallelDequeueUpTo(self): + with self.test_session() as sess: + q = tf.FIFOQueue(1000, tf.float32, shapes=()) + elems = [10.0 * x for x in range(1000)] + enqueue_op = q.enqueue_many((elems,)) + close_op = q.close() + dequeued_t = q.dequeue_up_to(101) + + enqueue_op.run() + close_op.run() + + # Dequeue up to 101 items in parallel on 10 threads, from closed queue. 
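These tests pin down the semantics of the new `dequeue_up_to()`: it behaves like `dequeue_many()` while the queue is open, but once the queue is closed it returns whatever remains instead of raising `OutOfRangeError`. A condensed sketch, not from the diff:

```python
import tensorflow as tf

q = tf.FIFOQueue(10, tf.float32, shapes=())
enqueue_op = q.enqueue_many(([10.0, 20.0, 30.0, 40.0],))
close_op = q.close()
dequeued_t = q.dequeue_up_to(3)

with tf.Session() as sess:
  sess.run(enqueue_op)
  sess.run(close_op)
  print(sess.run(dequeued_t))  # => [10. 20. 30.]
  print(sess.run(dequeued_t))  # => [40.]  (remainder, no OutOfRangeError)
```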
+ dequeued_elems = [] + + def dequeue(): + dequeued_elems.extend(sess.run(dequeued_t)) + threads = [self.checkedThread(target=dequeue) for _ in range(10)] + for thread in threads: + thread.start() + for thread in threads: + thread.join() + self.assertItemsEqual(elems, dequeued_elems) + def testParallelEnqueueAndDequeue(self): with self.test_session() as sess: q = tf.FIFOQueue(50, tf.float32, shapes=()) @@ -596,6 +663,33 @@ class FIFOQueueTest(tf.test.TestCase): self.assertAllEqual(elems, dequeued_elems) + def testBlockingDequeueUpTo(self): + with self.test_session() as sess: + q = tf.FIFOQueue(10, tf.float32, ()) + elems = [10.0, 20.0, 30.0, 40.0] + enqueue_op = q.enqueue_many((elems,)) + dequeued_t = q.dequeue_up_to(4) + + dequeued_elems = [] + + def enqueue(): + # The enqueue_op should run after the dequeue op has blocked. + # TODO(mrry): Figure out how to do this without sleeping. + time.sleep(0.1) + sess.run(enqueue_op) + + def dequeue(): + dequeued_elems.extend(sess.run(dequeued_t).tolist()) + + enqueue_thread = self.checkedThread(target=enqueue) + dequeue_thread = self.checkedThread(target=dequeue) + enqueue_thread.start() + dequeue_thread.start() + enqueue_thread.join() + dequeue_thread.join() + + self.assertAllEqual(elems, dequeued_elems) + def testDequeueManyWithTensorParameter(self): with self.test_session(): # Define a first queue that contains integer counts. @@ -710,6 +804,53 @@ class FIFOQueueTest(tf.test.TestCase): close_op.run() dequeue_thread.join() + def testBlockingDequeueManyButNotAllFromClosedQueue(self): + with self.test_session() as sess: + q = tf.FIFOQueue(10, tf.float32, ()) + elems = [10.0, 20.0, 30.0, 40.0] + enqueue_op = q.enqueue_many((elems,)) + close_op = q.close() + dequeued_t = q.dequeue_many(3) + + enqueue_op.run() + + def dequeue(): + self.assertAllEqual(elems[:3], sess.run(dequeued_t)) + # Expect the operation to fail due to the queue being closed. + with self.assertRaisesRegexp(tf.errors.OutOfRangeError, + "is closed and has insufficient"): + sess.run(dequeued_t) + + dequeue_thread = self.checkedThread(target=dequeue) + dequeue_thread.start() + # The close_op should run after the dequeue_thread has blocked. + # TODO(mrry): Figure out how to do this without sleeping. + time.sleep(0.1) + close_op.run() + dequeue_thread.join() + + def testDequeueUpToFromClosedQueueReturnsRemainder(self): + with self.test_session() as sess: + q = tf.FIFOQueue(10, tf.float32, ()) + elems = [10.0, 20.0, 30.0, 40.0] + enqueue_op = q.enqueue_many((elems,)) + close_op = q.close() + dequeued_t = q.dequeue_up_to(3) + + enqueue_op.run() + + def dequeue(): + self.assertAllEqual(elems[:3], sess.run(dequeued_t)) + self.assertAllEqual(elems[3:], sess.run(dequeued_t)) + + dequeue_thread = self.checkedThread(target=dequeue) + dequeue_thread.start() + # The close_op should run after the dequeue_thread has blocked. + # TODO(mrry): Figure out how to do this without sleeping. + time.sleep(0.1) + close_op.run() + dequeue_thread.join() + def testEnqueueManyLargerThanCapacityWithConcurrentDequeueMany(self): with self.test_session() as sess: q = tf.FIFOQueue(4, tf.float32, ()) @@ -788,7 +929,27 @@ class FIFOQueueTest(tf.test.TestCase): def dequeue(): # Expect the operation to fail due to the queue being closed. 
with self.assertRaisesRegexp(tf.errors.OutOfRangeError, - "is closed and has insufficient"): + "is closed and has insufficient"): + sess.run(dequeued_t) + + dequeue_thread = self.checkedThread(target=dequeue) + dequeue_thread.start() + # The close_op should run after the dequeue_thread has blocked. + # TODO(mrry): Figure out how to do this without sleeping. + time.sleep(0.1) + close_op.run() + dequeue_thread.join() + + def testBlockingDequeueUpToFromClosedEmptyQueue(self): + with self.test_session() as sess: + q = tf.FIFOQueue(10, tf.float32, ()) + close_op = q.close() + dequeued_t = q.dequeue_up_to(4) + + def dequeue(): + # Expect the operation to fail due to the queue being closed. + with self.assertRaisesRegexp(tf.errors.OutOfRangeError, + "is closed and has insufficient"): sess.run(dequeued_t) dequeue_thread = self.checkedThread(target=dequeue) @@ -1341,6 +1502,8 @@ class FIFOQueueWithTimeoutTest(tf.test.TestCase): with self.test_session( config=tf.ConfigProto(operation_timeout_in_ms=20)) as sess: q = tf.FIFOQueue(10, tf.float32) + self.assertEqual(tf.compat.as_bytes(""), + q.queue_ref.op.get_attr("container")) dequeued_t = q.dequeue() # Intentionally do not run any enqueue_ops so that dequeue will block @@ -1350,5 +1513,15 @@ class FIFOQueueWithTimeoutTest(tf.test.TestCase): sess.run(dequeued_t) +class QueueContainerTest(tf.test.TestCase): + + def testContainer(self): + with tf.Graph().as_default(): + with tf.container("test"): + q = tf.FIFOQueue(10, tf.float32) + self.assertEqual(tf.compat.as_bytes("test"), + q.queue_ref.op.get_attr("container")) + + if __name__ == "__main__": tf.test.main() diff --git a/tensorflow/python/kernel_tests/pack_op_test.py b/tensorflow/python/kernel_tests/pack_op_test.py index 349def7181..5d7055824d 100644 --- a/tensorflow/python/kernel_tests/pack_op_test.py +++ b/tensorflow/python/kernel_tests/pack_op_test.py @@ -22,6 +22,14 @@ import numpy as np import tensorflow as tf +def np_split_sqeeze(array, axis): + axis_len = array.shape[axis] + return [ + np.squeeze(arr, axis=(axis,)) + for arr in np.split(array, axis_len, axis=axis) + ] + + class PackOpTest(tf.test.TestCase): def testSimple(self): @@ -61,7 +69,7 @@ class PackOpTest(tf.test.TestCase): b = tf.reshape(a, tf.pack([2, 3])) self.assertAllEqual(b.get_shape(), [2, 3]) - def testGradients(self): + def testGradientsAxis0(self): np.random.seed(7) for use_gpu in False, True: for shape in (2,), (3,), (2, 3), (3, 2), (4, 3, 2): @@ -74,6 +82,21 @@ class PackOpTest(tf.test.TestCase): err = tf.test.compute_gradient_error(xs, shapes, c, shape) self.assertLess(err, 1e-6) + def testGradientsAxis1(self): + np.random.seed(7) + for use_gpu in False, True: + for shape in (2, 3), (3, 2), (4, 3, 2): + data = np.random.randn(*shape) + shapes = [shape[1:]] * shape[0] + out_shape = list(shape[1:]) + out_shape.insert(1, shape[0]) + with self.test_session(use_gpu=use_gpu): + # TODO(irving): Remove list() once we handle maps correctly + xs = list(map(tf.constant, data)) + c = tf.pack(xs, axis=1) + err = tf.test.compute_gradient_error(xs, shapes, c, out_shape) + self.assertLess(err, 1e-6) + def testZeroSize(self): # Verify that pack doesn't crash for zero size inputs for use_gpu in False, True: @@ -83,6 +106,38 @@ class PackOpTest(tf.test.TestCase): p = tf.pack(list(x)).eval() self.assertAllEqual(p, x) + def testAxis0Default(self): + with self.test_session(): + t = [tf.constant([1, 2, 3]), tf.constant([4, 5, 6])] + + packed = tf.pack(t).eval() + + self.assertAllEqual(packed, np.array([[1, 2, 3], [4, 5, 6]])) + + def 
testAgainstNumpy(self): + # For 1 to 5 dimensions. + for i in range(1, 6): + expected = np.random.random(np.random.permutation(i) + 1) + + # For all the possible axis to split it, including negative indices. + for j in range(-i, i): + test_arrays = np_split_sqeeze(expected, j) + + with self.test_session(): + actual = tf.pack(test_arrays, axis=j).eval() + + self.assertNDArrayNear(expected, actual, 1e-6) + + def testDimOutOfRange(self): + t = [tf.constant([1, 2, 3]), tf.constant([4, 5, 6])] + with self.assertRaisesRegexp(ValueError, r"axis = 2 not in \[-2, 2\)"): + tf.unpack(t, axis=2) + + def testDimOutOfNegativeRange(self): + t = [tf.constant([1, 2, 3]), tf.constant([4, 5, 6])] + with self.assertRaisesRegexp(ValueError, r"axis = -3 not in \[-2, 2\)"): + tf.unpack(t, axis=-3) + class AutomaticPackingTest(tf.test.TestCase): diff --git a/tensorflow/python/kernel_tests/padding_fifo_queue_test.py b/tensorflow/python/kernel_tests/padding_fifo_queue_test.py index 8e9449587a..ad2a52cd43 100644 --- a/tensorflow/python/kernel_tests/padding_fifo_queue_test.py +++ b/tensorflow/python/kernel_tests/padding_fifo_queue_test.py @@ -276,6 +276,16 @@ class PaddingFIFOQueueTest(tf.test.TestCase): enqueue_op.run() self.assertEqual([], dequeued_t.eval().tolist()) + def testEmptyDequeueUpToWithDynamicShape(self): + with self.test_session(): + q = tf.PaddingFIFOQueue(10, tf.float32, shapes=((None,),)) + enqueue_op = q.enqueue(([10.0],)) + dequeued_t = q.dequeue_up_to(0) + + self.assertEqual([], dequeued_t.eval().tolist()) + enqueue_op.run() + self.assertEqual([], dequeued_t.eval().tolist()) + def testConstructPaddingFIFOQueueWithNoShape(self): with self.test_session(): with self.assertRaisesRegexp( @@ -328,6 +338,18 @@ class PaddingFIFOQueueTest(tf.test.TestCase): self.assertAllEqual(elems[0:4], dequeued_t.eval()) self.assertAllEqual(elems[4:8], dequeued_t.eval()) + def testDequeueUpToNoBlocking(self): + with self.test_session(): + q = tf.PaddingFIFOQueue(10, tf.float32, ((),)) + elems = [10.0, 20.0, 30.0, 40.0, 50.0, 60.0, 70.0, 80.0, 90.0, 100.0] + enqueue_op = q.enqueue_many((elems,)) + dequeued_t = q.dequeue_up_to(4) + + enqueue_op.run() + + self.assertAllEqual(elems[0:4], dequeued_t.eval()) + self.assertAllEqual(elems[4:8], dequeued_t.eval()) + def testMultiDequeueMany(self): with self.test_session() as sess: q = tf.PaddingFIFOQueue(10, (tf.float32, tf.int32), @@ -451,6 +473,62 @@ class PaddingFIFOQueueTest(tf.test.TestCase): tf.TensorShape(int_val.shape).is_compatible_with( dequeued_single_t[1].get_shape())) + def testMultiDequeueUpToPartiallyKnownShapesAndVariableInputNoBlocking(self): + with self.test_session() as sess: + q = tf.PaddingFIFOQueue(10, (tf.string, tf.int32), + shapes=((None,), (1, None))) + str_elems = [ + ["a"], + ["ab"], + ["abc"], + ["abc", "d"], + ["abc", "d", "e"], + ["abc", "d", "e", "f"]] + + int_elems = [ + [[1]], + [[2]], + [[3]], + [[1, 2]], + [[1, 2, 3]], + [[1, 2, 3, 4]]] + + enqueue_ops = [q.enqueue((str_elems[i], int_elems[i])) for i in range(6)] + + dequeued_t = q.dequeue_up_to(5) + dequeued_single_t = q.dequeue() + + for enqueue_op in enqueue_ops: + enqueue_op.run() + string_val, int_val = sess.run(dequeued_t) + + self.assertAllEqual( + [[b"a", b"", b""], [b"ab", b"", b""], [b"abc", b"", b""], + [b"abc", b"d", b""], [b"abc", b"d", b"e"]], string_val) + self.assertAllEqual( + [[[1, 0, 0]], + [[2, 0, 0]], + [[3, 0, 0]], + [[1, 2, 0]], + [[1, 2, 3]]], + int_val) + self.assertTrue( + tf.TensorShape(string_val.shape).is_compatible_with( + dequeued_t[0].get_shape())) + 
self.assertTrue( + tf.TensorShape(int_val.shape).is_compatible_with( + dequeued_t[1].get_shape())) + + string_val, int_val = sess.run(dequeued_single_t) + self.assertAllEqual([b"abc", b"d", b"e", b"f"], string_val) + self.assertAllEqual([[1, 2, 3, 4]], int_val) + self.assertTrue( + tf.TensorShape(string_val.shape).is_compatible_with( + dequeued_single_t[0].get_shape())) + self.assertTrue( + tf.TensorShape(int_val.shape).is_compatible_with( + dequeued_single_t[1].get_shape())) + def testHighDimension(self): with self.test_session(): q = tf.PaddingFIFOQueue(10, tf.int32, ((4, 4, 4, 4),)) @@ -576,6 +654,29 @@ class PaddingFIFOQueueTest(tf.test.TestCase): thread.join() self.assertItemsEqual(elems, dequeued_elems) + def testParallelDequeueUpTo(self): + with self.test_session() as sess: + q = tf.PaddingFIFOQueue(1000, tf.float32, shapes=((),)) + elems = [10.0 * x for x in range(1000)] + enqueue_op = q.enqueue_many((elems,)) + close_op = q.close() + dequeued_t = q.dequeue_up_to(101) + + enqueue_op.run() + close_op.run() + + # Dequeue up to 101 items in parallel on 10 threads, from closed queue. + dequeued_elems = [] + + def dequeue(): + dequeued_elems.extend(sess.run(dequeued_t)) + threads = [self.checkedThread(target=dequeue) for _ in range(10)] + for thread in threads: + thread.start() + for thread in threads: + thread.join() + self.assertItemsEqual(elems, dequeued_elems) + def testParallelEnqueueAndDequeue(self): with self.test_session() as sess: q = tf.PaddingFIFOQueue(50, tf.float32, shapes=((),)) @@ -703,6 +804,33 @@ class PaddingFIFOQueueTest(tf.test.TestCase): self.assertAllEqual(elems, dequeued_elems) + def testBlockingDequeueUpTo(self): + with self.test_session() as sess: + q = tf.PaddingFIFOQueue(10, tf.float32, ((),)) + elems = [10.0, 20.0, 30.0, 40.0] + enqueue_op = q.enqueue_many((elems,)) + dequeued_t = q.dequeue_up_to(4) + + dequeued_elems = [] + + def enqueue(): + # The enqueue_op should run after the dequeue op has blocked. + # TODO(mrry): Figure out how to do this without sleeping. + time.sleep(0.1) + sess.run(enqueue_op) + + def dequeue(): + dequeued_elems.extend(sess.run(dequeued_t).tolist()) + + enqueue_thread = self.checkedThread(target=enqueue) + dequeue_thread = self.checkedThread(target=dequeue) + enqueue_thread.start() + dequeue_thread.start() + enqueue_thread.join() + dequeue_thread.join() + + self.assertAllEqual(elems, dequeued_elems) + def testDequeueManyWithTensorParameter(self): with self.test_session(): # Define a first queue that contains integer counts. @@ -772,6 +900,28 @@ class PaddingFIFOQueueTest(tf.test.TestCase): close_op.run() dequeue_thread.join() + def testDequeueUpToFromClosedQueueReturnsRemainder(self): + with self.test_session() as sess: + q = tf.PaddingFIFOQueue(10, tf.float32, ((),)) + elems = [10.0, 20.0, 30.0, 40.0] + enqueue_op = q.enqueue_many((elems,)) + close_op = q.close() + dequeued_t = q.dequeue_up_to(3) + + enqueue_op.run() + + def dequeue(): + self.assertAllEqual(elems[:3], sess.run(dequeued_t)) + self.assertAllEqual(elems[3:], sess.run(dequeued_t)) + + dequeue_thread = self.checkedThread(target=dequeue) + dequeue_thread.start() + # The close_op should run after the dequeue_thread has blocked. + # TODO(mrry): Figure out how to do this without sleeping. 
+ time.sleep(0.1) + close_op.run() + dequeue_thread.join() + def testBlockingDequeueFromClosedEmptyQueue(self): with self.test_session() as sess: q = tf.PaddingFIFOQueue(10, tf.float32, ((),)) @@ -817,6 +967,31 @@ class PaddingFIFOQueueTest(tf.test.TestCase): close_op.run() dequeue_thread.join() + def testBlockingDequeueManyButNotAllFromClosedQueue(self): + with self.test_session() as sess: + q = tf.PaddingFIFOQueue(10, tf.float32, ((),)) + elems = [10.0, 20.0, 30.0, 40.0] + enqueue_op = q.enqueue_many((elems,)) + close_op = q.close() + dequeued_t = q.dequeue_many(3) + + enqueue_op.run() + + def dequeue(): + self.assertAllEqual(elems[:3], sess.run(dequeued_t)) + # Expect the operation to fail due to the queue being closed. + with self.assertRaisesRegexp(tf.errors.OutOfRangeError, + "is closed and has insufficient"): + sess.run(dequeued_t) + + dequeue_thread = self.checkedThread(target=dequeue) + dequeue_thread.start() + # The close_op should run after the dequeue_thread has blocked. + # TODO(mrry): Figure out how to do this without sleeping. + time.sleep(0.1) + close_op.run() + dequeue_thread.join() + def testEnqueueManyLargerThanCapacityWithConcurrentDequeueMany(self): with self.test_session() as sess: q = tf.PaddingFIFOQueue(4, tf.float32, ((),)) @@ -906,6 +1081,26 @@ class PaddingFIFOQueueTest(tf.test.TestCase): close_op.run() dequeue_thread.join() + def testBlockingDequeueUpToFromClosedEmptyQueue(self): + with self.test_session() as sess: + q = tf.PaddingFIFOQueue(10, tf.float32, ((),)) + close_op = q.close() + dequeued_t = q.dequeue_up_to(4) + + def dequeue(): + # Expect the operation to fail due to the queue being closed. + with self.assertRaisesRegexp(tf.errors.OutOfRangeError, + "is closed and has insufficient"): + sess.run(dequeued_t) + + dequeue_thread = self.checkedThread(target=dequeue) + dequeue_thread.start() + # The close_op should run after the dequeue_thread has blocked. + # TODO(mrry): Figure out how to do this without sleeping. + time.sleep(0.1) + close_op.run() + dequeue_thread.join() + def testEnqueueToClosedQueue(self): with self.test_session(): q = tf.PaddingFIFOQueue(10, tf.float32, ((),)) diff --git a/tensorflow/python/kernel_tests/reshape_op_test.py b/tensorflow/python/kernel_tests/reshape_op_test.py index 0487621b46..a68f722244 100644 --- a/tensorflow/python/kernel_tests/reshape_op_test.py +++ b/tensorflow/python/kernel_tests/reshape_op_test.py @@ -99,11 +99,6 @@ class ReshapeTest(tf.test.TestCase): self._testBothReshape(x, [1, -1, 5]) def testErrors(self): - x = tf.constant(0.0, shape=[1, 0, 3]) - with self.assertRaisesRegexp( - ValueError, "cannot infer the missing input size"): - tf.reshape(x, [0, -1, 5]) - y = tf.constant(0.0, shape=[23, 29, 31]) with self.assertRaisesRegexp(ValueError, "isn't divisible by 17"): tf.reshape(y, [17, -1]) @@ -128,6 +123,20 @@ class ReshapeTest(tf.test.TestCase): y = tf.reshape(x, tf.placeholder(tf.int32, shape=(3,))) self.assertEqual([None, None, None], y.get_shape().as_list()) + # Unknown input shape, partial new shape using `tf.pack()`. + y = tf.reshape(x, [tf.placeholder(tf.int32), 37]) + self.assertEqual([None, 37], y.get_shape().as_list()) + + # Unknown input shape, partial new shape using `tf.concat()`. + y = tf.reshape(x, tf.concat(0, [tf.placeholder(tf.int32, shape=(2,)), + [37, 42]])) + self.assertEqual([None, None, 37, 42], y.get_shape().as_list()) + + # Unknown input shape, partial new shape using `tf.shape()`. 
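The reshape tests above exercise `constant_value_as_shape()`: the reshape shape function can now extract partial static shape information from a shape tensor that is only symbolically known, instead of giving up entirely. A condensed sketch, not from the diff:

```python
import tensorflow as tf

x = tf.placeholder(tf.float32)  # shape entirely unknown

# Partial shape via tf.pack: one symbolic dimension, one constant.
y = tf.reshape(x, tf.pack([tf.placeholder(tf.int32), 37]))
print(y.get_shape().as_list())  # => [None, 37]

# Partial shape via tf.concat: two symbolic dimensions, two constants.
y = tf.reshape(x, tf.concat(0, [tf.placeholder(tf.int32, shape=(2,)),
                                [37, 42]]))
print(y.get_shape().as_list())  # => [None, None, 37, 42]
```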
+ y = tf.reshape(x, tf.shape(tf.placeholder(tf.float32, + shape=[None, 37, None]))) + self.assertEqual([None, 37, None], y.get_shape().as_list()) + if __name__ == "__main__": tf.test.main() diff --git a/tensorflow/python/kernel_tests/rnn_test.py b/tensorflow/python/kernel_tests/rnn_test.py index 34b1c81e77..707133295c 100644 --- a/tensorflow/python/kernel_tests/rnn_test.py +++ b/tensorflow/python/kernel_tests/rnn_test.py @@ -19,6 +19,7 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +import itertools import time import timeit @@ -1157,6 +1158,114 @@ class BidirectionalRNNTest(tf.test.TestCase): self._testBidirectionalRNNWithoutSequenceLength(use_gpu=True, use_shape=True) + def _createBidirectionalDynamicRNN(self, use_gpu, use_shape, + use_state_tuple, use_time_major): + num_units = 3 + input_size = 5 + batch_size = 2 + max_length = 8 + + initializer = tf.random_uniform_initializer(-0.01, 0.01, seed=self._seed) + sequence_length = tf.placeholder(tf.int64) + cell_fw = tf.nn.rnn_cell.LSTMCell(num_units, + initializer=initializer, + state_is_tuple=use_state_tuple) + cell_bw = tf.nn.rnn_cell.LSTMCell(num_units, + initializer=initializer, + state_is_tuple=use_state_tuple) + inputs = max_length * [ + tf.placeholder(tf.float32, + shape=(batch_size if use_shape else None, input_size))] + inputs_c = tf.pack(inputs) + if not use_time_major: + inputs_c = tf.transpose(inputs_c, [1, 0, 2]) + outputs, states = tf.nn.bidirectional_dynamic_rnn( + cell_fw, + cell_bw, + inputs_c, + sequence_length, + dtype=tf.float32, + time_major=use_time_major) + outputs = tf.concat(2, outputs) + state_fw, state_bw = states + outputs_shape = [None, max_length, 2 * num_units] + if use_shape: + outputs_shape[0] = batch_size + if use_time_major: + outputs_shape[0], outputs_shape[1] = outputs_shape[1], outputs_shape[0] + self.assertEqual( + outputs.get_shape().as_list(), + outputs_shape) + + input_value = np.random.randn(batch_size, input_size) + + return input_value, inputs, outputs, state_fw, state_bw, sequence_length + + def _testBidirectionalDynamicRNN(self, use_gpu, use_shape, + use_state_tuple, use_time_major): + with self.test_session(use_gpu=use_gpu, graph=tf.Graph()) as sess: + input_value, inputs, outputs, state_fw, state_bw, sequence_length = ( + self._createBidirectionalDynamicRNN( + use_gpu, use_shape, use_state_tuple, use_time_major)) + tf.initialize_all_variables().run() + # Run with pre-specified sequence length of 2, 3 + if use_state_tuple: + out, c_fw, m_fw, c_bw, m_bw = sess.run( + [outputs, state_fw[0], state_fw[1], state_bw[0], state_bw[1]], + feed_dict={inputs[0]: input_value, + sequence_length: [2, 3]}) + s_fw = (c_fw, m_fw) + s_bw = (c_bw, m_bw) + else: + out, s_fw, s_bw = sess.run([outputs, state_fw, state_bw], + feed_dict={inputs[0]: input_value, + sequence_length: [2, 3]}) + + # Since the forward and backward LSTM cells were initialized with the + # same parameters, the forward and backward output has to be the same, + # but reversed in time. 
The format is output[time][batch][depth], and + # due to depth concatenation (as num_units=3 for both RNNs): + # - forward output: out[][][depth] for 0 <= depth < 3 + # - backward output: out[][][depth] for 4 <= depth < 6 + # + # First sequence in batch is length=2 + # Check that the time=0 forward output is equal to time=1 backward output + if not use_time_major: + out = np.swapaxes(out, 0, 1) + self.assertEqual(out[0][0][0], out[1][0][3]) + self.assertEqual(out[0][0][1], out[1][0][4]) + self.assertEqual(out[0][0][2], out[1][0][5]) + # Check that the time=1 forward output is equal to time=0 backward output + self.assertEqual(out[1][0][0], out[0][0][3]) + self.assertEqual(out[1][0][1], out[0][0][4]) + self.assertEqual(out[1][0][2], out[0][0][5]) + + # Second sequence in batch is length=3 + # Check that the time=0 forward output is equal to time=2 backward output + self.assertEqual(out[0][1][0], out[2][1][3]) + self.assertEqual(out[0][1][1], out[2][1][4]) + self.assertEqual(out[0][1][2], out[2][1][5]) + # Check that the time=1 forward output is equal to time=1 backward output + self.assertEqual(out[1][1][0], out[1][1][3]) + self.assertEqual(out[1][1][1], out[1][1][4]) + self.assertEqual(out[1][1][2], out[1][1][5]) + # Check that the time=2 forward output is equal to time=0 backward output + self.assertEqual(out[2][1][0], out[0][1][3]) + self.assertEqual(out[2][1][1], out[0][1][4]) + self.assertEqual(out[2][1][2], out[0][1][5]) + # Via the reasoning above, the forward and backward final state should be + # exactly the same + self.assertAllClose(s_fw, s_bw) + + def testBidirectionalDynamicRNN(self): + # Generate 2^4 option values + # from [True, True, True, True] to [False, False, False, False] + options = itertools.product([True, False], repeat=4) + for option in options: + self._testBidirectionalDynamicRNN(use_gpu=option[0], use_shape=option[1], + use_state_tuple=option[2], + use_time_major=option[3]) + class MultiDimensionalLSTMTest(tf.test.TestCase): diff --git a/tensorflow/python/kernel_tests/unpack_op_test.py b/tensorflow/python/kernel_tests/unpack_op_test.py index 7cc6d31efb..0cb701db82 100644 --- a/tensorflow/python/kernel_tests/unpack_op_test.py +++ b/tensorflow/python/kernel_tests/unpack_op_test.py @@ -23,6 +23,14 @@ from six.moves import xrange # pylint: disable=redefined-builtin import tensorflow as tf +def np_split_sqeeze(array, axis): + axis_len = array.shape[axis] + return [ + np.squeeze(arr, axis=(axis,)) + for arr in np.split(array, axis_len, axis=axis) + ] + + class UnpackOpTest(tf.test.TestCase): def testSimple(self): @@ -40,7 +48,7 @@ class UnpackOpTest(tf.test.TestCase): cs = [c.eval() for c in cs] self.assertAllEqual(cs, data) - def testGradients(self): + def testGradientsAxis0(self): for use_gpu in False, True: for shape in (2,), (3,), (2, 3), (3, 2), (4, 3, 2): data = np.random.randn(*shape) @@ -52,6 +60,19 @@ class UnpackOpTest(tf.test.TestCase): err = tf.test.compute_gradient_error(x, shape, cs[i], shapes[i]) self.assertLess(err, 1e-6) + def testGradientsAxis1(self): + for use_gpu in False, True: + for shape in (2, 3), (3, 2), (4, 3, 2): + data = np.random.randn(*shape) + out_shape = list(shape) + del out_shape[1] + for i in xrange(shape[1]): + with self.test_session(use_gpu=use_gpu): + x = tf.constant(data) + cs = tf.unpack(x, num=shape[1], axis=1) + err = tf.test.compute_gradient_error(x, shape, cs[i], out_shape) + self.assertLess(err, 1e-6) + def testInferNum(self): with self.test_session(): for shape in (2,), (3,), (2, 3), (3, 2), (4, 3, 2): @@ -60,12 +81,55 @@ 
class UnpackOpTest(tf.test.TestCase): self.assertEqual(type(cs), list) self.assertEqual(len(cs), shape[0]) - def testCannotInferNum(self): + def testCannotInferNumFromUnknownShape(self): x = tf.placeholder(np.float32) with self.assertRaisesRegexp( ValueError, r'Cannot infer num from shape <unknown>'): tf.unpack(x) + def testUnknownShapeOkWithNum(self): + x = tf.placeholder(np.float32) + tf.unpack(x, num=2) + + def testCannotInferNumFromNoneShape(self): + x = tf.placeholder(np.float32, shape=(None,)) + with self.assertRaisesRegexp(ValueError, + r'Cannot infer num from shape \(\?,\)'): + tf.unpack(x) + + def testAgainstNumpy(self): + # For 1 to 5 dimensions. + for i in range(1, 6): + a = np.random.random(np.random.permutation(i) + 1) + + # For all the possible axis to split it, including negative indices. + for j in range(-i, i): + expected = np_split_sqeeze(a, j) + + with self.test_session() as sess: + actual = sess.run(tf.unpack(a, axis=j)) + + self.assertAllEqual(expected, actual) + + def testAxis0Default(self): + with self.test_session() as sess: + a = tf.constant([[1, 2, 3], [4, 5, 6]], name='a') + + unpacked = sess.run(tf.unpack(a)) + + self.assertEqual(len(unpacked), 2) + self.assertAllEqual(unpacked[0], [1, 2, 3]) + self.assertAllEqual(unpacked[1], [4, 5, 6]) + + def testAxisOutOfRange(self): + a = tf.constant([[1, 2, 3], [4, 5, 6]], name='a') + with self.assertRaisesRegexp(ValueError, r'axis = 2 not in \[-2, 2\)'): + tf.unpack(a, axis=2) + + def testAxisOutOfNegativeRange(self): + a = tf.constant([[1, 2, 3], [4, 5, 6]], name='a') + with self.assertRaisesRegexp(ValueError, r'axis = -3 not in \[-2, 2\)'): + tf.unpack(a, axis=-3) if __name__ == '__main__': tf.test.main() diff --git a/tensorflow/python/kernel_tests/variables_test.py b/tensorflow/python/kernel_tests/variables_test.py index 5b70f94723..82f010cf5b 100644 --- a/tensorflow/python/kernel_tests/variables_test.py +++ b/tensorflow/python/kernel_tests/variables_test.py @@ -25,6 +25,7 @@ import tensorflow as tf from tensorflow.python.ops import control_flow_ops from tensorflow.python.ops import random_ops +from tensorflow.python.ops import state_ops class VariablesTestCase(tf.test.TestCase): @@ -418,5 +419,26 @@ class ObsoleteIsInitializedTest(tf.test.TestCase): inited.op.run() +class VariableContainerTest(tf.test.TestCase): + + def testContainer(self): + with tf.Graph().as_default(): + v0 = tf.Variable([0]) + with tf.container("l1"): + v1 = tf.Variable([1]) + with tf.container("l2"): + v2 = tf.Variable([2]) + special_v = state_ops.variable_op([1], tf.float32, container="l3") + v3 = tf.Variable([3]) + v4 = tf.Variable([4]) + self.assertEqual(tf.compat.as_bytes(""), v0.op.get_attr("container")) + self.assertEqual(tf.compat.as_bytes("l1"), v1.op.get_attr("container")) + self.assertEqual(tf.compat.as_bytes("l2"), v2.op.get_attr("container")) + self.assertEqual(tf.compat.as_bytes("l3"), + special_v.op.get_attr("container")) + self.assertEqual(tf.compat.as_bytes("l1"), v3.op.get_attr("container")) + self.assertEqual(tf.compat.as_bytes(""), v4.op.get_attr("container")) + + if __name__ == "__main__": tf.test.main() diff --git a/tensorflow/python/ops/array_grad.py b/tensorflow/python/ops/array_grad.py index 0e85aaf80f..7ccef03b1a 100644 --- a/tensorflow/python/ops/array_grad.py +++ b/tensorflow/python/ops/array_grad.py @@ -30,13 +30,13 @@ from tensorflow.python.ops import math_ops @ops.RegisterGradient("Pack") def _PackGrad(op, grad): """Gradient for pack op.""" - return array_ops.unpack(grad, num=op.get_attr("N")) + return 
array_ops.unpack(grad, num=op.get_attr("N"), axis=op.get_attr("axis")) @ops.RegisterGradient("Unpack") -def _UnpackGrad(_, *grads): +def _UnpackGrad(op, *grads): """Gradient for unpack op.""" - return array_ops.pack(grads) + return array_ops.pack(grads, axis=op.get_attr("axis")) @ops.RegisterGradient("Concat") @@ -149,6 +149,27 @@ def _SliceGrad(op, grad): return array_ops.pad(grad, paddings), None, None +@ops.RegisterGradient("StridedSlice") +def _StridedSliceGrad(op, grad): + """Gradient for unpack op.""" + x = array_ops.shape(op.inputs[0]) + begin = op.inputs[1] + end = op.inputs[2] + strides = op.inputs[3] + + return array_ops.strided_slice_grad( + x, + begin, + end, + strides, + grad, + begin_mask=op.get_attr("begin_mask"), + end_mask=op.get_attr("end_mask"), + ellipse_mask=op.get_attr("ellipse_mask"), + new_axis_mask=op.get_attr("new_axis_mask"), + shrink_axis_mask=op.get_attr("shrink_axis_mask")), None, None, None + + @ops.RegisterGradient("Split") def _SplitGrad(op, *grads): return None, array_ops.concat(op.inputs[0], list(grads)) diff --git a/tensorflow/python/ops/array_ops.py b/tensorflow/python/ops/array_ops.py index 70b4c5bd35..509f627170 100644 --- a/tensorflow/python/ops/array_ops.py +++ b/tensorflow/python/ops/array_ops.py @@ -495,7 +495,7 @@ def strided_slice(input_, ops.Tensor._override_operator("__getitem__", _SliceHelper) -def pack(values, name="pack"): +def pack(values, axis=0, name="pack"): """Packs a list of rank-`R` tensors into one rank-`(R+1)` tensor. Packs tensors in `values` into a tensor with rank one higher than each tensor @@ -508,17 +508,31 @@ def pack(values, name="pack"): Args: values: A list of `Tensor` objects with the same shape and type. + axis: An `int`. The axis to pack along. Defaults to the first dimension. + Supports negative indexes. name: A name for this operation (optional). Returns: output: A packed `Tensor` with the same type as `values`. + + Raises: + ValueError: If `axis` is out of the range [-(R+1), R+1). """ - try: - # If the input is a constant list, it can just be converted to a constant op - return ops.convert_to_tensor(values, name=name) - except (TypeError, ValueError): - # Input list contains non-constant tensors - return gen_array_ops._pack(values, name=name) + if axis == 0: + try: + # If the input is a constant list, it can be converted to a constant op + return ops.convert_to_tensor(values, name=name) + except (TypeError, ValueError): + pass # Input list contains non-constant tensors + + value_shape = ops.convert_to_tensor(values[0], name=name).get_shape() + if value_shape.ndims is not None: + expanded_num_dims = value_shape.ndims + 1 + if axis < -expanded_num_dims or axis >= expanded_num_dims: + raise ValueError("axis = %d not in [%d, %d)" % + (axis, -expanded_num_dims, expanded_num_dims)) + + return gen_array_ops._pack(values, axis=axis, name=name) # pylint: disable=invalid-name @@ -609,12 +623,12 @@ ops.register_tensor_conversion_function( (list, tuple), _autopacking_conversion_function, 99) -def unpack(value, num=None, name="unpack"): - """Unpacks the outer dimension of a rank-`R` tensor into rank-`(R-1)` tensors. +def unpack(value, num=None, axis=0, name="unpack"): + """Unpacks the given dimension of a rank-`R` tensor into rank-`(R-1)` tensors. - Unpacks `num` tensors from `value` along the first dimension. + Unpacks `num` tensors from `value` along the given dimension. If `num` is not specified (the default), it is inferred from `value`'s shape. - If `value.shape[0]` is not known, `ValueError` is raised. 
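The new `axis` argument threads through `pack`/`unpack` and, per the gradient changes above, through their gradients as well, so packing along any dimension round-trips cleanly. A quick sketch, not from the diff:

```python
import tensorflow as tf

a = tf.constant([[1, 2, 3], [4, 5, 6]])  # shape (2, 3)

cols = tf.unpack(a, axis=1)  # three tensors of shape (2,): columns of `a`
b = tf.pack(cols, axis=1)    # packs them back into shape (2, 3)

with tf.Session() as sess:
  print(sess.run(cols[0]))  # => [1 4]
  print(sess.run(b))        # => [[1 2 3]
                            #     [4 5 6]]
```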
+ If `value.shape[axis]` is not known, `ValueError` is raised. The ith tensor in `output` is the slice `value[i, ...]`. Each tensor in `output` has shape `value.shape[1:]`. @@ -625,8 +639,10 @@ def unpack(value, num=None, name="unpack"): Args: value: A rank `R > 0` `Tensor` to be unpacked. - num: An `int`. The first dimension of value. Automatically inferred if - `None` (the default). + num: An `int`. The length of the dimension `axis`. Automatically inferred + if `None` (the default). + axis: An `int`. The axis to unpack along. Defaults to the first + dimension. Supports negative indexes. name: A name for the operation (optional). Returns: @@ -634,14 +650,19 @@ def unpack(value, num=None, name="unpack"): Raises: ValueError: If `num` is unspecified and cannot be inferred. + ValueError: If `axis` is out of the range [-R, R). """ if num is None: value = ops.convert_to_tensor(value) - shape = value.get_shape() - num = shape[0].value - if num is None: - raise ValueError("Cannot infer num from shape %s" % shape) - return gen_array_ops._unpack(value, num=num, name=name) + value_shape = value.get_shape() + if value_shape.ndims is not None: + if axis < -value_shape.ndims or axis >= value_shape.ndims: + raise ValueError("axis = %d not in [%d, %d)" % + (axis, -value_shape.ndims, value_shape.ndims)) + num = value_shape[axis].value + if num is None: + raise ValueError("Cannot infer num from shape %s" % value_shape) + return gen_array_ops._unpack(value, num=num, axis=axis, name=name) def concat(concat_dim, values, name="concat"): @@ -707,15 +728,26 @@ def concat(concat_dim, values, name="concat"): @ops.RegisterShape("Pack") def _PackShape(op): input_shape = op.inputs[0].get_shape() + if input_shape.ndims is None: + return [tensor_shape.unknown_shape()] + for inp in op.inputs[1:]: input_shape = input_shape.merge_with(inp.get_shape()) - return [tensor_shape.TensorShape([len(op.inputs)]).concatenate(input_shape)] + + input_shape = input_shape.as_list() + input_shape.insert(op.get_attr("axis"), len(op.inputs)) + return [tensor_shape.TensorShape(input_shape)] @ops.RegisterShape("Unpack") def _UnpackShape(op): input_shape = op.inputs[0].get_shape() - return [input_shape[1:]] * op.get_attr("num") + if input_shape.ndims is None: + return [tensor_shape.unknown_shape()] * op.get_attr("num") + + input_shape = input_shape.as_list() + del input_shape[op.get_attr("axis")] + return [tensor_shape.TensorShape(input_shape)] * op.get_attr("num") @ops.RegisterShape("Concat") @@ -1437,6 +1469,12 @@ def _compute_size_of_strided_dim(spec, size): return unknown # unknown because stride is unknown +@ops.RegisterShape("StridedSliceGrad") +def _StridedSliceGradShape(op): + """Shape function for gradient of array_ops.slice.""" + return [tensor_util.constant_value(op.inputs[0])] + + @ops.RegisterShape("StridedSlice") def _StridedSliceShape(op): """Shape function for array_ops.slice.""" @@ -1742,45 +1780,38 @@ def _ReshapeShape(op): num_elements *= dim else: num_elements = tensor_shape.Dimension(None) - new_shape_shape = op.inputs[1].get_shape().with_rank(1) - new_shape = tensor_util.constant_value(op.inputs[1]) - if new_shape is None: - # Attempt to infer the rank of the output from the length of - # new_shape. - return [tensor_shape.unknown_shape(ndims=new_shape_shape[0].value)] - new_shape = np.reshape(new_shape, -1).tolist() - if -1 not in new_shape: + new_shape = tensor_util.constant_value_as_shape(op.inputs[1]) + if new_shape.ndims is None: + # We have no information about the shape of the output. 
diff --git a/tensorflow/python/ops/control_flow_grad.py b/tensorflow/python/ops/control_flow_grad.py
index 24d4f4a9a2..a977bb57e0 100644
--- a/tensorflow/python/ops/control_flow_grad.py
+++ b/tensorflow/python/ops/control_flow_grad.py
@@ -199,7 +199,7 @@ def _EnterGrad(op, grad):
   if op.get_attr("is_constant"):
     # Add a gradient accumulator for each loop invariant.
     if isinstance(grad, ops.Tensor):
-      result = grad_ctxt.AddBackPropAccumulator(grad)
+      result = grad_ctxt.AddBackPropAccumulator(op, grad)
     elif isinstance(grad, ops.IndexedSlices):
       result = grad_ctxt.AddBackPropIndexedSlicesAccumulator(grad)
     else:
diff --git a/tensorflow/python/ops/control_flow_ops.py b/tensorflow/python/ops/control_flow_ops.py
index 3063675e40..8c02399f81 100644
--- a/tensorflow/python/ops/control_flow_ops.py
+++ b/tensorflow/python/ops/control_flow_ops.py
@@ -773,13 +773,13 @@ class GradLoopState(object):
           # Record the history of this value in forward_ctxt.
           # TODO(yuanbyu): Avoid recording constants.
           self._grad_context.Exit()
-          h_value = cur_grad_state.AddForwardAccumulator(cur_value)
+          history_value = cur_grad_state.AddForwardAccumulator(cur_value)
           self._grad_context.Enter()
           break

       if real_value is None:
         # Add the stack pop op in the grad context.
-        real_value = self.AddBackPropAccumulatedValue(h_value, value)
+        real_value = self.AddBackPropAccumulatedValue(history_value, value)
       self._history_map[value.name] = real_value
     return real_value

@@ -966,13 +966,13 @@ class ControlFlowState(object):
         # Add forward accumulator for shape.
        grad_state.grad_context.Exit()
-        h_shape = grad_state.AddForwardAccumulator(
+        history_zeros_shape = grad_state.AddForwardAccumulator(
             zeros_shape, dead_branch=dead_branch)
         grad_state.grad_context.Enter()

         # Create a zero tensor with the right shape.
         shape = grad_state.AddBackPropAccumulatedValue(
-            h_shape, zeros_shape, dead_branch)
+            history_zeros_shape, zeros_shape, dead_branch)
         result = array_ops.zeros(shape, val.dtype)
     return result

@@ -1596,36 +1596,56 @@ class WhileContext(ControlFlowContext):
     self.Exit()
     return next_count

-  def AddBackPropAccumulator(self, value):
+  def AddBackPropAccumulator(self, op, grad):
     """Add an accumulation loop for every loop invariant.

-    This is added to the backprop loop. It is used to accumulate
-    partial gradients within each loop iteration. Called when in the
-    gradient while context.
+    This is added to the backprop loop. It is used to accumulate partial
+    gradients within each loop iteration. Called when in the gradient while
+    context.

     The pseudocode is:
     ```
     acc = 0.0;
     while (_pivot) {
-      acc += value;
+      acc += grad;
     }
     ```

     Args:
-      value: The partial gradient of an iteration for a loop invariant.
+      op: The Enter op for a loop invariant.
+      grad: The partial gradient of an iteration for a loop invariant.

     Returns:
       The gradient for a loop invariant.
     """
     self.Exit()
-    shape = value.get_shape()
-    if not shape.is_fully_defined():
-      shape = None
-    if self.outer_context: self.outer_context.Enter()
-    acc = constant_op.constant(0, value.dtype, shape=shape, name="b_acc")
-    if not shape:
-      acc._shape = value.get_shape()  # pylint: disable=protected-access
-    if self.outer_context: self.outer_context.Exit()
+    # Create a zeros tensor with the right shape for acc. If we don't
+    # know the full shape statically, we will have to get the shape
+    # dynamically from the forward inference. Getting the shape right
+    # for the zeros is only needed for the base case when the loop exits
+    # without running any iterations.
+    shape = grad.get_shape()
+    if shape.is_fully_defined():
+      if self.outer_context: self.outer_context.Enter()
+      acc = constant_op.constant(0, grad.dtype, shape=shape, name="b_acc")
+      if self.outer_context: self.outer_context.Exit()
+    else:
+      value = op.inputs[0]
+      if self.outer_context:
+        forward_ctxt = self.grad_state.forward_ctxt
+        forward_ctxt.outer_context.Enter()
+        zeros_shape = array_ops.shape(value)
+        forward_ctxt.outer_context.Exit()
+        outer_grad_state = self.grad_state.outer_grad_state
+        history_zeros_shape = outer_grad_state.AddForwardAccumulator(
+            zeros_shape)
+        self.outer_context.Enter()
+        real_shape = outer_grad_state.AddBackPropAccumulatedValue(
+            history_zeros_shape, zeros_shape)
+        acc = array_ops.zeros(real_shape, grad.dtype)
+        self.outer_context.Exit()
+      else:
+        zeros_shape = array_ops.shape(value)
+        acc = array_ops.zeros(zeros_shape, grad.dtype)
+      acc._shape = grad.get_shape()  # pylint: disable=protected-access

     self.Enter()
     self.AddName(acc.name)
@@ -1633,30 +1653,30 @@
                        parallel_iterations=self._parallel_iterations,
                        name="b_acc")
     merge_acc = merge([enter_acc, enter_acc], name="b_acc")[0]
-    switch_acc = switch(merge_acc, self._pivot)
+    switch_acc_false, switch_acc_true = switch(merge_acc, self._pivot)

-    add_acc = math_ops.add(switch_acc[1], value)
+    add_acc = math_ops.add(switch_acc_true, grad)
     next_acc = _NextIteration(add_acc)
     merge_acc.op._update_input(1, next_acc)  # pylint: disable=protected-access

-    acc_result = exit(switch_acc[0], name="b_acc")
+    acc_result = exit(switch_acc_false, name="b_acc")
     self.ExitResult([acc_result])
     return acc_result

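Reviewer note: the practical effect of the new else-branch is that a loop invariant whose shape is only known at runtime no longer needs a statically shaped zero accumulator. A hedged sketch of the case this enables (names are illustrative):

```python
import tensorflow as tf

x = tf.placeholder(tf.float32, shape=[None])  # shape only known at runtime

def cond(i, acc):
  return i < 5

def body(i, acc):
  return i + 1, acc + x  # x is a loop invariant used in every iteration

_, total = tf.while_loop(cond, body, [tf.constant(0), tf.zeros_like(x)])
grad = tf.gradients(total, x)[0]  # the b_acc zeros now come from tf.shape(x)
```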
""" - values = value.values - indices = value.indices + values = grad.values + indices = grad.indices self.Exit() shape = tensor_shape.TensorShape([tensor_shape.Dimension(1)] + @@ -1670,6 +1690,7 @@ class WhileContext(ControlFlowContext): values_acc._shape = shape # pylint: disable=protected-access indices_acc = constant_op.constant([0], indices.dtype) if self.outer_context: self.outer_context.Exit() + self.Enter() self.AddName(values_acc.name) self.AddName(indices_acc.name) @@ -1687,10 +1708,10 @@ class WhileContext(ControlFlowContext): for xm, xn in zip(merge_acc, next_acc): xm.op._update_input(1, xn) # pylint: disable=protected-access - acc_result = [exit(x[0], name="b_acc") for x in switch_acc] - self.ExitResult(acc_result) - return ops.IndexedSlices(values=acc_result[1], indices=acc_result[0], - dense_shape=self.ExitResult(value.dense_shape)) + acc_exits = [exit(x[0], name="b_acc") for x in switch_acc] + self.ExitResult(acc_exits) + return ops.IndexedSlices(values=acc_exits[1], indices=acc_exits[0], + dense_shape=grad.dense_shape) def _InitializeValues(self, values): self._values = set() @@ -1882,8 +1903,8 @@ def while_loop(cond, body, loop_vars, parallel_iterations=10, back_prop=True, ```python ijk_0 = (tf.constant(0), (tf.constant(1), tf.constant(2))) - c = lambda i, j, k: i < 10 - b = lambda i, j, k: (i, (j + k), (j - k)) + c = lambda i, (j, k): i < 10 + b = lambda i, (j, k): (i + 1, ((j + k), (j - k))) ijk_final = tf.while_loop(c, b, ijk_0) ``` """ diff --git a/tensorflow/python/ops/data_flow_ops.py b/tensorflow/python/ops/data_flow_ops.py index 00033d4bc7..a4a3c4e669 100644 --- a/tensorflow/python/ops/data_flow_ops.py +++ b/tensorflow/python/ops/data_flow_ops.py @@ -456,7 +456,7 @@ class QueueBase(object): If the queue is closed and there are more than `0` but fewer than `n` elements remaining, then instead of raising a `tf.errors.OutOfRangeError` like [`dequeue_many`](#QueueBase.dequeue_many), - the remaining elements are returned immediately. If the queue is + less than `n` elements are returned immediately. If the queue is closed and there are `0` elements left in the queue, then a `tf.errors.OutOfRangeError` is raised just like in `dequeue_many`. Otherwise the behavior is identical to `dequeue_many`. 
@@ -732,6 +732,7 @@ ops.NoGradient("HashTable")
 ops.NoGradient("InitializeTable")
 ops.NoGradient("InitializeTableFromTextFile")
 ops.NoGradient("MutableHashTable")
+ops.NoGradient("MutableHashTableOfTensors")

 ops.RegisterShape("QueueSize")(common_shapes.scalar_shape)
@@ -807,16 +808,13 @@ def _DynamicStitchShape(op):
 def _LookupTableFindShape(op):
   """Shape function for data_flow_ops._lookup_table_find."""
   op.inputs[0].get_shape().merge_with(tensor_shape.scalar())
-  shape_in = op.inputs[1].get_shape()
-  return [shape_in]
+  return [tensor_shape.unknown_shape()]


 @ops.RegisterShape("LookupTableInsert")
 def _LookupTableInsertShape(op):
   """Shape function for data_flow_ops._lookup_table_insert."""
   op.inputs[0].get_shape().merge_with(tensor_shape.scalar())
-  keys_shape = op.inputs[1].get_shape()
-  op.inputs[2].get_shape().merge_with(keys_shape)
   return []

@@ -827,8 +825,18 @@ def _LookupTableSizeShape(op):
   return [tensor_shape.scalar()]


+@ops.RegisterShape("LookupTableExport")
+def _LookupTableExportShape(op):
+  """Shape function for data_flow_ops._lookup_table_export_values."""
+  op.inputs[0].get_shape().merge_with(tensor_shape.scalar())
+  keys_shape = tensor_shape.vector(None)
+  values_shape = tensor_shape.unknown_shape()
+  return [keys_shape, values_shape]
+
+
 @ops.RegisterShape("HashTable")
 @ops.RegisterShape("MutableHashTable")
+@ops.RegisterShape("MutableHashTableOfTensors")
 def _HashTableShape(_):
   """Shape function for data_flow_ops._hash_table."""
   return [tensor_shape.scalar()]
diff --git a/tensorflow/python/ops/image_grad.py b/tensorflow/python/ops/image_grad.py
index ceaf518bd4..55fe7b067c 100644
--- a/tensorflow/python/ops/image_grad.py
+++ b/tensorflow/python/ops/image_grad.py
@@ -22,6 +22,7 @@ from tensorflow.python.framework import dtypes
 from tensorflow.python.framework import ops
 from tensorflow.python.framework import tensor_shape
 from tensorflow.python.framework import tensor_util
+from tensorflow.python.ops import array_ops
 from tensorflow.python.ops import gen_image_ops


@@ -88,3 +89,65 @@ def _ResizeShape(op):
 def _ResizeBilinearGradShape(op):
   """Shape function for ResizeBilinearGrad."""
   return [op.inputs[1].get_shape()]
+
+
+@ops.RegisterShape("CropAndResizeGradImage")
+def _CropAndResizeGradImageShape(op):
+  """Shape function for CropAndResizeGradImage."""
+  image_size = tensor_util.constant_value(op.inputs[3])
+  if image_size is not None:
+    batch = image_size[0]
+    height = image_size[1]
+    width = image_size[2]
+    depth = image_size[3]
+  else:
+    batch = None
+    height = None
+    width = None
+    depth = None
+  return [tensor_shape.TensorShape([batch, height, width, depth])]
+
+
+@ops.RegisterShape("CropAndResizeGradBoxes")
+def _CropAndResizeGradBoxesShape(op):
+  """Shape function for CropAndResizeGradBoxes."""
+  return [op.inputs[2].get_shape()]
+
+
+@ops.RegisterGradient("CropAndResize")
+def _CropAndResizeGrad(op, grad):
+  """The derivatives for crop_and_resize.
+
+  We back-propagate to the image only when the input image tensor has a
+  floating-point dtype, but we always back-propagate to the input boxes tensor.
+
+  Args:
+    op: The CropAndResize op.
+    grad: The tensor representing the gradient w.r.t. the output.
+
+  Returns:
+    The gradients w.r.t. the input image, boxes, as well as the always-None
+    gradients w.r.t. box_ind and crop_size.
+ """ + image = op.inputs[0] + if image.get_shape().is_fully_defined(): + image_shape = image.get_shape().as_list() + else: + image_shape = array_ops.shape(image) + + allowed_types = [dtypes.float16, dtypes.float32, dtypes.float64] + if op.inputs[0].dtype in allowed_types: + # pylint: disable=protected-access + grad0 = gen_image_ops.crop_and_resize_grad_image(grad, + op.inputs[1], + op.inputs[2], + image_shape, + T=op.get_attr("T")) + # pylint: enable=protected-access + else: + grad0 = None + + grad1 = gen_image_ops.crop_and_resize_grad_boxes(grad, op.inputs[0], + op.inputs[1], op.inputs[2]) + + return [grad0, grad1, None, None] diff --git a/tensorflow/python/ops/image_grad_test.py b/tensorflow/python/ops/image_grad_test.py index 9d9cdd4ed0..dab9619424 100644 --- a/tensorflow/python/ops/image_grad_test.py +++ b/tensorflow/python/ops/image_grad_test.py @@ -174,5 +174,134 @@ class ResizeBilinearOpTest(tf.test.TestCase): grad = tf.gradients(input_tensor, [resize_out]) self.assertEqual([None], grad) + +class CropAndResizeOpTest(tf.test.TestCase): + + def testShapeIsCorrectAfterOp(self): + batch = 2 + image_height = 3 + image_width = 4 + crop_height = 4 + crop_width = 5 + depth = 2 + num_boxes = 2 + + image_shape = [batch, image_height, image_width, depth] + crop_size = [crop_height, crop_width] + crops_shape = [num_boxes, crop_height, crop_width, depth] + + image = np.arange(0, batch * image_height * image_width * + depth).reshape(image_shape).astype(np.float32) + boxes = np.array([[0, 0, 1, 1], [.1, .2, .7, .8]], dtype=np.float32) + box_ind = np.array([0, 1], dtype=np.int32) + + for use_gpu in [False, True]: + with self.test_session(use_gpu=use_gpu) as sess: + crops = tf.image.crop_and_resize( + tf.constant(image, shape=image_shape), + tf.constant(boxes, shape=[num_boxes, 4]), + tf.constant(box_ind, shape=[num_boxes]), + tf.constant(crop_size, shape=[2])) + self.assertEqual(crops_shape, list(crops.get_shape())) + crops = sess.run(crops) + self.assertEqual(crops_shape, list(crops.shape)) + + def _randomUniformAvoidAnchors(self, low, high, anchors, radius, num_samples): + """Generate samples that are far enough from a set of anchor points. + + We generate uniform samples in [low, high], then reject those that are less + than radius away from any point in anchors. We stop after we have accepted + num_samples samples. + + Args: + low: The lower end of the interval. + high: The upper end of the interval. + anchors: A list of length num_crops with anchor points to avoid. + radius: Distance threshold for the samples from the anchors. + num_samples: How many samples to produce. + + Returns: + samples: A list of length num_samples with the accepted samples. + """ + self.assertTrue(low < high) + self.assertTrue(radius >= 0) + num_anchors = len(anchors) + # Make sure that at least half of the interval is not forbidden. + self.assertTrue(2 * radius * num_anchors < 0.5 * (high - low)) + anchors = np.reshape(anchors, num_anchors) + samples = [] + while len(samples) < num_samples: + sample = np.random.uniform(low, high) + if np.all(np.fabs(sample - anchors) > radius): + samples.append(sample) + return samples + + def testGradRandomBoxes(self): + """Test that the gradient is correct for randomly generated boxes. + + The mapping is piecewise differentiable with respect to the box coordinates. 
diff --git a/tensorflow/python/ops/image_grad_test.py b/tensorflow/python/ops/image_grad_test.py
index 9d9cdd4ed0..dab9619424 100644
--- a/tensorflow/python/ops/image_grad_test.py
+++ b/tensorflow/python/ops/image_grad_test.py
@@ -174,5 +174,134 @@ class ResizeBilinearOpTest(tf.test.TestCase):
       grad = tf.gradients(input_tensor, [resize_out])
       self.assertEqual([None], grad)

+
+class CropAndResizeOpTest(tf.test.TestCase):
+
+  def testShapeIsCorrectAfterOp(self):
+    batch = 2
+    image_height = 3
+    image_width = 4
+    crop_height = 4
+    crop_width = 5
+    depth = 2
+    num_boxes = 2
+
+    image_shape = [batch, image_height, image_width, depth]
+    crop_size = [crop_height, crop_width]
+    crops_shape = [num_boxes, crop_height, crop_width, depth]
+
+    image = np.arange(0, batch * image_height * image_width *
+                      depth).reshape(image_shape).astype(np.float32)
+    boxes = np.array([[0, 0, 1, 1], [.1, .2, .7, .8]], dtype=np.float32)
+    box_ind = np.array([0, 1], dtype=np.int32)
+
+    for use_gpu in [False, True]:
+      with self.test_session(use_gpu=use_gpu) as sess:
+        crops = tf.image.crop_and_resize(
+            tf.constant(image, shape=image_shape),
+            tf.constant(boxes, shape=[num_boxes, 4]),
+            tf.constant(box_ind, shape=[num_boxes]),
+            tf.constant(crop_size, shape=[2]))
+        self.assertEqual(crops_shape, list(crops.get_shape()))
+        crops = sess.run(crops)
+        self.assertEqual(crops_shape, list(crops.shape))
+
+  def _randomUniformAvoidAnchors(self, low, high, anchors, radius, num_samples):
+    """Generate samples that are far enough from a set of anchor points.
+
+    We generate uniform samples in [low, high], then reject those that are less
+    than radius away from any point in anchors. We stop after we have accepted
+    num_samples samples.
+
+    Args:
+      low: The lower end of the interval.
+      high: The upper end of the interval.
+      anchors: A list of anchor points to avoid.
+      radius: Distance threshold for the samples from the anchors.
+      num_samples: How many samples to produce.
+
+    Returns:
+      samples: A list of length num_samples with the accepted samples.
+    """
+    self.assertTrue(low < high)
+    self.assertTrue(radius >= 0)
+    num_anchors = len(anchors)
+    # Make sure that at least half of the interval is not forbidden.
+    self.assertTrue(2 * radius * num_anchors < 0.5 * (high - low))
+    anchors = np.reshape(anchors, num_anchors)
+    samples = []
+    while len(samples) < num_samples:
+      sample = np.random.uniform(low, high)
+      if np.all(np.fabs(sample - anchors) > radius):
+        samples.append(sample)
+    return samples
+
+  def testGradRandomBoxes(self):
+    """Test that the gradient is correct for randomly generated boxes.
+
+    The mapping is piecewise differentiable with respect to the box coordinates.
+    The points where the function is not differentiable are those which are
+    mapped to image pixels, i.e., the normalized y coordinates in
+    np.linspace(0, 1, image_height) and normalized x coordinates in
+    np.linspace(0, 1, image_width). Make sure that the box coordinates are
+    sufficiently far away from those rectangular grid centers that are points of
+    discontinuity, so that the finite difference Jacobian is close to the
+    computed one.
+    """
+    np.random.seed(1)  # Make it reproducible.
+    delta = 1e-3
+    radius = 2 * delta
+    low, high = -0.5, 1.5  # Also covers the case of extrapolation.
+
+    for image_height in range(1, 5):
+      for image_width in range(1, 3):
+        for crop_height in range(1, 3):
+          for crop_width in range(2, 4):
+            for depth in range(1, 3):
+              for num_boxes in range(1, 3):
+
+                batch = num_boxes
+                image_shape = [batch, image_height, image_width, depth]
+                crop_size = [crop_height, crop_width]
+                crops_shape = [num_boxes, crop_height, crop_width, depth]
+                boxes_shape = [num_boxes, 4]
+
+                image = np.arange(0, batch * image_height * image_width *
+                                  depth).reshape(image_shape).astype(np.float32)
+                boxes = []
+                for _ in range(num_boxes):
+                  # pylint: disable=unbalanced-tuple-unpacking
+                  y1, y2 = self._randomUniformAvoidAnchors(
+                      low, high, np.linspace(0, 1, image_height), radius, 2)
+                  x1, x2 = self._randomUniformAvoidAnchors(
+                      low, high, np.linspace(0, 1, image_width), radius, 2)
+                  # pylint: enable=unbalanced-tuple-unpacking
+                  boxes.append([y1, x1, y2, x2])
+
+                boxes = np.array(boxes, dtype=np.float32)
+                box_ind = np.arange(batch, dtype=np.int32)
+
+                for use_gpu in [False, True]:
+                  with self.test_session(use_gpu=use_gpu):
+                    image_tensor = tf.constant(image, shape=image_shape)
+                    boxes_tensor = tf.constant(boxes, shape=[num_boxes, 4])
+                    box_ind_tensor = tf.constant(box_ind, shape=[num_boxes])
+                    crops = tf.image.crop_and_resize(
+                        image_tensor,
+                        boxes_tensor,
+                        box_ind_tensor,
+                        tf.constant(crop_size, shape=[2]))
+
+                    err = tf.test.compute_gradient_error(
+                        [image_tensor, boxes_tensor],
+                        [image_shape, boxes_shape],
+                        crops,
+                        crops_shape,
+                        delta=delta,
+                        x_init_value=[image, boxes])
+
+                    self.assertLess(err, 2e-3)
+
+
 if __name__ == "__main__":
   tf.test.main()
diff --git a/tensorflow/python/ops/image_ops.py b/tensorflow/python/ops/image_ops.py
index 2df7b9c4c7..a72690c0f2 100644
--- a/tensorflow/python/ops/image_ops.py
+++ b/tensorflow/python/ops/image_ops.py
@@ -73,6 +73,8 @@ resized_image = tf.image.resize_images(image, 299, 299)
 @@crop_to_bounding_box
 @@extract_glimpse

+@@crop_and_resize
+
 ## Flipping and Transposing

 @@flip_up_down
@@ -1327,8 +1329,8 @@ def _random_crop_shape(op):
   return [tensor_shape.TensorShape(output_shape)]


-@ops.RegisterShape("ExtractGlimpse")
-def _ExtractGlimpseShape(op):
+@ops.RegisterShape('ExtractGlimpse')
+def _extract_glimpse_shape(op):
   """Shape function for ExtractGlimpse op."""
   input_shape = op.inputs[0].get_shape().with_rank(4)
   unused_size_shape = op.inputs[1].get_shape().merge_with(
@@ -1347,6 +1349,22 @@
       [input_shape[0], height, width, input_shape[3]])]


+@ops.RegisterShape('CropAndResize')
+def _crop_and_resize_shape(op):
+  """Shape function for the CropAndResize op."""
+  image_shape = op.inputs[0].get_shape().with_rank(4)
+  box_shape = op.inputs[1].get_shape().with_rank(2)
+  crop_size = tensor_util.constant_value(op.inputs[3])
+  if crop_size is not None:
+    crop_height = crop_size[0]
+    crop_width = crop_size[1]
+  else:
+    crop_height = None
+    crop_width = None
+  return [tensor_shape.TensorShape(
+      [box_shape[0], crop_height, crop_width, image_shape[3]])]
+
+
 __all__ = make_all(__name__)
 # ResizeMethod is not documented, but is documented in functions that use it.
 __all__.append('ResizeMethod')
diff --git a/tensorflow/python/ops/math_ops.py b/tensorflow/python/ops/math_ops.py
index 9397b83e7c..d27cefc61d 100644
--- a/tensorflow/python/ops/math_ops.py
+++ b/tensorflow/python/ops/math_ops.py
@@ -239,7 +239,7 @@ def abs(x, name=None):
   number.

   Args:
-    x: A `Tensor` or `SparseTensor` of type `float`, `double`, `int32`, or
+    x: A `Tensor` or `SparseTensor` of type `float32`, `float64`, `int32`, or
       `int64`.
     name: A name for the operation (optional).
@@ -352,7 +352,7 @@ def complex_abs(x, name=None):
   r"""Computes the complex absolute value of a tensor.

   Given a tensor `x` of complex numbers, this operation returns a tensor of type
-  `float` or `double` that is the absolute value of each element in `x`. All
+  `float32` or `float64` that is the absolute value of each element in `x`. All
   elements in `x` must be complex numbers of the form \\(a + bj\\). The
   absolute value is computed as \\( \sqrt{a^2 + b^2}\\).
@@ -414,10 +414,10 @@ def pow(x, y, name=None):
   ```

   Args:
-    x: A `Tensor` of type `float`, `double`, `int32`, `int64`, `complex64`, or
-      `complex128`.
-    y: A `Tensor` of type `float`, `double`, `int32`, `int64`, `complex64`, or
-      `complex128`.
+    x: A `Tensor` of type `float32`, `float64`, `int32`, `int64`, `complex64`,
+      or `complex128`.
+    y: A `Tensor` of type `float32`, `float64`, `int32`, `int64`, `complex64`,
+      or `complex128`.
     name: A name for the operation (optional).

   Returns:
@@ -471,7 +471,7 @@ def real(input, name=None):
   """Returns the real part of a complex number.

   Given a tensor `input` of complex numbers, this operation returns a tensor of
-  type `float` or `double` that is the real part of each element in `input`.
+  type `float32` or `float64` that is the real part of each element in `input`.
   All elements in `input` must be complex numbers of the form \\(a + bj\\),
   where *a* is the real part returned by this operation and *b* is the
   imaginary part.
@@ -489,7 +489,7 @@ def real(input, name=None):
     name: A name for the operation (optional).

   Returns:
-    A `Tensor` of type `float` or `double`.
+    A `Tensor` of type `float32` or `float64`.
   """
   with ops.op_scope([input], name, "Real") as name:
     return gen_math_ops.real(input, Tout=input.dtype.real_dtype, name=name)
@@ -499,7 +499,7 @@ def imag(input, name=None):
   """Returns the imaginary part of a complex number.

   Given a tensor `input` of complex numbers, this operation returns a tensor of
-  type `float` or `double` that is the imaginary part of each element in
+  type `float32` or `float64` that is the imaginary part of each element in
   `input`. All elements in `input` must be complex numbers of the form \\(a +
   bj\\), where *a* is the real part and *b* is the imaginary part returned by
   this operation.
@@ -516,7 +516,7 @@ def imag(input, name=None):
     name: A name for the operation (optional).

   Returns:
-    A `Tensor` of type `float` or `double`.
+    A `Tensor` of type `float32` or `float64`.
   """
   with ops.op_scope([input], name, "Imag") as name:
     return gen_math_ops.imag(input, Tout=input.dtype.real_dtype, name=name)
@@ -533,7 +533,7 @@ def round(x, name=None):
   ```

   Args:
-    x: A `Tensor` of type `float` or `double`.
+    x: A `Tensor` of type `float32` or `float64`.
     name: A name for the operation (optional).

   Returns:
@@ -1255,7 +1255,7 @@ def matmul(a, b,
   possibly after transposition. Both matrices must be of the same type.
  The supported types are:
-  `float`, `double`, `int32`, `complex64`.
+  `float32`, `float64`, `int32`, `complex64`.

   Either matrix can be transposed on the fly by setting the corresponding flag
   to `True`. This is `False` by default.
@@ -1279,7 +1279,7 @@ def matmul(a, b,
   ```

   Args:
-    a: `Tensor` of type `float`, `double`, `int32` or `complex64`.
+    a: `Tensor` of type `float32`, `float64`, `int32` or `complex64`.
     b: `Tensor` with same type as `a`.
     transpose_a: If `True`, `a` is transposed before multiplication.
     transpose_b: If `True`, `b` is transposed before multiplication.
@@ -1531,7 +1531,7 @@ def sigmoid(x, name=None):
   Specifically, `y = 1 / (1 + exp(-x))`.

   Args:
-    x: A Tensor with type `float`, `double`, `int32`, `complex64`, `int64`,
+    x: A Tensor with type `float32`, `float64`, `int32`, `complex64`, `int64`,
       or `qint32`.
     name: A name for the operation (optional).
@@ -1548,7 +1548,7 @@ def tanh(x, name=None):
   """Computes hyperbolic tangent of `x` element-wise.

   Args:
-    x: A Tensor with type `float`, `double`, `int32`, `complex64`, `int64`,
+    x: A Tensor with type `float32`, `float64`, `int32`, `complex64`, `int64`,
       or `qint32`.
     name: A name for the operation (optional).
diff --git a/tensorflow/python/ops/nn_conv_test.py b/tensorflow/python/ops/nn_conv_test.py
index 465e39350f..8c771fc2a3 100644
--- a/tensorflow/python/ops/nn_conv_test.py
+++ b/tensorflow/python/ops/nn_conv_test.py
@@ -283,11 +283,42 @@ class Conv2DTransposeTest(tf.test.TestCase):
                                 padding="SAME")
       err = tf.test.compute_gradient_error(
           [x, f], [x_shape, f_shape], output, y_shape)
-    print("DeConv gradient err = %g " % err)
+    print("conv2d_transpose gradient err = %g " % err)
     err_tolerance = 0.0005
     self.assertLess(err, err_tolerance)


+class Conv2DBackpropFilterGradTest(tf.test.TestCase):
+
+  def testGradient(self):
+    with self.test_session():
+      for padding in ["SAME", "VALID"]:
+        for stride in [1, 2]:
+          np.random.seed(1)
+          in_shape = [5, 8, 6, 4]
+          in_val = tf.constant(
+              2 * np.random.random_sample(in_shape) - 1,
+              dtype=tf.float32)
+          filter_shape = [3, 3, 4, 6]
+          # Make a convolution op with the current settings, just to easily get
+          # the shape of the output.
+          conv_out = tf.nn.conv2d(in_val, tf.zeros(filter_shape),
+                                  [1, stride, stride, 1], padding)
+          out_backprop_shape = conv_out.get_shape().as_list()
+          out_backprop_val = tf.constant(
+              2 * np.random.random_sample(out_backprop_shape) - 1,
+              dtype=tf.float32)
+          output = tf.nn.conv2d_backprop_filter(in_val, filter_shape,
+                                                out_backprop_val,
+                                                [1, stride, stride, 1], padding)
+          err = tf.test.compute_gradient_error([in_val, out_backprop_val],
+                                               [in_shape, out_backprop_shape],
+                                               output, filter_shape)
+          print("conv2d_backprop_filter gradient err = %g " % err)
+          err_tolerance = 1e-3
+          self.assertLess(err, err_tolerance)
+
+
 class Conv1DTest(tf.test.TestCase):

   def testBasic(self):
diff --git a/tensorflow/python/ops/nn_grad.py b/tensorflow/python/ops/nn_grad.py
index ba2b9d957e..3f4bb0e068 100644
--- a/tensorflow/python/ops/nn_grad.py
+++ b/tensorflow/python/ops/nn_grad.py
@@ -28,7 +28,7 @@ from tensorflow.python.ops import gen_nn_ops


 @ops.RegisterGradient("Conv2DBackpropInput")
-def _Conv2DBackpropGrad(op, grad):
+def _Conv2DBackpropInputGrad(op, grad):
   """The derivatives for deconvolution.
   Args:
@@ -49,6 +49,25 @@ def _Conv2DBackpropGrad(op, grad):
                                  op.get_attr("data_format"))]


+@ops.RegisterGradient("Conv2DBackpropFilter")
+def _Conv2DBackpropFilterGrad(op, grad):
+  return [
+      nn_ops.conv2d_backprop_input(
+          array_ops.shape(op.inputs[0]), grad, op.inputs[2],
+          op.get_attr("strides"),
+          op.get_attr("padding"),
+          op.get_attr("use_cudnn_on_gpu"),
+          op.get_attr("data_format")),
+      None,
+      nn_ops.conv2d(
+          op.inputs[0], grad,
+          op.get_attr("strides"),
+          op.get_attr("padding"),
+          op.get_attr("use_cudnn_on_gpu"),
+          op.get_attr("data_format"))
+  ]
+
+
 @ops.RegisterGradient("Conv3D")
 def _Conv3DGrad(op, grad):
   return [nn_ops.conv3d_backprop_input_v2(array_ops.shape(op.inputs[0]),
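Reviewer note: registering a gradient for Conv2DBackpropFilter is what makes second-order expressions like the following well-defined; a hedged sketch (sizes are illustrative):

```python
import tensorflow as tf

x = tf.random_normal([1, 5, 5, 1])
w = tf.random_normal([3, 3, 1, 1])
y = tf.nn.conv2d(x, w, strides=[1, 1, 1, 1], padding="SAME")
loss = tf.reduce_sum(tf.square(y))

dw = tf.gradients(loss, w)[0]  # produced by a Conv2DBackpropFilter op
ddx = tf.gradients(dw, x)      # defined once the new gradient is registered
```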
diff --git a/tensorflow/python/ops/rnn.py b/tensorflow/python/ops/rnn.py
index d8a96db0a2..d2832541dd 100644
--- a/tensorflow/python/ops/rnn.py
+++ b/tensorflow/python/ops/rnn.py
@@ -469,6 +469,125 @@ def bidirectional_rnn(cell_fw, cell_bw, inputs,
   return (outputs, output_state_fw, output_state_bw)


+def bidirectional_dynamic_rnn(cell_fw, cell_bw, inputs, sequence_length=None,
+                              initial_state_fw=None, initial_state_bw=None,
+                              dtype=None, parallel_iterations=None,
+                              swap_memory=False, time_major=False, scope=None):
+  """Creates a dynamic version of bidirectional recurrent neural network.
+
+  Similar to the unidirectional case above (rnn) but takes input and builds
+  independent forward and backward RNNs. The input_size of forward and
+  backward cell must match. The initial state for both directions is zero by
+  default (but can be set optionally) and no intermediate states are ever
+  returned -- the network is fully unrolled for the given (passed in)
+  length(s) of the sequence(s) or completely unrolled if length(s) is not
+  given.
+
+  Args:
+    cell_fw: An instance of RNNCell, to be used for forward direction.
+    cell_bw: An instance of RNNCell, to be used for backward direction.
+    inputs: The RNN inputs.
+      If time_major == False (default), this must be a tensor of shape:
+        `[batch_size, max_time, input_size]`.
+      If time_major == True, this must be a tensor of shape:
+        `[max_time, batch_size, input_size]`.
+    sequence_length: An int32/int64 vector, size `[batch_size]`,
+      containing the actual lengths for each of the sequences.
+    initial_state_fw: (optional) An initial state for the forward RNN.
+      This must be a tensor of appropriate type and shape
+      `[batch_size x cell_fw.state_size]`.
+      If `cell_fw.state_size` is a tuple, this should be a tuple of
+      tensors having shapes `[batch_size, s] for s in cell_fw.state_size`.
+    initial_state_bw: (optional) Same as for `initial_state_fw`, but using
+      the corresponding properties of `cell_bw`.
+    parallel_iterations: (Default: 32). The number of iterations to run in
+      parallel. Those operations which do not have any temporal dependency
+      and can be run in parallel, will be. This parameter trades off
+      time for space. Values >> 1 use more memory but take less time,
+      while smaller values use less memory but computations take longer.
+    swap_memory: Transparently swap the tensors produced in forward inference
+      but needed for back prop from GPU to CPU. This allows training RNNs
+      which would typically not fit on a single GPU, with very minimal (or no)
+      performance penalty.
+    time_major: The shape format of the `inputs` and `outputs` Tensors.
+      If true, these `Tensors` must be shaped `[max_time, batch_size, depth]`.
+      If false, these `Tensors` must be shaped `[batch_size, max_time, depth]`.
+      Using `time_major = True` is a bit more efficient because it avoids
+      transposes at the beginning and end of the RNN calculation. However,
+      most TensorFlow data is batch-major, so by default this function
+      accepts input and emits output in batch-major form.
+    dtype: (optional) The data type for the initial state. Required if
+      either of the initial states is not provided.
+    scope: VariableScope for the created subgraph; defaults to "BiRNN"
+
+  Returns:
+    A tuple (outputs, output_states) where:
+      outputs: A tuple (output_fw, output_bw) containing the forward and
+        the backward rnn output `Tensor`.
+        If time_major == False (default),
+          output_fw will be a `Tensor` shaped:
+          `[batch_size, max_time, cell_fw.output_size]`
+          and output_bw will be a `Tensor` shaped:
+          `[batch_size, max_time, cell_bw.output_size]`.
+        If time_major == True,
+          output_fw will be a `Tensor` shaped:
+          `[max_time, batch_size, cell_fw.output_size]`
+          and output_bw will be a `Tensor` shaped:
+          `[max_time, batch_size, cell_bw.output_size]`.
+        It returns a tuple instead of a single concatenated `Tensor`, unlike
+        in the `bidirectional_rnn`. If the concatenated one is preferred,
+        the forward and backward outputs can be concatenated as
+        `tf.concat(2, outputs)`.
+      output_states: A tuple (output_state_fw, output_state_bw) containing
+        the forward and the backward final states of bidirectional rnn.
+
+  Raises:
+    TypeError: If `cell_fw` or `cell_bw` is not an instance of `RNNCell`.
+  """
+
+  if not isinstance(cell_fw, rnn_cell.RNNCell):
+    raise TypeError("cell_fw must be an instance of RNNCell")
+  if not isinstance(cell_bw, rnn_cell.RNNCell):
+    raise TypeError("cell_bw must be an instance of RNNCell")
+
+  name = scope or "BiRNN"
+  # Forward direction
+  with vs.variable_scope(name + "_FW") as fw_scope:
+    output_fw, output_state_fw = dynamic_rnn(
+        cell=cell_fw, inputs=inputs, sequence_length=sequence_length,
+        initial_state=initial_state_fw, dtype=dtype,
+        parallel_iterations=parallel_iterations, swap_memory=swap_memory,
+        time_major=time_major, scope=fw_scope)
+  # Backward direction
+  if not time_major:
+    time_dim = 1
+    batch_dim = 0
+  else:
+    time_dim = 0
+    batch_dim = 1
+  with vs.variable_scope(name + "_BW") as bw_scope:
+    inputs_reverse = array_ops.reverse_sequence(
+        input=inputs, seq_lengths=sequence_length,
+        seq_dim=time_dim, batch_dim=batch_dim)
+    tmp, output_state_bw = dynamic_rnn(
+        cell=cell_bw, inputs=inputs_reverse, sequence_length=sequence_length,
+        initial_state=initial_state_bw, dtype=dtype,
+        parallel_iterations=parallel_iterations, swap_memory=swap_memory,
+        time_major=time_major, scope=bw_scope)
+  output_bw = array_ops.reverse_sequence(
+      input=tmp, seq_lengths=sequence_length,
+      seq_dim=time_dim, batch_dim=batch_dim)
+
+  outputs = (output_fw, output_bw)
+  output_states = (output_state_fw, output_state_bw)
+
+  return (outputs, output_states)
+
+
 def dynamic_rnn(cell, inputs, sequence_length=None, initial_state=None,
                 dtype=None, parallel_iterations=None, swap_memory=False,
                 time_major=False, scope=None):
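Reviewer note: a usage sketch of the new function, assuming it is re-exported alongside the existing `dynamic_rnn` (e.g. as `tf.nn.bidirectional_dynamic_rnn`; sizes and cell choice are illustrative):

```python
import tensorflow as tf

inputs = tf.placeholder(tf.float32, [4, 10, 8])  # [batch, time, depth]
seq_len = tf.placeholder(tf.int32, [4])

cell_fw = tf.nn.rnn_cell.GRUCell(16)
cell_bw = tf.nn.rnn_cell.GRUCell(16)

(out_fw, out_bw), _ = tf.nn.bidirectional_dynamic_rnn(
    cell_fw, cell_bw, inputs, sequence_length=seq_len, dtype=tf.float32)

# Concatenate the two directions, as the docstring suggests.
outputs = tf.concat(2, [out_fw, out_bw])  # shape [4, 10, 32]
```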
diff --git a/tensorflow/python/ops/variables.py b/tensorflow/python/ops/variables.py
index efb1fc919f..4be42ce1cb 100644
--- a/tensorflow/python/ops/variables.py
+++ b/tensorflow/python/ops/variables.py
@@ -877,6 +877,15 @@ def local_variables():
   return ops.get_collection(ops.GraphKeys.LOCAL_VARIABLES)


+def model_variables():
+  """Returns all variables in the MODEL_VARIABLES collection.
+
+  Returns:
+    A list of model Variable objects.
+  """
+  return ops.get_collection(ops.GraphKeys.MODEL_VARIABLES)
+
+
 def moving_average_variables():
   """Returns all variables that maintain their moving averages.
diff --git a/tensorflow/python/platform/test.py b/tensorflow/python/platform/test.py
index 63bac26b52..0274d62c48 100644
--- a/tensorflow/python/platform/test.py
+++ b/tensorflow/python/platform/test.py
@@ -59,6 +59,7 @@ from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function

+from tensorflow.python.client import device_lib
 from tensorflow.python.framework import test_util
 from tensorflow.python.platform import googletest
 from tensorflow.python.util.all_util import make_all
@@ -97,6 +98,11 @@ def is_built_with_cuda():
   return test_util.IsGoogleCudaEnabled()


+def is_gpu_available():
+  """Returns whether TensorFlow can access a GPU."""
+  return any(x.device_type == 'GPU' for x in device_lib.list_local_devices())
+
+
 __all__ = make_all(__name__)
 # TODO(irving,vrv): Remove once TestCase is documented
 __all__.append('TestCase')
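Reviewer note: a small sketch of how the new helper can be used, assuming `make_all` exposes it as `tf.test.is_gpu_available` (the device strings are the usual placement names):

```python
import tensorflow as tf

def best_device():
  # Fall back to CPU when no GPU is visible to the local process.
  return "/gpu:0" if tf.test.is_gpu_available() else "/cpu:0"

with tf.device(best_device()):
  logits = tf.constant([1.0, 2.0])
```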
diff --git a/tensorflow/python/training/moving_averages.py b/tensorflow/python/training/moving_averages.py
index 6ecd20ad48..9cc1d917ab 100644
--- a/tensorflow/python/training/moving_averages.py
+++ b/tensorflow/python/training/moving_averages.py
@@ -266,8 +266,10 @@ class ExponentialMovingAverage(object):
     if var_list is None:
       var_list = variables.trainable_variables()
     for var in var_list:
-      if var.dtype.base_dtype not in [dtypes.float32, dtypes.float64]:
-        raise TypeError("The variables must be float or double: %s" % var.name)
+      if var.dtype.base_dtype not in [dtypes.float16, dtypes.float32,
+                                      dtypes.float64]:
+        raise TypeError("The variables must be half, float, or double: %s" %
+                        var.name)
       if var in self._averages:
         raise ValueError("Moving average already computed for: %s" % var.name)
diff --git a/tensorflow/python/training/saver.py b/tensorflow/python/training/saver.py
index 2a943b5911..1d6a8a6632 100644
--- a/tensorflow/python/training/saver.py
+++ b/tensorflow/python/training/saver.py
@@ -285,7 +285,7 @@ class BaseSaverBuilder(object):

   def _AddShardedRestoreOps(self, filename_tensor, per_device,
                             restore_sequentially, reshape):
-    """Add Ops to save variables from multiple devices.
+    """Add Ops to restore variables from multiple devices.

     Args:
       filename_tensor: Tensor for the path of the file to load.
diff --git a/tensorflow/python/training/server_lib_test.py b/tensorflow/python/training/server_lib_test.py
index d939b5aabd..94fa6c295f 100644
--- a/tensorflow/python/training/server_lib_test.py
+++ b/tensorflow/python/training/server_lib_test.py
@@ -52,6 +52,154 @@ class GrpcServerTest(tf.test.TestCase):
     sess_2.close()
     # TODO(mrry): Add `server.stop()` and `server.join()` when these work.

+  # Verifies behavior of multiple variables with multiple sessions connecting
+  # to the same server.
+  def testSameVariablesNoClear(self):
+    server = tf.train.Server.create_local_server()
+
+    with tf.Session(server.target) as sess_1:
+      v0 = tf.Variable([[2, 1]], name="v0")
+      v1 = tf.Variable([[1], [2]], name="v1")
+      v2 = tf.matmul(v0, v1)
+      sess_1.run([v0.initializer, v1.initializer])
+      self.assertAllEqual([[4]], sess_1.run(v2))
+
+    with tf.Session(server.target) as sess_2:
+      new_v0 = tf.get_default_graph().get_tensor_by_name("v0:0")
+      new_v1 = tf.get_default_graph().get_tensor_by_name("v1:0")
+      new_v2 = tf.matmul(new_v0, new_v1)
+      self.assertAllEqual([[4]], sess_2.run(new_v2))
+
+  # Verifies behavior of tf.Session.reset().
+  def testSameVariablesClear(self):
+    server = tf.train.Server.create_local_server()
+
+    # Creates a graph with 2 variables.
+    v0 = tf.Variable([[2, 1]], name="v0")
+    v1 = tf.Variable([[1], [2]], name="v1")
+    v2 = tf.matmul(v0, v1)
+
+    # Verifies that both sessions connecting to the same target return
+    # the same results.
+    sess_1 = tf.Session(server.target)
+    sess_2 = tf.Session(server.target)
+    sess_1.run(tf.initialize_all_variables())
+    self.assertAllEqual([[4]], sess_1.run(v2))
+    self.assertAllEqual([[4]], sess_2.run(v2))
+
+    # Resets the target. Sessions abort. Use sess_2 to verify.
+    tf.Session.reset(server.target)
+    with self.assertRaises(tf.errors.AbortedError):
+      self.assertAllEqual([[4]], sess_2.run(v2))
+
+    # Connects to the same target. Device memory for the variables would have
+    # been released, so they will be uninitialized.
+    sess_2 = tf.Session(server.target)
+    with self.assertRaises(tf.errors.FailedPreconditionError):
+      sess_2.run(v2)
+    # Reinitializes the variables.
+    sess_2.run(tf.initialize_all_variables())
+    self.assertAllEqual([[4]], sess_2.run(v2))
+    sess_2.close()
+
+  # Verifies behavior of tf.Session.reset() with multiple containers using
+  # default container names as defined by the target name.
+  def testSameVariablesClearContainer(self):
+    # Starts two servers with different names so they map to different
+    # resource "containers".
+    server0 = tf.train.Server({"local0": ["localhost:0"]}, protocol="grpc",
+                              start=True)
+    server1 = tf.train.Server({"local1": ["localhost:0"]}, protocol="grpc",
+                              start=True)
+
+    # Creates a graph with 2 variables.
+    v0 = tf.Variable(1.0, name="v0")
+    v1 = tf.Variable(2.0, name="v0")
+
+    # Initializes the variables. Verifies that the values are correct.
+    sess_0 = tf.Session(server0.target)
+    sess_1 = tf.Session(server1.target)
+    sess_0.run(v0.initializer)
+    sess_1.run(v1.initializer)
+    self.assertAllEqual(1.0, sess_0.run(v0))
+    self.assertAllEqual(2.0, sess_1.run(v1))
+
+    # Resets container "local0". Verifies that v0 is no longer initialized.
+    tf.Session.reset(server0.target, ["local0"])
+    sess = tf.Session(server0.target)
+    with self.assertRaises(tf.errors.FailedPreconditionError):
+      sess.run(v0)
+    # Reinitializes v0 for the following test.
+    sess.run(v0.initializer)
+
+    # Verifies that v1 is still valid.
+    self.assertAllEqual(2.0, sess_1.run(v1))
+
+    # Resets container "local1". Verifies that v1 is no longer initialized.
+    tf.Session.reset(server1.target, ["local1"])
+    sess = tf.Session(server1.target)
+    with self.assertRaises(tf.errors.FailedPreconditionError):
+      sess.run(v1)
+    # Verifies that v0 is still valid.
+    sess = tf.Session(server0.target)
+    self.assertAllEqual(1.0, sess.run(v0))
+
+  # Verifies behavior of tf.Session.reset() with multiple containers using
+  # tf.container.
+  def testMultipleContainers(self):
+    with tf.container("test0"):
+      v0 = tf.Variable(1.0, name="v0")
+    with tf.container("test1"):
+      v1 = tf.Variable(2.0, name="v0")
+    server = tf.train.Server.create_local_server()
+    sess = tf.Session(server.target)
+    sess.run(tf.initialize_all_variables())
+    self.assertAllEqual(1.0, sess.run(v0))
+    self.assertAllEqual(2.0, sess.run(v1))
+
+    # Resets container. Session aborts.
+    tf.Session.reset(server.target, ["test0"])
+    with self.assertRaises(tf.errors.AbortedError):
+      sess.run(v1)
+
+    # Connects to the same target. Device memory for v0 would have
+    # been released, so it will be uninitialized. But v1 should still
+    # be valid.
+    sess = tf.Session(server.target)
+    with self.assertRaises(tf.errors.FailedPreconditionError):
+      sess.run(v0)
+    self.assertAllEqual(2.0, sess.run(v1))
+
+  # Verifies various reset failures.
+  def testResetFails(self):
+    # Creates variable with container name.
+    with tf.container("test0"):
+      v0 = tf.Variable(1.0, name="v0")
+    # Creates variable with default container.
+    v1 = tf.Variable(2.0, name="v1")
+    # Verifies that resetting a non-existent target returns an error.
+    with self.assertRaises(tf.errors.NotFoundError):
+      tf.Session.reset("nonexistent", ["test0"])
+
+    # Verifies that resetting with a config works, and that resetting a
+    # target with no server behind it times out.
+    with self.assertRaises(tf.errors.DeadlineExceededError):
+      tf.Session.reset("grpc://localhost:0", ["test0"],
+                       config=tf.ConfigProto(operation_timeout_in_ms=5))
+
+    # Verifies that no containers are cleared when a non-existent container
+    # is named.
+    server = tf.train.Server.create_local_server()
+    sess = tf.Session(server.target)
+    sess.run(tf.initialize_all_variables())
+    self.assertAllEqual(1.0, sess.run(v0))
+    self.assertAllEqual(2.0, sess.run(v1))
+    # No container is reset, but the server is reset.
+    tf.Session.reset(server.target, ["test1"])
+    # Verifies that both variables are still valid.
+    sess = tf.Session(server.target)
+    self.assertAllEqual(1.0, sess.run(v0))
+    self.assertAllEqual(2.0, sess.run(v1))
+
   def testLargeConstant(self):
     server = tf.train.Server.create_local_server()
     with tf.Session(server.target) as sess:
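Reviewer note: to summarize the surface these tests exercise, a minimal end-to-end sketch of `tf.Session.reset` with a named container (the container name is illustrative):

```python
import tensorflow as tf

server = tf.train.Server.create_local_server()

with tf.container("experiment0"):
  v = tf.Variable(1.0, name="v")

sess = tf.Session(server.target)
sess.run(v.initializer)

# Clears the "experiment0" container and closes sessions on this target.
tf.Session.reset(server.target, ["experiment0"])

# A fresh session sees v uninitialized until its initializer is re-run.
sess = tf.Session(server.target)
try:
  sess.run(v)
except tf.errors.FailedPreconditionError:
  sess.run(v.initializer)
print(sess.run(v))  # 1.0
```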