From 3dfd14421d71c1d6a79f72217cd7b6510cbcb38f Mon Sep 17 00:00:00 2001 From: Manjunath Kudlur Date: Fri, 4 Dec 2015 09:54:09 -0800 Subject: TensorFlow: upstream changes to git. Change 109418220 Update WORKSPACE to use gmock.BUILD from google/protobuf instead of a duplicate. Update google/protobuf's commit hash to include damieng@'s commit. Change 109417314 TensorFlow: add .gitignore to ignore some in-tree modified files. Change 109400051 Optionally build full TensorFlow for Android. 1. --define ANDROID_TYPES=__ANDROID_TYPES_FULL__ to register ops for all types, not just float. Today this increases codesize by ~700K when compiled for ARM, though only for clients who request full type support. 2. Add more ops to android_extended_ops, sufficient to train on the linear regression baseball codelab. Change 109388118 Fix the option changed in templatize. Oops. Change 109382553 Allows setting a function name in an op's attr in the py frontend. Change 109380896 Remove assert_same_graph in favor of op_scope. Change the latter to handle tensor-like objects such as SparseTensor, IndexedSlices, and Variable. Base CL: 109418322 --- .gitignore | 11 +++ WORKSPACE | 2 +- tensorflow/core/BUILD | 7 ++ tensorflow/core/kernels/cwise_ops_common.h | 13 ++- tensorflow/python/framework/ops.py | 145 +++++++++++++++++------------ tensorflow/python/framework/ops_test.py | 108 +++++++++++++++------ tensorflow/python/ops/op_def_library.py | 4 + 7 files changed, 195 insertions(+), 95 deletions(-) create mode 100644 .gitignore diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000000..a9401de1f6 --- /dev/null +++ b/.gitignore @@ -0,0 +1,11 @@ +node_modules +/bazel-bin +/bazel-genfiles +/bazel-out +/bazel-tensorflow +/bazel-testlogs +/bazel-tf +/third_party/py/numpy/numpy_include +/tools/bazel.rc +/util/python/python_include +/util/python/python_lib diff --git a/WORKSPACE b/WORKSPACE index 38993d5816..d3dd81791c 100644 --- a/WORKSPACE +++ b/WORKSPACE @@ -16,7 +16,7 @@ new_http_archive( name = "gmock_archive", url = "https://googlemock.googlecode.com/files/gmock-1.7.0.zip", sha256 = "26fcbb5925b74ad5fc8c26b0495dfc96353f4d553492eb97e85a8a6d2f43095b", - build_file = "gmock.BUILD", + build_file = "google/protobuf/gmock.BUILD", ) bind( diff --git a/tensorflow/core/BUILD b/tensorflow/core/BUILD index 8d0e71efd6..ce19032104 100644 --- a/tensorflow/core/BUILD +++ b/tensorflow/core/BUILD @@ -632,6 +632,7 @@ filegroup( srcs = [ "//tensorflow/core:kernels/avgpooling_op.cc", "//tensorflow/core:kernels/avgpooling_op.h", + "//tensorflow/core:kernels/bcast_ops.cc", "//tensorflow/core:kernels/control_flow_ops.cc", "//tensorflow/core:kernels/control_flow_ops.h", "//tensorflow/core:kernels/conv_2d.h", @@ -642,19 +643,23 @@ filegroup( "//tensorflow/core:kernels/cwise_op_less.cc", "//tensorflow/core:kernels/cwise_op_log.cc", "//tensorflow/core:kernels/cwise_op_mul.cc", + "//tensorflow/core:kernels/cwise_op_neg.cc", "//tensorflow/core:kernels/cwise_op_sigmoid.cc", "//tensorflow/core:kernels/cwise_op_sqrt.cc", "//tensorflow/core:kernels/cwise_op_square.cc", "//tensorflow/core:kernels/cwise_op_sub.cc", "//tensorflow/core:kernels/cwise_op_tanh.cc", "//tensorflow/core:kernels/dynamic_partition_op.cc", + "//tensorflow/core:kernels/dynamic_stitch_op.cc", "//tensorflow/core:kernels/lrn_op.cc", "//tensorflow/core:kernels/maxpooling_op.cc", "//tensorflow/core:kernels/maxpooling_op.h", "//tensorflow/core:kernels/reduction_ops.h", "//tensorflow/core:kernels/reduction_ops_common.h", "//tensorflow/core:kernels/reduction_ops_max.cc", + "//tensorflow/core:kernels/reduction_ops_mean.cc", "//tensorflow/core:kernels/reduction_ops_min.cc", + "//tensorflow/core:kernels/reduction_ops_prod.cc", "//tensorflow/core:kernels/reduction_ops_sum.cc", "//tensorflow/core:kernels/relu_op.cc", "//tensorflow/core:kernels/relu_op.h", @@ -663,6 +668,8 @@ filegroup( "//tensorflow/core:kernels/softsign_op.cc", "//tensorflow/core:kernels/softsign_op.h", "//tensorflow/core:kernels/stack_ops.cc", + "//tensorflow/core:kernels/tile_ops.cc", + "//tensorflow/core:kernels/tile_ops.h", "//tensorflow/core:kernels/transpose_op.cc", "//tensorflow/core:kernels/transpose_op.h", "//tensorflow/core:kernels/transpose_op_functor.h", diff --git a/tensorflow/core/kernels/cwise_ops_common.h b/tensorflow/core/kernels/cwise_ops_common.h index adf4203322..0b57207832 100644 --- a/tensorflow/core/kernels/cwise_ops_common.h +++ b/tensorflow/core/kernels/cwise_ops_common.h @@ -367,11 +367,14 @@ struct SelectFunctor { OP>); // Macros to register kernels for multiple types (T0, T1, etc.) on -// device type "D" (CPU or GPU) for operatin "N" (e.g., sqrt) using +// device type "D" (CPU or GPU) for operation "N" (e.g., sqrt) using // the functor "F" (e.g., functor:sqrt). -#if defined(__ANDROID__) -// On Android, only register the first type (float) +#if defined(__ANDROID_TYPES_SLIM__) +// Normally Android TensorFlow is built with a reduced number of types (float). +// Override on the command-line "--define ANDROID_TYPES=__ANDROID_TYPES_FULL__" +// to generate a library with full type support with a consequent increase in +// code size. #define REGISTER2(OP, D, N, F, T0, T1) REGISTER(OP, D, N, F, T0) #define REGISTER3(OP, D, N, F, T0, T1, T2) REGISTER(OP, D, N, F, T0) #define REGISTER4(OP, D, N, F, T0, T1, T2, T3) REGISTER(OP, D, N, F, T0) @@ -381,7 +384,7 @@ struct SelectFunctor { REGISTER(OP, D, N, F, T0) #define REGISTER8(OP, D, N, F, T0, T1, T2, T3, T4, T5, T6, T7) \ REGISTER(OP, D, N, F, T0) -#else // !defined(__ANDROID__) +#else // !defined(__ANDROID_TYPES_SLIM__) #define REGISTER2(OP, D, N, F, T0, T1) \ REGISTER(OP, D, N, F, T0) \ REGISTER(OP, D, N, F, T1) @@ -403,7 +406,7 @@ struct SelectFunctor { #define REGISTER8(OP, D, N, F, T0, T1, T2, T3, T4, T5, T6, T7) \ REGISTER4(OP, D, N, F, T0, T1, T2, T3) \ REGISTER4(OP, D, N, F, T4, T5, T6, T7) -#endif // defined(__ANDROID__) +#endif // defined(__ANDROID_TYPES_SLIM__) } // end namespace tensorflow diff --git a/tensorflow/python/framework/ops.py b/tensorflow/python/framework/ops.py index f9796ca679..8a1bcac0aa 100644 --- a/tensorflow/python/framework/ops.py +++ b/tensorflow/python/framework/ops.py @@ -39,6 +39,7 @@ from tensorflow.python.framework import registry from tensorflow.python.framework import tensor_shape from tensorflow.python.framework import versions from tensorflow.python.util import compat +from tensorflow.python.platform import logging def _convert_stack(stack): @@ -95,6 +96,22 @@ def _extract_stack(): return ret +def _as_graph_element(obj): + """Convert `obj` to a graph element if possible, otherwise return `None`. + + Args: + obj: Object to convert. + + Returns: + The result of `obj._as_graph_element()` if that method is available; + otherwise `None`. + """ + conv_fn = getattr(obj, "_as_graph_element", None) + if conv_fn and callable(conv_fn): + return conv_fn() + return None + + class Tensor(object): """Represents a value produced by an `Operation`. @@ -680,6 +697,7 @@ class IndexedSlices(object): def __init__(self, values, indices, dense_shape=None): """Creates an `IndexedSlices`.""" + _get_graph_from_inputs([values, indices, dense_shape]) self._values = values self._indices = indices self._dense_shape = dense_shape @@ -719,30 +737,15 @@ class IndexedSlices(object): """The `DType` of elements in this tensor.""" return self.values.dtype - def __str__(self): - return "IndexedSlices(indices=%s, values=%s)" % ( - self._indices, self._values) - - -def assert_same_graph(items, expected_graph=None): - """Asserts all items are from the same graph. + @property + def graph(self): + """The `Graph` that contains the values, indices, and shape tensors.""" + return self._values.graph - Args: - items: List of graph items (e.g., Variable, Tensor, SparseTensor, - Operation, or IndexedSlices). - expected_graph: Expected graph. If not specified, assert all tensors are - from the same graph. - Returns: - items, for chaining. - Raises: - ValueError: If any graphs do not match. - """ - for item in items: - if not expected_graph: - expected_graph = item.graph - elif expected_graph != item.graph: - raise ValueError("Items must be from the same graph.") - return items + def __str__(self): + return "IndexedSlices(indices=%s, values=%s%s)" % ( + self._indices, self._values, + (", dense_shape=%s" % self._dense_shape) if self._dense_shape else "") class SparseTensor(object): @@ -1106,7 +1109,7 @@ class Operation(object): """ if not isinstance(tensor, Tensor): raise TypeError("tensor must be a Tensor: %s" % tensor) - assert_same_graph([self, tensor]) + _assert_same_graph(self, tensor) if dtype is None: dtype = tensor.dtype else: @@ -1138,7 +1141,7 @@ class Operation(object): """ if not isinstance(tensor, Tensor): raise TypeError("tensor must be a Tensor: %s" % tensor) - assert_same_graph([self, tensor]) + _assert_same_graph(self, tensor) if dtype is None: dtype = tensor.dtype else: @@ -1166,7 +1169,7 @@ class Operation(object): """ if not isinstance(op, Operation): raise TypeError("op must be an Operation: %s" % op) - assert_same_graph([self, op]) + _assert_same_graph(self, op) self._control_inputs.append(op) self._recompute_node_def() @@ -1887,9 +1890,7 @@ class Graph(object): else: raise ValueError("allow_tensor and allow_operation can't both be False.") - conv_fn = getattr(obj, "_as_graph_element", None) - if conv_fn and callable(conv_fn): - obj = conv_fn() + obj = _as_graph_element(obj) or obj # If obj appears to be a name... if isinstance(obj, compat.bytes_or_text_types): @@ -2971,6 +2972,21 @@ def get_default_graph(): return _default_graph_stack.get_default() +def _assert_same_graph(original_item, item): + """Fail if the 2 items are from different graphs. + + Args: + original_item: Original item to check against. + item: Item to check. + + Raises: + ValueError: if graphs do not match. + """ + if original_item.graph is not item.graph: + raise ValueError( + "%s must be from the same graph as %s." % (item, original_item)) + + def _get_graph_from_inputs(op_input_list, graph=None): """Returns the appropriate graph to use for the given inputs. @@ -2986,8 +3002,8 @@ def _get_graph_from_inputs(op_input_list, graph=None): "op_input_list", we attempt to use the default graph. Args: - op_input_list: A list of inputs to an operation, which may include Tensor - and Operation objects. + op_input_list: A list of inputs to an operation, which may include `Tensor`, + `Operation`, and other objects that may be converted to a graph element. graph: (Optional) The explicit graph to use. Raises: @@ -3001,37 +3017,35 @@ def _get_graph_from_inputs(op_input_list, graph=None): The appropriate graph to use for the given inputs. """ op_input_list = tuple(op_input_list) # Handle generators correctly - - # 1. If the graph is specified explicitly, we validate that all of the inputs - # are compatible with that graph. - if graph is not None: - if not isinstance(graph, Graph): - raise TypeError("Input graph needs to be a Graph: %s" % graph) - for op_input in op_input_list: - if isinstance(op_input, Operation): - if op_input.graph is not graph: - raise ValueError("Operation %s is not from the passed-in graph" - % op_input) - elif isinstance(op_input, Tensor): - if op_input.graph is not graph: - raise ValueError("Tensor %s is not from the passed-in graph" - % op_input) - return graph - - # 2. Otherwise, we attempt to select a graph from one of the Operation- - # or Tensor-valued inputs. - original_input = None + if graph and not isinstance(graph, Graph): + raise TypeError("Input graph needs to be a Graph: %s" % graph) + + # 1. We validate that all of the inputs are from the same graph. This is + # either the supplied graph parameter, or the first one selected from one + # the graph-element-valued inputs. In the latter case, we hold onto + # that input in original_graph_element so we can provide a more + # informative error if a mismatch is found. + original_graph_element = None for op_input in op_input_list: - if isinstance(op_input, (Operation, Tensor)): - if original_input is None: - original_input = op_input - else: - assert_same_graph([original_input, op_input]) - if original_input is not None: - return original_input.graph + # Determine if this is a valid graph_element. + graph_element = None + if isinstance(op_input, (Operation, Tensor, SparseTensor, IndexedSlices)): + graph_element = op_input + else: + graph_element = _as_graph_element(op_input) - # 3. If all else fails, we use the default graph, which is always there. - return get_default_graph() + if graph_element: + if not graph: + original_graph_element = graph_element + graph = graph_element.graph + elif original_graph_element: + _assert_same_graph(original_graph_element, graph_element) + elif graph_element.graph is not graph: + raise ValueError( + "%s is not from the passed-in graph." % graph_element) + + # 2. If all else fails, we use the default graph, which is always there. + return graph or get_default_graph() class GraphKeys(object): @@ -3115,7 +3129,7 @@ def get_collection(key, scope=None): # pylint: disable=g-doc-return-or-yield @contextlib.contextmanager -def op_scope(values, name, default_name): +def op_scope(values, name, default_name=None): """Returns a context manager for use when defining a Python op. This context manager validates that the given `values` are from the @@ -3140,10 +3154,17 @@ def op_scope(values, name, default_name): default_name: The default name to use if the `name` argument is `None`. Returns: - A context manager for use in defining a Python op. + A context manager for use in defining Python ops. Yields the name scope. + + Raises: + ValueError: if neither `name` nor `default_name` is provided. """ g = _get_graph_from_inputs(values) n = default_name if name is None else name + if n is None: + raise ValueError( + "At least one of name (%s) and default_name (%s) must be provided." % ( + name, default_name)) with g.as_default(), g.name_scope(n) as scope: yield scope # pylint: enable=g-doc-return-or-yield diff --git a/tensorflow/python/framework/ops_test.py b/tensorflow/python/framework/ops_test.py index 43044b1d39..5831ccd108 100644 --- a/tensorflow/python/framework/ops_test.py +++ b/tensorflow/python/framework/ops_test.py @@ -27,6 +27,7 @@ from tensorflow.python.framework import tensor_shape from tensorflow.python.framework import test_kernel_label_op from tensorflow.python.framework import test_util from tensorflow.python.ops import common_shapes +from tensorflow.python.ops import variables from tensorflow.python.platform import googletest @@ -356,19 +357,19 @@ class NameTest(test_util.TensorFlowTestCase): self.assertEqual("my_op", op2.name) self.assertEqual("my_op:0", op2.outputs[0].name) - def testname_scope(self): + def testNameScope(self): g = ops.Graph() with g.name_scope("foo") as foo: - self.assertEqual(foo, "foo/") + self.assertEqual("foo/", foo) with g.name_scope("foo2") as foo2: - self.assertEqual(foo2, "foo/foo2/") + self.assertEqual("foo/foo2/", foo2) with g.name_scope(None) as empty1: - self.assertEqual(empty1, "") + self.assertEqual("", empty1) with g.name_scope("foo3") as foo3: - self.assertEqual(foo3, "foo3/") + self.assertEqual("foo3/", foo3) with g.name_scope("") as empty2: - self.assertEqual(empty2, "") + self.assertEqual("", empty2) self.assertEqual("const", g.create_op("const", [], [dtypes.float32]).name) @@ -792,6 +793,80 @@ class ControlDependenciesTest(test_util.TensorFlowTestCase): self.assertEqual(b.op.control_inputs, []) +class OpScopeTest(test_util.TensorFlowTestCase): + + def testNoScopeName(self): + g0 = ops.Graph() + values = [ + g0.create_op("a", [], [dtypes.float32]), + g0.create_op("b", [], [dtypes.float32])] + with self.assertRaises(ValueError): + with ops.op_scope(values, None): + pass + with self.assertRaises(ValueError): + with ops.op_scope(values, None, None): + pass + + def testEmptyScopeName(self): + g0 = ops.Graph() + a = g0.create_op("a", [], [dtypes.float32]) + b = g0.create_op("b", [], [dtypes.float32]) + with ops.op_scope([a, b], "") as scope: + self.assertEqual("", scope) + self.assertEqual(g0, ops.get_default_graph()) + with ops.op_scope([a, b], "", "my_default_scope") as scope: + self.assertEqual("", scope) + self.assertEqual(g0, ops.get_default_graph()) + + def testDefaultScopeName(self): + g0 = ops.Graph() + a = g0.create_op("a", [], [dtypes.float32]) + b = g0.create_op("b", [], [dtypes.float32]) + scope_name = "my_scope" + default_scope_name = "my_default_scope" + with ops.op_scope([a, b], scope_name, default_scope_name) as scope: + self.assertEqual("%s/" % scope_name, scope) + self.assertEqual(g0, ops.get_default_graph()) + with ops.op_scope([a, b], None, default_scope_name) as scope: + self.assertEqual("%s/" % default_scope_name, scope) + self.assertEqual(g0, ops.get_default_graph()) + + def _testGraphElements(self, graph_elements): + scope_name = "my_scope" + with ops.op_scope(graph_elements, scope_name) as scope: + self.assertEqual("%s/" % scope_name, scope) + self.assertEqual(graph_elements[0].graph, ops.get_default_graph()) + g1 = ops.Graph() + c = g1.create_op("c", [], [dtypes.float32]) + with self.assertRaises(ValueError): + with ops.op_scope(graph_elements + [c], scope_name): + pass + + def testTensor(self): + g0 = ops.Graph() + a = g0.create_op("a", [], [dtypes.float32]) + b = g0.create_op("b", [], [dtypes.float32]) + self._testGraphElements([a, b]) + + def testSparseTensor(self): + g0 = ops.Graph() + a = g0.create_op("a", [], [dtypes.float32]) + b = g0.create_op("b", [], [dtypes.float32]) + sparse = ops.SparseTensor( + _apply_op(g0, "const", [], [dtypes.int64]), + _apply_op(g0, "const", [], [dtypes.float32]), + _apply_op(g0, "const", [], [dtypes.int64])) + self._testGraphElements([a, sparse, b]) + + def testVariable(self): + g0 = ops.Graph() + with g0.as_default(): + variable = variables.Variable([1.0]) + a = g0.create_op("a", [], [dtypes.float32]) + b = g0.create_op("b", [], [dtypes.float32]) + self._testGraphElements([a, variable, b]) + + class GraphTest(test_util.TensorFlowTestCase): def setUp(self): @@ -835,27 +910,6 @@ class GraphTest(test_util.TensorFlowTestCase): with self.assertRaises(TypeError): g.as_graph_element(NonConvertibleObj()) - def testAssertSameGraph(self): - g0 = ops.Graph() - a = g0.create_op("a", [], [dtypes.float32]) - b = g0.create_op("b", [], [dtypes.float32]) - ops.assert_same_graph([a, b]) - ops.assert_same_graph([a, b], g0) - g1 = ops.Graph() - c = g1.create_op("c", [], [dtypes.float32]) - self.assertRaises(ValueError, ops.assert_same_graph, [a, b, c]) - self.assertRaises(ValueError, ops.assert_same_graph, [c], g0) - self.assertRaises(ValueError, ops.assert_same_graph, [a], g1) - - sparse = ops.SparseTensor( - _apply_op(g0, "const", [], [dtypes.int64]), - _apply_op(g0, "const", [], [dtypes.float32]), - _apply_op(g0, "const", [], [dtypes.int64])) - ops.assert_same_graph([sparse, a, b]) - ops.assert_same_graph([sparse, a, b], g0) - self.assertRaises(ValueError, ops.assert_same_graph, [sparse, a, c]) - self.assertRaises(ValueError, ops.assert_same_graph, [sparse, a, c], g1) - ops.RegisterShape("KernelLabel")(common_shapes.scalar_shape) diff --git a/tensorflow/python/ops/op_def_library.py b/tensorflow/python/ops/op_def_library.py index c2ad3bdb58..946d543b44 100644 --- a/tensorflow/python/ops/op_def_library.py +++ b/tensorflow/python/ops/op_def_library.py @@ -616,6 +616,10 @@ class OpDefLibrary(object): elif attr_def.type == "list(tensor)": attr_value.list.tensor.extend( [_MakeTensor(x, key) for x in value]) + elif attr_def.type == "func": + if not isinstance(value, compat.bytes_or_text_types): + raise TypeError("Expects a string for the func name") + attr_value.func.name = value else: raise TypeError("Unrecognized Attr type " + attr_def.type) -- cgit v1.2.3