-rw-r--r--  .gitignore                                    11
-rw-r--r--  WORKSPACE                                       2
-rw-r--r--  tensorflow/core/BUILD                           7
-rw-r--r--  tensorflow/core/kernels/cwise_ops_common.h     13
-rw-r--r--  tensorflow/python/framework/ops.py            145
-rw-r--r--  tensorflow/python/framework/ops_test.py       108
-rw-r--r--  tensorflow/python/ops/op_def_library.py         4
7 files changed, 195 insertions, 95 deletions
diff --git a/.gitignore b/.gitignore
new file mode 100644
index 0000000000..a9401de1f6
--- /dev/null
+++ b/.gitignore
@@ -0,0 +1,11 @@
+node_modules
+/bazel-bin
+/bazel-genfiles
+/bazel-out
+/bazel-tensorflow
+/bazel-testlogs
+/bazel-tf
+/third_party/py/numpy/numpy_include
+/tools/bazel.rc
+/util/python/python_include
+/util/python/python_lib
diff --git a/WORKSPACE b/WORKSPACE
index 38993d5816..d3dd81791c 100644
--- a/WORKSPACE
+++ b/WORKSPACE
@@ -16,7 +16,7 @@ new_http_archive(
name = "gmock_archive",
url = "https://googlemock.googlecode.com/files/gmock-1.7.0.zip",
sha256 = "26fcbb5925b74ad5fc8c26b0495dfc96353f4d553492eb97e85a8a6d2f43095b",
- build_file = "gmock.BUILD",
+ build_file = "google/protobuf/gmock.BUILD",
)
bind(
diff --git a/tensorflow/core/BUILD b/tensorflow/core/BUILD
index 8d0e71efd6..ce19032104 100644
--- a/tensorflow/core/BUILD
+++ b/tensorflow/core/BUILD
@@ -632,6 +632,7 @@ filegroup(
srcs = [
"//tensorflow/core:kernels/avgpooling_op.cc",
"//tensorflow/core:kernels/avgpooling_op.h",
+ "//tensorflow/core:kernels/bcast_ops.cc",
"//tensorflow/core:kernels/control_flow_ops.cc",
"//tensorflow/core:kernels/control_flow_ops.h",
"//tensorflow/core:kernels/conv_2d.h",
@@ -642,19 +643,23 @@ filegroup(
"//tensorflow/core:kernels/cwise_op_less.cc",
"//tensorflow/core:kernels/cwise_op_log.cc",
"//tensorflow/core:kernels/cwise_op_mul.cc",
+ "//tensorflow/core:kernels/cwise_op_neg.cc",
"//tensorflow/core:kernels/cwise_op_sigmoid.cc",
"//tensorflow/core:kernels/cwise_op_sqrt.cc",
"//tensorflow/core:kernels/cwise_op_square.cc",
"//tensorflow/core:kernels/cwise_op_sub.cc",
"//tensorflow/core:kernels/cwise_op_tanh.cc",
"//tensorflow/core:kernels/dynamic_partition_op.cc",
+ "//tensorflow/core:kernels/dynamic_stitch_op.cc",
"//tensorflow/core:kernels/lrn_op.cc",
"//tensorflow/core:kernels/maxpooling_op.cc",
"//tensorflow/core:kernels/maxpooling_op.h",
"//tensorflow/core:kernels/reduction_ops.h",
"//tensorflow/core:kernels/reduction_ops_common.h",
"//tensorflow/core:kernels/reduction_ops_max.cc",
+ "//tensorflow/core:kernels/reduction_ops_mean.cc",
"//tensorflow/core:kernels/reduction_ops_min.cc",
+ "//tensorflow/core:kernels/reduction_ops_prod.cc",
"//tensorflow/core:kernels/reduction_ops_sum.cc",
"//tensorflow/core:kernels/relu_op.cc",
"//tensorflow/core:kernels/relu_op.h",
@@ -663,6 +668,8 @@ filegroup(
"//tensorflow/core:kernels/softsign_op.cc",
"//tensorflow/core:kernels/softsign_op.h",
"//tensorflow/core:kernels/stack_ops.cc",
+ "//tensorflow/core:kernels/tile_ops.cc",
+ "//tensorflow/core:kernels/tile_ops.h",
"//tensorflow/core:kernels/transpose_op.cc",
"//tensorflow/core:kernels/transpose_op.h",
"//tensorflow/core:kernels/transpose_op_functor.h",
diff --git a/tensorflow/core/kernels/cwise_ops_common.h b/tensorflow/core/kernels/cwise_ops_common.h
index adf4203322..0b57207832 100644
--- a/tensorflow/core/kernels/cwise_ops_common.h
+++ b/tensorflow/core/kernels/cwise_ops_common.h
@@ -367,11 +367,14 @@ struct SelectFunctor<CPUDevice, T> {
OP<D##Device, F<T>>);
// Macros to register kernels for multiple types (T0, T1, etc.) on
-// device type "D" (CPU or GPU) for operatin "N" (e.g., sqrt) using
+// device type "D" (CPU or GPU) for operation "N" (e.g., sqrt) using
// the functor "F" (e.g., functor:sqrt).
-#if defined(__ANDROID__)
-// On Android, only register the first type (float)
+#if defined(__ANDROID_TYPES_SLIM__)
+// Normally Android TensorFlow is built with a reduced number of types (float).
+// Override on the command-line "--define ANDROID_TYPES=__ANDROID_TYPES_FULL__"
+// to generate a library with full type support with a consequent increase in
+// code size.
#define REGISTER2(OP, D, N, F, T0, T1) REGISTER(OP, D, N, F, T0)
#define REGISTER3(OP, D, N, F, T0, T1, T2) REGISTER(OP, D, N, F, T0)
#define REGISTER4(OP, D, N, F, T0, T1, T2, T3) REGISTER(OP, D, N, F, T0)
@@ -381,7 +384,7 @@ struct SelectFunctor<CPUDevice, T> {
REGISTER(OP, D, N, F, T0)
#define REGISTER8(OP, D, N, F, T0, T1, T2, T3, T4, T5, T6, T7) \
REGISTER(OP, D, N, F, T0)
-#else // !defined(__ANDROID__)
+#else // !defined(__ANDROID_TYPES_SLIM__)
#define REGISTER2(OP, D, N, F, T0, T1) \
REGISTER(OP, D, N, F, T0) \
REGISTER(OP, D, N, F, T1)
@@ -403,7 +406,7 @@ struct SelectFunctor<CPUDevice, T> {
#define REGISTER8(OP, D, N, F, T0, T1, T2, T3, T4, T5, T6, T7) \
REGISTER4(OP, D, N, F, T0, T1, T2, T3) \
REGISTER4(OP, D, N, F, T4, T5, T6, T7)
-#endif // defined(__ANDROID__)
+#endif // defined(__ANDROID_TYPES_SLIM__)
} // end namespace tensorflow
diff --git a/tensorflow/python/framework/ops.py b/tensorflow/python/framework/ops.py
index f9796ca679..8a1bcac0aa 100644
--- a/tensorflow/python/framework/ops.py
+++ b/tensorflow/python/framework/ops.py
@@ -39,6 +39,7 @@ from tensorflow.python.framework import registry
from tensorflow.python.framework import tensor_shape
from tensorflow.python.framework import versions
from tensorflow.python.util import compat
+from tensorflow.python.platform import logging
def _convert_stack(stack):
@@ -95,6 +96,22 @@ def _extract_stack():
return ret
+def _as_graph_element(obj):
+ """Convert `obj` to a graph element if possible, otherwise return `None`.
+
+ Args:
+ obj: Object to convert.
+
+ Returns:
+ The result of `obj._as_graph_element()` if that method is available;
+ otherwise `None`.
+ """
+ conv_fn = getattr(obj, "_as_graph_element", None)
+ if conv_fn and callable(conv_fn):
+ return conv_fn()
+ return None
+
+
class Tensor(object):
"""Represents a value produced by an `Operation`.
@@ -680,6 +697,7 @@ class IndexedSlices(object):
def __init__(self, values, indices, dense_shape=None):
"""Creates an `IndexedSlices`."""
+ _get_graph_from_inputs([values, indices, dense_shape])
self._values = values
self._indices = indices
self._dense_shape = dense_shape
@@ -719,30 +737,15 @@ class IndexedSlices(object):
"""The `DType` of elements in this tensor."""
return self.values.dtype
- def __str__(self):
- return "IndexedSlices(indices=%s, values=%s)" % (
- self._indices, self._values)
-
-
-def assert_same_graph(items, expected_graph=None):
- """Asserts all items are from the same graph.
+ @property
+ def graph(self):
+ """The `Graph` that contains the values, indices, and shape tensors."""
+ return self._values.graph
- Args:
- items: List of graph items (e.g., Variable, Tensor, SparseTensor,
- Operation, or IndexedSlices).
- expected_graph: Expected graph. If not specified, assert all tensors are
- from the same graph.
- Returns:
- items, for chaining.
- Raises:
- ValueError: If any graphs do not match.
- """
- for item in items:
- if not expected_graph:
- expected_graph = item.graph
- elif expected_graph != item.graph:
- raise ValueError("Items must be from the same graph.")
- return items
+ def __str__(self):
+ return "IndexedSlices(indices=%s, values=%s%s)" % (
+ self._indices, self._values,
+ (", dense_shape=%s" % self._dense_shape) if self._dense_shape else "")
class SparseTensor(object):
@@ -1106,7 +1109,7 @@ class Operation(object):
"""
if not isinstance(tensor, Tensor):
raise TypeError("tensor must be a Tensor: %s" % tensor)
- assert_same_graph([self, tensor])
+ _assert_same_graph(self, tensor)
if dtype is None:
dtype = tensor.dtype
else:
@@ -1138,7 +1141,7 @@ class Operation(object):
"""
if not isinstance(tensor, Tensor):
raise TypeError("tensor must be a Tensor: %s" % tensor)
- assert_same_graph([self, tensor])
+ _assert_same_graph(self, tensor)
if dtype is None:
dtype = tensor.dtype
else:
@@ -1166,7 +1169,7 @@ class Operation(object):
"""
if not isinstance(op, Operation):
raise TypeError("op must be an Operation: %s" % op)
- assert_same_graph([self, op])
+ _assert_same_graph(self, op)
self._control_inputs.append(op)
self._recompute_node_def()
@@ -1887,9 +1890,7 @@ class Graph(object):
else:
raise ValueError("allow_tensor and allow_operation can't both be False.")
- conv_fn = getattr(obj, "_as_graph_element", None)
- if conv_fn and callable(conv_fn):
- obj = conv_fn()
+ obj = _as_graph_element(obj) or obj
# If obj appears to be a name...
if isinstance(obj, compat.bytes_or_text_types):
@@ -2971,6 +2972,21 @@ def get_default_graph():
return _default_graph_stack.get_default()
+def _assert_same_graph(original_item, item):
+ """Fail if the 2 items are from different graphs.
+
+ Args:
+ original_item: Original item to check against.
+ item: Item to check.
+
+ Raises:
+ ValueError: if graphs do not match.
+ """
+ if original_item.graph is not item.graph:
+ raise ValueError(
+ "%s must be from the same graph as %s." % (item, original_item))
+
+
def _get_graph_from_inputs(op_input_list, graph=None):
"""Returns the appropriate graph to use for the given inputs.
@@ -2986,8 +3002,8 @@ def _get_graph_from_inputs(op_input_list, graph=None):
"op_input_list", we attempt to use the default graph.
Args:
- op_input_list: A list of inputs to an operation, which may include Tensor
- and Operation objects.
+ op_input_list: A list of inputs to an operation, which may include `Tensor`,
+ `Operation`, and other objects that may be converted to a graph element.
graph: (Optional) The explicit graph to use.
Raises:
@@ -3001,37 +3017,35 @@ def _get_graph_from_inputs(op_input_list, graph=None):
The appropriate graph to use for the given inputs.
"""
op_input_list = tuple(op_input_list) # Handle generators correctly
-
- # 1. If the graph is specified explicitly, we validate that all of the inputs
- # are compatible with that graph.
- if graph is not None:
- if not isinstance(graph, Graph):
- raise TypeError("Input graph needs to be a Graph: %s" % graph)
- for op_input in op_input_list:
- if isinstance(op_input, Operation):
- if op_input.graph is not graph:
- raise ValueError("Operation %s is not from the passed-in graph"
- % op_input)
- elif isinstance(op_input, Tensor):
- if op_input.graph is not graph:
- raise ValueError("Tensor %s is not from the passed-in graph"
- % op_input)
- return graph
-
- # 2. Otherwise, we attempt to select a graph from one of the Operation-
- # or Tensor-valued inputs.
- original_input = None
+ if graph and not isinstance(graph, Graph):
+ raise TypeError("Input graph needs to be a Graph: %s" % graph)
+
+ # 1. We validate that all of the inputs are from the same graph. This is
+ # either the supplied graph parameter, or the first one selected from one of
+ # the graph-element-valued inputs. In the latter case, we hold onto
+ # that input in original_graph_element so we can provide a more
+ # informative error if a mismatch is found.
+ original_graph_element = None
for op_input in op_input_list:
- if isinstance(op_input, (Operation, Tensor)):
- if original_input is None:
- original_input = op_input
- else:
- assert_same_graph([original_input, op_input])
- if original_input is not None:
- return original_input.graph
+ # Determine if this is a valid graph_element.
+ graph_element = None
+ if isinstance(op_input, (Operation, Tensor, SparseTensor, IndexedSlices)):
+ graph_element = op_input
+ else:
+ graph_element = _as_graph_element(op_input)
- # 3. If all else fails, we use the default graph, which is always there.
- return get_default_graph()
+ if graph_element:
+ if not graph:
+ original_graph_element = graph_element
+ graph = graph_element.graph
+ elif original_graph_element:
+ _assert_same_graph(original_graph_element, graph_element)
+ elif graph_element.graph is not graph:
+ raise ValueError(
+ "%s is not from the passed-in graph." % graph_element)
+
+ # 2. If all else fails, we use the default graph, which is always there.
+ return graph or get_default_graph()
class GraphKeys(object):
@@ -3115,7 +3129,7 @@ def get_collection(key, scope=None):
# pylint: disable=g-doc-return-or-yield
@contextlib.contextmanager
-def op_scope(values, name, default_name):
+def op_scope(values, name, default_name=None):
"""Returns a context manager for use when defining a Python op.
This context manager validates that the given `values` are from the
@@ -3140,10 +3154,17 @@ def op_scope(values, name, default_name):
default_name: The default name to use if the `name` argument is `None`.
Returns:
- A context manager for use in defining a Python op.
+ A context manager for use in defining Python ops. Yields the name scope.
+
+ Raises:
+ ValueError: if neither `name` nor `default_name` is provided.
"""
g = _get_graph_from_inputs(values)
n = default_name if name is None else name
+ if n is None:
+ raise ValueError(
+ "At least one of name (%s) and default_name (%s) must be provided." % (
+ name, default_name))
with g.as_default(), g.name_scope(n) as scope:
yield scope
# pylint: enable=g-doc-return-or-yield
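
Note on the op_scope change above: besides making default_name optional, the
context manager now raises ValueError when both name and default_name are None,
and it still resolves the enclosing graph from `values` via
_get_graph_from_inputs. A minimal usage sketch, assuming the graph-mode Python
API of this release (my_op and its inputs are hypothetical):

    from tensorflow.python.framework import ops

    def my_op(a, b, name=None):
      # op_scope selects the graph that `a` and `b` belong to, makes it the
      # default graph inside the block, and yields the resolved name scope.
      # If `name` is None, "MyOp" is used; if both are None, ValueError.
      with ops.op_scope([a, b], name, "MyOp") as scope:
        a = ops.convert_to_tensor(a, name="a")
        b = ops.convert_to_tensor(b, name="b")
        # ... build the op's output tensors under `scope` ...
        return a, b
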
diff --git a/tensorflow/python/framework/ops_test.py b/tensorflow/python/framework/ops_test.py
index 43044b1d39..5831ccd108 100644
--- a/tensorflow/python/framework/ops_test.py
+++ b/tensorflow/python/framework/ops_test.py
@@ -27,6 +27,7 @@ from tensorflow.python.framework import tensor_shape
from tensorflow.python.framework import test_kernel_label_op
from tensorflow.python.framework import test_util
from tensorflow.python.ops import common_shapes
+from tensorflow.python.ops import variables
from tensorflow.python.platform import googletest
@@ -356,19 +357,19 @@ class NameTest(test_util.TensorFlowTestCase):
self.assertEqual("my_op", op2.name)
self.assertEqual("my_op:0", op2.outputs[0].name)
- def testname_scope(self):
+ def testNameScope(self):
g = ops.Graph()
with g.name_scope("foo") as foo:
- self.assertEqual(foo, "foo/")
+ self.assertEqual("foo/", foo)
with g.name_scope("foo2") as foo2:
- self.assertEqual(foo2, "foo/foo2/")
+ self.assertEqual("foo/foo2/", foo2)
with g.name_scope(None) as empty1:
- self.assertEqual(empty1, "")
+ self.assertEqual("", empty1)
with g.name_scope("foo3") as foo3:
- self.assertEqual(foo3, "foo3/")
+ self.assertEqual("foo3/", foo3)
with g.name_scope("") as empty2:
- self.assertEqual(empty2, "")
+ self.assertEqual("", empty2)
self.assertEqual("const",
g.create_op("const", [], [dtypes.float32]).name)
@@ -792,6 +793,80 @@ class ControlDependenciesTest(test_util.TensorFlowTestCase):
self.assertEqual(b.op.control_inputs, [])
+class OpScopeTest(test_util.TensorFlowTestCase):
+
+ def testNoScopeName(self):
+ g0 = ops.Graph()
+ values = [
+ g0.create_op("a", [], [dtypes.float32]),
+ g0.create_op("b", [], [dtypes.float32])]
+ with self.assertRaises(ValueError):
+ with ops.op_scope(values, None):
+ pass
+ with self.assertRaises(ValueError):
+ with ops.op_scope(values, None, None):
+ pass
+
+ def testEmptyScopeName(self):
+ g0 = ops.Graph()
+ a = g0.create_op("a", [], [dtypes.float32])
+ b = g0.create_op("b", [], [dtypes.float32])
+ with ops.op_scope([a, b], "") as scope:
+ self.assertEqual("", scope)
+ self.assertEqual(g0, ops.get_default_graph())
+ with ops.op_scope([a, b], "", "my_default_scope") as scope:
+ self.assertEqual("", scope)
+ self.assertEqual(g0, ops.get_default_graph())
+
+ def testDefaultScopeName(self):
+ g0 = ops.Graph()
+ a = g0.create_op("a", [], [dtypes.float32])
+ b = g0.create_op("b", [], [dtypes.float32])
+ scope_name = "my_scope"
+ default_scope_name = "my_default_scope"
+ with ops.op_scope([a, b], scope_name, default_scope_name) as scope:
+ self.assertEqual("%s/" % scope_name, scope)
+ self.assertEqual(g0, ops.get_default_graph())
+ with ops.op_scope([a, b], None, default_scope_name) as scope:
+ self.assertEqual("%s/" % default_scope_name, scope)
+ self.assertEqual(g0, ops.get_default_graph())
+
+ def _testGraphElements(self, graph_elements):
+ scope_name = "my_scope"
+ with ops.op_scope(graph_elements, scope_name) as scope:
+ self.assertEqual("%s/" % scope_name, scope)
+ self.assertEqual(graph_elements[0].graph, ops.get_default_graph())
+ g1 = ops.Graph()
+ c = g1.create_op("c", [], [dtypes.float32])
+ with self.assertRaises(ValueError):
+ with ops.op_scope(graph_elements + [c], scope_name):
+ pass
+
+ def testTensor(self):
+ g0 = ops.Graph()
+ a = g0.create_op("a", [], [dtypes.float32])
+ b = g0.create_op("b", [], [dtypes.float32])
+ self._testGraphElements([a, b])
+
+ def testSparseTensor(self):
+ g0 = ops.Graph()
+ a = g0.create_op("a", [], [dtypes.float32])
+ b = g0.create_op("b", [], [dtypes.float32])
+ sparse = ops.SparseTensor(
+ _apply_op(g0, "const", [], [dtypes.int64]),
+ _apply_op(g0, "const", [], [dtypes.float32]),
+ _apply_op(g0, "const", [], [dtypes.int64]))
+ self._testGraphElements([a, sparse, b])
+
+ def testVariable(self):
+ g0 = ops.Graph()
+ with g0.as_default():
+ variable = variables.Variable([1.0])
+ a = g0.create_op("a", [], [dtypes.float32])
+ b = g0.create_op("b", [], [dtypes.float32])
+ self._testGraphElements([a, variable, b])
+
+
class GraphTest(test_util.TensorFlowTestCase):
def setUp(self):
@@ -835,27 +910,6 @@ class GraphTest(test_util.TensorFlowTestCase):
with self.assertRaises(TypeError):
g.as_graph_element(NonConvertibleObj())
- def testAssertSameGraph(self):
- g0 = ops.Graph()
- a = g0.create_op("a", [], [dtypes.float32])
- b = g0.create_op("b", [], [dtypes.float32])
- ops.assert_same_graph([a, b])
- ops.assert_same_graph([a, b], g0)
- g1 = ops.Graph()
- c = g1.create_op("c", [], [dtypes.float32])
- self.assertRaises(ValueError, ops.assert_same_graph, [a, b, c])
- self.assertRaises(ValueError, ops.assert_same_graph, [c], g0)
- self.assertRaises(ValueError, ops.assert_same_graph, [a], g1)
-
- sparse = ops.SparseTensor(
- _apply_op(g0, "const", [], [dtypes.int64]),
- _apply_op(g0, "const", [], [dtypes.float32]),
- _apply_op(g0, "const", [], [dtypes.int64]))
- ops.assert_same_graph([sparse, a, b])
- ops.assert_same_graph([sparse, a, b], g0)
- self.assertRaises(ValueError, ops.assert_same_graph, [sparse, a, c])
- self.assertRaises(ValueError, ops.assert_same_graph, [sparse, a, c], g1)
-
ops.RegisterShape("KernelLabel")(common_shapes.scalar_shape)
diff --git a/tensorflow/python/ops/op_def_library.py b/tensorflow/python/ops/op_def_library.py
index c2ad3bdb58..946d543b44 100644
--- a/tensorflow/python/ops/op_def_library.py
+++ b/tensorflow/python/ops/op_def_library.py
@@ -616,6 +616,10 @@ class OpDefLibrary(object):
elif attr_def.type == "list(tensor)":
attr_value.list.tensor.extend(
[_MakeTensor(x, key) for x in value])
+ elif attr_def.type == "func":
+ if not isinstance(value, compat.bytes_or_text_types):
+ raise TypeError("Expects a string for the func name")
+ attr_value.func.name = value
else:
raise TypeError("Unrecognized Attr type " + attr_def.type)
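
For context on the op_def_library.py change: the new "func" branch simply
records the supplied string as the function name on the attr protobuf and
rejects non-string values. A hedged illustration of the resulting proto,
assuming the standard attr_value_pb2 module (the function name is made up):

    from tensorflow.core.framework import attr_value_pb2

    attr_value = attr_value_pb2.AttrValue()
    # Only bytes/text are accepted by the new branch; anything else raises
    # TypeError("Expects a string for the func name").
    attr_value.func.name = "my_grad_fn"  # hypothetical function name
    print(attr_value)  # prints the text-format proto: func { name: "my_grad_fn" }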