aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python/framework
diff options
context:
space:
mode:
authorGravatar Yan Facai (颜发才) <facai.yan@gmail.com>2018-09-11 21:12:51 +0800
committerGravatar Yan Facai (颜发才) <facai.yan@gmail.com>2018-09-11 21:12:51 +0800
commitbe41397bd903c89c99fe63e61c0bf62870e062ea (patch)
tree2e2dce6e3d444edc16d9163508056941296311a1 /tensorflow/python/framework
parent204ef67242ce7fbba067b631c4d6c4bcd64288c2 (diff)
parent9fd56039064871a736bb7cff398b2a8e08454bee (diff)
Merge branch 'master' into CLN/remove_print_for_assert
Diffstat (limited to 'tensorflow/python/framework')
-rw-r--r--tensorflow/python/framework/constant_op.py3
-rw-r--r--tensorflow/python/framework/error_interpolation.py16
-rw-r--r--tensorflow/python/framework/error_interpolation_test.py29
-rw-r--r--tensorflow/python/framework/file_system_test.py2
-rw-r--r--tensorflow/python/framework/function_test.py10
-rw-r--r--tensorflow/python/framework/importer_test.py18
-rw-r--r--tensorflow/python/framework/meta_graph_test.py9
-rw-r--r--tensorflow/python/framework/ops.py7
-rw-r--r--tensorflow/python/framework/ops_test.py50
-rw-r--r--tensorflow/python/framework/python_op_gen_internal.cc13
-rw-r--r--tensorflow/python/framework/sparse_tensor_test.py6
-rw-r--r--tensorflow/python/framework/subscribe_test.py14
-rw-r--r--tensorflow/python/framework/tensor_shape.py4
-rw-r--r--tensorflow/python/framework/tensor_util.py2
-rw-r--r--tensorflow/python/framework/tensor_util_test.py2
-rw-r--r--tensorflow/python/framework/test_util.py227
-rw-r--r--tensorflow/python/framework/test_util_test.py3
17 files changed, 206 insertions, 209 deletions
diff --git a/tensorflow/python/framework/constant_op.py b/tensorflow/python/framework/constant_op.py
index eca34ac26e..4b2706d4cf 100644
--- a/tensorflow/python/framework/constant_op.py
+++ b/tensorflow/python/framework/constant_op.py
@@ -105,7 +105,8 @@ def convert_to_eager_tensor(value, ctx, dtype=None):
scalar_cache = ctx.scalar_cache()
tensor = scalar_cache.get(cache_key, None)
if tensor is not None:
- return tensor
+ return ops.EagerTensor(
+ value, context=handle, device=device, dtype=dtype, other_value=tensor)
t = ops.EagerTensor(value, context=handle, device=device, dtype=dtype)
scalar_cache[cache_key] = t
return t
diff --git a/tensorflow/python/framework/error_interpolation.py b/tensorflow/python/framework/error_interpolation.py
index a69018d00d..bc3c81b2a2 100644
--- a/tensorflow/python/framework/error_interpolation.py
+++ b/tensorflow/python/framework/error_interpolation.py
@@ -15,7 +15,7 @@
"""Function for interpolating formatted errors from the TensorFlow runtime.
Exposes the function `interpolate` to interpolate messages with tags of the form
-^^type:name:format^^.
+{{type name}}.
"""
from __future__ import absolute_import
@@ -32,9 +32,9 @@ import six
from tensorflow.python.util import tf_stack
_NAME_REGEX = r"[A-Za-z0-9.][A-Za-z0-9_.\-/]*?"
-_TAG_REGEX = r"\^\^({name}):({name})\^\^".format(name=_NAME_REGEX)
+_TAG_REGEX = r"{{{{({name}) ({name})}}}}".format(name=_NAME_REGEX)
_INTERPOLATION_REGEX = r"^(.*?)({tag})".format(tag=_TAG_REGEX)
-_INTERPOLATION_PATTERN = re.compile(_INTERPOLATION_REGEX)
+_INTERPOLATION_PATTERN = re.compile(_INTERPOLATION_REGEX, re.DOTALL)
_ParseTag = collections.namedtuple("_ParseTag", ["type", "name"])
@@ -48,8 +48,8 @@ def _parse_message(message):
"""Parses the message.
Splits the message into separators and tags. Tags are named tuples
- representing the string ^^type:name^^ and they are separated by
- separators. For example, in "123^^node:Foo^^456^^node:Bar^^789", there are
+ representing the string {{type name}} and they are separated by
+ separators. For example, in "123{{node Foo}}456{{node Bar}}789", there are
two tags and three separators. The separators are the numeric characters.
Args:
@@ -58,7 +58,7 @@ def _parse_message(message):
Returns:
(list of separator strings, list of _ParseTags).
- For example, if message is "123^^node:Foo^^456" then this function
+ For example, if message is "123{{node Foo}}456" then this function
returns (["123", "456"], [_ParseTag("node", "Foo")])
"""
seps = []
@@ -276,7 +276,7 @@ def interpolate(error_message, graph):
message.
Returns:
- The string with tags of the form ^^type:name^^ interpolated.
+ The string with tags of the form {{type name}} interpolated.
"""
seps, tags = _parse_message(error_message)
subs = []
@@ -288,7 +288,7 @@ def interpolate(error_message, graph):
except KeyError:
op = None
- msg = "^^%s:%s^^" % (t.type, t.name)
+ msg = "{{%s %s}}" % (t.type, t.name)
if op is not None:
field_dict = compute_field_dict(op)
if t.type == "node":
diff --git a/tensorflow/python/framework/error_interpolation_test.py b/tensorflow/python/framework/error_interpolation_test.py
index a7c7bbf28b..1b77548592 100644
--- a/tensorflow/python/framework/error_interpolation_test.py
+++ b/tensorflow/python/framework/error_interpolation_test.py
@@ -167,26 +167,31 @@ class InterpolateFilenamesAndLineNumbersTest(test.TestCase):
self.assertEqual(interpolated_string, normal_string)
def testOneTagWithAFakeNameResultsInPlaceholders(self):
- one_tag_string = "^^node:MinusOne^^"
+ one_tag_string = "{{node MinusOne}}"
interpolated_string = error_interpolation.interpolate(
one_tag_string, self.graph)
self.assertEqual(one_tag_string, interpolated_string)
def testTwoTagsNoSeps(self):
- two_tags_no_seps = "^^node:One^^^^node:Three^^"
+ two_tags_no_seps = "{{node One}}{{node Three}}"
interpolated_string = error_interpolation.interpolate(
two_tags_no_seps, self.graph)
self.assertRegexpMatches(interpolated_string,
"constant_op.py:[0-9]+.*constant_op.py:[0-9]+")
def testTwoTagsWithSeps(self):
- two_tags_with_seps = ";;;^^node:Two^^,,,^^node:Three^^;;;"
+ two_tags_with_seps = ";;;{{node Two}},,,{{node Three}};;;"
interpolated_string = error_interpolation.interpolate(
two_tags_with_seps, self.graph)
expected_regex = (
- r"^;;;.*constant_op.py:[0-9]+\) ,,,.*constant_op.py:[0-9]*\) ;;;$")
+ r"^;;;.*constant_op.py:[0-9]+\) ,,,.*constant_op.py:[0-9]+\) ;;;$")
self.assertRegexpMatches(interpolated_string, expected_regex)
+ def testNewLine(self):
+ newline = "\n\n{{node One}}"
+ interpolated_string = error_interpolation.interpolate(newline, self.graph)
+ self.assertRegexpMatches(interpolated_string, "constant_op.py:[0-9]+.*")
+
class InterpolateDeviceSummaryTest(test.TestCase):
@@ -206,23 +211,23 @@ class InterpolateDeviceSummaryTest(test.TestCase):
self.graph = self.three.graph
def testNodeZeroHasNoDeviceSummaryInfo(self):
- message = "^^colocation_node:zero^^"
+ message = "{{colocation_node zero}}"
result = error_interpolation.interpolate(message, self.graph)
self.assertIn("No device assignments were active", result)
def testNodeOneHasExactlyOneInterpolatedDevice(self):
- message = "^^colocation_node:one^^"
+ message = "{{colocation_node one}}"
result = error_interpolation.interpolate(message, self.graph)
self.assertEqual(2, result.count("tf.device(/cpu)"))
def testNodeTwoHasTwoInterpolatedDevice(self):
- message = "^^colocation_node:two^^"
+ message = "{{colocation_node two}}"
result = error_interpolation.interpolate(message, self.graph)
self.assertEqual(2, result.count("tf.device(/cpu)"))
self.assertEqual(2, result.count("tf.device(/cpu:0)"))
def testNodeThreeHasFancyFunctionDisplayNameForInterpolatedDevice(self):
- message = "^^colocation_node:three^^"
+ message = "{{colocation_node three}}"
result = error_interpolation.interpolate(message, self.graph)
num_devices = result.count("tf.device")
self.assertEqual(2, num_devices)
@@ -256,12 +261,12 @@ class InterpolateColocationSummaryTest(test.TestCase):
self.graph = node_three.graph
def testNodeThreeHasColocationInterpolation(self):
- message = "^^colocation_node:Three_with_one^^"
+ message = "{{colocation_node Three_with_one}}"
result = error_interpolation.interpolate(message, self.graph)
self.assertIn("colocate_with(One)", result)
def testNodeFourHasColocationInterpolationForNodeThreeOnly(self):
- message = "^^colocation_node:Four_with_three^^"
+ message = "{{colocation_node Four_with_three}}"
result = error_interpolation.interpolate(message, self.graph)
self.assertIn("colocate_with(Three_with_one)", result)
self.assertNotIn(
@@ -269,13 +274,13 @@ class InterpolateColocationSummaryTest(test.TestCase):
"Node One should not appear in Four_with_three's summary:\n%s" % result)
def testNodeFiveHasColocationInterpolationForNodeOneAndTwo(self):
- message = "^^colocation_node:Five_with_one_with_two^^"
+ message = "{{colocation_node Five_with_one_with_two}}"
result = error_interpolation.interpolate(message, self.graph)
self.assertIn("colocate_with(One)", result)
self.assertIn("colocate_with(Two)", result)
def testColocationInterpolationForNodeLackingColocation(self):
- message = "^^colocation_node:One^^"
+ message = "{{colocation_node One}}"
result = error_interpolation.interpolate(message, self.graph)
self.assertIn("No node-device colocations", result)
self.assertNotIn("Two", result)
diff --git a/tensorflow/python/framework/file_system_test.py b/tensorflow/python/framework/file_system_test.py
index 5eb59141a2..6901715e5d 100644
--- a/tensorflow/python/framework/file_system_test.py
+++ b/tensorflow/python/framework/file_system_test.py
@@ -37,7 +37,7 @@ class FileSystemTest(test.TestCase):
load_library.load_file_system_library(file_system_library)
def testBasic(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
reader = io_ops.WholeFileReader("test_reader")
queue = data_flow_ops.FIFOQueue(99, [dtypes.string], shapes=())
queue.enqueue_many([["test://foo"]]).run()
diff --git a/tensorflow/python/framework/function_test.py b/tensorflow/python/framework/function_test.py
index ee723bacaf..903768a039 100644
--- a/tensorflow/python/framework/function_test.py
+++ b/tensorflow/python/framework/function_test.py
@@ -419,7 +419,7 @@ class FunctionTest(test.TestCase):
with ops.control_dependencies([z]):
return x * 2
- with ops.Graph().as_default(), self.test_session():
+ with ops.Graph().as_default(), self.cached_session():
z = Foo(constant_op.constant(3.0))
self.assertAllEqual(z.eval(), 6.0)
@@ -434,7 +434,7 @@ class FunctionTest(test.TestCase):
# Foo contains a stateful op (Assert).
self.assertEqual([("Assert", "Assert")], Foo.stateful_ops)
g = ops.Graph()
- with g.as_default(), self.test_session():
+ with g.as_default(), self.cached_session():
self.assertAllEqual(Foo(constant_op.constant(3.0)).eval(), 6.0)
with self.assertRaisesRegexp(errors_impl.InvalidArgumentError,
"assertion failed.*-3"):
@@ -448,7 +448,7 @@ class FunctionTest(test.TestCase):
[control_flow_ops.Assert(math_ops.less_equal(x, 10.0), [x])]):
return array_ops.identity(x)
- with self.test_session():
+ with self.cached_session():
self.assertEqual(1.0, MyFn(1.0).eval())
with self.assertRaisesRegexp(errors_impl.InvalidArgumentError,
"assertion"):
@@ -667,7 +667,7 @@ class FunctionTest(test.TestCase):
with ops.Graph().as_default():
z = CubeXPlusY(3.0, -2.0)
- with self.test_session():
+ with self.cached_session():
self.assertAllEqual(z.eval(), 25.0)
def testNestedDefinedFunction(self):
@@ -683,7 +683,7 @@ class FunctionTest(test.TestCase):
with ops.Graph().as_default():
z = CubeXPlusY(3.0, -2.0)
- with self.test_session():
+ with self.cached_session():
self.assertAllEqual(z.eval(), 25.0)
def testUnusedFunction(self):
diff --git a/tensorflow/python/framework/importer_test.py b/tensorflow/python/framework/importer_test.py
index 18e7d8aa14..2b4d8e7299 100644
--- a/tensorflow/python/framework/importer_test.py
+++ b/tensorflow/python/framework/importer_test.py
@@ -396,7 +396,7 @@ class ImportGraphDefTest(test.TestCase):
# Run the imported graph.
# TODO(b/76173421): make this work (currently DCHECKS)
- # with self.test_session() as sess:
+ # with self.cached_session() as sess:
# sess.run(imported_init)
# self.assertEqual(sess.run(imported_var), 1.0)
# self.assertEqual(sess.run(imported_assign), 2.0)
@@ -417,7 +417,7 @@ class ImportGraphDefTest(test.TestCase):
imported_r, = importer.import_graph_def(graph_def,
return_elements=[r.name])
self.assertEqual(imported_r.name, "import/" + r.name)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
self.assertEqual(sess.run(imported_r), 10)
def testImportWhileLoopInCond(self):
@@ -436,7 +436,7 @@ class ImportGraphDefTest(test.TestCase):
pred = array_ops.placeholder(dtypes.bool)
out = control_flow_ops.cond(pred, ImportFn,
lambda: constant_op.constant(1))
- with self.test_session() as sess:
+ with self.cached_session() as sess:
self.assertEqual(sess.run(out, {pred: True}), 10)
self.assertEqual(sess.run(out, {pred: False}), 1)
@@ -457,7 +457,7 @@ class ImportGraphDefTest(test.TestCase):
out = control_flow_ops.while_loop(
lambda i: i < 2, ImportFn, [0],
shape_invariants=[tensor_shape.TensorShape(None)])
- with self.test_session() as sess:
+ with self.cached_session() as sess:
self.assertEqual(sess.run(out), 10)
def testTypeMismatchInGraphDef(self):
@@ -929,7 +929,7 @@ class ImportGraphDefTest(test.TestCase):
input_map={"a:0": constant_op.constant(5.0)},
name="",
return_elements=["id:0"])
- with self.test_session():
+ with self.cached_session():
self.assertEqual(5.0, t.eval())
def testInvalidInputForReturnOperations(self):
@@ -958,7 +958,7 @@ class ImportGraphDefTest(test.TestCase):
array_ops.stack([c, c], name="pack")
gdef = g.as_graph_def()
- with self.test_session():
+ with self.cached_session():
pack, = importer.import_graph_def(gdef, return_elements=["pack"])
self.assertAllEqual(pack.outputs[0].eval(), [5.0, 5.0])
@@ -1063,7 +1063,7 @@ class ImportGraphDefTest(test.TestCase):
self.assertEqual([10], biases_grad.get_shape())
def testLargeGraph(self):
- with self.test_session():
+ with self.cached_session():
# The default message byte limit is 64M. Ours is 2G with a warning at 512.
# Adding a 130M entries float32 tensor should exceed the warning, but not
# the hard limit.
@@ -1254,7 +1254,7 @@ class ImportGraphDefTest(test.TestCase):
z = TestFunc()
- with self.test_session():
+ with self.cached_session():
z_val = z.eval()
self.assertEqual(z_val, -2.0)
@@ -1284,7 +1284,7 @@ class ImportGraphDefTest(test.TestCase):
z2 = importer.import_graph_def(gdef, return_elements=["z:0"],
input_map=input_map)[0]
- with self.test_session() as sess:
+ with self.cached_session() as sess:
z1_val, z2_val = sess.run((z1, z2))
self.assertAllEqual(z1_val, z2_val)
diff --git a/tensorflow/python/framework/meta_graph_test.py b/tensorflow/python/framework/meta_graph_test.py
index 6e5f7aafac..fc98b91a01 100644
--- a/tensorflow/python/framework/meta_graph_test.py
+++ b/tensorflow/python/framework/meta_graph_test.py
@@ -117,7 +117,7 @@ class SimpleMetaGraphTest(test.TestCase):
self.assertEqual(new_output_value, output_value)
def testStrippedOpListNestedFunctions(self):
- with self.test_session():
+ with self.cached_session():
# Square two levels deep
@function.Defun(dtypes.int32)
def f0(x):
@@ -169,7 +169,7 @@ class SimpleMetaGraphTest(test.TestCase):
# and "Tout" maps to complex64. Since these attr values map to their
# defaults, they must be stripped unless stripping of default attrs is
# disabled.
- with self.test_session():
+ with self.cached_session():
real_num = constant_op.constant(1.0, dtype=dtypes.float32, name="real")
imag_num = constant_op.constant(2.0, dtype=dtypes.float32, name="imag")
math_ops.complex(real_num, imag_num, name="complex")
@@ -212,7 +212,8 @@ class SimpleMetaGraphTest(test.TestCase):
def testDefaultAttrStrippingNestedFunctions(self):
"""Verifies that default attributes are stripped from function node defs."""
- with self.test_session():
+ with self.cached_session():
+
@function.Defun(dtypes.float32, dtypes.float32)
def f0(i, j):
return math_ops.complex(i, j, name="double_nested_complex")
@@ -251,7 +252,7 @@ class SimpleMetaGraphTest(test.TestCase):
meta_info_def = meta_graph_pb2.MetaGraphDef.MetaInfoDef()
meta_info_def.stripped_op_list.op.add()
- with self.test_session():
+ with self.cached_session():
meta_graph_def = meta_graph.create_meta_graph_def(
meta_info_def=meta_info_def, graph_def=graph_def,
strip_default_attrs=True)
diff --git a/tensorflow/python/framework/ops.py b/tensorflow/python/framework/ops.py
index 4cfd639bf9..75678cbc01 100644
--- a/tensorflow/python/framework/ops.py
+++ b/tensorflow/python/framework/ops.py
@@ -55,6 +55,7 @@ from tensorflow.python.platform import app
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.util import compat
from tensorflow.python.util import decorator_utils
+from tensorflow.python.util import deprecation
from tensorflow.python.util import function_utils
from tensorflow.python.util import lock_util
from tensorflow.python.util import tf_contextlib
@@ -5363,6 +5364,7 @@ def enable_eager_execution(config=None,
computational graph).
For example:
+
```python
tf.enable_eager_execution()
@@ -5807,11 +5809,8 @@ class GraphKeys(object):
_STREAMING_MODEL_PORTS = "streaming_model_ports"
@decorator_utils.classproperty
+ @deprecation.deprecated(None, "Use `tf.GraphKeys.GLOBAL_VARIABLES` instead.")
def VARIABLES(cls): # pylint: disable=no-self-argument
- logging.log_first_n(logging.WARN,
- "VARIABLES collection name is deprecated, please use "
- "GLOBAL_VARIABLES instead; VARIABLES will be removed "
- "after 2017-03-02.", 1)
return cls.GLOBAL_VARIABLES
diff --git a/tensorflow/python/framework/ops_test.py b/tensorflow/python/framework/ops_test.py
index ced0581402..d59adf3d48 100644
--- a/tensorflow/python/framework/ops_test.py
+++ b/tensorflow/python/framework/ops_test.py
@@ -58,12 +58,12 @@ ops._set_call_cpp_shape_fn(common_shapes.call_cpp_shape_fn)
class ResourceTest(test_util.TensorFlowTestCase):
def testBuildGraph(self):
- with self.test_session():
+ with self.cached_session():
pt = test_ops.stub_resource_handle_op(container="a", shared_name="b")
test_ops.resource_create_op(pt).run()
def testInitialize(self):
- with self.test_session():
+ with self.cached_session():
handle = test_ops.stub_resource_handle_op(container="a", shared_name="b")
resources.register_resource(
handle=handle,
@@ -100,35 +100,35 @@ class TensorAndShapeTest(test_util.TensorFlowTestCase):
pass
def testAddShape(self):
- with self.test_session():
+ with self.cached_session():
a = array_ops.zeros([2, 3])
b = array_ops.ones([1, 3])
c = a + b
self.assertEqual([2, 3], c.shape)
def testUnknownDim(self):
- with self.test_session():
+ with self.cached_session():
a = array_ops.placeholder(dtype=dtypes.float32, shape=[2, None, 3])
b = array_ops.placeholder(dtype=dtypes.float32, shape=[2, None, 3])
c = a + b
self.assertEqual([2, None, 3], c.shape.as_list())
def testUnknownShape(self):
- with self.test_session():
+ with self.cached_session():
a = array_ops.placeholder(dtype=dtypes.float32, shape=None)
b = array_ops.ones([1, 3])
c = a + b
self.assertEqual(tensor_shape.unknown_shape(), c.shape)
def testScalarShape(self):
- with self.test_session():
+ with self.cached_session():
a = array_ops.placeholder(dtype=dtypes.float32, shape=[])
b = array_ops.ones([])
c = a + b
self.assertEqual(tensor_shape.scalar(), c.shape)
def testShapeFunctionError(self):
- with self.test_session():
+ with self.cached_session():
a = array_ops.ones([1, 2, 3])
b = array_ops.ones([4, 5, 6])
with self.assertRaisesRegexp(
@@ -141,7 +141,7 @@ class TensorAndShapeTest(test_util.TensorFlowTestCase):
class IndexedSlicesTest(test_util.TensorFlowTestCase):
def testToTensor(self):
- with self.test_session():
+ with self.cached_session():
values = constant_op.constant([2, 3, 5, 7], shape=[2, 2])
indices = constant_op.constant([0, 2])
dense_shape = constant_op.constant([3, 2])
@@ -150,7 +150,7 @@ class IndexedSlicesTest(test_util.TensorFlowTestCase):
self.assertAllEqual(tensor.eval(), [[2, 3], [0, 0], [5, 7]])
def testNegation(self):
- with self.test_session():
+ with self.cached_session():
values = constant_op.constant([2, 3, 5, 7], shape=[2, 2])
indices = constant_op.constant([0, 2])
x = -ops.IndexedSlices(values, indices)
@@ -158,7 +158,7 @@ class IndexedSlicesTest(test_util.TensorFlowTestCase):
self.assertAllEqual(x.indices.eval(), [0, 2])
def testScalarMul(self):
- with self.test_session():
+ with self.cached_session():
values = constant_op.constant([2, 3, 5, 7], shape=[2, 2])
indices = constant_op.constant([0, 2])
x = math_ops.scalar_mul(-2, ops.IndexedSlices(values, indices))
@@ -307,14 +307,14 @@ class OperationTest(test_util.TensorFlowTestCase):
self.assertEqual(tensor_shape.unknown_shape(), op.get_shape())
def testConvertToTensorNestedArray(self):
- with self.test_session():
+ with self.cached_session():
values = [[2], [3], [5], [7]]
tensor = ops.convert_to_tensor(values)
self.assertAllEqual((4, 1), tensor.get_shape().as_list())
self.assertAllEqual(values, tensor.eval())
def testShapeTuple(self):
- with self.test_session():
+ with self.cached_session():
c = constant_op.constant(1)
self.assertEqual(c._shape_tuple(), ()) # pylint: disable=protected-access
@@ -328,14 +328,14 @@ class OperationTest(test_util.TensorFlowTestCase):
self.assertTrue(isinstance(converted, ops.EagerTensor))
def testConvertToTensorNestedTuple(self):
- with self.test_session():
+ with self.cached_session():
values = ((2,), (3,), (5,), (7,))
tensor = ops.convert_to_tensor(values)
self.assertAllEqual((4, 1), tensor.get_shape().as_list())
self.assertAllEqual(values, ops.convert_to_tensor(values).eval())
def testConvertToTensorNestedTensors(self):
- with self.test_session():
+ with self.cached_session():
values = ((2,), (3,), (5,), (7,))
tensor = ops.convert_to_tensor(
[constant_op.constant(row) for row in values])
@@ -347,25 +347,25 @@ class OperationTest(test_util.TensorFlowTestCase):
self.assertAllEqual(values, tensor.eval())
def testConvertToTensorNestedMix(self):
- with self.test_session():
+ with self.cached_session():
values = ([2], (3,), [constant_op.constant(5)], constant_op.constant([7]))
tensor = ops.convert_to_tensor(values)
self.assertAllEqual((4, 1), tensor.get_shape().as_list())
self.assertAllEqual(((2,), (3,), (5,), (7,)), tensor.eval())
def testConvertToTensorPreferred(self):
- with self.test_session():
+ with self.cached_session():
values = [2, 3, 5, 7]
tensor = ops.convert_to_tensor(values, preferred_dtype=dtypes.float32)
self.assertEqual(dtypes.float32, tensor.dtype)
- with self.test_session():
+ with self.cached_session():
# Convert empty tensor to anything.
values = []
tensor = ops.convert_to_tensor(values, preferred_dtype=dtypes.int64)
self.assertEqual(dtypes.int64, tensor.dtype)
- with self.test_session():
+ with self.cached_session():
# The preferred dtype is a type error and will convert to
# float32 instead.
values = [1.23]
@@ -941,7 +941,7 @@ class NameStackTest(test_util.TensorFlowTestCase):
self.assertEqual("bar_2", g.unique_name("bar"))
def testNameAndVariableScope(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
with sess.graph.name_scope("l0"):
with variable_scope.variable_scope("l1"):
with sess.graph.name_scope("l1") as scope:
@@ -2164,7 +2164,7 @@ class InitScopeTest(test_util.TensorFlowTestCase):
g = ops.Graph()
with g.as_default():
- with self.test_session():
+ with self.cached_session():
# First ensure that graphs that are not building functions are
# not escaped.
function_with_variables("foo")
@@ -2416,11 +2416,11 @@ class AttrScopeTest(test_util.TensorFlowTestCase):
return (a, b)
def testNoLabel(self):
- with self.test_session():
+ with self.cached_session():
self.assertAllEqual((None, None), self._get_test_attrs())
def testLabelMap(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
a1 = self._get_test_attrs()
with sess.graph._attr_scope({
"_A": attr_value_pb2.AttrValue(s=compat.as_bytes("foo"))
@@ -2454,12 +2454,12 @@ ops.RegisterShape("KernelLabel")(common_shapes.scalar_shape)
class KernelLabelTest(test_util.TensorFlowTestCase):
def testNoLabel(self):
- with self.test_session():
+ with self.cached_session():
self.assertAllEqual(b"My label is: default",
test_ops.kernel_label().eval())
def testLabelMap(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
default_1 = test_ops.kernel_label()
# pylint: disable=protected-access
with sess.graph._kernel_label_map({"KernelLabel": "overload_1"}):
@@ -2900,7 +2900,7 @@ class NameScopeTest(test_util.TensorFlowTestCase):
class TracebackTest(test_util.TensorFlowTestCase):
def testTracebackWithStartLines(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
a = constant_op.constant(2.0)
sess.run(
a,
diff --git a/tensorflow/python/framework/python_op_gen_internal.cc b/tensorflow/python/framework/python_op_gen_internal.cc
index f2270342b0..f6aef5bc50 100644
--- a/tensorflow/python/framework/python_op_gen_internal.cc
+++ b/tensorflow/python/framework/python_op_gen_internal.cc
@@ -15,18 +15,20 @@ limitations under the License.
#include "tensorflow/python/framework/python_op_gen_internal.h"
+#include <float.h>
#include <stdio.h>
+#include <iomanip>
#include <sstream>
#include <unordered_map>
#include "tensorflow/core/framework/api_def.pb.h"
#include "tensorflow/core/framework/attr_value.pb.h"
#include "tensorflow/core/framework/op.h"
-#include "tensorflow/core/framework/op_def.pb_text.h"
#include "tensorflow/core/framework/op_def.pb.h"
+#include "tensorflow/core/framework/op_def.pb_text.h"
#include "tensorflow/core/framework/op_def_util.h"
#include "tensorflow/core/framework/op_gen_lib.h"
-#include "tensorflow/core/framework/tensor.pb_text.h"
#include "tensorflow/core/framework/tensor.pb.h"
+#include "tensorflow/core/framework/tensor.pb_text.h"
#include "tensorflow/core/framework/tensor_shape.pb.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/framework/types.pb.h"
@@ -435,7 +437,12 @@ string AttrValueToPython(const string& type, const AttrValue& value,
if (std::isnan(value.f()) || std::isinf(value.f())) {
return strings::StrCat("float('", value.f(), "')");
} else {
- return strings::StrCat(value.f());
+ // Use locale-independent conversion.
+ static_assert(FLT_DIG < 10, "FLT_DIG is too big");
+ std::ostringstream s;
+ s.imbue(std::locale::classic());
+ s << std::setprecision(FLT_DIG) << value.f();
+ return s.str();
}
} else if (type == "bool") {
return value.b() ? "True" : "False";
diff --git a/tensorflow/python/framework/sparse_tensor_test.py b/tensorflow/python/framework/sparse_tensor_test.py
index 2bcfbc17df..22423c4f58 100644
--- a/tensorflow/python/framework/sparse_tensor_test.py
+++ b/tensorflow/python/framework/sparse_tensor_test.py
@@ -45,7 +45,7 @@ class SparseTensorTest(test_util.TensorFlowTestCase):
self.assertEqual(sp.dense_shape.dtype, dtypes.int64)
self.assertEqual(sp.get_shape(), (4, 5))
- with self.test_session() as sess:
+ with self.cached_session() as sess:
value = sp.eval()
self.assertAllEqual(indices, value.indices)
self.assertAllEqual(values, value.values)
@@ -81,14 +81,14 @@ class SparseTensorTest(test_util.TensorFlowTestCase):
class ConvertToTensorOrSparseTensorTest(test_util.TensorFlowTestCase):
def test_convert_dense(self):
- with self.test_session():
+ with self.cached_session():
value = [42, 43]
from_value = sparse_tensor.convert_to_tensor_or_sparse_tensor(
value)
self.assertAllEqual(value, from_value.eval())
def test_convert_sparse(self):
- with self.test_session():
+ with self.cached_session():
indices = [[0, 1], [1, 0]]
values = [42, 43]
shape = [2, 2]
diff --git a/tensorflow/python/framework/subscribe_test.py b/tensorflow/python/framework/subscribe_test.py
index d6de45fdc4..1d594e4078 100644
--- a/tensorflow/python/framework/subscribe_test.py
+++ b/tensorflow/python/framework/subscribe_test.py
@@ -65,7 +65,7 @@ class SubscribeTest(test_util.TensorFlowTestCase):
self.assertFalse(c0.op in d.op.control_inputs)
self.assertTrue(c.op in d.op.control_inputs)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
c_out = sess.run([c])
n_out = sess.run([n])
d_out = sess.run([d])
@@ -144,7 +144,7 @@ class SubscribeTest(test_util.TensorFlowTestCase):
b = subscribe.subscribe(b,
lambda t: script_ops.py_func(sub, [t], [t.dtype]))
- with self.test_session() as sess:
+ with self.cached_session() as sess:
c_out = sess.run([c])
d_out = sess.run([d])
@@ -204,7 +204,7 @@ class SubscribeTest(test_util.TensorFlowTestCase):
self.assertIs(c_sub, c_sub3)
# Expect the three side effect graphs to have been evaluated.
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run([c_sub])
self.assertIn('graph1', shared)
self.assertIn('graph2', shared)
@@ -227,7 +227,7 @@ class SubscribeTest(test_util.TensorFlowTestCase):
v1, lambda t: script_ops.py_func(sub, [t], [t.dtype]))
self.assertTrue(subscribe._is_subscribed_identity(v1_sub))
- with self.test_session() as sess:
+ with self.cached_session() as sess:
# Initialize the variables first.
sess.run([v1.initializer])
sess.run([v2.initializer])
@@ -272,7 +272,7 @@ class SubscribeTest(test_util.TensorFlowTestCase):
self.assertIs(tensor_array_sub, tensor_array.handle)
self.assertFalse(subscribe._is_subscribed_identity(tensor_array.handle))
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run([reader])
self.assertEqual(0, len(shared))
@@ -303,7 +303,7 @@ class SubscribeTest(test_util.TensorFlowTestCase):
subscribe.subscribe(sparse_add.op.outputs,
lambda t: script_ops.py_func(sub, [t], [t.dtype]))
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run([neg])
# All three ops have been processed.
@@ -374,7 +374,7 @@ class SubscribeTest(test_util.TensorFlowTestCase):
# Verify that sub(x1) and sub(branch) are not.
self.assertIsNot(context(subscriptions[0]), context(subscriptions[1]))
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(cond)
self.assertEqual(3, len(results))
diff --git a/tensorflow/python/framework/tensor_shape.py b/tensorflow/python/framework/tensor_shape.py
index 11b681d544..3c2a736fb9 100644
--- a/tensorflow/python/framework/tensor_shape.py
+++ b/tensorflow/python/framework/tensor_shape.py
@@ -606,8 +606,8 @@ class TensorShape(object):
slice.
Raises:
- ValueError: If `key` is a slice, and any of its elements are negative, or
- if `self` is completely unknown and the step is set.
+ ValueError: If `key` is a slice and `self` is completely unknown and
+ the step is set.
"""
if self._dims is not None:
if isinstance(key, slice):
diff --git a/tensorflow/python/framework/tensor_util.py b/tensorflow/python/framework/tensor_util.py
index b14290c203..26170b000d 100644
--- a/tensorflow/python/framework/tensor_util.py
+++ b/tensorflow/python/framework/tensor_util.py
@@ -367,7 +367,7 @@ def make_tensor_proto(values, dtype=None, shape=None, verify_shape=False):
A `TensorProto`. Depending on the type, it may contain data in the
"tensor_content" attribute, which is not directly useful to Python programs.
To access the values you should convert the proto back to a numpy ndarray
- with `tensor_util.MakeNdarray(proto)`.
+ with `tf.make_ndarray(proto)`.
If `values` is a `TensorProto`, it is immediately returned; `dtype` and
`shape` are ignored.
diff --git a/tensorflow/python/framework/tensor_util_test.py b/tensorflow/python/framework/tensor_util_test.py
index 395cf43b3f..bdf759f220 100644
--- a/tensorflow/python/framework/tensor_util_test.py
+++ b/tensorflow/python/framework/tensor_util_test.py
@@ -768,7 +768,7 @@ class TensorUtilTest(test.TestCase):
def __array__(self, dtype=None):
return np.asarray(self.array, dtype)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
ma = MockArray(np.array([10, 20, 30]))
t = ops.convert_to_tensor(ma)
a = sess.run(t)
diff --git a/tensorflow/python/framework/test_util.py b/tensorflow/python/framework/test_util.py
index 6d03e956da..b34330aa2a 100644
--- a/tensorflow/python/framework/test_util.py
+++ b/tensorflow/python/framework/test_util.py
@@ -465,29 +465,31 @@ def assert_no_new_pyobjects_executing_eagerly(f):
f(self, **kwargs)
gc.collect()
previous_count = len(gc.get_objects())
- collection_sizes_before = {
- collection: len(ops.get_collection(collection))
- for collection in ops.get_default_graph().collections
- }
+ if ops.has_default_graph():
+ collection_sizes_before = {
+ collection: len(ops.get_collection(collection))
+ for collection in ops.get_default_graph().collections
+ }
for _ in range(3):
f(self, **kwargs)
# Note that gc.get_objects misses anything that isn't subject to garbage
# collection (C types). Collections are a common source of leaks, so we
# test for collection sizes explicitly.
- for collection_key in ops.get_default_graph().collections:
- collection = ops.get_collection(collection_key)
- size_before = collection_sizes_before.get(collection_key, 0)
- if len(collection) > size_before:
- raise AssertionError(
- ("Collection %s increased in size from "
- "%d to %d (current items %s).") % (collection_key, size_before,
- len(collection), collection))
- # Make sure our collection checks don't show up as leaked memory by
- # removing references to temporary variables.
- del collection
- del collection_key
- del size_before
- del collection_sizes_before
+ if ops.has_default_graph():
+ for collection_key in ops.get_default_graph().collections:
+ collection = ops.get_collection(collection_key)
+ size_before = collection_sizes_before.get(collection_key, 0)
+ if len(collection) > size_before:
+ raise AssertionError(
+ ("Collection %s increased in size from "
+ "%d to %d (current items %s).") %
+ (collection_key, size_before, len(collection), collection))
+ # Make sure our collection checks don't show up as leaked memory by
+ # removing references to temporary variables.
+ del collection
+ del collection_key
+ del size_before
+ del collection_sizes_before
gc.collect()
# There should be no new Python objects hanging around.
new_count = len(gc.get_objects())
@@ -535,15 +537,16 @@ def assert_no_new_tensors(f):
tensors_before = set(
id(obj) for obj in gc.get_objects() if _is_tensorflow_object(obj))
- if context.executing_eagerly():
- f(self, **kwargs)
- ops.reset_default_graph()
- else:
- # Run the test in a new graph so that collections get cleared when it's
- # done, but inherit the graph key so optimizers behave.
- outside_graph_key = ops.get_default_graph()._graph_key
- with ops.Graph().as_default():
- ops.get_default_graph()._graph_key = outside_graph_key
+ outside_executed_eagerly = context.executing_eagerly()
+ # Run the test in a new graph so that collections get cleared when it's
+ # done, but inherit the graph key so optimizers behave.
+ outside_graph_key = ops.get_default_graph()._graph_key
+ with ops.Graph().as_default():
+ ops.get_default_graph()._graph_key = outside_graph_key
+ if outside_executed_eagerly:
+ with context.eager_mode():
+ f(self, **kwargs)
+ else:
f(self, **kwargs)
# Make an effort to clear caches, which would otherwise look like leaked
# Tensors.
@@ -1072,13 +1075,9 @@ class TensorFlowTestCase(googletest.TestCase):
if context.executing_eagerly():
yield None
else:
- sess = self._create_session(graph, config, use_gpu, force_gpu)
- with self._constrain_devices_and_set_default(
- sess, use_gpu, force_gpu) as constrained_sess:
- # We need to do this to make sure the session closes, otherwise, even
- # if the user does with self.session():, it will not close the session.
- with constrained_sess:
- yield constrained_sess
+ with self._create_session(graph, config, force_gpu) as sess:
+ with self._constrain_devices_and_set_default(sess, use_gpu, force_gpu):
+ yield sess
@contextlib.contextmanager
def cached_session(self,
@@ -1126,10 +1125,11 @@ class TensorFlowTestCase(googletest.TestCase):
if context.executing_eagerly():
yield None
else:
- with self._get_cached_session(
- graph, config, use_gpu, force_gpu,
- crash_if_inconsistent_args=True) as sess:
- yield sess
+ sess = self._get_cached_session(
+ graph, config, force_gpu, crash_if_inconsistent_args=True)
+ with self._constrain_devices_and_set_default(sess, use_gpu,
+ force_gpu) as cached:
+ yield cached
@contextlib.contextmanager
def test_session(self,
@@ -1145,10 +1145,11 @@ class TensorFlowTestCase(googletest.TestCase):
yield None
else:
if graph is None:
- with self._get_cached_session(
- graph, config, use_gpu, force_gpu,
- crash_if_inconsistent_args=False) as sess:
- yield sess
+ sess = self._get_cached_session(
+ graph, config, force_gpu, crash_if_inconsistent_args=False)
+ with self._constrain_devices_and_set_default(sess, use_gpu,
+ force_gpu) as cached:
+ yield cached
else:
with self.session(graph, config, use_gpu, force_gpu) as sess:
yield sess
@@ -1326,9 +1327,17 @@ class TensorFlowTestCase(googletest.TestCase):
def _assertArrayLikeAllClose(self, a, b, rtol=1e-6, atol=1e-6, msg=None):
a = self._GetNdArray(a)
b = self._GetNdArray(b)
- self.assertEqual(
- a.shape, b.shape,
- "Shape mismatch: expected %s, got %s." % (a.shape, b.shape))
+ # When the array rank is small, print its contents. Numpy array printing is
+ # implemented using inefficient recursion so prints can cause tests to
+ # time out.
+ if a.shape != b.shape and (b.ndim <= 3 or b.size < 500):
+ shape_mismatch_msg = ("Shape mismatch: expected %s, got %s with contents "
+ "%s.") % (a.shape, b.shape, b)
+ else:
+ shape_mismatch_msg = "Shape mismatch: expected %s, got %s." % (a.shape,
+ b.shape)
+ self.assertEqual(a.shape, b.shape, shape_mismatch_msg)
+
msgs = [msg]
if not np.allclose(a, b, rtol=rtol, atol=atol):
# Add more details than np.testing.assert_allclose.
@@ -1836,91 +1845,69 @@ class TensorFlowTestCase(googletest.TestCase):
with sess.graph.device("/cpu:0"):
yield sess
- def _create_session(self, graph, config, use_gpu, force_gpu):
+ def _create_session(self, graph, config, force_gpu):
"""See session() for details."""
- if context.executing_eagerly():
- return None
- else:
+ def prepare_config(config):
+ """Returns a config for sessions.
- def prepare_config(config):
- """Returns a config for sessions.
-
- Args:
- config: An optional config_pb2.ConfigProto to use to configure the
- session.
- Returns:
- A config_pb2.ConfigProto object.
- """
- if config is None:
- config = config_pb2.ConfigProto()
- config.allow_soft_placement = not force_gpu
- config.gpu_options.per_process_gpu_memory_fraction = 0.3
- elif force_gpu and config.allow_soft_placement:
- config = config_pb2.ConfigProto().CopyFrom(config)
- config.allow_soft_placement = False
- # Don't perform optimizations for tests so we don't inadvertently run
- # gpu ops on cpu
- config.graph_options.optimizer_options.opt_level = -1
- config.graph_options.rewrite_options.constant_folding = (
- rewriter_config_pb2.RewriterConfig.OFF)
- config.graph_options.rewrite_options.arithmetic_optimization = (
- rewriter_config_pb2.RewriterConfig.OFF)
- return config
-
- return ErrorLoggingSession(graph=graph, config=prepare_config(config))
+ Args:
+ config: An optional config_pb2.ConfigProto to use to configure the
+ session.
+
+ Returns:
+ A config_pb2.ConfigProto object.
+ """
+ if config is None:
+ config = config_pb2.ConfigProto()
+ config.allow_soft_placement = not force_gpu
+ config.gpu_options.per_process_gpu_memory_fraction = 0.3
+ elif force_gpu and config.allow_soft_placement:
+ config = config_pb2.ConfigProto().CopyFrom(config)
+ config.allow_soft_placement = False
+ # Don't perform optimizations for tests so we don't inadvertently run
+ # gpu ops on cpu
+ config.graph_options.optimizer_options.opt_level = -1
+ config.graph_options.rewrite_options.constant_folding = (
+ rewriter_config_pb2.RewriterConfig.OFF)
+ config.graph_options.rewrite_options.arithmetic_optimization = (
+ rewriter_config_pb2.RewriterConfig.OFF)
+ return config
+
+ return ErrorLoggingSession(graph=graph, config=prepare_config(config))
- @contextlib.contextmanager
def _get_cached_session(self,
graph=None,
config=None,
- use_gpu=False,
force_gpu=False,
crash_if_inconsistent_args=True):
"""See cached_session() for documentation."""
- if context.executing_eagerly():
- yield None
+ if self._cached_session is None:
+ sess = self._create_session(
+ graph=graph, config=config, force_gpu=force_gpu)
+ self._cached_session = sess
+ self._cached_graph = graph
+ self._cached_config = config
+ self._cached_force_gpu = force_gpu
+ return sess
else:
- if self._cached_session is None:
- sess = self._create_session(
- graph=graph, config=config, use_gpu=use_gpu, force_gpu=force_gpu)
- self._cached_session = sess
- self._cached_graph = graph
- self._cached_config = config
- self._cached_use_gpu = use_gpu
- self._cached_force_gpu = force_gpu
- with self._constrain_devices_and_set_default(
- sess, use_gpu, force_gpu) as constrained_sess:
- yield constrained_sess
- else:
- if crash_if_inconsistent_args and self._cached_graph is not graph:
- raise ValueError("The graph used to get the cached session is "
- "different than the one that was used to create the "
- "session. Maybe create a new session with "
- "self.session()")
- if crash_if_inconsistent_args and self._cached_config is not config:
- raise ValueError("The config used to get the cached session is "
- "different than the one that was used to create the "
- "session. Maybe create a new session with "
- "self.session()")
- if crash_if_inconsistent_args and self._cached_use_gpu is not use_gpu:
- raise ValueError(
- "The use_gpu value used to get the cached session is "
- "different than the one that was used to create the "
- "session. Maybe create a new session with "
- "self.session()")
- if crash_if_inconsistent_args and (self._cached_force_gpu is
- not force_gpu):
- raise ValueError(
- "The force_gpu value used to get the cached session is "
- "different than the one that was used to create the "
- "session. Maybe create a new session with "
- "self.session()")
- # If you modify this logic, make sure to modify it in _create_session
- # as well.
- sess = self._cached_session
- with self._constrain_devices_and_set_default(
- sess, use_gpu, force_gpu) as constrained_sess:
- yield constrained_sess
+ if crash_if_inconsistent_args and self._cached_graph is not graph:
+ raise ValueError("The graph used to get the cached session is "
+ "different than the one that was used to create the "
+ "session. Maybe create a new session with "
+ "self.session()")
+ if crash_if_inconsistent_args and self._cached_config is not config:
+ raise ValueError("The config used to get the cached session is "
+ "different than the one that was used to create the "
+ "session. Maybe create a new session with "
+ "self.session()")
+ if crash_if_inconsistent_args and (self._cached_force_gpu is
+ not force_gpu):
+ raise ValueError(
+ "The force_gpu value used to get the cached session is "
+ "different than the one that was used to create the "
+ "session. Maybe create a new session with "
+ "self.session()")
+ return self._cached_session
@tf_export("test.create_local_cluster")
diff --git a/tensorflow/python/framework/test_util_test.py b/tensorflow/python/framework/test_util_test.py
index c9b5d46f98..22189afa59 100644
--- a/tensorflow/python/framework/test_util_test.py
+++ b/tensorflow/python/framework/test_util_test.py
@@ -71,9 +71,6 @@ class TestUtilTest(test_util.TensorFlowTestCase):
with self.cached_session(graph=ops.Graph()) as sess2:
pass
with self.assertRaises(ValueError):
- with self.cached_session(use_gpu=True) as sess2:
- pass
- with self.assertRaises(ValueError):
with self.cached_session(force_gpu=True) as sess2:
pass
# We make sure that test_session will cache the session even after the