# Copyright 2015 The TensorFlow Authors. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== """Tests for tensorflow.python.framework.ops.""" from __future__ import absolute_import from __future__ import division from __future__ import print_function import gc import os import threading import weakref from tensorflow.core.framework import attr_value_pb2 from tensorflow.core.protobuf import config_pb2 from tensorflow.python.client import session from tensorflow.python.eager import context from tensorflow.python.eager import function as eager_function from tensorflow.python.framework import common_shapes from tensorflow.python.framework import constant_op from tensorflow.python.framework import device as pydev from tensorflow.python.framework import dtypes from tensorflow.python.framework import errors from tensorflow.python.framework import function from tensorflow.python.framework import ops from tensorflow.python.framework import sparse_tensor from tensorflow.python.framework import tensor_shape from tensorflow.python.framework import tensor_util from tensorflow.python.framework import test_ops from tensorflow.python.framework import test_util from tensorflow.python.framework import versions from tensorflow.python.ops import array_ops from tensorflow.python.ops import control_flow_ops from tensorflow.python.ops import math_ops from tensorflow.python.ops import resource_variable_ops from tensorflow.python.ops import resources from tensorflow.python.ops import variable_scope from tensorflow.python.ops import variables import tensorflow.python.ops.gradients # pylint: disable=unused-import from tensorflow.python.platform import googletest from tensorflow.python.util import compat ops._set_call_cpp_shape_fn(common_shapes.call_cpp_shape_fn) class ResourceTest(test_util.TensorFlowTestCase): def testBuildGraph(self): with self.cached_session(): pt = test_ops.stub_resource_handle_op(container="a", shared_name="b") test_ops.resource_create_op(pt).run() def testInitialize(self): with self.cached_session(): handle = test_ops.stub_resource_handle_op(container="a", shared_name="b") resources.register_resource( handle=handle, create_op=test_ops.resource_create_op(handle), is_initialized_op=test_ops.resource_initialized_op(handle)) self.assertEquals( len( resources.report_uninitialized_resources( resources.shared_resources()).eval()), 1) resources.initialize_resources(resources.shared_resources()).run() self.assertEquals( len( resources.report_uninitialized_resources( resources.shared_resources()).eval()), 0) class TensorAndShapeTest(test_util.TensorFlowTestCase): def testShape(self): op = ops.Operation( ops._NodeDef("FloatOutput", "myop"), ops.Graph(), [], [dtypes.float32]) t = op.outputs[0] self.assertEqual(tensor_shape.unknown_shape(), t.get_shape()) t.set_shape([1, 2, 3]) self.assertEqual([1, 2, 3], t.get_shape()) def testIterable(self): op = ops.Operation( ops._NodeDef("FloatOutput", "myop"), ops.Graph(), [], [dtypes.float32]) t = op.outputs[0] self.assertTrue(isinstance(t, ops.Tensor)) with self.assertRaisesRegexp(TypeError, "iter"): for _ in t: pass def testAddShape(self): with self.cached_session(): a = array_ops.zeros([2, 3]) b = array_ops.ones([1, 3]) c = a + b self.assertEqual([2, 3], c.shape) def testUnknownDim(self): with self.cached_session(): a = array_ops.placeholder(dtype=dtypes.float32, shape=[2, None, 3]) b = array_ops.placeholder(dtype=dtypes.float32, shape=[2, None, 3]) c = a + b self.assertEqual([2, None, 3], c.shape.as_list()) def testUnknownShape(self): with self.cached_session(): a = array_ops.placeholder(dtype=dtypes.float32, shape=None) b = array_ops.ones([1, 3]) c = a + b self.assertEqual(tensor_shape.unknown_shape(), c.shape) def testScalarShape(self): with self.cached_session(): a = array_ops.placeholder(dtype=dtypes.float32, shape=[]) b = array_ops.ones([]) c = a + b self.assertEqual(tensor_shape.scalar(), c.shape) def testShapeFunctionError(self): with self.cached_session(): a = array_ops.ones([1, 2, 3]) b = array_ops.ones([4, 5, 6]) with self.assertRaisesRegexp( ValueError, r"Dimensions must be equal, but are 2 and 5 for 'add' \(op: 'Add'\) " r"with input shapes: \[1,2,3\], \[4,5,6\]."): _ = a + b class IndexedSlicesTest(test_util.TensorFlowTestCase): def testToTensor(self): with self.cached_session(): values = constant_op.constant([2, 3, 5, 7], shape=[2, 2]) indices = constant_op.constant([0, 2]) dense_shape = constant_op.constant([3, 2]) x = ops.IndexedSlices(values, indices, dense_shape) tensor = ops.convert_to_tensor(x, name="tensor") self.assertAllEqual(tensor.eval(), [[2, 3], [0, 0], [5, 7]]) def testNegation(self): with self.cached_session(): values = constant_op.constant([2, 3, 5, 7], shape=[2, 2]) indices = constant_op.constant([0, 2]) x = -ops.IndexedSlices(values, indices) self.assertAllEqual(x.values.eval(), [[-2, -3], [-5, -7]]) self.assertAllEqual(x.indices.eval(), [0, 2]) def testScalarMul(self): with self.cached_session(): values = constant_op.constant([2, 3, 5, 7], shape=[2, 2]) indices = constant_op.constant([0, 2]) x = math_ops.scalar_mul(-2, ops.IndexedSlices(values, indices)) self.assertAllEqual(x.values.eval(), [[-4, -6], [-10, -14]]) self.assertAllEqual(x.indices.eval(), [0, 2]) class NodeDefConstructorTest(test_util.TensorFlowTestCase): def testNoArgs(self): nodedef = ops._NodeDef("None", "bar") self.assertProtoEquals("op: 'None' name: 'bar'", nodedef) def testArgs(self): nodedef = ops._NodeDef("foo", "bar", device="/device:baz:*") self.assertProtoEquals("op:'foo' name:'bar' device:'/device:baz:*'", nodedef) nodedef = ops._NodeDef("foo", "bar", device=pydev.DeviceSpec(job="j")) self.assertProtoEquals("op:'foo' name:'bar' device:'/job:j'", nodedef) def _apply_op(g, *args, **kwargs): op = g.create_op(*args, **kwargs) if len(op.outputs) == 1: return op.outputs[0] else: return op.outputs class OperationTest(test_util.TensorFlowTestCase): def testNoInputs(self): op = test_ops.float_output_string_output(name="myop").a.op self.assertEqual(2, len(op.values())) self.assertEqual(0, len(op.inputs)) self.assertEqual("myop", op.name) float_t, label_str_t = op.values() self.assertEqual(dtypes.float32, float_t.dtype) self.assertEqual(op, float_t.op) self.assertEqual(0, float_t._value_index) self.assertEqual(0, len(float_t.consumers())) self.assertEqual("myop", float_t._as_node_def_input()) self.assertEqual(dtypes.string, label_str_t.dtype) self.assertEqual(op, label_str_t.op) self.assertEqual(1, label_str_t._value_index) self.assertEqual(0, len(label_str_t.consumers())) self.assertEqual("myop:1", label_str_t._as_node_def_input()) self.assertProtoEquals("op:'FloatOutputStringOutput' name:'myop'", op.node_def) def testNoOutputs(self): op1 = test_ops.float_output(name="myop1").op float_t, = op1.values() op2 = test_ops.float_input(float_t, name="myop2") self.assertEqual(0, len(op2.values())) self.assertEqual(1, len(op2.inputs)) self.assertIs(float_t, op2.inputs[0]) self.assertEqual(1, len(float_t.consumers())) self.assertEqual(op2, float_t.consumers()[0]) self.assertProtoEquals("op:'FloatOutput' name:'myop1'", op1.node_def) self.assertProtoEquals("op:'FloatInput' name:'myop2' input:'myop1'", op2.node_def) def testInputsAndOutputs(self): op1 = test_ops.float_output(name="myop1").op self.assertEqual(1, len(op1.values())) float1_t, = op1.values() op2 = test_ops.float_output_string_output(name="myop2").a.op self.assertEqual(2, len(op2.values())) float2_t, label2_str_t = op2.values() # Note that we consume label2_str_t twice here. op3 = test_ops.foo2(float1_t, label2_str_t, label2_str_t, name="myop3").d.op self.assertEqual(2, len(op3.values())) self.assertEqual(1, len(float1_t.consumers())) self.assertEqual(op3, float1_t.consumers()[0]) self.assertEqual(0, len(float2_t.consumers())) self.assertEqual(2, len(label2_str_t.consumers())) self.assertEqual(op3, label2_str_t.consumers()[0]) self.assertEqual(op3, label2_str_t.consumers()[1]) self.assertProtoEquals(""" op:'Foo2' name:'myop3' input:'myop1' input:'myop2:1' input:'myop2:1' """, op3.node_def) def testDeviceObject(self): op = ops.Operation(ops._NodeDef("None", "myop"), ops.Graph(), [], []) op._set_device("/job:goo/device:GPU:0") self.assertProtoEquals( "op:'None' name:'myop' device:'/job:goo/device:GPU:0' ", op.node_def) op = ops.Operation(ops._NodeDef("None", "op2"), ops.Graph(), [], []) op._set_device( pydev.DeviceSpec( job="muu", device_type="CPU", device_index=0)) self.assertProtoEquals( "op:'None' name:'op2' device:'/job:muu/device:CPU:0'", op.node_def) def testReferenceInput(self): g = ops.Graph() op1 = ops.Operation( ops._NodeDef("RefOutputFloatOutput", "op1"), g, [], [dtypes.float32_ref, dtypes.float32]) self.assertProtoEquals("op:'RefOutputFloatOutput' name:'op1'", op1.node_def) self.assertEquals([], list(op1.inputs)) ref_t, nonref_t = op1.values() # NOTE(mrry): Must specify input_types to preserve ref-typed input. op2 = ops.Operation( ops._NodeDef("RefInputFloatInput", "op2"), g, [ref_t, nonref_t], [], input_types=[dtypes.float32_ref, dtypes.float32]) self.assertProtoEquals( "op:'RefInputFloatInput' name:'op2' input:'op1' input:'op1:1'", op2.node_def) self.assertEquals([ref_t, nonref_t], list(op2.inputs)) op3 = ops.Operation( ops._NodeDef("TwoFloatInputs", "op3"), g, [ref_t, nonref_t], []) self.assertProtoEquals( "op:'TwoFloatInputs' name:'op3' input:'op1' input:'op1:1'", op3.node_def) def testInvalidNames(self): g = ops.Graph() with self.assertRaises(ValueError): ops.Operation(ops._NodeDef("op", ""), g) with self.assertRaises(ValueError): ops.Operation(ops._NodeDef("op", "_invalid"), g) with self.assertRaises(ValueError): ops.Operation(ops._NodeDef("op", "-invalid"), g) with self.assertRaises(ValueError): ops.Operation(ops._NodeDef("op", "/invalid"), g) with self.assertRaises(ValueError): ops.Operation(ops._NodeDef("op", "invalid:0"), g) def testNoShapeFunction(self): op = test_ops.a() self.assertEqual(tensor_shape.unknown_shape(), op.get_shape()) def testConvertToTensorNestedArray(self): with self.cached_session(): values = [[2], [3], [5], [7]] tensor = ops.convert_to_tensor(values) self.assertAllEqual((4, 1), tensor.get_shape().as_list()) self.assertAllEqual(values, tensor.eval()) def testShapeTuple(self): with self.cached_session(): c = constant_op.constant(1) self.assertEqual(c._shape_tuple(), ()) # pylint: disable=protected-access def testConvertToTensorEager(self): with context.eager_mode(): t = constant_op.constant(1) self.assertTrue(isinstance(t, ops.EagerTensor)) converted = ops.convert_to_tensor(t) self.assertTrue(isinstance(converted, ops.EagerTensor)) converted = ops.convert_to_tensor(1) self.assertTrue(isinstance(converted, ops.EagerTensor)) def testConvertToTensorNestedTuple(self): with self.cached_session(): values = ((2,), (3,), (5,), (7,)) tensor = ops.convert_to_tensor(values) self.assertAllEqual((4, 1), tensor.get_shape().as_list()) self.assertAllEqual(values, ops.convert_to_tensor(values).eval()) def testConvertToTensorNestedTensors(self): with self.cached_session(): values = ((2,), (3,), (5,), (7,)) tensor = ops.convert_to_tensor( [constant_op.constant(row) for row in values]) self.assertAllEqual((4, 1), tensor.get_shape().as_list()) self.assertAllEqual(values, tensor.eval()) tensor = ops.convert_to_tensor( [[constant_op.constant(v) for v in row] for row in values]) self.assertAllEqual((4, 1), tensor.get_shape().as_list()) self.assertAllEqual(values, tensor.eval()) def testConvertToTensorNestedMix(self): with self.cached_session(): values = ([2], (3,), [constant_op.constant(5)], constant_op.constant([7])) tensor = ops.convert_to_tensor(values) self.assertAllEqual((4, 1), tensor.get_shape().as_list()) self.assertAllEqual(((2,), (3,), (5,), (7,)), tensor.eval()) def testConvertToTensorPreferred(self): with self.cached_session(): values = [2, 3, 5, 7] tensor = ops.convert_to_tensor(values, preferred_dtype=dtypes.float32) self.assertEqual(dtypes.float32, tensor.dtype) with self.cached_session(): # Convert empty tensor to anything. values = [] tensor = ops.convert_to_tensor(values, preferred_dtype=dtypes.int64) self.assertEqual(dtypes.int64, tensor.dtype) with self.cached_session(): # The preferred dtype is a type error and will convert to # float32 instead. values = [1.23] tensor = ops.convert_to_tensor(values, preferred_dtype=dtypes.int64) self.assertEqual(dtypes.float32, tensor.dtype) def testConvertToInvalidTensorType(self): with self.assertRaises(TypeError): # Forcing an invalid dtype should fail with a type error. values = [1.23] _ = ops.convert_to_tensor(values, dtype=dtypes.int64) def testNoConvert(self): # Operation cannot be converted to Tensor. op = control_flow_ops.no_op() with self.assertRaisesRegexp(TypeError, r"Can't convert Operation '.*' to Tensor"): ops.convert_to_tensor(op) def testStr(self): node_def = ops._NodeDef("None", "op1") op = ops.Operation(node_def, ops.Graph(), [], [dtypes.float32]) self.assertEqual(str(node_def), str(op)) def testRepr(self): op = ops.Operation( ops._NodeDef("None", "op1"), ops.Graph(), [], [dtypes.float32]) self.assertEqual("", repr(op)) def testGetAttr(self): op = test_ops.default_attrs() self.assertEqual(op.get_attr("string_val"), b"abc") self.assertEqual(op.get_attr("string_list_val"), [b"abc", b""]) self.assertEqual(op.get_attr("int_val"), 123) self.assertEqual(op.get_attr("int_list_val"), [1, 2, 3]) self.assertEqual(op.get_attr("float_val"), 10.0) self.assertEqual(op.get_attr("float_list_val"), [10.0]) self.assertEqual(op.get_attr("bool_val"), True) self.assertEqual(op.get_attr("bool_list_val"), [True, False]) self.assertEqual(op.get_attr("shape_val"), tensor_shape.as_shape([2, 1]).as_proto()) self.assertEqual(op.get_attr("shape_list_val"), [tensor_shape.as_shape([]).as_proto(), tensor_shape.as_shape([1]).as_proto()]) self.assertEqual(op.get_attr("tensor_val"), tensor_util.make_tensor_proto(1, dtypes.int32)) self.assertEqual(op.get_attr("tensor_list_val"), [tensor_util.make_tensor_proto(1, dtypes.int32)]) type_val = op.get_attr("type_val") # First check that type_val is a DType, because the assertEquals will work # no matter what since DType overrides __eq__ self.assertIsInstance(type_val, dtypes.DType) self.assertEqual(type_val, dtypes.int32) type_list_val = op.get_attr("type_list_val") self.assertTrue(all(isinstance(x, dtypes.DType) for x in type_list_val)) self.assertEqual(type_list_val, [dtypes.int32, dtypes.float32]) @function.Defun(dtypes.float32, func_name="MyFunc") def func(x): return x op = test_ops.func_attr(func) self.assertEqual(op.get_attr("f"), attr_value_pb2.NameAttrList(name="MyFunc")) # Try fetching missing attr with self.assertRaisesRegexp( ValueError, "Operation 'FuncAttr' has no attr named 'FakeAttr'."): op.get_attr("FakeAttr") # TODO(b/65162920): remove this test when users who are directly mutating the # node_def have been updated to proper usage. def testSetAttr(self): op = test_ops.int_attr().op op._set_attr("foo", attr_value_pb2.AttrValue(i=2)) # TODO(skyewm): add node_def check self.assertEqual(op.get_attr("foo"), 2) # TODO(nolivia): test all error cases def testAddControlInput(self): with ops.Graph().as_default(): x = constant_op.constant(1).op y = constant_op.constant(2).op z = constant_op.constant(3).op z._add_control_input(x) # pylint: disable=protected-access self.assertEqual(z.control_inputs, [x]) z._add_control_input(x) # pylint: disable=protected-access self.assertEqual(z.control_inputs, [x]) z._add_control_inputs([x, y, y]) # pylint: disable=protected-access self.assertEqual(z.control_inputs, [x, y]) self.assertEqual(x._control_outputs, [z]) def testRemoveAllControlInputs(self): a = constant_op.constant(1) with ops.control_dependencies([a]): b = constant_op.constant(2) c = constant_op.constant(3) d = constant_op.constant(4) e = constant_op.constant(5) with ops.control_dependencies([a, c]): f = d + e self.assertEqual(a.op.control_inputs, []) self.assertEqual(b.op.control_inputs, [a.op]) self.assertEqual(f.op.control_inputs, [a.op, c.op]) a.op._remove_all_control_inputs() # pylint: disable=protected-access self.assertEqual(a.op.control_inputs, []) b.op._remove_all_control_inputs() # pylint: disable=protected-access self.assertEqual(b.op.control_inputs, []) f.op._remove_all_control_inputs() # pylint: disable=protected-access self.assertEqual(f.op.control_inputs, []) self.assertEqual(list(f.op.inputs), [d, e]) def testControlInputCycle(self): graph = ops.Graph() with graph.as_default(): z = constant_op.constant(0) x = constant_op.constant(1) y = constant_op.constant(2) y.op._add_control_input(z.op) # pylint: disable=protected-access y.op._add_control_input(x.op) # pylint: disable=protected-access x.op._add_control_input(y.op) # pylint: disable=protected-access with self.session(graph=graph) as sess: with self.assertRaisesRegexp( errors.InvalidArgumentError, "Graph is invalid, contains a cycle with 2 nodes"): sess.run(x) def testUpdateInput(self): g = ops.Graph() with g.as_default(): x = constant_op.constant(1) y = constant_op.constant(2) z = x + y z.op._update_input(0, y) # pylint: disable=protected-access self.assertEquals(list(z.op.inputs), [y, y]) self.assertEquals(x.consumers(), []) self.assertEquals(y.consumers(), [z.op, z.op]) with session.Session(graph=g) as sess: self.assertEquals(sess.run(z), 4) z.op._update_input(0, x) # pylint: disable=protected-access self.assertEquals(list(z.op.inputs), [x, y]) self.assertEquals(x.consumers(), [z.op]) self.assertEquals(y.consumers(), [z.op]) with session.Session(graph=g) as sess: self.assertEquals(sess.run(z), 3) z.op._update_input(1, y) # pylint: disable=protected-access self.assertEquals(list(z.op.inputs), [x, y]) self.assertEquals(x.consumers(), [z.op]) self.assertEquals(y.consumers(), [z.op]) with session.Session(graph=g) as sess: self.assertEquals(sess.run(z), 3) def testUpdateInputGraphError(self): g_0 = ops.Graph() g_1 = ops.Graph() with g_0.as_default(): x = constant_op.constant(1) with g_1.as_default(): y = constant_op.constant(2) z = y * 2 with self.assertRaisesRegexp(ValueError, "must be from the same graph"): z.op._update_input(0, x) # pylint: disable=protected-access def testUpdateInputTypeError(self): g = ops.Graph() with g.as_default(): w = constant_op.constant(0) x = constant_op.constant("") y = constant_op.constant(1) z = y + w z.op._update_input(0, x) # pylint: disable=protected-access with session.Session(graph=g) as sess: with self.assertRaisesRegexp( errors.InvalidArgumentError, "Input 0 of node add was passed string from Const_1:0 incompatible " "with expected int32"): sess.run(z) def testUpdateInputShapeError(self): g = ops.Graph() with g.as_default(): w = constant_op.constant(2, shape=[3, 1]) x = constant_op.constant(0, shape=[3, 1]) y = constant_op.constant(1, shape=[2, 2]) z = w + x with self.assertRaisesRegexp( errors.InvalidArgumentError, r"Cannot update edge, incompatible shapes: \[2,2\] and \[3,1\]"): z.op._update_input(0, y) # pylint: disable=protected-access def testUpdateInputOutOfRange(self): g = ops.Graph() with g.as_default(): x = constant_op.constant(1) with self.assertRaisesRegexp( errors.OutOfRangeError, r"Cannot update edge. Input index \[1\] is greater than the number of " r"total inputs \[0\]." ): x.op._update_input(1, x) # pylint: disable=protected-access def testOpDef(self): x = constant_op.constant(0) y = constant_op.constant(1) z = x + y self.assertEqual(x.op.op_def.name, "Const") self.assertEqual(len(x.op.op_def.input_arg), 0) self.assertEqual(len(x.op.op_def.output_arg), 1) self.assertEqual(z.op.op_def.name, "Add") self.assertEqual(len(z.op.op_def.input_arg), 2) self.assertEqual(len(z.op.op_def.output_arg), 1) def testInputFromDifferentGraphError(self): g_0 = ops.Graph() g_1 = ops.Graph() with g_0.as_default(): x = constant_op.constant(1) with g_1.as_default(): y = constant_op.constant(2) with self.assertRaisesRegexp(ValueError, "must be from the same graph"): y * x # pylint: disable=pointless-statement def testInputsAreImmutable(self): g = ops.Graph() with g.as_default(): x = test_ops.int_output() op = test_ops.int_input_int_output(x, name="myop").op with self.assertRaisesRegexp( AttributeError, "'_InputList' object has no attribute 'append'"): op.inputs.append(None) class CreateOpTest(test_util.TensorFlowTestCase): def testNodeDefArgs(self): g = ops.Graph() op1 = g.create_op("FloatOutput", [], [dtypes.float32], None, name="myop1") with g.device("/device:GPU:0"): op2 = g.create_op( "FloatOutputStringOutput", [], [dtypes.float32, dtypes.string], None, name="myop2") op3 = g.create_op( "Foo3", [list(op1.values())[0], list(op2.values())[1], list(op2.values())[0]], [dtypes.float32, dtypes.int32], None, name="myop3") self.assertDeviceEqual(None, op1.device) self.assertDeviceEqual("/device:GPU:0", op2.device) self.assertDeviceEqual(None, op3.device) self.assertProtoEquals("name:'myop1' op:'FloatOutput'", op1.node_def) self.assertProtoEquals( "name:'myop2' op:'FloatOutputStringOutput' device:'/device:GPU:0'", op2.node_def) self.assertProtoEquals( "name:'myop3' input:'myop1' input:'myop2:1' input:'myop2' op:'Foo3'", op3.node_def) def testReferenceInput(self): g = ops.Graph() op1 = g.create_op( "RefOutputFloatOutput", [], [dtypes.float32_ref, dtypes.float32], name="op1") self.assertProtoEquals("op:'RefOutputFloatOutput' name:'op1'", op1.node_def) ref_t, nonref_t = op1.values() # NOTE(mrry): Must specify input_types to preserve ref-typed input. op2 = g.create_op( "RefInputFloatInput", [ref_t, nonref_t], [], input_types=[dtypes.float32_ref, dtypes.float32], name="op2") self.assertProtoEquals( "op:'RefInputFloatInput' name:'op2' input:'op1' input:'op1:1'", op2.node_def) op3 = g.create_op("TwoFloatInputs", [ref_t, nonref_t], [], name="op3") self.assertProtoEquals( "op:'TwoFloatInputs' name:'op3' input:'op1' input:'op1:1'", op3.node_def) def testFinalized(self): g = ops.Graph() g.finalize() with self.assertRaises(RuntimeError): g.create_op("FloatOutput", [], [dtypes.float32], None, name="myop1") # Test unfinalize. g._unsafe_unfinalize() g.create_op("FloatOutput", [], [dtypes.float32], None, name="myop1") # NOTE(skyewm): these cases test the private Graph._create_op_from_tf_operation # method. Arguably we should only test the public APIs that depend on this # method. However, this logic is complex and tricky, and it can be difficult to # ascertain if we have adequate coverage (e.g. a graph may run successfully if # the control flow context isn't set properly, but a more complicated use case # that might not be obvious to test will fail). Thus we instead explicitly test # the low-level behavior. class CreateOpFromTFOperationTest(test_util.TensorFlowTestCase): def testBasic(self): g = ops.Graph() with g.as_default(): x = test_ops.int_output() c_op = ops._create_c_op( g, ops._NodeDef("IntInputIntOutput", "myop"), [x], []) op = g._create_op_from_tf_operation(c_op) self.assertEqual(op.name, "myop") self.assertEqual(op.type, "IntInputIntOutput") self.assertEqual(len(op.outputs), 1) self.assertEqual(op.outputs[0].shape, tensor_shape.unknown_shape()) self.assertEqual(list(op.inputs), [x]) self.assertEqual(op.control_inputs, []) self.assertEqual(op.graph, g) self.assertEqual(x.consumers(), [op]) self.assertIsNotNone(op.traceback) self.assertEqual(g.get_operation_by_name("myop"), op) self.assertEqual(g.get_tensor_by_name("myop:0"), op.outputs[0]) @test_util.enable_c_shapes def testShape(self): g = ops.Graph() with g.as_default(): x = constant_op.constant([[1, 2, 3], [4, 5, 6]]) c_op = ops._create_c_op(g, ops._NodeDef("Identity", "myop"), [x], []) op = g._create_op_from_tf_operation(c_op) self.assertEqual(op.name, "myop") self.assertEqual(op.type, "Identity") self.assertEqual(len(op.outputs), 1) self.assertEqual(op.outputs[0].shape, tensor_shape.matrix(2, 3)) def testUniqueName(self): g = ops.Graph() with g.as_default(): c_op = ops._create_c_op(g, ops._NodeDef("IntOutput", "myop"), [], []) c_op2 = ops._create_c_op(g, ops._NodeDef("IntOutput", "myop_1"), [], []) op = g._create_op_from_tf_operation(c_op) op2 = g._create_op_from_tf_operation(c_op2) # Create ops with same names as op1 and op2. We expect the new names to be # uniquified. op3 = test_ops.int_output(name="myop").op op4 = test_ops.int_output(name="myop_1").op self.assertEqual(op.name, "myop") self.assertEqual(op2.name, "myop_1") self.assertEqual(op3.name, "myop_2") self.assertEqual(op4.name, "myop_1_1") def testCond(self): g = ops.Graph() with g.as_default(): x = test_ops.int_output() def true_fn(): ops._create_c_op(ops.get_default_graph(), ops._NodeDef("IntInput", "cond/myop"), [x], []) new_ops = g._add_new_tf_operations() self.assertEqual(len(new_ops), 1) return x control_flow_ops.cond(x < 10, true_fn, lambda: x) op = g.get_operation_by_name("cond/myop") self.assertIsNotNone(op) self.assertEqual(op.name, "cond/myop") self.assertEqual(op.type, "IntInput") self.assertEqual(op.outputs, []) op_input = op.inputs[0].op self.assertEqual(op_input.type, "Switch") self.assertEqual(op_input.inputs[0], x) self.assertEqual(op.graph, g) # pylint: disable=protected-access self.assertIsNotNone(op._get_control_flow_context()) self.assertEqual(op._get_control_flow_context().name, "cond/cond_text") # pylint: enable=protected-access def testWhileLoop(self): g = ops.Graph() with g.as_default(): x = test_ops.int_output() def body(i): ops._create_c_op(ops.get_default_graph(), ops._NodeDef("IntInput", "myloop/myop"), [x], []) new_ops = g._add_new_tf_operations() self.assertEqual(len(new_ops), 1) return i control_flow_ops.while_loop(lambda i: i < 10, body, [0], name="myloop") op = g.get_operation_by_name("myloop/myop") self.assertIsNotNone(op) self.assertEqual(op.name, "myloop/myop") self.assertEqual(op.type, "IntInput") self.assertEqual(op.outputs, []) op_input = op.inputs[0].op self.assertEqual(op_input.type, "Enter") self.assertEqual(list(op_input.inputs), [x]) self.assertEqual(op.graph, g) # pylint: disable=protected-access self.assertIsNotNone(op._get_control_flow_context()) self.assertEqual(op._get_control_flow_context().name, "myloop/while_context") # pylint: enable=protected-access def testWhileLoopWithInternalControlDep(self): g = ops.Graph() with g.as_default(): x = test_ops.int_output() def body(i): c = constant_op.constant(1.0, name="c") ops._create_c_op(ops.get_default_graph(), ops._NodeDef("IntInput", "myloop/myop"), [x], []) with ops.control_dependencies([c]): new_ops = g._add_new_tf_operations() self.assertEqual(len(new_ops), 1) return i control_flow_ops.while_loop(lambda i: i < 10, body, [0], name="myloop") op = g.get_operation_by_name("myloop/myop") self.assertIsNotNone(op) c = g.get_operation_by_name("myloop/c") self.assertIsNotNone(c) # Internal control dep is preserved self.assertEqual(op.control_inputs, [c]) def testWhileLoopWithExternalControlDep(self): g = ops.Graph() with g.as_default(): x = test_ops.int_output() c = constant_op.constant(1.0) def body(i): ops._create_c_op(ops.get_default_graph(), ops._NodeDef("IntInput", "myloop/myop"), [x], []) with ops.control_dependencies([c]): new_ops = g._add_new_tf_operations() self.assertEqual(len(new_ops), 1) return i control_flow_ops.while_loop(lambda i: i < 10, body, [0], name="myloop") op = g.get_operation_by_name("myloop/myop") self.assertIsNotNone(op) # External control dep is removed and replaced with internal control dep self.assertNotEqual(op.control_inputs[0], c.op) self.assertIsNotNone(op.control_inputs[0]._get_control_flow_context()) class ApplyOpTest(test_util.TensorFlowTestCase): def testNodeDefArgs(self): g = ops.Graph() t1 = _apply_op(g, "FloatOutput", [], [dtypes.float32], name="myop1") with g.device("/device:GPU:0"): t2 = _apply_op( g, "TwoIntOutputs", [], [dtypes.int32, dtypes.int32], name="myop2") t3 = _apply_op( g, "Foo1", [t1, t2[1], t2[0]], [dtypes.float32, dtypes.int32], name="myop3") self.assertTrue(isinstance(t1, ops.Tensor)) self.assertTrue(isinstance(t2, list)) self.assertTrue(isinstance(t3, list)) self.assertTrue(isinstance(t3[0], ops.Tensor)) self.assertEqual("myop1", t1._as_node_def_input()) self.assertEqual("myop2", t2[0]._as_node_def_input()) self.assertEqual("myop2:1", t2[1]._as_node_def_input()) self.assertEqual("myop3", t3[0]._as_node_def_input()) # Validate that we got the right ops as well self.assertProtoEquals("name:'myop1' op:'FloatOutput'", t1.op.node_def) self.assertProtoEquals( "name:'myop2' op:'TwoIntOutputs' device:'/device:GPU:0'", t2[0].op.node_def) self.assertProtoEquals( "name:'myop3' input:'myop1' input:'myop2:1' input:'myop2' op:'Foo1'", t3[0].op.node_def) def testReferenceInput(self): g = ops.Graph() ref_t, nonref_t = _apply_op( g, "RefOutputFloatOutput", [], [dtypes.float32_ref, dtypes.float32], name="op1") self.assertProtoEquals("op:'RefOutputFloatOutput' name:'op1'", ref_t.op.node_def) # NOTE(mrry): Must specify input_types to preserve ref-typed input. out_2 = _apply_op( g, "RefInputFloatInputIntOutput", [ref_t, nonref_t], [dtypes.int32], input_types=[dtypes.float32_ref, dtypes.float32], name="op2") self.assertProtoEquals( "op:'RefInputFloatInputIntOutput' name:'op2' input:'op1' input:'op1:1'", out_2.op.node_def) out_3 = _apply_op( g, "TwoFloatInputsIntOutput", [ref_t, nonref_t], [dtypes.int32], name="op3") self.assertProtoEquals( "op:'TwoFloatInputsIntOutput' name:'op3' input:'op1' input:'op1:1'", out_3.op.node_def) class NameStackTest(test_util.TensorFlowTestCase): def testBasics(self): g = ops.Graph() self.assertEqual("foo", g.unique_name("foo", mark_as_used=False)) self.assertEqual("foo", g.unique_name("foo", mark_as_used=False)) self.assertEqual("foo", g.unique_name("foo")) self.assertEqual("foo_1", g.unique_name("foo", mark_as_used=False)) self.assertEqual("foo_1", g.unique_name("foo")) self.assertEqual("foo_2", g.unique_name("foo", mark_as_used=False)) self.assertEqual("foo_2", g.unique_name("foo")) self.assertEqual("foo_1_1", g.unique_name("foo_1", mark_as_used=False)) self.assertEqual("foo_1_1", g.unique_name("foo_1")) self.assertEqual("foo_1_2", g.unique_name("foo_1", mark_as_used=False)) self.assertEqual("foo_1_2", g.unique_name("foo_1")) self.assertEqual("foo_1_2_1", g.unique_name("foo_1_2", mark_as_used=False)) self.assertEqual("foo_1_2_1", g.unique_name("foo_1_2")) with g.name_scope("bar"): self.assertEqual("bar/foo", g.unique_name("foo", mark_as_used=False)) self.assertEqual("bar/foo", g.unique_name("foo")) self.assertEqual("bar/foo_1", g.unique_name("foo", mark_as_used=False)) self.assertEqual("bar/foo_1", g.unique_name("foo")) with g.name_scope(None): self.assertEqual("foo_3", g.unique_name("foo", mark_as_used=False)) self.assertEqual("foo_3", g.unique_name("foo")) with g.name_scope("baz"): self.assertEqual( "bar/baz/foo", g.unique_name( "foo", mark_as_used=False)) self.assertEqual("bar/baz/foo", g.unique_name("foo")) self.assertEqual( "bar/baz/foo_1", g.unique_name( "foo", mark_as_used=False)) self.assertEqual("bar/baz/foo_1", g.unique_name("foo")) with g.name_scope("baz"): self.assertEqual( "bar/baz_1/foo", g.unique_name( "foo", mark_as_used=False)) self.assertEqual("bar/baz_1/foo", g.unique_name("foo")) self.assertEqual( "bar/baz_1/foo_1", g.unique_name( "foo", mark_as_used=False)) self.assertEqual("bar/baz_1/foo_1", g.unique_name("foo")) with g.name_scope("quux"): self.assertEqual("quux/foo", g.unique_name("foo", mark_as_used=False)) self.assertEqual("quux/foo", g.unique_name("foo")) with g.name_scope("bar"): with g.name_scope("baz"): self.assertEqual( "bar_1/baz/foo", g.unique_name( "foo", mark_as_used=False)) self.assertEqual("bar_1/baz/foo", g.unique_name("foo")) self.assertEqual("foo_4", g.unique_name("foo", mark_as_used=False)) self.assertEqual("foo_4", g.unique_name("foo")) self.assertEqual("bar_2", g.unique_name("bar", mark_as_used=False)) self.assertEqual("bar_2", g.unique_name("bar")) def testNameAndVariableScope(self): with self.cached_session() as sess: with sess.graph.name_scope("l0"): with variable_scope.variable_scope("l1"): with sess.graph.name_scope("l1") as scope: self.assertEqual("l0/l1/l1/", scope) self.assertEqual( "l0/l1/l1/foo", sess.graph.unique_name( "foo", mark_as_used=False)) self.assertEqual("l0/l1/l1/foo", sess.graph.unique_name("foo")) with sess.graph.name_scope("l2") as scope: self.assertEqual("l0/l1/l2/", scope) self.assertEqual( "l0/l1/l2/foo", sess.graph.unique_name( "foo", mark_as_used=False)) self.assertEqual("l0/l1/l2/foo", sess.graph.unique_name("foo")) def testOutOfOrderUniqueName(self): g = ops.Graph() self.assertEqual("foo_2", g.unique_name("foo_2")) self.assertEqual("foo", g.unique_name("foo")) self.assertEqual("foo_1", g.unique_name("foo")) self.assertEqual("foo_3", g.unique_name("foo")) def testUniqueNameCaseInsensitivity(self): g = ops.Graph() self.assertEqual("foo", g.unique_name("foo")) self.assertEqual("Foo_1", g.unique_name("Foo")) with g.name_scope("bar"): self.assertEqual("bar/foo", g.unique_name("foo")) with g.name_scope("Bar"): self.assertEqual("Bar_1/foo", g.unique_name("foo")) def testInvalidNameRaisesError(self): g = ops.Graph() with g.name_scope(""): # Should not raise pass with g.name_scope("foo/"): # Should not raise with g.name_scope("_bar"): # Should not raise pass with self.assertRaises(ValueError): with g.name_scope("foo:0"): pass with self.assertRaises(ValueError): with g.name_scope("_bar"): pass class NameTest(test_util.TensorFlowTestCase): def testGenerateName(self): g = ops.Graph() op0 = g.create_op("TwoFloatOutputs", [], [dtypes.float32, dtypes.float32]) self.assertEqual("TwoFloatOutputs", op0.name) self.assertEqual("TwoFloatOutputs:0", op0.outputs[0].name) self.assertEqual("TwoFloatOutputs:1", op0.outputs[1].name) op1 = g.create_op("FloatOutput", [], [dtypes.float32]) self.assertEqual("FloatOutput", op1.name) self.assertEqual("FloatOutput:0", op1.outputs[0].name) op2 = g.create_op("FloatOutput", [], [dtypes.float32]) self.assertEqual("FloatOutput_1", op2.name) self.assertEqual("FloatOutput_1:0", op2.outputs[0].name) op3 = g.create_op("FloatOutput", [], [dtypes.float32], name="my_op") self.assertEqual("my_op", op3.name) self.assertEqual("my_op:0", op3.outputs[0].name) def testNameScope(self): g = ops.Graph() with g.name_scope("foo") as foo: self.assertEqual("foo/", foo) with g.name_scope("foo2") as foo2: self.assertEqual("foo/foo2/", foo2) with g.name_scope(None) as empty1: self.assertEqual("", empty1) with g.name_scope("foo3") as foo3: self.assertEqual("foo3/", foo3) with g.name_scope("") as empty2: self.assertEqual("", empty2) self.assertEqual("FloatOutput", g.create_op("FloatOutput", [], [dtypes.float32]).name) with g.name_scope("bar") as scope: self.assertEqual("bar/FloatOutput", g.create_op("FloatOutput", [], [dtypes.float32]).name) self.assertEqual("bar/FloatOutput_1", g.create_op("FloatOutput", [], [dtypes.float32]).name) # If you use the value from "with .. as", that values is used as-is. self.assertEqual( "bar", g.create_op( "FloatOutput", [], [dtypes.float32], name=scope).name) with g.name_scope("baz") as scope: with g.name_scope("quux"): self.assertEqual("baz/quux/FloatOutput", g.create_op("FloatOutput", [], [dtypes.float32]).name) # If you use the value from the enclosing "with .. as", nothing is pushed. with g.name_scope(scope): self.assertEqual("baz/FloatOutput", g.create_op("FloatOutput", [], [dtypes.float32]).name) self.assertEqual( "baz", g.create_op( "FloatOutput", [], [dtypes.float32], name=scope).name) self.assertEqual( "trailing", g.create_op( "FloatOutput", [], [dtypes.float32], name="trailing/").name) with g.name_scope("bar"): self.assertEqual("bar_1/FloatOutput", g.create_op("FloatOutput", [], [dtypes.float32]).name) with g.name_scope("bar/"): self.assertEqual("bar/FloatOutput_2", g.create_op("FloatOutput", [], [dtypes.float32]).name) class DeviceTest(test_util.TensorFlowTestCase): def testNoDevice(self): g = ops.Graph() op = g.create_op("FloatOutput", [], [dtypes.float32]) self.assertDeviceEqual(None, op.device) gd = g.as_graph_def() self.assertProtoEqualsVersion(""" node { name: "FloatOutput" op: "FloatOutput" } """, gd) def testDevicePartialString(self): g = ops.Graph() with g.device("/job:worker/replica:2"): g.create_op("FloatOutput", [], [dtypes.float32]) gd = g.as_graph_def() self.assertProtoEqualsVersion(""" node { name: "FloatOutput" op: "FloatOutput" device: "/job:worker/replica:2" } """, gd) def testDeviceFull(self): g = ops.Graph() with g.device( pydev.DeviceSpec( job="worker", replica=2, task=0, device_type="CPU", device_index=3)): g.create_op("FloatOutput", [], [dtypes.float32]) gd = g.as_graph_def() self.assertProtoEqualsVersion(""" node { name: "FloatOutput" op: "FloatOutput" device: "/job:worker/replica:2/task:0/device:CPU:3" } """, gd) def testNesting(self): g = ops.Graph() with g.device("/job:worker/replica:2"): g.create_op("FloatOutput", [], [dtypes.float32]) with g.device("/job:worker/replica:3/task:0"): g.create_op("FloatOutput", [], [dtypes.float32]) g.create_op("FloatOutput", [], [dtypes.float32]) gd = g.as_graph_def() self.assertProtoEqualsVersion(""" node { name: "FloatOutput" op: "FloatOutput" device: "/job:worker/replica:2" } node { name: "FloatOutput_1" op: "FloatOutput" device: "/job:worker/replica:3/task:0" } node { name: "FloatOutput_2" op: "FloatOutput" device: "/job:worker/replica:2" } """, gd) def testNestingString(self): g = ops.Graph() with g.device("/job:worker/replica:2"): g.create_op("FloatOutput", [], [dtypes.float32]) with g.device("/job:worker/replica:3/task:0"): g.create_op("FloatOutput", [], [dtypes.float32]) g.create_op("FloatOutput", [], [dtypes.float32]) gd = g.as_graph_def() self.assertProtoEqualsVersion(""" node { name: "FloatOutput" op: "FloatOutput" device: "/job:worker/replica:2" } node { name: "FloatOutput_1" op: "FloatOutput" device: "/job:worker/replica:3/task:0" } node { name: "FloatOutput_2" op: "FloatOutput" device: "/job:worker/replica:2" } """, gd) def testNestingOverrideGpuCpu(self): g = ops.Graph() with g.device("/job:worker/replica:2/device:CPU:1"): g.create_op("FloatOutput", [], [dtypes.float32]) with g.device("/job:worker/replica:2/device:GPU:2"): g.create_op("FloatOutput", [], [dtypes.float32]) g.create_op("FloatOutput", [], [dtypes.float32]) gd = g.as_graph_def() self.assertProtoEqualsVersion(""" node { name: "FloatOutput" op: "FloatOutput" device: "/job:worker/replica:2/device:CPU:1" } node { name: "FloatOutput_1" op: "FloatOutput" device: "/job:worker/replica:2/device:GPU:2" } node { name: "FloatOutput_2" op: "FloatOutput" device: "/job:worker/replica:2/device:CPU:1" } """, gd) def testNestingWithMergeDeviceFunction(self): g = ops.Graph() with g.device(pydev.merge_device("/device:GPU:0")): g.create_op("FloatOutput", [], [dtypes.float32]) with g.device(pydev.merge_device("/job:worker")): g.create_op("FloatOutput", [], [dtypes.float32]) with g.device(pydev.merge_device("/device:CPU:0")): g.create_op("FloatOutput", [], [dtypes.float32]) with g.device(pydev.merge_device("/job:ps")): g.create_op("FloatOutput", [], [dtypes.float32]) with g.device(pydev.merge_device(None)): g.create_op("FloatOutput", [], [dtypes.float32]) gd = g.as_graph_def() self.assertProtoEqualsVersion(""" node { name: "FloatOutput" op: "FloatOutput" device: "/device:GPU:0" } node { name: "FloatOutput_1" op: "FloatOutput" device: "/job:worker/device:GPU:0" } node { name: "FloatOutput_2" op: "FloatOutput" device: "/job:worker/device:CPU:0" } node { name: "FloatOutput_3" op: "FloatOutput" device: "/job:ps/device:CPU:0" } node { name: "FloatOutput_4" op: "FloatOutput" device: "/job:ps/device:CPU:0" } """, gd) def testNestingWithDeviceStrings(self): g = ops.Graph() with g.device("/device:GPU:0"): g.create_op("FloatOutput", [], [dtypes.float32]) with g.device("/job:worker"): g.create_op("FloatOutput", [], [dtypes.float32]) with g.device("/device:CPU:0"): g.create_op("FloatOutput", [], [dtypes.float32]) with g.device("/job:ps"): g.create_op("FloatOutput", [], [dtypes.float32]) with g.device(""): g.create_op("FloatOutput", [], [dtypes.float32]) gd = g.as_graph_def() self.assertProtoEqualsVersion(""" node { name: "FloatOutput" op: "FloatOutput" device: "/device:GPU:0" } node { name: "FloatOutput_1" op: "FloatOutput" device: "/job:worker/device:GPU:0" } node { name: "FloatOutput_2" op: "FloatOutput" device: "/job:worker/device:CPU:0" } node { name: "FloatOutput_3" op: "FloatOutput" device: "/job:ps/device:CPU:0" } node { name: "FloatOutput_4" op: "FloatOutput" device: "/job:ps/device:CPU:0" } """, gd) def testNestingWithDeviceStringWildcard(self): g = ops.Graph() with g.device("/device:GPU:7"): g.create_op("FloatOutput", [], [dtypes.float32]) with g.device("/device:GPU:*"): g.create_op("FloatOutput", [], [dtypes.float32]) with g.device("/device:CPU:*"): g.create_op("FloatOutput", [], [dtypes.float32]) with g.device("/device:CPU:5"): g.create_op("FloatOutput", [], [dtypes.float32]) gd = g.as_graph_def() self.assertProtoEqualsVersion(""" node { name: "FloatOutput" op: "FloatOutput" device: "/device:GPU:7" } node { name: "FloatOutput_1" op: "FloatOutput" device: "/device:GPU:7" } node { name: "FloatOutput_2" op: "FloatOutput" device: "/device:CPU:*" } node { name: "FloatOutput_3" op: "FloatOutput" device: "/device:CPU:5" } """, gd) def testNoneClearsDefault(self): g = ops.Graph() with g.device("/job:worker/replica:2/device:CPU:1"): g.create_op("FloatOutput", [], [dtypes.float32]) with g.device(None): g.create_op("FloatOutput", [], [dtypes.float32]) g.create_op("FloatOutput", [], [dtypes.float32]) gd = g.as_graph_def() self.assertProtoEqualsVersion(""" node { name: "FloatOutput" op: "FloatOutput" device: "/job:worker/replica:2/device:CPU:1" } node { name: "FloatOutput_1" op: "FloatOutput" } node { name: "FloatOutput_2" op: "FloatOutput" device: "/job:worker/replica:2/device:CPU:1" } """, gd) def testNoneIgnoresOuterDeviceFunction(self): g = ops.Graph() with g.device(lambda op: "/job:worker/replica:2/device:CPU:1"): g.create_op("FloatOutput", [], [dtypes.float32]) with g.device(None): g.create_op("FloatOutput", [], [dtypes.float32]) g.create_op("FloatOutput", [], [dtypes.float32]) gd = g.as_graph_def() self.assertProtoEqualsVersion(""" node { name: "FloatOutput" op: "FloatOutput" device: "/job:worker/replica:2/device:CPU:1" } node { name: "FloatOutput_1" op: "FloatOutput" } node { name: "FloatOutput_2" op: "FloatOutput" device: "/job:worker/replica:2/device:CPU:1" } """, gd) def _overwritingDeviceFunction(self, unused_op): # This device function unconditionally overwrites the device of ops. # # NOTE(mrry): Writing device functions like this is not # recommended. Instead, in most cases you should use # `pydev.merge_device("/job:ps")` or simply `"/job:ps"` as the # argument to `tf.device()` and the device component will be merged in. return "/job:overwrite" def testOverwritingBehavior(self): g = ops.Graph() with g.device(self._overwritingDeviceFunction): g.create_op("FloatOutput", [], [dtypes.float32]) with g.device("/job:ps"): # Will be overwritten. g.create_op("FloatOutput", [], [dtypes.float32]) with g.device(pydev.merge_device("/job:ps")): # Will be overwritten. g.create_op("FloatOutput", [], [dtypes.float32]) with g.device(None): # Disables overwriting device function with g.device("/job:ps"): g.create_op("FloatOutput", [], [dtypes.float32]) with g.device(None): # Disables overwriting device function with g.device(pydev.merge_device("/job:ps")): g.create_op("FloatOutput", [], [dtypes.float32]) gd = g.as_graph_def() self.assertProtoEqualsVersion(""" node { name: "FloatOutput" op: "FloatOutput" device: "/job:overwrite" } node { name: "FloatOutput_1" op: "FloatOutput" device: "/job:overwrite" } node { name: "FloatOutput_2" op: "FloatOutput" device: "/job:overwrite" } node { name: "FloatOutput_3" op: "FloatOutput" device: "/job:ps" } node { name: "FloatOutput_4" op: "FloatOutput" device: "/job:ps" } """, gd) class MultithreadedGraphStateTest(test_util.TensorFlowTestCase): class TestThread(threading.Thread): def __init__(self, graph, replica_id): super(MultithreadedGraphStateTest.TestThread, self).__init__() self._graph = graph self._replica_id = replica_id # This thread sets this event when it mutated the graph. The caller can # wait for that. self.has_mutated_graph = threading.Event() # This thread waits for when it should continue. The caller can set this # event. self.should_continue = threading.Event() def run(self): # Mutate a graph's stack, then set `has_mutated_graph`, then wait for # `should_continue`, then add an op to the graph affected by the graph's # stack. raise NotImplementedError("must be implemented in descendants") def testDeviceFunctionStack(self): class DeviceSettingThread(self.TestThread): def run(self): with g.device("/job:worker/replica:{}".format(self._replica_id)): self.has_mutated_graph.set() self.should_continue.wait() self.should_continue.clear() g.create_op( "FloatOutput", [], [dtypes.float32], name="FloatOutput_{}".format(self._replica_id)) g = ops.Graph() # If `switch_to_thread` isn't called, then device placement of the ops # below is not deterministic. g.switch_to_thread_local() threads = [DeviceSettingThread(g, i) for i in range(3)] for t in threads: t.start() t.has_mutated_graph.wait() t.has_mutated_graph.clear() for t in threads: t.should_continue.set() t.join() gd = g.as_graph_def() self.assertProtoEqualsVersion(""" node { name: "FloatOutput_0" op: "FloatOutput" device: "/job:worker/replica:0" } node { name: "FloatOutput_1" op: "FloatOutput" device: "/job:worker/replica:1" } node { name: "FloatOutput_2" op: "FloatOutput" device: "/job:worker/replica:2" } """, gd) def testColocateWith(self): class ColocatingThread(self.TestThread): def __init__(self, graph, replica_id, op_to_colocate_with): super(ColocatingThread, self).__init__(graph, replica_id) self._op_to_colocate_with = op_to_colocate_with def run(self): with g.colocate_with(self._op_to_colocate_with): self.has_mutated_graph.set() self.should_continue.wait() self.should_continue.clear() g.create_op( "FloatOutput", [], [dtypes.float32], name="FloatOutput_{}".format(self._replica_id)) g = ops.Graph() ops_to_colocate_with = [] for i in range(3): with g.device("/job:worker/replica:{}".format(i)): ops_to_colocate_with.append( g.create_op( "FloatOutput", [], [dtypes.float32], name="ColocateWithMe_{}".format(i))) # If `switch_to_thread` isn't called, then `device` and `attr` values for # the ops below are not deterministic. g.switch_to_thread_local() threads = [ ColocatingThread(g, i, ops_to_colocate_with[i]) for i in range(3) ] for t in threads: t.start() t.has_mutated_graph.wait() t.has_mutated_graph.clear() for t in threads: t.should_continue.set() t.join() gd = g.as_graph_def() self.assertProtoEqualsVersion(""" node { name: "ColocateWithMe_0" op: "FloatOutput" device: "/job:worker/replica:0" } node { name: "ColocateWithMe_1" op: "FloatOutput" device: "/job:worker/replica:1" } node { name: "ColocateWithMe_2" op: "FloatOutput" device: "/job:worker/replica:2" } node { name: "FloatOutput_0" op: "FloatOutput" device: "/job:worker/replica:0" attr { key: "_class" value { list { s: "loc:@ColocateWithMe_0"}}}} node { name: "FloatOutput_1" op: "FloatOutput" device: "/job:worker/replica:1" attr { key: "_class" value { list { s: "loc:@ColocateWithMe_1"}}}} node { name: "FloatOutput_2" op: "FloatOutput" device: "/job:worker/replica:2" attr { key: "_class" value { list { s: "loc:@ColocateWithMe_2"}}}} """, gd) def testControlDependencies(self): class DependingThread(self.TestThread): def __init__(self, graph, replica_id, dependency_op): super(DependingThread, self).__init__(graph, replica_id) self._dependency_op = dependency_op def run(self): with g.control_dependencies([self._dependency_op]): self.has_mutated_graph.set() self.should_continue.wait() self.should_continue.clear() g.create_op( "FloatOutput", [], [dtypes.float32], name="FloatOutput_{}".format(self._replica_id)) g = ops.Graph() dependency_ops = [] for i in range(3): dependency_ops.append( g.create_op( "FloatOutput", [], [dtypes.float32], name="ColocateWithMe_{}".format(i))) # If `switch_to_thread` isn't called, then `input` values for the ops below # are not deterministic. g.switch_to_thread_local() threads = [DependingThread(g, i, dependency_ops[i]) for i in range(3)] for t in threads: t.start() t.has_mutated_graph.wait() t.has_mutated_graph.clear() for t in threads: t.should_continue.set() t.join() gd = g.as_graph_def() self.assertProtoEqualsVersion(""" node { name: "ColocateWithMe_0" op: "FloatOutput" } node { name: "ColocateWithMe_1" op: "FloatOutput" } node { name: "ColocateWithMe_2" op: "FloatOutput" } node { name: "FloatOutput_0" op: "FloatOutput" input: "^ColocateWithMe_0" } node { name: "FloatOutput_1" op: "FloatOutput" input: "^ColocateWithMe_1" } node { name: "FloatOutput_2" op: "FloatOutput" input: "^ColocateWithMe_2" } """, gd) def testNameStack(self): class NameSettingThread(self.TestThread): def run(self): with g.name_scope("foo"): op1 = g.create_op("FloatOutput", [], [dtypes.float32]) self.has_mutated_graph.set() self.should_continue.wait() self.should_continue.clear() op2 = g.create_op("FloatOutput", [], [dtypes.float32]) self.result = (op1, op2) g = ops.Graph() threads = [NameSettingThread(g, i) for i in range(3)] for t in threads: t.start() t.has_mutated_graph.wait() t.has_mutated_graph.clear() for t in threads: t.should_continue.set() t.join() suffixes = ["", "_1", "_2"] for t, s in zip(threads, suffixes): self.assertEquals("foo" + s + "/FloatOutput", t.result[0].name) self.assertEquals("foo" + s + "/FloatOutput_1", t.result[1].name) class ObjectWithName(object): def __init__(self, name): self._name = name @property def name(self): return self._name class CollectionTest(test_util.TensorFlowTestCase): def test_get_collections(self): g = ops.Graph() self.assertSequenceEqual(g.collections, []) g.add_to_collection("key", 12) g.add_to_collection("key", 15) self.assertSequenceEqual(g.collections, ["key"]) g.add_to_collection("other", "foo") self.assertSequenceEqual(sorted(g.collections), ["key", "other"]) def test_add_to_collection(self): g = ops.Graph() g.add_to_collection("key", 12) g.add_to_collection("other", "foo") g.add_to_collection("key", 34) # Note that only blank1 is returned. g.add_to_collection("blah", 27) blank1 = ObjectWithName("prefix/foo") g.add_to_collection("blah", blank1) blank2 = ObjectWithName("junk/foo") g.add_to_collection("blah", blank2) self.assertEqual([12, 34], g.get_collection("key")) self.assertEqual([], g.get_collection("nothing")) self.assertEqual([27, blank1, blank2], g.get_collection("blah")) self.assertEqual([blank1], g.get_collection("blah", "prefix")) self.assertEqual([blank1], g.get_collection("blah", ".*x")) # Make sure that get_collection() returns a first-level # copy of the collection, while get_collection_ref() returns # the original list. other_collection_snapshot = g.get_collection("other") other_collection_ref = g.get_collection_ref("other") self.assertEqual(["foo"], other_collection_snapshot) self.assertEqual(["foo"], other_collection_ref) g.add_to_collection("other", "bar") self.assertEqual(["foo"], other_collection_snapshot) self.assertEqual(["foo", "bar"], other_collection_ref) self.assertEqual(["foo", "bar"], g.get_collection("other")) self.assertTrue(other_collection_ref is g.get_collection_ref("other")) # Verify that getting an empty collection ref returns a modifiable list. empty_coll_ref = g.get_collection_ref("empty") self.assertEqual([], empty_coll_ref) empty_coll = g.get_collection("empty") self.assertEqual([], empty_coll) self.assertFalse(empty_coll is empty_coll_ref) empty_coll_ref2 = g.get_collection_ref("empty") self.assertTrue(empty_coll_ref2 is empty_coll_ref) # Add to the collection. empty_coll_ref.append("something") self.assertEqual(["something"], empty_coll_ref) self.assertEqual(["something"], empty_coll_ref2) self.assertEqual([], empty_coll) self.assertEqual(["something"], g.get_collection("empty")) empty_coll_ref3 = g.get_collection_ref("empty") self.assertTrue(empty_coll_ref3 is empty_coll_ref) def test_add_to_collections_uniquify(self): g = ops.Graph() g.add_to_collections([1, 2, 1], "key") # Make sure "key" is not added twice self.assertEqual(["key"], g.get_collection(1)) def test_add_to_collections_from_list(self): g = ops.Graph() g.add_to_collections(["abc", "123"], "key") self.assertEqual(["key"], g.get_collection("abc")) self.assertEqual(["key"], g.get_collection("123")) def test_add_to_collections_from_tuple(self): g = ops.Graph() g.add_to_collections(("abc", "123"), "key") self.assertEqual(["key"], g.get_collection("abc")) self.assertEqual(["key"], g.get_collection("123")) def test_add_to_collections_from_generator(self): g = ops.Graph() def generator(): yield "abc" yield "123" g.add_to_collections(generator(), "key") self.assertEqual(["key"], g.get_collection("abc")) self.assertEqual(["key"], g.get_collection("123")) def test_add_to_collections_from_set(self): g = ops.Graph() g.add_to_collections(set(["abc", "123"]), "key") self.assertEqual(["key"], g.get_collection("abc")) self.assertEqual(["key"], g.get_collection("123")) def test_add_to_collections_from_string(self): g = ops.Graph() g.add_to_collections("abc", "key") self.assertEqual(["key"], g.get_collection("abc")) def test_default_graph(self): with ops.Graph().as_default(): ops.add_to_collection("key", 90) ops.add_to_collection("key", 100) # Collections are ordered. self.assertEqual([90, 100], ops.get_collection("key")) def test_defun(self): with context.eager_mode(): @eager_function.defun def defun(): ops.add_to_collection("int", 1) ops.add_to_collection("tensor", constant_op.constant(2)) @eager_function.defun def inner_defun(): self.assertEqual(ops.get_collection("int"), [1]) three = ops.get_collection("tensor")[0] + ops.get_collection("int")[0] ops.add_to_collection("int", 2) self.assertEqual(ops.get_collection("int"), [1, 2]) ops.add_to_collection("foo", "bar") self.assertEqual(ops.get_collection("foo"), ["bar"]) return three self.assertEqual(ops.get_collection("int"), [1]) three = inner_defun() self.assertEqual(ops.get_collection("int"), [1, 2]) self.assertEqual(ops.get_collection("foo"), ["bar"]) return three three = defun() self.assertEqual(three.numpy(), 3) ops.NotDifferentiable("FloatOutput") @ops.RegisterGradient("CopyOp") def _CopyGrad(op, x_grad): # pylint: disable=invalid-name _ = op return x_grad @ops.RegisterGradient("copy_override") def _CopyOverrideGrad(op, x_grad): # pylint: disable=invalid-name _ = op return x_grad class RegistrationTest(test_util.TensorFlowTestCase): def testRegisterGradients(self): x = test_ops.float_output() y = test_ops.copy_op(x) fn = ops.get_gradient_function(y.op) self.assertEqual(_CopyGrad, fn) def testOverrideGradients(self): g = ops.Graph() with g.as_default(): x = test_ops.float_output() with g.gradient_override_map({"CopyOp": "copy_override"}): y = test_ops.copy_op(x) fn = ops.get_gradient_function(y.op) self.assertEqual(_CopyOverrideGrad, fn) def testNonExistentOverride(self): g = ops.Graph() with g.as_default(): x = test_ops.float_output() with g.gradient_override_map({"CopyOp": "unknown_override"}): y = test_ops.copy_op(x) with self.assertRaisesRegexp(LookupError, "unknown_override"): ops.get_gradient_function(y.op) class ComparisonTest(test_util.TensorFlowTestCase): def testMembershipAllowed(self): g = ops.Graph() t1 = _apply_op(g, "FloatOutput", [], [dtypes.float32], name="myop1") t2 = _apply_op(g, "FloatOutput", [], [dtypes.float32], name="myop2") self.assertTrue(isinstance(t1, ops.Tensor)) self.assertTrue(isinstance(t2, ops.Tensor)) self.assertTrue(t1 in [t1]) self.assertTrue(t1 not in [t2]) class ControlDependenciesTest(test_util.TensorFlowTestCase): def testBasic(self): g = ops.Graph() with g.as_default(): # Creating unregistered ops with _apply_op() doesn't work with the C API # TODO(skyewm): address this more consistently. Possible solutions are # to use registered ops in all tests, create a way to register ops in # Python tests, or conditionally disable the op registration check in # the C API. a = constant_op.constant(1.0) b = constant_op.constant(1.0) with g.control_dependencies([a]): c = constant_op.constant(1.0) d = array_ops.identity(b) e = array_ops.identity(c) self.assertEqual(c.op.control_inputs, [a.op]) self.assertEqual(d.op.control_inputs, [a.op]) # e should be dominated by c. self.assertEqual(e.op.control_inputs, []) @test_util.run_in_graph_and_eager_modes def testEager(self): def future(): future.calls += 1 return constant_op.constant(2.0) future.calls = 0 if context.executing_eagerly(): a = constant_op.constant(1.0) b = future with ops.control_dependencies([a, b]): c = constant_op.constant(3.0) self.assertEqual(future.calls, 1) else: g = ops.Graph() with g.as_default(): a = constant_op.constant(1.0) b = future() with g.control_dependencies([a, b]): c = constant_op.constant(3.0) self.assertEqual(c.op.control_inputs, [a.op, b.op]) self.assertEqual(future.calls, 1) def testBasicWithConversion(self): g = ops.Graph() a = _apply_op(g, "FloatOutput", [], [dtypes.float32]) class ConvertibleObj(object): def _as_graph_element(self): return a with g.control_dependencies([ConvertibleObj()]): c = _apply_op(g, "FloatOutput", [], [dtypes.float32]) self.assertEqual(c.op.control_inputs, [a.op]) def testNested(self): g = ops.Graph() a_1 = _apply_op(g, "FloatOutput", [], [dtypes.float32]) a_2 = _apply_op(g, "FloatOutput", [], [dtypes.float32]) a_3 = _apply_op(g, "FloatOutput", [], [dtypes.float32]) a_4 = _apply_op(g, "FloatOutput", [], [dtypes.float32]) with g.control_dependencies([a_1, a_2, a_3, a_4]): b_1 = _apply_op(g, "FloatOutput", [], [dtypes.float32]) with g.control_dependencies([a_1]): with g.control_dependencies([a_2]): with g.control_dependencies([a_3]): with g.control_dependencies([a_4]): b_2 = _apply_op(g, "FloatOutput", [], [dtypes.float32]) self.assertItemsEqual([a_1.op, a_2.op, a_3.op, a_4.op], b_1.op.control_inputs) self.assertItemsEqual(b_1.op.control_inputs, b_2.op.control_inputs) def testClear(self): g = ops.Graph() a_1 = _apply_op(g, "FloatOutput", [], [dtypes.float32]) a_2 = _apply_op(g, "FloatOutput", [], [dtypes.float32]) a_3 = _apply_op(g, "FloatOutput", [], [dtypes.float32]) a_4 = _apply_op(g, "FloatOutput", [], [dtypes.float32]) with g.control_dependencies([a_1]): with g.control_dependencies([a_2]): with g.control_dependencies(None): with g.control_dependencies([a_3]): with g.control_dependencies([a_4]): # deps [a_3, a_4] b_3_4 = _apply_op(g, "FloatOutput", [], [dtypes.float32]) # deps = [a_3] b_3 = _apply_op(g, "FloatOutput", [], [dtypes.float32]) # deps back to None b_none = _apply_op(g, "FloatOutput", [], [dtypes.float32]) # deps back to [a_1, a_2] b_1_2 = _apply_op(g, "FloatOutput", [], [dtypes.float32]) # deps back to [a_1] b_1 = _apply_op(g, "FloatOutput", [], [dtypes.float32]) with g.control_dependencies(None): # deps are None again b_none2 = _apply_op(g, "FloatOutput", [], [dtypes.float32]) self.assertItemsEqual([a_3.op, a_4.op], b_3_4.op.control_inputs) self.assertItemsEqual([a_3.op], b_3.op.control_inputs) self.assertItemsEqual([], b_none.op.control_inputs) self.assertItemsEqual([a_1.op, a_2.op], b_1_2.op.control_inputs) self.assertItemsEqual([a_1.op], b_1.op.control_inputs) self.assertItemsEqual([], b_none2.op.control_inputs) def testComplex(self): g = ops.Graph() # Usage pattern: # * Nodes a_i are constants defined at the outermost scope, and are used # as control inputs for the ith nested scope. # * Nodes b_i are defined as Mul(a_3, a_4) at each scope. # * Nodes c_i are defined as Mul(a_1, b_1) at each scope. # * Nodes d_i are defined as Mul(b_i, c_i) at each scope. # * Nodes e_i are defined as Mul(e_i-1, e_i-1) at each scope i > 1. a_1 = _apply_op(g, "FloatOutput", [], [dtypes.float32]) a_2 = _apply_op(g, "FloatOutput", [], [dtypes.float32]) a_3 = _apply_op(g, "FloatOutput", [], [dtypes.float32]) a_4 = _apply_op(g, "FloatOutput", [], [dtypes.float32]) with g.control_dependencies([a_1]): b_1 = _apply_op(g, "TwoFloatInputsFloatOutput", [a_3, a_4], [dtypes.float32]) c_1 = _apply_op(g, "TwoFloatInputsFloatOutput", [a_1, b_1], [dtypes.float32]) d_1 = _apply_op(g, "TwoFloatInputsFloatOutput", [b_1, c_1], [dtypes.float32]) e_1 = _apply_op(g, "FloatOutput", [], [dtypes.float32]) with g.control_dependencies([a_2]): b_2 = _apply_op(g, "TwoFloatInputsFloatOutput", [a_3, a_4], [dtypes.float32]) c_2 = _apply_op(g, "TwoFloatInputsFloatOutput", [a_1, b_1], [dtypes.float32]) d_2 = _apply_op(g, "TwoFloatInputsFloatOutput", [b_2, c_2], [dtypes.float32]) e_2 = _apply_op(g, "TwoFloatInputsFloatOutput", [e_1, e_1], [dtypes.float32]) with g.control_dependencies([a_3]): b_3 = _apply_op(g, "TwoFloatInputsFloatOutput", [a_3, a_4], [dtypes.float32]) c_3 = _apply_op(g, "TwoFloatInputsFloatOutput", [a_1, b_1], [dtypes.float32]) d_3 = _apply_op(g, "TwoFloatInputsFloatOutput", [b_3, c_3], [dtypes.float32]) e_3 = _apply_op(g, "TwoFloatInputsFloatOutput", [e_2, e_2], [dtypes.float32]) with g.control_dependencies([a_4]): b_4 = _apply_op(g, "TwoFloatInputsFloatOutput", [a_3, a_4], [dtypes.float32]) c_4 = _apply_op(g, "TwoFloatInputsFloatOutput", [a_1, b_1], [dtypes.float32]) d_4 = _apply_op(g, "TwoFloatInputsFloatOutput", [b_4, c_4], [dtypes.float32]) e_4 = _apply_op(g, "TwoFloatInputsFloatOutput", [e_3, e_3], [dtypes.float32]) self.assertItemsEqual([a_1.op], b_1.op.control_inputs) self.assertItemsEqual([a_1.op, a_2.op], b_2.op.control_inputs) self.assertItemsEqual([a_1.op, a_2.op], b_3.op.control_inputs) self.assertItemsEqual([a_1.op, a_2.op], b_4.op.control_inputs) self.assertItemsEqual([], c_1.op.control_inputs) self.assertItemsEqual([a_2.op], c_2.op.control_inputs) self.assertItemsEqual([a_2.op, a_3.op], c_3.op.control_inputs) self.assertItemsEqual([a_2.op, a_3.op, a_4.op], c_4.op.control_inputs) self.assertItemsEqual([], d_1.op.control_inputs) self.assertItemsEqual([], d_2.op.control_inputs) self.assertItemsEqual([], d_3.op.control_inputs) self.assertItemsEqual([], d_4.op.control_inputs) self.assertItemsEqual([a_1.op], e_1.op.control_inputs) self.assertItemsEqual([a_2.op], e_2.op.control_inputs) self.assertItemsEqual([a_3.op], e_3.op.control_inputs) self.assertItemsEqual([a_4.op], e_4.op.control_inputs) def testRepeatedDependency(self): g = ops.Graph() a = g.create_op("TwoFloatOutputs", [], [dtypes.float32, dtypes.float32]) a_0, a_1 = a.outputs with g.control_dependencies([a_0]): b = _apply_op(g, "FloatOutput", [], [dtypes.float32]) with g.control_dependencies([a_1]): c = _apply_op(g, "FloatOutput", [], [dtypes.float32]) self.assertEqual(b.op.control_inputs, [a]) self.assertEqual(c.op.control_inputs, [a]) def testNoControlDependencyWithDataDependency(self): g = ops.Graph() a = _apply_op(g, "FloatOutput", [], [dtypes.float32]) with g.control_dependencies([a]): b = _apply_op(g, "Identity", [a], [dtypes.float32]) self.assertEqual(b.op.control_inputs, []) class OpScopeTest(test_util.TensorFlowTestCase): @test_util.run_in_graph_and_eager_modes def testNames(self): with ops.name_scope("foo") as foo: self.assertEqual("foo/", foo) with ops.name_scope("foo2") as foo2: self.assertEqual("foo/foo2/", foo2) with ops.name_scope(None) as empty1: self.assertEqual("", empty1) with ops.name_scope("foo3") as foo3: self.assertEqual("foo3/", foo3) with ops.name_scope("") as empty2: self.assertEqual("", empty2) with ops.name_scope("foo/") as outer_foo: self.assertEqual("foo/", outer_foo) with ops.name_scope("") as empty3: self.assertEqual("", empty3) with ops.name_scope("foo4") as foo4: self.assertEqual("foo/foo4/", foo4) with ops.name_scope("foo5//") as foo5: self.assertEqual("foo5//", foo5) with ops.name_scope("foo6") as foo6: self.assertEqual("foo5//foo6/", foo6) with ops.name_scope("/") as foo7: self.assertEqual("/", foo7) with ops.name_scope("//") as foo8: self.assertEqual("//", foo8) with ops.name_scope("a//b/c") as foo9: self.assertEqual("foo/a//b/c/", foo9) with ops.name_scope("a//b/c") as foo10: self.assertEqual("a//b/c/", foo10) @test_util.run_in_graph_and_eager_modes def testEagerDefaultScopeName(self): with ops.name_scope(None, "default") as scope: self.assertEqual(scope, "default/") with ops.name_scope(None, "default2") as scope2: self.assertEqual(scope2, "default/default2/") def testNoScopeName(self): g0 = ops.Graph() values = [ g0.create_op("A", [], [dtypes.float32]), g0.create_op("B", [], [dtypes.float32]) ] with self.assertRaises(ValueError): with ops.name_scope(None, values=values): pass with self.assertRaises(ValueError): with ops.name_scope(None, None, values): pass def testEmptyScopeName(self): g0 = ops.Graph() a = g0.create_op("A", [], [dtypes.float32]) b = g0.create_op("B", [], [dtypes.float32]) with ops.name_scope("", values=[a, b]) as scope: self.assertEqual("", scope) self.assertEqual(g0, ops.get_default_graph()) with ops.name_scope("", "my_default_scope", [a, b]) as scope: self.assertEqual("", scope) self.assertEqual(g0, ops.get_default_graph()) def testDefaultScopeName(self): g0 = ops.Graph() a = g0.create_op("A", [], [dtypes.float32]) b = g0.create_op("B", [], [dtypes.float32]) scope_name = "my_scope" default_scope_name = "my_default_scope" with ops.name_scope(scope_name, default_scope_name, [a, b]) as scope: self.assertEqual("%s/" % scope_name, scope) self.assertEqual(g0, ops.get_default_graph()) with ops.name_scope(None, default_scope_name, [a, b]) as scope: self.assertEqual("%s/" % default_scope_name, scope) self.assertEqual(g0, ops.get_default_graph()) def _testGraphElements(self, graph_elements): scope_name = "my_scope" with ops.name_scope(scope_name, values=graph_elements) as scope: self.assertEqual("%s/" % scope_name, scope) self.assertEqual(graph_elements[0].graph, ops.get_default_graph()) g1 = ops.Graph() a = g1.create_op("A", [], [dtypes.float32]) with self.assertRaises(ValueError): with ops.name_scope(scope_name, values=graph_elements + [a]): pass def testTensor(self): g0 = ops.Graph() a = g0.create_op("A", [], [dtypes.float32]) b = g0.create_op("B", [], [dtypes.float32]) self._testGraphElements([a, b]) def testSparseTensor(self): g0 = ops.Graph() a = g0.create_op("A", [], [dtypes.float32]) b = g0.create_op("B", [], [dtypes.float32]) sparse = sparse_tensor.SparseTensor( _apply_op(g0, "Int64Output", [], [dtypes.int64]), _apply_op(g0, "FloatOutput", [], [dtypes.float32]), _apply_op(g0, "Int64Output", [], [dtypes.int64])) self._testGraphElements([a, sparse, b]) def testVariable(self): g0 = ops.Graph() with g0.as_default(): variable = variables.Variable([1.0]) a = g0.create_op("A", [], [dtypes.float32]) b = g0.create_op("B", [], [dtypes.float32]) self._testGraphElements([a, variable, b]) class InitScopeTest(test_util.TensorFlowTestCase): def testClearsControlDependencies(self): g = ops.Graph() a_1 = _apply_op(g, "FloatOutput", [], [dtypes.float32]) a_2 = _apply_op(g, "FloatOutput", [], [dtypes.float32]) a_3 = _apply_op(g, "FloatOutput", [], [dtypes.float32]) a_4 = _apply_op(g, "FloatOutput", [], [dtypes.float32]) with g.as_default(): with g.control_dependencies([a_1]): with g.control_dependencies([a_2]): with ops.init_scope(): with g.control_dependencies([a_3]): with g.control_dependencies([a_4]): # deps [a_3, a_4] b_3_4 = _apply_op(g, "FloatOutput", [], [dtypes.float32]) # deps = [a_3] b_3 = _apply_op(g, "FloatOutput", [], [dtypes.float32]) # deps back to None b_none = _apply_op(g, "FloatOutput", [], [dtypes.float32]) # deps back to [a_1, a_2] b_1_2 = _apply_op(g, "FloatOutput", [], [dtypes.float32]) # deps back to [a_1] b_1 = _apply_op(g, "FloatOutput", [], [dtypes.float32]) with ops.init_scope(): # deps are None again b_none2 = _apply_op(g, "FloatOutput", [], [dtypes.float32]) self.assertItemsEqual([a_3.op, a_4.op], b_3_4.op.control_inputs) self.assertItemsEqual([a_3.op], b_3.op.control_inputs) self.assertItemsEqual([], b_none.op.control_inputs) self.assertItemsEqual([a_1.op, a_2.op], b_1_2.op.control_inputs) self.assertItemsEqual([a_1.op], b_1.op.control_inputs) self.assertItemsEqual([], b_none2.op.control_inputs) def testLiftsOpsFromFunctions(self): g0 = ops.Graph() g1 = ops.Graph() g1._building_function = True # pylint: disable=protected-access g2 = ops.Graph() g2._building_function = True # pylint: disable=protected-access with g0.as_default(): with g1.as_default(): with g2.as_default(): with ops.init_scope(): _ = constant_op.constant(1.0) self.assertEqual(len(g2.get_operations()), 0) self.assertEqual(len(g1.get_operations()), 0) self.assertEqual(len(g0.get_operations()), 1) def testPreservesDevices(self): g0 = ops.Graph() with g0.as_default(), ops.device("CPU:0"): g1 = ops.Graph() g1._building_function = True # pylint: disable=protected-access with g1.as_default(), ops.device("GPU:0"): with ops.init_scope(): # init_scope should preserve device set under `g1`. on_gpu = constant_op.constant(1.0) self.assertEqual(on_gpu.device, "/device:GPU:0") still_on_gpu = constant_op.constant(1.0) self.assertEqual(still_on_gpu.device, "/device:GPU:0") on_cpu = constant_op.constant(1.0) self.assertEqual(on_cpu.device, "/device:CPU:0") def testComposes(self): g0 = ops.Graph() g1 = ops.Graph() g1._building_function = True # pylint: disable=protected-access g2 = ops.Graph() g2._building_function = True # pylint: disable=protected-access g3 = ops.Graph() g3._building_function = False # pylint: disable=protected-access with g0.as_default(): with g1.as_default(): with ops.init_scope(): # This op should be lifted into g0. _ = constant_op.constant(1.0) self.assertIs(g0, ops.get_default_graph()) self.assertEqual(len(g2.get_operations()), 0) self.assertEqual(len(g1.get_operations()), 0) self.assertEqual(len(g0.get_operations()), 1) with g2.as_default(): with ops.init_scope(): # This op should be lifted into g0. _ = constant_op.constant(1.0) self.assertIs(g0, ops.get_default_graph()) with g3.as_default(): with ops.init_scope(): # This op should be lifted into g3, because g3 is not building a # function. _ = constant_op.constant(1.0) self.assertIs(g3, ops.get_default_graph()) self.assertEqual(len(g3.get_operations()), 1) self.assertEqual(len(g2.get_operations()), 0) self.assertEqual(len(g1.get_operations()), 0) self.assertEqual(len(g0.get_operations()), 2) def testEscapesToEagerContext(self): g = ops.Graph() g._building_function = True # pylint: disable=protected-access with context.eager_mode(): with context.graph_mode(): with g.as_default(): with ops.init_scope(): # Because g is building a function, init_scope should # escape out to the eager context. self.assertTrue(context.executing_eagerly()) # g should be reinstated as the default graph, and the # graph context should be re-entered. self.assertIs(g, ops.get_default_graph()) self.assertFalse(context.executing_eagerly()) def testStaysInEagerWhenOnlyEagerContextActive(self): with context.eager_mode(): with ops.init_scope(): self.assertTrue(context.eager_mode()) self.assertTrue(context.eager_mode()) def testEscapesDefunWhenInEagerMode(self): def function_with_variables(): with ops.init_scope(): self.v = resource_variable_ops.ResourceVariable(3) return self.v.assign_add(1) with context.eager_mode(): # Each invocation of function_with_variables recreates a variable. self.assertEqual(4, int(function_with_variables())) self.assertEqual(4, int(function_with_variables())) compiled = eager_function.defun(function_with_variables) # The init_scope in function_with_variables lifts the variable out # of the graph function constructed by defun; hence, # compiled now appears to be stateful. self.assertEqual(4, int(compiled())) self.assertEqual(5, int(compiled())) def testEscapesDefunWhenInGraphMode(self): def function_with_variables(name): with ops.init_scope(): _ = variable_scope.get_variable(name, shape=(1,)) g = ops.Graph() with g.as_default(): with self.cached_session(): # First ensure that graphs that are not building functions are # not escaped. function_with_variables("foo") with self.assertRaisesRegexp(ValueError, r"Variable foo already exists.*"): # This will fail because reuse is not set to True. function_with_variables("foo") compiled = eager_function.defun(function_with_variables) compiled("bar") self.assertEqual( len(ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES)), 2) # The second call to `compiled` should not create variables: the # init_scope has lifted the variable creation code out of the defun. compiled("bar") self.assertEqual( len(ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES)), 2) def testEscapesNestedDefun(self): def inner_function(): with ops.init_scope(): self.v = resource_variable_ops.ResourceVariable(1) return self.v.assign_add(2) def outer_function(inner=None): with ops.init_scope(): self.v0 = resource_variable_ops.ResourceVariable(0) return self.v0.assign_add(1) + inner() with context.eager_mode(): # Each invocation of outer_function recreates variables. self.assertEqual(4, int(outer_function(inner=inner_function))) self.assertEqual(4, int(outer_function(inner=inner_function))) compiled_inner = eager_function.defun(inner_function) compiled_outer = eager_function.defun(outer_function) # The init_scope lifts variables out of the graph functions # constructed by defun; hence, compiled_outer should now appear to be # stateful. self.assertEqual(4, int(compiled_outer(inner=compiled_inner))) self.assertEqual(7, int(compiled_outer(inner=compiled_inner))) def testFallsBackToGlobalGraphWhenAllGraphsAreBuildingFunctions(self): with context.graph_mode(): ops.reset_default_graph() # This doesn't push anything onto the graph stack, but it does # set the stack's global graph. global_graph = ops.get_default_graph() fn_graph = ops.Graph() # pylint: disable=protected-access fn_graph._building_function = True self.assertEqual(len(ops._default_graph_stack.stack), 0) with fn_graph.as_default(): self.assertEqual(len(ops._default_graph_stack.stack), 1) with ops.init_scope(): self.assertGreater(len(ops._default_graph_stack.stack), 1) dummy = constant_op.constant(1.0) self.assertEqual(len(ops._default_graph_stack.stack), 1) # Note that the global graph is _not_ on the graph stack. self.assertEqual(len(ops._default_graph_stack.stack), 0) # Ensure that `dummy` was added to the global graph. self.assertEqual(global_graph, dummy.graph) # pylint: enable=protected-access def testInstallsDefaultGraphWhenGraphStackIsEmptyInGraphMode(self): with context.graph_mode(): # pylint: disable=protected-access self.assertEqual(len(ops._default_graph_stack.stack), 0) with ops.init_scope(): self.assertGreater(len(ops._default_graph_stack.stack), 0) self.assertEqual(len(ops._default_graph_stack.stack), 0) # pylint: enable=protected-access def testPreservesNameScopeInGraphConstruction(self): with ops.Graph().as_default(): function_graph = ops.Graph() with function_graph.as_default(): with ops.name_scope("inner"), ops.init_scope(): self.assertEqual(ops.get_name_scope(), "inner") self.assertEqual(ops.get_name_scope(), "") def testEnteringGraphFromEagerIsSticky(self): with context.eager_mode(): g = ops.Graph() with g.as_default(): with ops.init_scope(): self.assertFalse(context.executing_eagerly()) self.assertEqual(g, ops.get_default_graph()) def testMixGraphEager(self): with context.eager_mode(): c = constant_op.constant(1.0) with ops.Graph().as_default(): with self.assertRaisesRegexp( RuntimeError, "Attempting to capture an EagerTensor"): math_ops.add(c, c) c2 = constant_op.constant(2.0) with self.assertRaisesRegexp( TypeError, "contains objects other than 'EagerTensor'"): math_ops.add(c2, c2) def testPreservesNameScopeInEagerExecution(self): with context.eager_mode(): def foo(): with ops.name_scope("inner"), ops.init_scope(): if context.executing_eagerly(): # A trailing slash is always appended when eager execution is # enabled. self.assertEqual(context.context().scope_name, "inner/") else: self.assertEqual(ops.get_name_scope(), "inner") foo() self.assertEqual(ops.get_name_scope(), "") foo_compiled = eager_function.defun(foo) foo_compiled() self.assertEqual(ops.get_name_scope(), "") class GraphTest(test_util.TensorFlowTestCase): def setUp(self): ops.reset_default_graph() def _AssertDefault(self, expected): self.assertIs(expected, ops.get_default_graph()) def testResetDefaultGraphNesting(self): g0 = ops.Graph() with self.assertRaises(AssertionError): with g0.as_default(): ops.reset_default_graph() def testGraphContextManagerCancelsEager(self): with context.eager_mode(): with ops.Graph().as_default(): self.assertFalse(context.executing_eagerly()) def testGraphContextManager(self): g0 = ops.Graph() with g0.as_default() as g1: self.assertIs(g0, g1) def testDefaultGraph(self): orig = ops.get_default_graph() self._AssertDefault(orig) g0 = ops.Graph() self._AssertDefault(orig) context_manager_0 = g0.as_default() self._AssertDefault(orig) with context_manager_0 as g0: self._AssertDefault(g0) with ops.Graph().as_default() as g1: self._AssertDefault(g1) self._AssertDefault(g0) self._AssertDefault(orig) def testPreventFeeding(self): g = ops.Graph() a = constant_op.constant(2.0) self.assertTrue(g.is_feedable(a)) g.prevent_feeding(a) self.assertFalse(g.is_feedable(a)) def testPreventFetching(self): g = ops.Graph() a = constant_op.constant(2.0) self.assertTrue(g.is_fetchable(a)) g.prevent_fetching(a.op) self.assertFalse(g.is_fetchable(a)) def testAsGraphElementConversions(self): class ConvertibleObj(object): def _as_graph_element(self): return "FloatOutput:0" class NonConvertibleObj(object): pass g = ops.Graph() a = _apply_op(g, "FloatOutput", [], [dtypes.float32]) self.assertEqual(a, g.as_graph_element(ConvertibleObj())) with self.assertRaises(TypeError): g.as_graph_element(NonConvertibleObj()) # Regression test against creating custom __del__ functions in classes # involved in cyclic references, e.g. Graph and Operation. (Python won't gc # cycles that require calling a __del__ method, because the __del__ method can # theoretically increase the object's refcount to "save" it from gc, and any # already-deleted objects in the cycle would have be to restored.) def testGarbageCollected(self): # Create a graph we can delete and a weak reference to monitor if it's gc'd g = ops.Graph() g_ref = weakref.ref(g) # Create some ops with g.as_default(): a = constant_op.constant(2.0) b = constant_op.constant(3.0) c = math_ops.add(a, b) # Create a session we can delete with session.Session(graph=g) as sess: sess.run(c) # Delete all references and trigger gc del g del a del b del c del sess gc.collect() self.assertIsNone(g_ref()) def testRunnableAfterInvalidShape(self): with ops.Graph().as_default(): with self.assertRaises(ValueError): math_ops.add([1, 2], [1, 2, 3]) a = constant_op.constant(1) with session.Session() as sess: sess.run(a) def testRunnableAfterInvalidShapeWithKernelLabelMap(self): g = ops.Graph() with g.as_default(): with g._kernel_label_map({"KernelLabelRequired": "overload_1"}): with self.assertRaises(ValueError): test_ops.kernel_label_required(1) a = constant_op.constant(1) with session.Session() as sess: sess.run(a) class AttrScopeTest(test_util.TensorFlowTestCase): def _get_test_attrs(self): x = control_flow_ops.no_op() try: a = compat.as_text(x.get_attr("_A")) except ValueError: a = None try: b = compat.as_text(x.get_attr("_B")) except ValueError: b = None return (a, b) def testNoLabel(self): with self.cached_session(): self.assertAllEqual((None, None), self._get_test_attrs()) def testLabelMap(self): with self.cached_session() as sess: a1 = self._get_test_attrs() with sess.graph._attr_scope({ "_A": attr_value_pb2.AttrValue(s=compat.as_bytes("foo")) }): a2 = self._get_test_attrs() with sess.graph._attr_scope({ "_A": None, "_B": attr_value_pb2.AttrValue(s=compat.as_bytes("bar")) }): a3 = self._get_test_attrs() with sess.graph._attr_scope({ "_A": attr_value_pb2.AttrValue(s=compat.as_bytes("baz")) }): a4 = self._get_test_attrs() a5 = self._get_test_attrs() a6 = self._get_test_attrs() a7 = self._get_test_attrs() self.assertAllEqual((None, None), a1) self.assertAllEqual(("foo", None), a2) self.assertAllEqual((None, "bar"), a3) self.assertAllEqual(("baz", "bar"), a4) self.assertAllEqual((None, "bar"), a5) self.assertAllEqual(("foo", None), a6) self.assertAllEqual((None, None), a7) ops.RegisterShape("KernelLabel")(common_shapes.scalar_shape) class KernelLabelTest(test_util.TensorFlowTestCase): def testNoLabel(self): with self.cached_session(): self.assertAllEqual(b"My label is: default", test_ops.kernel_label().eval()) def testLabelMap(self): with self.cached_session() as sess: default_1 = test_ops.kernel_label() # pylint: disable=protected-access with sess.graph._kernel_label_map({"KernelLabel": "overload_1"}): overload_1_1 = test_ops.kernel_label() with sess.graph._kernel_label_map({"KernelLabel": "overload_2"}): overload_2 = test_ops.kernel_label() with sess.graph._kernel_label_map({"KernelLabel": ""}): default_2 = test_ops.kernel_label() overload_1_2 = test_ops.kernel_label() # pylint: enable=protected-access default_3 = test_ops.kernel_label() self.assertAllEqual(b"My label is: default", default_1.eval()) self.assertAllEqual(b"My label is: default", default_2.eval()) self.assertAllEqual(b"My label is: default", default_3.eval()) self.assertAllEqual(b"My label is: overload_1", overload_1_1.eval()) self.assertAllEqual(b"My label is: overload_1", overload_1_2.eval()) self.assertAllEqual(b"My label is: overload_2", overload_2.eval()) class AsGraphDefTest(test_util.TensorFlowTestCase): def testGraphDefVersion(self): """Test that the graphdef version is plumbed through to kernels.""" with ops.Graph().as_default() as g: version = g.graph_def_versions.producer with self.session(graph=g): v = test_ops.graph_def_version().eval() self.assertEqual(version, v) def testAddShapes(self): with ops.Graph().as_default() as g: t1, t2, t3, t4, t5 = _apply_op(g, "FiveFloatOutputs", [], [dtypes.float32] * 5) t1.set_shape(None) t2.set_shape([]) t3.set_shape([None]) t4.set_shape([43, 37]) t5.set_shape([43, None]) b = constant_op.constant(1.0) # pylint: disable=unused-variable gd = g.as_graph_def(add_shapes=True) self.assertProtoEqualsVersion(""" node { name: "FiveFloatOutputs" op: "FiveFloatOutputs" attr { key: "_output_shapes" value { list { shape { unknown_rank: true } shape { } shape { dim { size: -1 } } shape { dim { size: 43 } dim { size: 37 } } shape { dim { size: 43 } dim { size: -1 } } } } } } node { name: "Const" op: "Const" attr { key: "_output_shapes" value { list { shape { } } } } attr { key: "dtype" value { type: DT_FLOAT } } attr { key: "value" value { tensor { dtype: DT_FLOAT tensor_shape { } float_val: 1.0 } } } } """, gd) @ops.RegisterStatistics("a", "flops") def _calc_a_forward_flops(unused_graph, unused_node): return ops.OpStats("flops", 20) class StatisticsTest(test_util.TensorFlowTestCase): def testRegisteredNode(self): graph = ops.Graph() node = ops._NodeDef("a", "an_a") flops = ops.get_stats_for_node_def(graph, node, "flops") self.assertEqual(20, flops.value) missing_stat = ops.get_stats_for_node_def(graph, node, "missing_stat") self.assertEqual(None, missing_stat.value) def testUnregisteredNode(self): graph = ops.Graph() node = ops._NodeDef("b", "a_b") weight_params = ops.get_stats_for_node_def(graph, node, "weight_params") self.assertEqual(None, weight_params.value) def testAccumulateStatistics(self): flops_total = ops.OpStats("flops") self.assertEqual(None, flops_total.value) second_flops = ops.OpStats("flops", 3) flops_total += second_flops self.assertEqual(3, flops_total.value) class DeviceStackTest(test_util.TensorFlowTestCase): def testBasicDeviceAssignmentMetadata(self): def device_func(unused_op): return "/cpu:*" const_zero = constant_op.constant([0.0], name="zero") with ops.device("/cpu"): const_one = constant_op.constant([1.0], name="one") with ops.device("/cpu:0"): const_two = constant_op.constant([2.0], name="two") with ops.device(device_func): const_three = constant_op.constant(3.0, name="three") self.assertEqual(0, len(const_zero.op._device_assignments)) one_list = const_one.op._device_assignments self.assertEqual(1, len(one_list)) self.assertEqual("/cpu", one_list[0].obj) self.assertEqual("ops_test.py", os.path.basename(one_list[0].filename)) two_list = const_two.op._device_assignments self.assertEqual(2, len(two_list)) devices = [t.obj for t in two_list] self.assertEqual(set(["/cpu", "/cpu:0"]), set(devices)) three_list = const_three.op._device_assignments self.assertEqual(1, len(three_list)) func_description = three_list[0].obj expected_regex = r"device_func<.*ops_test.py, [0-9]+" self.assertRegexpMatches(func_description, expected_regex) def testDeviceAssignmentMetadataForGraphDeviceAndTfDeviceFunctions(self): with ops.device("/cpu"): const_one = constant_op.constant([1.0], name="one") with ops.get_default_graph().device("/cpu"): const_two = constant_op.constant([2.0], name="two") one_metadata = const_one.op._device_assignments[0] two_metadata = const_two.op._device_assignments[0] # Verify both types of device assignment return the right stack info. self.assertRegexpMatches("ops_test.py", os.path.basename(one_metadata.filename)) self.assertEqual(one_metadata.filename, two_metadata.filename) self.assertEqual(one_metadata.lineno + 2, two_metadata.lineno) class ColocationGroupTest(test_util.TensorFlowTestCase): def testBasic(self): a = constant_op.constant([2.0], name="a") with ops.colocate_with(a.op): b = constant_op.constant(3.0) c = constant_op.constant(4.0) self.assertEqual([b"loc:@a"], a.op.colocation_groups()) self.assertEqual([b"loc:@a"], b.op.colocation_groups()) with self.assertRaises(ValueError): c.op.get_attr("_class") def testBasicColocationMetadata(self): const_two = constant_op.constant([2.0], name="two") with ops.colocate_with(const_two.op): const_three = constant_op.constant(3.0, name="three") locations_dict = const_three.op._colocation_dict self.assertIn("two", locations_dict) metadata = locations_dict["two"] self.assertIsNone(metadata.obj) # Check that this test's filename is recorded as the file containing the # colocation statement. self.assertEqual("ops_test.py", os.path.basename(metadata.filename)) def testColocationDeviceInteraction(self): with ops.device("/cpu:0"): with ops.device("/device:GPU:0"): a = constant_op.constant([2.0], name="a") with ops.colocate_with(a.op): # 'b' is created in the scope of /cpu:0, but it is # colocated with 'a', which is on '/device:GPU:0'. colocate_with # overrides devices because it is a stronger constraint. b = constant_op.constant(3.0) self.assertEqual([b"loc:@a"], b.op.colocation_groups()) self.assertEqual(a.op.device, b.op.device) def testColocationCanonicalization(self): with ops.device("/device:GPU:0"): _ = constant_op.constant(2.0) with ops.device(lambda op: "/device:GPU:0"): b = constant_op.constant(3.0) with ops.get_default_graph().colocate_with(b): with ops.device("/device:GPU:0"): c = constant_op.constant(4.0) # A's device will be /device:GPU:0 # B's device will be /device:GPU:0 # C's device will be /device:GPU:0 because it # inherits B's device name, after canonicalizing the names. self.assertEqual(b.op.device, c.op.device) def testLocationOverrides(self): with ops.device("/cpu:0"): with ops.device("/device:GPU:0"): a = constant_op.constant([2.0], name="a") # Note that this colocation is "redundant", since we are # within the scope of "/device:GPU:0". However, we would like to # preserve in the GraphDef that these two ops should be # colocated in a portable way. with ops.colocate_with(a.op): b = constant_op.constant(3.0) c = constant_op.constant(4.0) d = constant_op.constant(5.0) self.assertEqual([b"loc:@a"], b.op.colocation_groups()) self.assertEqual("/device:GPU:0", a.op.device) self.assertEqual(a.op.device, b.op.device) # Test that device function stack is restored. self.assertEqual("/device:GPU:0", c.op.device) self.assertEqual("/device:CPU:0", d.op.device) def testNestedColocateWith(self): a = constant_op.constant([2.0], name="a") with ops.colocate_with(a.op): b = constant_op.constant(3.0) with ops.colocate_with(b.op): c = constant_op.constant(4.0) self.assertEqual([b"loc:@a"], b.op.colocation_groups()) self.assertEqual([b"loc:@a"], c.op.colocation_groups()) def testMultiColocationGroups(self): a = constant_op.constant([2.0], name="a") b = constant_op.constant(3.0, name="b") with ops.colocate_with(a.op): with ops.colocate_with(b.op): c = constant_op.constant(4.0) self.assertEqual(set([b"loc:@a", b"loc:@b"]), set(c.op.colocation_groups())) def testColocationIgnoreStack(self): a = constant_op.constant([2.0], name="a") b = constant_op.constant(3.0, name="b") with ops.colocate_with(a.op): with ops.colocate_with(b.op, ignore_existing=True): c = constant_op.constant(4.0) self.assertEqual(set([b"loc:@b"]), set(c.op.colocation_groups())) def testColocateWithReset(self): a = constant_op.constant([2.0], name="a") with ops.colocate_with(a.op): b = constant_op.constant(3.0, name="b") with ops.colocate_with(None, ignore_existing=True): c = constant_op.constant(4.0, name="c") self.assertEqual([b"loc:@a"], b.op.colocation_groups()) self.assertEqual([b"loc:@c"], c.op.colocation_groups()) def testColocateWithInitialNoneThenNested(self): a = constant_op.constant([2.0], name="a") with ops.colocate_with(a.op): with ops.colocate_with(None, ignore_existing=True): b = constant_op.constant(3.0, name="b") with ops.colocate_with(b.op): c = constant_op.constant(4.0, name="c") self.assertEqual([b"loc:@b"], b.op.colocation_groups()) self.assertEqual([b"loc:@b"], c.op.colocation_groups()) def testColocateVariables(self): a = variables.Variable([2.0], name="a") with ops.colocate_with(a.op): b = variables.Variable([3.0], name="b") self.assertEqual([b"loc:@a"], b.op.colocation_groups()) def testInconsistentDeviceWithinColocate(self): with ops.device("/device:GPU:0"): a = constant_op.constant([2.0], name="a") with ops.colocate_with(a.op): # This is allowed due to legacy but clearly wrong, since we # should really be colocating with 'a'. We allow devices to # override colocate_with, but we log warnings to suggest that # this is probably unintentional or misguided. with ops.device("/cpu:0"): b = constant_op.constant([3.0], name="b") self.assertEqual("/device:CPU:0", b.device) def testMakeColocationConflictMessage(self): """Test that provides an example of a complicated error message.""" # We could test the message with any ops, but this test will be more # instructive with a real colocation conflict. with ops.device("/device:GPU:0"): a = constant_op.constant([2.0], name="a") with ops.colocate_with(a.op): with ops.device("/cpu:0"): b = constant_op.constant([3.0], name="b") # The definition-location of the nodes will be wrong because of running # from within a TF unittest. The rest of the info should be correct. message = ops.get_default_graph()._make_colocation_conflict_message(a.op, b.op) self.assertRegexpMatches(message, r"Tried to colocate op 'a' \(defined at.*\)") self.assertRegexpMatches(message, "No node-device.*'a'") self.assertRegexpMatches(message, "Device assignments active.*'a'") self.assertRegexpMatches(message, "GPU:0") self.assertRegexpMatches(message, "Node-device colocations active.*'b'") self.assertRegexpMatches(message, "Device assignments active.*'b'") self.assertRegexpMatches(message, "cpu:0") class DeprecatedTest(test_util.TensorFlowTestCase): def testSuccess(self): with ops.Graph().as_default() as g: test_util.set_producer_version(g, 7) old = test_ops.old() with self.session(graph=g): old.run() def _error(self): return ((r"Op Old is not available in GraphDef version %d\. " r"It has been removed in version 8\. For reasons\.") % versions.GRAPH_DEF_VERSION) def testGraphConstructionFail(self): with ops.Graph().as_default(): with self.assertRaisesRegexp(NotImplementedError, self._error()): test_ops.old() class DenseTensorLikeTypeTest(test_util.TensorFlowTestCase): def testSuccess(self): op = ops.Operation( ops._NodeDef("FloatOutput", "myop"), ops.Graph(), [], [dtypes.float32]) t = op.outputs[0] self.assertTrue(ops.is_dense_tensor_like(t)) v = variables.Variable([17]) self.assertTrue(ops.is_dense_tensor_like(v)) class BadClassNoName(object): pass class BadClassBadName(object): def name(self): pass class BadClassNoDtype(object): @property def name(self): pass class BadClassBadDtype(object): @property def name(self): pass def dtype(self): pass def testBadClass(self): with self.assertRaisesRegexp(TypeError, "`name`"): ops.register_dense_tensor_like_type( DenseTensorLikeTypeTest.BadClassNoName) with self.assertRaisesRegexp(TypeError, "`name`"): ops.register_dense_tensor_like_type( DenseTensorLikeTypeTest.BadClassBadName) with self.assertRaisesRegexp(TypeError, "`dtype`"): ops.register_dense_tensor_like_type( DenseTensorLikeTypeTest.BadClassNoDtype) with self.assertRaisesRegexp(TypeError, "`dtype`"): ops.register_dense_tensor_like_type( DenseTensorLikeTypeTest.BadClassBadDtype) class NameScopeTest(test_util.TensorFlowTestCase): def testStripAndPrependScope(self): strs = [ "hidden1/hidden1/weights", # Same prefix. Should strip. "hidden1///hidden1/weights", # Extra "/". Should strip. "^hidden1/hidden1/weights", # Same prefix. Should strip. "loc:@hidden1/hidden1/weights", # Same prefix. Should strip. "hhidden1/hidden1/weights", # Different prefix. Should keep. "hidden1" ] # Not a prefix. Should keep. expected_striped = [ "hidden1/weights", "hidden1/weights", "^hidden1/weights", "loc:@hidden1/weights", "hhidden1/hidden1/weights", "hidden1" ] expected_prepended = [ "hidden2/hidden1/weights", "hidden2/hidden1/weights", "^hidden2/hidden1/weights", "loc:@hidden2/hidden1/weights", "hidden2/hhidden1/hidden1/weights", "hidden2/hidden1" ] name_scope_to_strip = "hidden1" name_scope_to_add = "hidden2" for es, ep, s in zip(expected_striped, expected_prepended, strs): striped = ops.strip_name_scope(s, name_scope_to_strip) self.assertEqual(es, striped) self.assertEqual(ep, ops.prepend_name_scope(striped, name_scope_to_add)) def testGetNameScope(self): with ops.Graph().as_default() as g: with ops.name_scope("scope1"): with ops.name_scope("scope2"): with ops.name_scope("scope3"): self.assertEqual("scope1/scope2/scope3", g.get_name_scope()) self.assertEqual("scope1/scope2", g.get_name_scope()) self.assertEqual("scope1", g.get_name_scope()) self.assertEqual("", g.get_name_scope()) def testTwoGraphs(self): def f(): g1 = ops.Graph() g2 = ops.Graph() with g1.as_default(): with g2.as_default(): with ops.name_scope("_"): pass self.assertRaisesRegexp(ValueError, "'_' is not a valid scope name", f) class TracebackTest(test_util.TensorFlowTestCase): def testTracebackWithStartLines(self): with self.cached_session() as sess: a = constant_op.constant(2.0) sess.run( a, options=config_pb2.RunOptions( trace_level=config_pb2.RunOptions.FULL_TRACE)) self.assertTrue(sess.graph.get_operations()) # Tests that traceback_with_start_lines is the same as traceback # but includes one more element at the end. for op in sess.graph.get_operations(): self.assertEquals(len(op.traceback), len(op.traceback_with_start_lines)) for frame, frame_with_start_line in zip( op.traceback, op.traceback_with_start_lines): self.assertEquals(5, len(frame_with_start_line)) self.assertEquals(frame, frame_with_start_line[:-1]) class EnableEagerExecutionTest(test_util.TensorFlowTestCase): def testBadArgumentsToEnableEagerExecution(self): with self.assertRaisesRegexp(TypeError, "config must be a tf.ConfigProto"): ops.enable_eager_execution(context.DEVICE_PLACEMENT_SILENT) with self.assertRaisesRegexp(ValueError, "device_policy must be one of"): c = config_pb2.ConfigProto() ops.enable_eager_execution(c, c) with self.assertRaisesRegexp(ValueError, "execution_mode must be one of"): c = config_pb2.ConfigProto() ops.enable_eager_execution(c, execution_mode=c) if __name__ == "__main__": googletest.main()