diff options
author | 2018-09-11 21:12:51 +0800 | |
---|---|---|
committer | 2018-09-11 21:12:51 +0800 | |
commit | be41397bd903c89c99fe63e61c0bf62870e062ea (patch) | |
tree | 2e2dce6e3d444edc16d9163508056941296311a1 /tensorflow/python/framework | |
parent | 204ef67242ce7fbba067b631c4d6c4bcd64288c2 (diff) | |
parent | 9fd56039064871a736bb7cff398b2a8e08454bee (diff) |
Merge branch 'master' into CLN/remove_print_for_assert
Diffstat (limited to 'tensorflow/python/framework')
-rw-r--r-- | tensorflow/python/framework/constant_op.py | 3 | ||||
-rw-r--r-- | tensorflow/python/framework/error_interpolation.py | 16 | ||||
-rw-r--r-- | tensorflow/python/framework/error_interpolation_test.py | 29 | ||||
-rw-r--r-- | tensorflow/python/framework/file_system_test.py | 2 | ||||
-rw-r--r-- | tensorflow/python/framework/function_test.py | 10 | ||||
-rw-r--r-- | tensorflow/python/framework/importer_test.py | 18 | ||||
-rw-r--r-- | tensorflow/python/framework/meta_graph_test.py | 9 | ||||
-rw-r--r-- | tensorflow/python/framework/ops.py | 7 | ||||
-rw-r--r-- | tensorflow/python/framework/ops_test.py | 50 | ||||
-rw-r--r-- | tensorflow/python/framework/python_op_gen_internal.cc | 13 | ||||
-rw-r--r-- | tensorflow/python/framework/sparse_tensor_test.py | 6 | ||||
-rw-r--r-- | tensorflow/python/framework/subscribe_test.py | 14 | ||||
-rw-r--r-- | tensorflow/python/framework/tensor_shape.py | 4 | ||||
-rw-r--r-- | tensorflow/python/framework/tensor_util.py | 2 | ||||
-rw-r--r-- | tensorflow/python/framework/tensor_util_test.py | 2 | ||||
-rw-r--r-- | tensorflow/python/framework/test_util.py | 227 | ||||
-rw-r--r-- | tensorflow/python/framework/test_util_test.py | 3 |
17 files changed, 206 insertions, 209 deletions
diff --git a/tensorflow/python/framework/constant_op.py b/tensorflow/python/framework/constant_op.py index eca34ac26e..4b2706d4cf 100644 --- a/tensorflow/python/framework/constant_op.py +++ b/tensorflow/python/framework/constant_op.py @@ -105,7 +105,8 @@ def convert_to_eager_tensor(value, ctx, dtype=None): scalar_cache = ctx.scalar_cache() tensor = scalar_cache.get(cache_key, None) if tensor is not None: - return tensor + return ops.EagerTensor( + value, context=handle, device=device, dtype=dtype, other_value=tensor) t = ops.EagerTensor(value, context=handle, device=device, dtype=dtype) scalar_cache[cache_key] = t return t diff --git a/tensorflow/python/framework/error_interpolation.py b/tensorflow/python/framework/error_interpolation.py index a69018d00d..bc3c81b2a2 100644 --- a/tensorflow/python/framework/error_interpolation.py +++ b/tensorflow/python/framework/error_interpolation.py @@ -15,7 +15,7 @@ """Function for interpolating formatted errors from the TensorFlow runtime. Exposes the function `interpolate` to interpolate messages with tags of the form -^^type:name:format^^. +{{type name}}. """ from __future__ import absolute_import @@ -32,9 +32,9 @@ import six from tensorflow.python.util import tf_stack _NAME_REGEX = r"[A-Za-z0-9.][A-Za-z0-9_.\-/]*?" -_TAG_REGEX = r"\^\^({name}):({name})\^\^".format(name=_NAME_REGEX) +_TAG_REGEX = r"{{{{({name}) ({name})}}}}".format(name=_NAME_REGEX) _INTERPOLATION_REGEX = r"^(.*?)({tag})".format(tag=_TAG_REGEX) -_INTERPOLATION_PATTERN = re.compile(_INTERPOLATION_REGEX) +_INTERPOLATION_PATTERN = re.compile(_INTERPOLATION_REGEX, re.DOTALL) _ParseTag = collections.namedtuple("_ParseTag", ["type", "name"]) @@ -48,8 +48,8 @@ def _parse_message(message): """Parses the message. Splits the message into separators and tags. Tags are named tuples - representing the string ^^type:name^^ and they are separated by - separators. For example, in "123^^node:Foo^^456^^node:Bar^^789", there are + representing the string {{type name}} and they are separated by + separators. For example, in "123{{node Foo}}456{{node Bar}}789", there are two tags and three separators. The separators are the numeric characters. Args: @@ -58,7 +58,7 @@ def _parse_message(message): Returns: (list of separator strings, list of _ParseTags). - For example, if message is "123^^node:Foo^^456" then this function + For example, if message is "123{{node Foo}}456" then this function returns (["123", "456"], [_ParseTag("node", "Foo")]) """ seps = [] @@ -276,7 +276,7 @@ def interpolate(error_message, graph): message. Returns: - The string with tags of the form ^^type:name^^ interpolated. + The string with tags of the form {{type name}} interpolated. """ seps, tags = _parse_message(error_message) subs = [] @@ -288,7 +288,7 @@ def interpolate(error_message, graph): except KeyError: op = None - msg = "^^%s:%s^^" % (t.type, t.name) + msg = "{{%s %s}}" % (t.type, t.name) if op is not None: field_dict = compute_field_dict(op) if t.type == "node": diff --git a/tensorflow/python/framework/error_interpolation_test.py b/tensorflow/python/framework/error_interpolation_test.py index a7c7bbf28b..1b77548592 100644 --- a/tensorflow/python/framework/error_interpolation_test.py +++ b/tensorflow/python/framework/error_interpolation_test.py @@ -167,26 +167,31 @@ class InterpolateFilenamesAndLineNumbersTest(test.TestCase): self.assertEqual(interpolated_string, normal_string) def testOneTagWithAFakeNameResultsInPlaceholders(self): - one_tag_string = "^^node:MinusOne^^" + one_tag_string = "{{node MinusOne}}" interpolated_string = error_interpolation.interpolate( one_tag_string, self.graph) self.assertEqual(one_tag_string, interpolated_string) def testTwoTagsNoSeps(self): - two_tags_no_seps = "^^node:One^^^^node:Three^^" + two_tags_no_seps = "{{node One}}{{node Three}}" interpolated_string = error_interpolation.interpolate( two_tags_no_seps, self.graph) self.assertRegexpMatches(interpolated_string, "constant_op.py:[0-9]+.*constant_op.py:[0-9]+") def testTwoTagsWithSeps(self): - two_tags_with_seps = ";;;^^node:Two^^,,,^^node:Three^^;;;" + two_tags_with_seps = ";;;{{node Two}},,,{{node Three}};;;" interpolated_string = error_interpolation.interpolate( two_tags_with_seps, self.graph) expected_regex = ( - r"^;;;.*constant_op.py:[0-9]+\) ,,,.*constant_op.py:[0-9]*\) ;;;$") + r"^;;;.*constant_op.py:[0-9]+\) ,,,.*constant_op.py:[0-9]+\) ;;;$") self.assertRegexpMatches(interpolated_string, expected_regex) + def testNewLine(self): + newline = "\n\n{{node One}}" + interpolated_string = error_interpolation.interpolate(newline, self.graph) + self.assertRegexpMatches(interpolated_string, "constant_op.py:[0-9]+.*") + class InterpolateDeviceSummaryTest(test.TestCase): @@ -206,23 +211,23 @@ class InterpolateDeviceSummaryTest(test.TestCase): self.graph = self.three.graph def testNodeZeroHasNoDeviceSummaryInfo(self): - message = "^^colocation_node:zero^^" + message = "{{colocation_node zero}}" result = error_interpolation.interpolate(message, self.graph) self.assertIn("No device assignments were active", result) def testNodeOneHasExactlyOneInterpolatedDevice(self): - message = "^^colocation_node:one^^" + message = "{{colocation_node one}}" result = error_interpolation.interpolate(message, self.graph) self.assertEqual(2, result.count("tf.device(/cpu)")) def testNodeTwoHasTwoInterpolatedDevice(self): - message = "^^colocation_node:two^^" + message = "{{colocation_node two}}" result = error_interpolation.interpolate(message, self.graph) self.assertEqual(2, result.count("tf.device(/cpu)")) self.assertEqual(2, result.count("tf.device(/cpu:0)")) def testNodeThreeHasFancyFunctionDisplayNameForInterpolatedDevice(self): - message = "^^colocation_node:three^^" + message = "{{colocation_node three}}" result = error_interpolation.interpolate(message, self.graph) num_devices = result.count("tf.device") self.assertEqual(2, num_devices) @@ -256,12 +261,12 @@ class InterpolateColocationSummaryTest(test.TestCase): self.graph = node_three.graph def testNodeThreeHasColocationInterpolation(self): - message = "^^colocation_node:Three_with_one^^" + message = "{{colocation_node Three_with_one}}" result = error_interpolation.interpolate(message, self.graph) self.assertIn("colocate_with(One)", result) def testNodeFourHasColocationInterpolationForNodeThreeOnly(self): - message = "^^colocation_node:Four_with_three^^" + message = "{{colocation_node Four_with_three}}" result = error_interpolation.interpolate(message, self.graph) self.assertIn("colocate_with(Three_with_one)", result) self.assertNotIn( @@ -269,13 +274,13 @@ class InterpolateColocationSummaryTest(test.TestCase): "Node One should not appear in Four_with_three's summary:\n%s" % result) def testNodeFiveHasColocationInterpolationForNodeOneAndTwo(self): - message = "^^colocation_node:Five_with_one_with_two^^" + message = "{{colocation_node Five_with_one_with_two}}" result = error_interpolation.interpolate(message, self.graph) self.assertIn("colocate_with(One)", result) self.assertIn("colocate_with(Two)", result) def testColocationInterpolationForNodeLackingColocation(self): - message = "^^colocation_node:One^^" + message = "{{colocation_node One}}" result = error_interpolation.interpolate(message, self.graph) self.assertIn("No node-device colocations", result) self.assertNotIn("Two", result) diff --git a/tensorflow/python/framework/file_system_test.py b/tensorflow/python/framework/file_system_test.py index 5eb59141a2..6901715e5d 100644 --- a/tensorflow/python/framework/file_system_test.py +++ b/tensorflow/python/framework/file_system_test.py @@ -37,7 +37,7 @@ class FileSystemTest(test.TestCase): load_library.load_file_system_library(file_system_library) def testBasic(self): - with self.test_session() as sess: + with self.cached_session() as sess: reader = io_ops.WholeFileReader("test_reader") queue = data_flow_ops.FIFOQueue(99, [dtypes.string], shapes=()) queue.enqueue_many([["test://foo"]]).run() diff --git a/tensorflow/python/framework/function_test.py b/tensorflow/python/framework/function_test.py index ee723bacaf..903768a039 100644 --- a/tensorflow/python/framework/function_test.py +++ b/tensorflow/python/framework/function_test.py @@ -419,7 +419,7 @@ class FunctionTest(test.TestCase): with ops.control_dependencies([z]): return x * 2 - with ops.Graph().as_default(), self.test_session(): + with ops.Graph().as_default(), self.cached_session(): z = Foo(constant_op.constant(3.0)) self.assertAllEqual(z.eval(), 6.0) @@ -434,7 +434,7 @@ class FunctionTest(test.TestCase): # Foo contains a stateful op (Assert). self.assertEqual([("Assert", "Assert")], Foo.stateful_ops) g = ops.Graph() - with g.as_default(), self.test_session(): + with g.as_default(), self.cached_session(): self.assertAllEqual(Foo(constant_op.constant(3.0)).eval(), 6.0) with self.assertRaisesRegexp(errors_impl.InvalidArgumentError, "assertion failed.*-3"): @@ -448,7 +448,7 @@ class FunctionTest(test.TestCase): [control_flow_ops.Assert(math_ops.less_equal(x, 10.0), [x])]): return array_ops.identity(x) - with self.test_session(): + with self.cached_session(): self.assertEqual(1.0, MyFn(1.0).eval()) with self.assertRaisesRegexp(errors_impl.InvalidArgumentError, "assertion"): @@ -667,7 +667,7 @@ class FunctionTest(test.TestCase): with ops.Graph().as_default(): z = CubeXPlusY(3.0, -2.0) - with self.test_session(): + with self.cached_session(): self.assertAllEqual(z.eval(), 25.0) def testNestedDefinedFunction(self): @@ -683,7 +683,7 @@ class FunctionTest(test.TestCase): with ops.Graph().as_default(): z = CubeXPlusY(3.0, -2.0) - with self.test_session(): + with self.cached_session(): self.assertAllEqual(z.eval(), 25.0) def testUnusedFunction(self): diff --git a/tensorflow/python/framework/importer_test.py b/tensorflow/python/framework/importer_test.py index 18e7d8aa14..2b4d8e7299 100644 --- a/tensorflow/python/framework/importer_test.py +++ b/tensorflow/python/framework/importer_test.py @@ -396,7 +396,7 @@ class ImportGraphDefTest(test.TestCase): # Run the imported graph. # TODO(b/76173421): make this work (currently DCHECKS) - # with self.test_session() as sess: + # with self.cached_session() as sess: # sess.run(imported_init) # self.assertEqual(sess.run(imported_var), 1.0) # self.assertEqual(sess.run(imported_assign), 2.0) @@ -417,7 +417,7 @@ class ImportGraphDefTest(test.TestCase): imported_r, = importer.import_graph_def(graph_def, return_elements=[r.name]) self.assertEqual(imported_r.name, "import/" + r.name) - with self.test_session() as sess: + with self.cached_session() as sess: self.assertEqual(sess.run(imported_r), 10) def testImportWhileLoopInCond(self): @@ -436,7 +436,7 @@ class ImportGraphDefTest(test.TestCase): pred = array_ops.placeholder(dtypes.bool) out = control_flow_ops.cond(pred, ImportFn, lambda: constant_op.constant(1)) - with self.test_session() as sess: + with self.cached_session() as sess: self.assertEqual(sess.run(out, {pred: True}), 10) self.assertEqual(sess.run(out, {pred: False}), 1) @@ -457,7 +457,7 @@ class ImportGraphDefTest(test.TestCase): out = control_flow_ops.while_loop( lambda i: i < 2, ImportFn, [0], shape_invariants=[tensor_shape.TensorShape(None)]) - with self.test_session() as sess: + with self.cached_session() as sess: self.assertEqual(sess.run(out), 10) def testTypeMismatchInGraphDef(self): @@ -929,7 +929,7 @@ class ImportGraphDefTest(test.TestCase): input_map={"a:0": constant_op.constant(5.0)}, name="", return_elements=["id:0"]) - with self.test_session(): + with self.cached_session(): self.assertEqual(5.0, t.eval()) def testInvalidInputForReturnOperations(self): @@ -958,7 +958,7 @@ class ImportGraphDefTest(test.TestCase): array_ops.stack([c, c], name="pack") gdef = g.as_graph_def() - with self.test_session(): + with self.cached_session(): pack, = importer.import_graph_def(gdef, return_elements=["pack"]) self.assertAllEqual(pack.outputs[0].eval(), [5.0, 5.0]) @@ -1063,7 +1063,7 @@ class ImportGraphDefTest(test.TestCase): self.assertEqual([10], biases_grad.get_shape()) def testLargeGraph(self): - with self.test_session(): + with self.cached_session(): # The default message byte limit is 64M. Ours is 2G with a warning at 512. # Adding a 130M entries float32 tensor should exceed the warning, but not # the hard limit. @@ -1254,7 +1254,7 @@ class ImportGraphDefTest(test.TestCase): z = TestFunc() - with self.test_session(): + with self.cached_session(): z_val = z.eval() self.assertEqual(z_val, -2.0) @@ -1284,7 +1284,7 @@ class ImportGraphDefTest(test.TestCase): z2 = importer.import_graph_def(gdef, return_elements=["z:0"], input_map=input_map)[0] - with self.test_session() as sess: + with self.cached_session() as sess: z1_val, z2_val = sess.run((z1, z2)) self.assertAllEqual(z1_val, z2_val) diff --git a/tensorflow/python/framework/meta_graph_test.py b/tensorflow/python/framework/meta_graph_test.py index 6e5f7aafac..fc98b91a01 100644 --- a/tensorflow/python/framework/meta_graph_test.py +++ b/tensorflow/python/framework/meta_graph_test.py @@ -117,7 +117,7 @@ class SimpleMetaGraphTest(test.TestCase): self.assertEqual(new_output_value, output_value) def testStrippedOpListNestedFunctions(self): - with self.test_session(): + with self.cached_session(): # Square two levels deep @function.Defun(dtypes.int32) def f0(x): @@ -169,7 +169,7 @@ class SimpleMetaGraphTest(test.TestCase): # and "Tout" maps to complex64. Since these attr values map to their # defaults, they must be stripped unless stripping of default attrs is # disabled. - with self.test_session(): + with self.cached_session(): real_num = constant_op.constant(1.0, dtype=dtypes.float32, name="real") imag_num = constant_op.constant(2.0, dtype=dtypes.float32, name="imag") math_ops.complex(real_num, imag_num, name="complex") @@ -212,7 +212,8 @@ class SimpleMetaGraphTest(test.TestCase): def testDefaultAttrStrippingNestedFunctions(self): """Verifies that default attributes are stripped from function node defs.""" - with self.test_session(): + with self.cached_session(): + @function.Defun(dtypes.float32, dtypes.float32) def f0(i, j): return math_ops.complex(i, j, name="double_nested_complex") @@ -251,7 +252,7 @@ class SimpleMetaGraphTest(test.TestCase): meta_info_def = meta_graph_pb2.MetaGraphDef.MetaInfoDef() meta_info_def.stripped_op_list.op.add() - with self.test_session(): + with self.cached_session(): meta_graph_def = meta_graph.create_meta_graph_def( meta_info_def=meta_info_def, graph_def=graph_def, strip_default_attrs=True) diff --git a/tensorflow/python/framework/ops.py b/tensorflow/python/framework/ops.py index 4cfd639bf9..75678cbc01 100644 --- a/tensorflow/python/framework/ops.py +++ b/tensorflow/python/framework/ops.py @@ -55,6 +55,7 @@ from tensorflow.python.platform import app from tensorflow.python.platform import tf_logging as logging from tensorflow.python.util import compat from tensorflow.python.util import decorator_utils +from tensorflow.python.util import deprecation from tensorflow.python.util import function_utils from tensorflow.python.util import lock_util from tensorflow.python.util import tf_contextlib @@ -5363,6 +5364,7 @@ def enable_eager_execution(config=None, computational graph). For example: + ```python tf.enable_eager_execution() @@ -5807,11 +5809,8 @@ class GraphKeys(object): _STREAMING_MODEL_PORTS = "streaming_model_ports" @decorator_utils.classproperty + @deprecation.deprecated(None, "Use `tf.GraphKeys.GLOBAL_VARIABLES` instead.") def VARIABLES(cls): # pylint: disable=no-self-argument - logging.log_first_n(logging.WARN, - "VARIABLES collection name is deprecated, please use " - "GLOBAL_VARIABLES instead; VARIABLES will be removed " - "after 2017-03-02.", 1) return cls.GLOBAL_VARIABLES diff --git a/tensorflow/python/framework/ops_test.py b/tensorflow/python/framework/ops_test.py index ced0581402..d59adf3d48 100644 --- a/tensorflow/python/framework/ops_test.py +++ b/tensorflow/python/framework/ops_test.py @@ -58,12 +58,12 @@ ops._set_call_cpp_shape_fn(common_shapes.call_cpp_shape_fn) class ResourceTest(test_util.TensorFlowTestCase): def testBuildGraph(self): - with self.test_session(): + with self.cached_session(): pt = test_ops.stub_resource_handle_op(container="a", shared_name="b") test_ops.resource_create_op(pt).run() def testInitialize(self): - with self.test_session(): + with self.cached_session(): handle = test_ops.stub_resource_handle_op(container="a", shared_name="b") resources.register_resource( handle=handle, @@ -100,35 +100,35 @@ class TensorAndShapeTest(test_util.TensorFlowTestCase): pass def testAddShape(self): - with self.test_session(): + with self.cached_session(): a = array_ops.zeros([2, 3]) b = array_ops.ones([1, 3]) c = a + b self.assertEqual([2, 3], c.shape) def testUnknownDim(self): - with self.test_session(): + with self.cached_session(): a = array_ops.placeholder(dtype=dtypes.float32, shape=[2, None, 3]) b = array_ops.placeholder(dtype=dtypes.float32, shape=[2, None, 3]) c = a + b self.assertEqual([2, None, 3], c.shape.as_list()) def testUnknownShape(self): - with self.test_session(): + with self.cached_session(): a = array_ops.placeholder(dtype=dtypes.float32, shape=None) b = array_ops.ones([1, 3]) c = a + b self.assertEqual(tensor_shape.unknown_shape(), c.shape) def testScalarShape(self): - with self.test_session(): + with self.cached_session(): a = array_ops.placeholder(dtype=dtypes.float32, shape=[]) b = array_ops.ones([]) c = a + b self.assertEqual(tensor_shape.scalar(), c.shape) def testShapeFunctionError(self): - with self.test_session(): + with self.cached_session(): a = array_ops.ones([1, 2, 3]) b = array_ops.ones([4, 5, 6]) with self.assertRaisesRegexp( @@ -141,7 +141,7 @@ class TensorAndShapeTest(test_util.TensorFlowTestCase): class IndexedSlicesTest(test_util.TensorFlowTestCase): def testToTensor(self): - with self.test_session(): + with self.cached_session(): values = constant_op.constant([2, 3, 5, 7], shape=[2, 2]) indices = constant_op.constant([0, 2]) dense_shape = constant_op.constant([3, 2]) @@ -150,7 +150,7 @@ class IndexedSlicesTest(test_util.TensorFlowTestCase): self.assertAllEqual(tensor.eval(), [[2, 3], [0, 0], [5, 7]]) def testNegation(self): - with self.test_session(): + with self.cached_session(): values = constant_op.constant([2, 3, 5, 7], shape=[2, 2]) indices = constant_op.constant([0, 2]) x = -ops.IndexedSlices(values, indices) @@ -158,7 +158,7 @@ class IndexedSlicesTest(test_util.TensorFlowTestCase): self.assertAllEqual(x.indices.eval(), [0, 2]) def testScalarMul(self): - with self.test_session(): + with self.cached_session(): values = constant_op.constant([2, 3, 5, 7], shape=[2, 2]) indices = constant_op.constant([0, 2]) x = math_ops.scalar_mul(-2, ops.IndexedSlices(values, indices)) @@ -307,14 +307,14 @@ class OperationTest(test_util.TensorFlowTestCase): self.assertEqual(tensor_shape.unknown_shape(), op.get_shape()) def testConvertToTensorNestedArray(self): - with self.test_session(): + with self.cached_session(): values = [[2], [3], [5], [7]] tensor = ops.convert_to_tensor(values) self.assertAllEqual((4, 1), tensor.get_shape().as_list()) self.assertAllEqual(values, tensor.eval()) def testShapeTuple(self): - with self.test_session(): + with self.cached_session(): c = constant_op.constant(1) self.assertEqual(c._shape_tuple(), ()) # pylint: disable=protected-access @@ -328,14 +328,14 @@ class OperationTest(test_util.TensorFlowTestCase): self.assertTrue(isinstance(converted, ops.EagerTensor)) def testConvertToTensorNestedTuple(self): - with self.test_session(): + with self.cached_session(): values = ((2,), (3,), (5,), (7,)) tensor = ops.convert_to_tensor(values) self.assertAllEqual((4, 1), tensor.get_shape().as_list()) self.assertAllEqual(values, ops.convert_to_tensor(values).eval()) def testConvertToTensorNestedTensors(self): - with self.test_session(): + with self.cached_session(): values = ((2,), (3,), (5,), (7,)) tensor = ops.convert_to_tensor( [constant_op.constant(row) for row in values]) @@ -347,25 +347,25 @@ class OperationTest(test_util.TensorFlowTestCase): self.assertAllEqual(values, tensor.eval()) def testConvertToTensorNestedMix(self): - with self.test_session(): + with self.cached_session(): values = ([2], (3,), [constant_op.constant(5)], constant_op.constant([7])) tensor = ops.convert_to_tensor(values) self.assertAllEqual((4, 1), tensor.get_shape().as_list()) self.assertAllEqual(((2,), (3,), (5,), (7,)), tensor.eval()) def testConvertToTensorPreferred(self): - with self.test_session(): + with self.cached_session(): values = [2, 3, 5, 7] tensor = ops.convert_to_tensor(values, preferred_dtype=dtypes.float32) self.assertEqual(dtypes.float32, tensor.dtype) - with self.test_session(): + with self.cached_session(): # Convert empty tensor to anything. values = [] tensor = ops.convert_to_tensor(values, preferred_dtype=dtypes.int64) self.assertEqual(dtypes.int64, tensor.dtype) - with self.test_session(): + with self.cached_session(): # The preferred dtype is a type error and will convert to # float32 instead. values = [1.23] @@ -941,7 +941,7 @@ class NameStackTest(test_util.TensorFlowTestCase): self.assertEqual("bar_2", g.unique_name("bar")) def testNameAndVariableScope(self): - with self.test_session() as sess: + with self.cached_session() as sess: with sess.graph.name_scope("l0"): with variable_scope.variable_scope("l1"): with sess.graph.name_scope("l1") as scope: @@ -2164,7 +2164,7 @@ class InitScopeTest(test_util.TensorFlowTestCase): g = ops.Graph() with g.as_default(): - with self.test_session(): + with self.cached_session(): # First ensure that graphs that are not building functions are # not escaped. function_with_variables("foo") @@ -2416,11 +2416,11 @@ class AttrScopeTest(test_util.TensorFlowTestCase): return (a, b) def testNoLabel(self): - with self.test_session(): + with self.cached_session(): self.assertAllEqual((None, None), self._get_test_attrs()) def testLabelMap(self): - with self.test_session() as sess: + with self.cached_session() as sess: a1 = self._get_test_attrs() with sess.graph._attr_scope({ "_A": attr_value_pb2.AttrValue(s=compat.as_bytes("foo")) @@ -2454,12 +2454,12 @@ ops.RegisterShape("KernelLabel")(common_shapes.scalar_shape) class KernelLabelTest(test_util.TensorFlowTestCase): def testNoLabel(self): - with self.test_session(): + with self.cached_session(): self.assertAllEqual(b"My label is: default", test_ops.kernel_label().eval()) def testLabelMap(self): - with self.test_session() as sess: + with self.cached_session() as sess: default_1 = test_ops.kernel_label() # pylint: disable=protected-access with sess.graph._kernel_label_map({"KernelLabel": "overload_1"}): @@ -2900,7 +2900,7 @@ class NameScopeTest(test_util.TensorFlowTestCase): class TracebackTest(test_util.TensorFlowTestCase): def testTracebackWithStartLines(self): - with self.test_session() as sess: + with self.cached_session() as sess: a = constant_op.constant(2.0) sess.run( a, diff --git a/tensorflow/python/framework/python_op_gen_internal.cc b/tensorflow/python/framework/python_op_gen_internal.cc index f2270342b0..f6aef5bc50 100644 --- a/tensorflow/python/framework/python_op_gen_internal.cc +++ b/tensorflow/python/framework/python_op_gen_internal.cc @@ -15,18 +15,20 @@ limitations under the License. #include "tensorflow/python/framework/python_op_gen_internal.h" +#include <float.h> #include <stdio.h> +#include <iomanip> #include <sstream> #include <unordered_map> #include "tensorflow/core/framework/api_def.pb.h" #include "tensorflow/core/framework/attr_value.pb.h" #include "tensorflow/core/framework/op.h" -#include "tensorflow/core/framework/op_def.pb_text.h" #include "tensorflow/core/framework/op_def.pb.h" +#include "tensorflow/core/framework/op_def.pb_text.h" #include "tensorflow/core/framework/op_def_util.h" #include "tensorflow/core/framework/op_gen_lib.h" -#include "tensorflow/core/framework/tensor.pb_text.h" #include "tensorflow/core/framework/tensor.pb.h" +#include "tensorflow/core/framework/tensor.pb_text.h" #include "tensorflow/core/framework/tensor_shape.pb.h" #include "tensorflow/core/framework/types.h" #include "tensorflow/core/framework/types.pb.h" @@ -435,7 +437,12 @@ string AttrValueToPython(const string& type, const AttrValue& value, if (std::isnan(value.f()) || std::isinf(value.f())) { return strings::StrCat("float('", value.f(), "')"); } else { - return strings::StrCat(value.f()); + // Use locale-independent conversion. + static_assert(FLT_DIG < 10, "FLT_DIG is too big"); + std::ostringstream s; + s.imbue(std::locale::classic()); + s << std::setprecision(FLT_DIG) << value.f(); + return s.str(); } } else if (type == "bool") { return value.b() ? "True" : "False"; diff --git a/tensorflow/python/framework/sparse_tensor_test.py b/tensorflow/python/framework/sparse_tensor_test.py index 2bcfbc17df..22423c4f58 100644 --- a/tensorflow/python/framework/sparse_tensor_test.py +++ b/tensorflow/python/framework/sparse_tensor_test.py @@ -45,7 +45,7 @@ class SparseTensorTest(test_util.TensorFlowTestCase): self.assertEqual(sp.dense_shape.dtype, dtypes.int64) self.assertEqual(sp.get_shape(), (4, 5)) - with self.test_session() as sess: + with self.cached_session() as sess: value = sp.eval() self.assertAllEqual(indices, value.indices) self.assertAllEqual(values, value.values) @@ -81,14 +81,14 @@ class SparseTensorTest(test_util.TensorFlowTestCase): class ConvertToTensorOrSparseTensorTest(test_util.TensorFlowTestCase): def test_convert_dense(self): - with self.test_session(): + with self.cached_session(): value = [42, 43] from_value = sparse_tensor.convert_to_tensor_or_sparse_tensor( value) self.assertAllEqual(value, from_value.eval()) def test_convert_sparse(self): - with self.test_session(): + with self.cached_session(): indices = [[0, 1], [1, 0]] values = [42, 43] shape = [2, 2] diff --git a/tensorflow/python/framework/subscribe_test.py b/tensorflow/python/framework/subscribe_test.py index d6de45fdc4..1d594e4078 100644 --- a/tensorflow/python/framework/subscribe_test.py +++ b/tensorflow/python/framework/subscribe_test.py @@ -65,7 +65,7 @@ class SubscribeTest(test_util.TensorFlowTestCase): self.assertFalse(c0.op in d.op.control_inputs) self.assertTrue(c.op in d.op.control_inputs) - with self.test_session() as sess: + with self.cached_session() as sess: c_out = sess.run([c]) n_out = sess.run([n]) d_out = sess.run([d]) @@ -144,7 +144,7 @@ class SubscribeTest(test_util.TensorFlowTestCase): b = subscribe.subscribe(b, lambda t: script_ops.py_func(sub, [t], [t.dtype])) - with self.test_session() as sess: + with self.cached_session() as sess: c_out = sess.run([c]) d_out = sess.run([d]) @@ -204,7 +204,7 @@ class SubscribeTest(test_util.TensorFlowTestCase): self.assertIs(c_sub, c_sub3) # Expect the three side effect graphs to have been evaluated. - with self.test_session() as sess: + with self.cached_session() as sess: sess.run([c_sub]) self.assertIn('graph1', shared) self.assertIn('graph2', shared) @@ -227,7 +227,7 @@ class SubscribeTest(test_util.TensorFlowTestCase): v1, lambda t: script_ops.py_func(sub, [t], [t.dtype])) self.assertTrue(subscribe._is_subscribed_identity(v1_sub)) - with self.test_session() as sess: + with self.cached_session() as sess: # Initialize the variables first. sess.run([v1.initializer]) sess.run([v2.initializer]) @@ -272,7 +272,7 @@ class SubscribeTest(test_util.TensorFlowTestCase): self.assertIs(tensor_array_sub, tensor_array.handle) self.assertFalse(subscribe._is_subscribed_identity(tensor_array.handle)) - with self.test_session() as sess: + with self.cached_session() as sess: sess.run([reader]) self.assertEqual(0, len(shared)) @@ -303,7 +303,7 @@ class SubscribeTest(test_util.TensorFlowTestCase): subscribe.subscribe(sparse_add.op.outputs, lambda t: script_ops.py_func(sub, [t], [t.dtype])) - with self.test_session() as sess: + with self.cached_session() as sess: sess.run([neg]) # All three ops have been processed. @@ -374,7 +374,7 @@ class SubscribeTest(test_util.TensorFlowTestCase): # Verify that sub(x1) and sub(branch) are not. self.assertIsNot(context(subscriptions[0]), context(subscriptions[1])) - with self.test_session() as sess: + with self.cached_session() as sess: sess.run(cond) self.assertEqual(3, len(results)) diff --git a/tensorflow/python/framework/tensor_shape.py b/tensorflow/python/framework/tensor_shape.py index 11b681d544..3c2a736fb9 100644 --- a/tensorflow/python/framework/tensor_shape.py +++ b/tensorflow/python/framework/tensor_shape.py @@ -606,8 +606,8 @@ class TensorShape(object): slice. Raises: - ValueError: If `key` is a slice, and any of its elements are negative, or - if `self` is completely unknown and the step is set. + ValueError: If `key` is a slice and `self` is completely unknown and + the step is set. """ if self._dims is not None: if isinstance(key, slice): diff --git a/tensorflow/python/framework/tensor_util.py b/tensorflow/python/framework/tensor_util.py index b14290c203..26170b000d 100644 --- a/tensorflow/python/framework/tensor_util.py +++ b/tensorflow/python/framework/tensor_util.py @@ -367,7 +367,7 @@ def make_tensor_proto(values, dtype=None, shape=None, verify_shape=False): A `TensorProto`. Depending on the type, it may contain data in the "tensor_content" attribute, which is not directly useful to Python programs. To access the values you should convert the proto back to a numpy ndarray - with `tensor_util.MakeNdarray(proto)`. + with `tf.make_ndarray(proto)`. If `values` is a `TensorProto`, it is immediately returned; `dtype` and `shape` are ignored. diff --git a/tensorflow/python/framework/tensor_util_test.py b/tensorflow/python/framework/tensor_util_test.py index 395cf43b3f..bdf759f220 100644 --- a/tensorflow/python/framework/tensor_util_test.py +++ b/tensorflow/python/framework/tensor_util_test.py @@ -768,7 +768,7 @@ class TensorUtilTest(test.TestCase): def __array__(self, dtype=None): return np.asarray(self.array, dtype) - with self.test_session() as sess: + with self.cached_session() as sess: ma = MockArray(np.array([10, 20, 30])) t = ops.convert_to_tensor(ma) a = sess.run(t) diff --git a/tensorflow/python/framework/test_util.py b/tensorflow/python/framework/test_util.py index 6d03e956da..b34330aa2a 100644 --- a/tensorflow/python/framework/test_util.py +++ b/tensorflow/python/framework/test_util.py @@ -465,29 +465,31 @@ def assert_no_new_pyobjects_executing_eagerly(f): f(self, **kwargs) gc.collect() previous_count = len(gc.get_objects()) - collection_sizes_before = { - collection: len(ops.get_collection(collection)) - for collection in ops.get_default_graph().collections - } + if ops.has_default_graph(): + collection_sizes_before = { + collection: len(ops.get_collection(collection)) + for collection in ops.get_default_graph().collections + } for _ in range(3): f(self, **kwargs) # Note that gc.get_objects misses anything that isn't subject to garbage # collection (C types). Collections are a common source of leaks, so we # test for collection sizes explicitly. - for collection_key in ops.get_default_graph().collections: - collection = ops.get_collection(collection_key) - size_before = collection_sizes_before.get(collection_key, 0) - if len(collection) > size_before: - raise AssertionError( - ("Collection %s increased in size from " - "%d to %d (current items %s).") % (collection_key, size_before, - len(collection), collection)) - # Make sure our collection checks don't show up as leaked memory by - # removing references to temporary variables. - del collection - del collection_key - del size_before - del collection_sizes_before + if ops.has_default_graph(): + for collection_key in ops.get_default_graph().collections: + collection = ops.get_collection(collection_key) + size_before = collection_sizes_before.get(collection_key, 0) + if len(collection) > size_before: + raise AssertionError( + ("Collection %s increased in size from " + "%d to %d (current items %s).") % + (collection_key, size_before, len(collection), collection)) + # Make sure our collection checks don't show up as leaked memory by + # removing references to temporary variables. + del collection + del collection_key + del size_before + del collection_sizes_before gc.collect() # There should be no new Python objects hanging around. new_count = len(gc.get_objects()) @@ -535,15 +537,16 @@ def assert_no_new_tensors(f): tensors_before = set( id(obj) for obj in gc.get_objects() if _is_tensorflow_object(obj)) - if context.executing_eagerly(): - f(self, **kwargs) - ops.reset_default_graph() - else: - # Run the test in a new graph so that collections get cleared when it's - # done, but inherit the graph key so optimizers behave. - outside_graph_key = ops.get_default_graph()._graph_key - with ops.Graph().as_default(): - ops.get_default_graph()._graph_key = outside_graph_key + outside_executed_eagerly = context.executing_eagerly() + # Run the test in a new graph so that collections get cleared when it's + # done, but inherit the graph key so optimizers behave. + outside_graph_key = ops.get_default_graph()._graph_key + with ops.Graph().as_default(): + ops.get_default_graph()._graph_key = outside_graph_key + if outside_executed_eagerly: + with context.eager_mode(): + f(self, **kwargs) + else: f(self, **kwargs) # Make an effort to clear caches, which would otherwise look like leaked # Tensors. @@ -1072,13 +1075,9 @@ class TensorFlowTestCase(googletest.TestCase): if context.executing_eagerly(): yield None else: - sess = self._create_session(graph, config, use_gpu, force_gpu) - with self._constrain_devices_and_set_default( - sess, use_gpu, force_gpu) as constrained_sess: - # We need to do this to make sure the session closes, otherwise, even - # if the user does with self.session():, it will not close the session. - with constrained_sess: - yield constrained_sess + with self._create_session(graph, config, force_gpu) as sess: + with self._constrain_devices_and_set_default(sess, use_gpu, force_gpu): + yield sess @contextlib.contextmanager def cached_session(self, @@ -1126,10 +1125,11 @@ class TensorFlowTestCase(googletest.TestCase): if context.executing_eagerly(): yield None else: - with self._get_cached_session( - graph, config, use_gpu, force_gpu, - crash_if_inconsistent_args=True) as sess: - yield sess + sess = self._get_cached_session( + graph, config, force_gpu, crash_if_inconsistent_args=True) + with self._constrain_devices_and_set_default(sess, use_gpu, + force_gpu) as cached: + yield cached @contextlib.contextmanager def test_session(self, @@ -1145,10 +1145,11 @@ class TensorFlowTestCase(googletest.TestCase): yield None else: if graph is None: - with self._get_cached_session( - graph, config, use_gpu, force_gpu, - crash_if_inconsistent_args=False) as sess: - yield sess + sess = self._get_cached_session( + graph, config, force_gpu, crash_if_inconsistent_args=False) + with self._constrain_devices_and_set_default(sess, use_gpu, + force_gpu) as cached: + yield cached else: with self.session(graph, config, use_gpu, force_gpu) as sess: yield sess @@ -1326,9 +1327,17 @@ class TensorFlowTestCase(googletest.TestCase): def _assertArrayLikeAllClose(self, a, b, rtol=1e-6, atol=1e-6, msg=None): a = self._GetNdArray(a) b = self._GetNdArray(b) - self.assertEqual( - a.shape, b.shape, - "Shape mismatch: expected %s, got %s." % (a.shape, b.shape)) + # When the array rank is small, print its contents. Numpy array printing is + # implemented using inefficient recursion so prints can cause tests to + # time out. + if a.shape != b.shape and (b.ndim <= 3 or b.size < 500): + shape_mismatch_msg = ("Shape mismatch: expected %s, got %s with contents " + "%s.") % (a.shape, b.shape, b) + else: + shape_mismatch_msg = "Shape mismatch: expected %s, got %s." % (a.shape, + b.shape) + self.assertEqual(a.shape, b.shape, shape_mismatch_msg) + msgs = [msg] if not np.allclose(a, b, rtol=rtol, atol=atol): # Add more details than np.testing.assert_allclose. @@ -1836,91 +1845,69 @@ class TensorFlowTestCase(googletest.TestCase): with sess.graph.device("/cpu:0"): yield sess - def _create_session(self, graph, config, use_gpu, force_gpu): + def _create_session(self, graph, config, force_gpu): """See session() for details.""" - if context.executing_eagerly(): - return None - else: + def prepare_config(config): + """Returns a config for sessions. - def prepare_config(config): - """Returns a config for sessions. - - Args: - config: An optional config_pb2.ConfigProto to use to configure the - session. - Returns: - A config_pb2.ConfigProto object. - """ - if config is None: - config = config_pb2.ConfigProto() - config.allow_soft_placement = not force_gpu - config.gpu_options.per_process_gpu_memory_fraction = 0.3 - elif force_gpu and config.allow_soft_placement: - config = config_pb2.ConfigProto().CopyFrom(config) - config.allow_soft_placement = False - # Don't perform optimizations for tests so we don't inadvertently run - # gpu ops on cpu - config.graph_options.optimizer_options.opt_level = -1 - config.graph_options.rewrite_options.constant_folding = ( - rewriter_config_pb2.RewriterConfig.OFF) - config.graph_options.rewrite_options.arithmetic_optimization = ( - rewriter_config_pb2.RewriterConfig.OFF) - return config - - return ErrorLoggingSession(graph=graph, config=prepare_config(config)) + Args: + config: An optional config_pb2.ConfigProto to use to configure the + session. + + Returns: + A config_pb2.ConfigProto object. + """ + if config is None: + config = config_pb2.ConfigProto() + config.allow_soft_placement = not force_gpu + config.gpu_options.per_process_gpu_memory_fraction = 0.3 + elif force_gpu and config.allow_soft_placement: + config = config_pb2.ConfigProto().CopyFrom(config) + config.allow_soft_placement = False + # Don't perform optimizations for tests so we don't inadvertently run + # gpu ops on cpu + config.graph_options.optimizer_options.opt_level = -1 + config.graph_options.rewrite_options.constant_folding = ( + rewriter_config_pb2.RewriterConfig.OFF) + config.graph_options.rewrite_options.arithmetic_optimization = ( + rewriter_config_pb2.RewriterConfig.OFF) + return config + + return ErrorLoggingSession(graph=graph, config=prepare_config(config)) - @contextlib.contextmanager def _get_cached_session(self, graph=None, config=None, - use_gpu=False, force_gpu=False, crash_if_inconsistent_args=True): """See cached_session() for documentation.""" - if context.executing_eagerly(): - yield None + if self._cached_session is None: + sess = self._create_session( + graph=graph, config=config, force_gpu=force_gpu) + self._cached_session = sess + self._cached_graph = graph + self._cached_config = config + self._cached_force_gpu = force_gpu + return sess else: - if self._cached_session is None: - sess = self._create_session( - graph=graph, config=config, use_gpu=use_gpu, force_gpu=force_gpu) - self._cached_session = sess - self._cached_graph = graph - self._cached_config = config - self._cached_use_gpu = use_gpu - self._cached_force_gpu = force_gpu - with self._constrain_devices_and_set_default( - sess, use_gpu, force_gpu) as constrained_sess: - yield constrained_sess - else: - if crash_if_inconsistent_args and self._cached_graph is not graph: - raise ValueError("The graph used to get the cached session is " - "different than the one that was used to create the " - "session. Maybe create a new session with " - "self.session()") - if crash_if_inconsistent_args and self._cached_config is not config: - raise ValueError("The config used to get the cached session is " - "different than the one that was used to create the " - "session. Maybe create a new session with " - "self.session()") - if crash_if_inconsistent_args and self._cached_use_gpu is not use_gpu: - raise ValueError( - "The use_gpu value used to get the cached session is " - "different than the one that was used to create the " - "session. Maybe create a new session with " - "self.session()") - if crash_if_inconsistent_args and (self._cached_force_gpu is - not force_gpu): - raise ValueError( - "The force_gpu value used to get the cached session is " - "different than the one that was used to create the " - "session. Maybe create a new session with " - "self.session()") - # If you modify this logic, make sure to modify it in _create_session - # as well. - sess = self._cached_session - with self._constrain_devices_and_set_default( - sess, use_gpu, force_gpu) as constrained_sess: - yield constrained_sess + if crash_if_inconsistent_args and self._cached_graph is not graph: + raise ValueError("The graph used to get the cached session is " + "different than the one that was used to create the " + "session. Maybe create a new session with " + "self.session()") + if crash_if_inconsistent_args and self._cached_config is not config: + raise ValueError("The config used to get the cached session is " + "different than the one that was used to create the " + "session. Maybe create a new session with " + "self.session()") + if crash_if_inconsistent_args and (self._cached_force_gpu is + not force_gpu): + raise ValueError( + "The force_gpu value used to get the cached session is " + "different than the one that was used to create the " + "session. Maybe create a new session with " + "self.session()") + return self._cached_session @tf_export("test.create_local_cluster") diff --git a/tensorflow/python/framework/test_util_test.py b/tensorflow/python/framework/test_util_test.py index c9b5d46f98..22189afa59 100644 --- a/tensorflow/python/framework/test_util_test.py +++ b/tensorflow/python/framework/test_util_test.py @@ -71,9 +71,6 @@ class TestUtilTest(test_util.TensorFlowTestCase): with self.cached_session(graph=ops.Graph()) as sess2: pass with self.assertRaises(ValueError): - with self.cached_session(use_gpu=True) as sess2: - pass - with self.assertRaises(ValueError): with self.cached_session(force_gpu=True) as sess2: pass # We make sure that test_session will cache the session even after the |