diff options
Diffstat (limited to 'tensorflow/python/framework/importer_test.py')
-rw-r--r-- | tensorflow/python/framework/importer_test.py | 546 |
1 file changed, 546 insertions, 0 deletions
# (New file from commit diff: tensorflow/python/framework/importer_test.py)
"""Tests for tensorflow.python.framework.importer."""

import tensorflow.python.platform

import tensorflow as tf

from google.protobuf import text_format

from tensorflow.core.framework import op_def_pb2
from tensorflow.python.framework import device
from tensorflow.python.framework import op_def_registry


# A registry of toy ops used by the tests below.  The naming convention
# encodes each op's signature: a leading 'O' means output-only, a leading
# 'I' means input-only; 'i' = DT_INT32, 'f' = DT_FLOAT, 'r' = ref arg,
# 'n' = an N-repeated arg, 'tl' = a type-list output.
_op_list = op_def_pb2.OpList()
text_format.Merge("""
  op {
    name: 'None'
  }
  op {
    name: 'Oi'
    output_arg { name: 'a' type: DT_INT32 }
  }
  op {
    name: 'Or'
    output_arg { name: 'a' type: DT_INT32 is_ref: true }
  }
  op {
    name: 'Of'
    output_arg { name: 'a' type: DT_FLOAT }
  }
  op {
    name: 'Ii'
    input_arg { name: 'a' type: DT_INT32 }
  }
  op {
    name: 'If'
    input_arg { name: 'a' type: DT_FLOAT }
  }
  op {
    name: 'Oii'
    output_arg { name: 'a' type: DT_INT32 }
    output_arg { name: 'b' type: DT_INT32 }
  }
  op {
    name: 'Oif'
    output_arg { name: 'a' type: DT_INT32 }
    output_arg { name: 'b' type: DT_FLOAT }
  }
  op {
    name: 'Iii'
    input_arg { name: 'a' type: DT_INT32 }
    input_arg { name: 'b' type: DT_INT32 }
  }
  op {
    name: 'Iff'
    input_arg { name: 'a' type: DT_FLOAT }
    input_arg { name: 'b' type: DT_FLOAT }
  }
  op {
    name: 'Iif'
    input_arg { name: 'a' type: DT_INT32 }
    input_arg { name: 'b' type: DT_FLOAT }
  }
  op {
    name: 'Iri'
    input_arg { name: 'a' type: DT_INT32 is_ref: true }
    input_arg { name: 'b' type: DT_INT32 }
  }
  op {
    name: 'In'
    input_arg { name: 'a' number_attr: 'N' type_attr: 'T' }
    attr { name: 'N' type: 'int' minimum: 1 }
    attr { name: 'T' type: 'type' }
  }
  op {
    name: 'Otl'
    output_arg { name: 'a' type_list_attr: 't' }
    # BUGFIX: the attr must be named 't' (was 'T') so that it matches the
    # output_arg's type_list_attr and the attr key used by test nodes.
    attr { name: 't' type: 'list(type)' minimum: 1 }
  }
  op {
    name: 'Unary'
    input_arg { name: 'a' type_attr: 'T' }
    output_arg { name: 'b' type_attr: 'T' }
    attr { name: 'T' type: 'type' }
  }
""", _op_list)
op_def_registry.register_op_list(_op_list)
# NOTE(mrry): Dummy shape registrations for ops used in the tests.
for op_def in _op_list.op:
  tf.RegisterShape(op_def.name)(None)


class ImportGraphDefTest(tf.test.TestCase):
  """Tests for `tf.import_graph_def()`."""

  def _MakeGraphDef(self, text):
    """Parses `text` (a text-format GraphDef) into a GraphDef proto.

    Args:
      text: A string in protobuf text format describing a GraphDef.

    Returns:
      The parsed `tf.GraphDef` proto.
    """
    ret = tf.GraphDef()
    text_format.Merge(text, ret)
    return ret

  def testBasic(self):
    # Imports a small multi-output graph and checks tensor distinctness,
    # topology, op/output types, and the 'import/' name prefix.
    with tf.Graph().as_default():
      a, b, c, d = tf.import_graph_def(
          self._MakeGraphDef("""
          node { name: 'A' op: 'Oif' }
          node { name: 'B' op: 'Otl'
                 attr { key: 't'
                        value { list { type: DT_INT32 type: DT_FLOAT } } } }
          node { name: 'C' op: 'In'
                 attr { key: 'N' value { i: 2 } }
                 attr { key: 'T' value { type: DT_INT32 } }
                 input: 'A:0' input: 'B:0' }
          node { name: 'D' op: 'In'
                 attr { key: 'N' value { i: 2 } }
                 attr { key: 'T' value { type: DT_FLOAT } }
                 input: 'A:1' input: 'B:1' }
          """),
          return_elements=['A', 'B', 'C', 'D'],
          name='import')

      # Assert that the import process creates distinct tensors.
      self.assertNotEqual(a.outputs[0].name, a.outputs[1].name)
      self.assertNotEqual(b.outputs[0].name, b.outputs[1].name)
      self.assertNotEqual(a.outputs[0].name, b.outputs[0].name)
      self.assertNotEqual(a.outputs[0].name, b.outputs[1].name)
      self.assertNotEqual(a.outputs[1].name, b.outputs[0].name)
      self.assertNotEqual(a.outputs[1].name, b.outputs[1].name)

      # Assert that the ops are connected according to the GraphDef topology.
      self.assertEqual(c.inputs[0], a.outputs[0])
      self.assertEqual(c.inputs[1], b.outputs[0])
      self.assertEqual(d.inputs[0], a.outputs[1])
      self.assertEqual(d.inputs[1], b.outputs[1])

      # Check the types of the returned ops and tensors.
      self.assertEqual(a.type, 'Oif')
      self.assertEqual(b.type, 'Otl')
      self.assertEqual(c.type, 'In')
      self.assertEqual(d.type, 'In')
      self.assertEqual(a.outputs[0].dtype, tf.int32)
      self.assertEqual(a.outputs[1].dtype, tf.float32)
      self.assertEqual(b.outputs[0].dtype, tf.int32)
      self.assertEqual(b.outputs[1].dtype, tf.float32)

      # Check the names of the returned ops.
      self.assertEqual(a.name, 'import/A')
      self.assertEqual(b.name, 'import/B')
      self.assertEqual(c.name, 'import/C')
      self.assertEqual(d.name, 'import/D')

  def testInputMap(self):
    # Entries of input_map replace the corresponding graph_def tensors;
    # unmapped tensors are still wired from the imported ops.
    with tf.Graph().as_default():
      feed_a_0 = tf.constant(0, dtype=tf.int32)
      feed_b_1 = tf.constant(1, dtype=tf.int32)

      a, b, c, d = tf.import_graph_def(
          self._MakeGraphDef("""
          node { name: 'A' op: 'Oii' }
          node { name: 'B' op: 'Oii' }
          node { name: 'C' op: 'In'
                 attr { key: 'N' value { i: 2 } }
                 attr { key: 'T' value { type: DT_INT32 } }
                 input: 'A:0' input: 'B:0' }
          node { name: 'D' op: 'In'
                 attr { key: 'N' value { i: 2 } }
                 attr { key: 'T' value { type: DT_INT32 } }
                 input: 'A:1' input: 'B:1' }
          """),
          input_map={'A:0': feed_a_0, 'B:1': feed_b_1},
          return_elements=['A', 'B', 'C', 'D'])

      self.assertEqual(c.inputs[0], feed_a_0)
      self.assertEqual(c.inputs[1], b.outputs[0])
      self.assertEqual(d.inputs[0], a.outputs[1])
      self.assertEqual(d.inputs[1], feed_b_1)

  def testImplicitZerothOutput(self):
    # An input of 'A' (no ':index') refers to output 0 of 'A'.
    with tf.Graph().as_default():
      a, b = tf.import_graph_def(
          self._MakeGraphDef("""
          node { name: 'A' op: 'Oii' }
          node { name: 'B' op: 'Ii' input: 'A' }
          """),
          return_elements=['A', 'B'])

      self.assertEqual(b.inputs[0], a.outputs[0])

  def testInputMapImplicitZerothOutput(self):
    # An input_map key of 'A' (no ':index') also maps output 0 of 'A'.
    with tf.Graph().as_default():
      feed_a_0 = tf.constant(0, dtype=tf.int32)
      b, = tf.import_graph_def(
          self._MakeGraphDef("""
          node { name: 'A' op: 'Oii' }
          node { name: 'B' op: 'Ii' input: 'A:0' }
          """),
          input_map={'A': feed_a_0},
          return_elements=['B'])

      self.assertEqual(b.inputs[0], feed_a_0)
  def testWithControlDependency(self):
    # A '^A' input becomes a control edge on the consumer, not a data input.
    with tf.Graph().as_default():
      a, b = tf.import_graph_def(
          self._MakeGraphDef("""
          node { name: 'A' op: 'None' }
          node { name: 'B' op: 'None' input: '^A' }
          """),
          return_elements=['A', 'B'])

      self.assertEqual(b.control_inputs, [a])

  def testWithRefs(self):
    # A ref-typed output ('Or') may feed both ref and non-ref inputs; the
    # consumer's recorded input dtypes show whether the ref survived.
    with tf.Graph().as_default():
      a, b, c, d = tf.import_graph_def(
          self._MakeGraphDef("""
          node { name: 'A' op: 'Or' }
          node { name: 'B' op: 'Oi' }
          node { name: 'C' op: 'Iii' input: 'A:0' input: 'B:0' }
          node { name: 'D' op: 'Iri' input: 'A:0' input: 'B:0' }
          """),
          return_elements=['A', 'B', 'C', 'D'])

      self.assertEqual(c.inputs[0], a.outputs[0])
      self.assertEqual(c.inputs[1], b.outputs[0])
      self.assertEqual(d.inputs[0], a.outputs[0])
      self.assertEqual(d.inputs[1], b.outputs[0])

      self.assertEqual(a.outputs[0].dtype, tf.int32_ref)
      # NOTE(review): _input_dtypes is a private Operation attribute; this
      # test intentionally peeks at it to check ref/non-ref handling.
      self.assertEqual(c._input_dtypes, [tf.int32, tf.int32])
      self.assertEqual(c.outputs, [])
      self.assertEqual(d._input_dtypes,
                       [tf.int32_ref, tf.int32])
      self.assertEqual(d.outputs, [])

  def testCyclic(self):
    # Importing a cyclic graph (A -> B -> A) must wire both edges.
    with tf.Graph().as_default():
      a, b = tf.import_graph_def(
          self._MakeGraphDef("""
          node { name: 'A' op: 'Unary'
                 attr { key: 'T' value { type: DT_INT32 } } input: 'B:0' }
          node { name: 'B' op: 'Unary'
                 attr { key: 'T' value { type: DT_INT32 } } input: 'A:0' }
          """),
          return_elements=['A', 'B'])

      self.assertEqual(a.inputs[0], b.outputs[0])
      self.assertEqual(b.inputs[0], a.outputs[0])

  def testTypeMismatchInGraphDef(self):
    # Feeding an int32 output into a float input must raise ValueError.
    with tf.Graph().as_default():
      with self.assertRaises(ValueError) as e:
        tf.import_graph_def(
            self._MakeGraphDef("""
            node { name: 'A' op: 'Oi' }
            node { name: 'B' op: 'If' input: 'A:0' }
            """))
      self.assertTrue(
          'Cannot convert a tensor of type int32 to an input of type float' in
          str(e.exception))

  def testInvalidSignatureTooManyInputsInGraphDef(self):
    # A node supplying more inputs than its op declares must raise.
    with tf.Graph().as_default():
      with self.assertRaises(ValueError) as e:
        tf.import_graph_def(
            self._MakeGraphDef("""
            node { name: 'A' op: 'Oi' }
            node { name: 'B' op: 'None' input: 'A:0' }
            """))
      self.assertTrue('More inputs specified (u\'A:0\') than the op expects' in
                      str(e.exception))

  def testInvalidSignatureNotEnoughInputsInGraphDef(self):
    # A node supplying fewer inputs than its op declares must raise.
    with tf.Graph().as_default():
      with self.assertRaises(ValueError) as e:
        tf.import_graph_def(
            self._MakeGraphDef("""
            node { name: 'A' op: 'Oi' }
            node { name: 'B' op: 'Iif' input: 'A:0' }
            """))
      self.assertTrue('Input types mismatch (expected \'int32, float32\' but '
                      'got \'int32\')' in str(e.exception))

  def testMissingInputOpInGraphDef(self):
    # Referencing a producer op absent from graph_def must raise.
    with tf.Graph().as_default():
      with self.assertRaises(ValueError) as e:
        tf.import_graph_def(
            self._MakeGraphDef("""
            node { name: 'B' op: 'If' input: 'A:0' }
            """))
      self.assertTrue('Input tensor %r not found' % (u'A:0',) in
                      str(e.exception))

  def testMissingInputOpInGraphDefButAppearsInInputMap(self):
    # ...but the same missing producer is fine if input_map supplies it.
    with tf.Graph().as_default():
      feed_a_0 = tf.constant(5.0)
      b, = tf.import_graph_def(
          self._MakeGraphDef("""
          node { name: 'B' op: 'If' input: 'A:0' }
          """),
          input_map={'A:0': feed_a_0},
          return_elements=['B'])
      self.assertEqual(b.inputs[0], feed_a_0)

  def testMissingInputTensorInGraphDef(self):
    # Referencing an out-of-range output index of an existing op must raise.
    with tf.Graph().as_default():
      with self.assertRaises(ValueError) as e:
        tf.import_graph_def(
            self._MakeGraphDef("""
            node { name: 'A' op: 'Of' }
            node { name: 'B' op: 'If' input: 'A:1' }
            """))
      self.assertTrue('Input tensor %r not found' % (u'A:1',) in
                      str(e.exception))

  def testMissingControlInputInGraphDef(self):
    # A control input ('^A') naming a missing op must raise.
    with tf.Graph().as_default():
      with self.assertRaises(ValueError) as e:
        tf.import_graph_def(
            self._MakeGraphDef("""
            node { name: 'B' op: 'None' input: '^A' }
            """))
      self.assertTrue('Control input %r not found' % (u'^A',) in
                      str(e.exception))

  def testInvalidTensorNameOutputIndexInGraphDef(self):
    # A non-numeric output index ('A:B') is not a valid tensor name.
    with tf.Graph().as_default():
      with self.assertRaises(ValueError) as e:
        tf.import_graph_def(
            self._MakeGraphDef("""
            node { name: 'B' op: 'None' input: 'A:B' }
            """))
      self.assertEqual(
          'Cannot convert %r to a tensor name.' % (u'A:B',), str(e.exception))

  def testInvalidTensorNameInGraphDef(self):
    # Too many ':'-separated components ('A:B:0') is not a valid tensor name.
    with tf.Graph().as_default():
      with self.assertRaises(ValueError) as e:
        tf.import_graph_def(
            self._MakeGraphDef("""
            node { name: 'B' op: 'None' input: 'A:B:0' }
            """))
      self.assertEqual(
          'Cannot convert %r to a tensor name.' % (u'A:B:0',), str(e.exception))

  def testMissingReturnOperation(self):
    # Requesting a return op that is not in graph_def must raise.
    with tf.Graph().as_default():
      with self.assertRaises(ValueError) as e:
        tf.import_graph_def(
            self._MakeGraphDef("""
            node { name: 'A' op: 'None' }
            """),
            return_elements=['B'])
      self.assertTrue('return_element %r not found in graph_def.' % ('B') in
                      str(e.exception))

  def testMissingReturnTensor(self):
    # Requesting a return tensor with a bad output index, a missing op,
    # or a malformed name must raise in each case.
    with tf.Graph().as_default():
      with self.assertRaises(ValueError) as e:
        tf.import_graph_def(
            self._MakeGraphDef("""
            node { name: 'A' op: 'Oi' }
            """),
            return_elements=['A:1'])
      self.assertTrue('return_element %r not found in graph_def.' % ('A:1') in
                      str(e.exception))

      with self.assertRaises(ValueError) as e:
        tf.import_graph_def(
            self._MakeGraphDef("""
            node { name: 'A' op: 'Oi' }
            """),
            return_elements=['B:0'])
      self.assertTrue('return_element %r not found in graph_def.' % ('B:0') in
                      str(e.exception))

      with self.assertRaises(ValueError) as e:
        tf.import_graph_def(
            self._MakeGraphDef("""
            node { name: 'A' op: 'Oi' }
            """),
            return_elements=['A:B:0'])
      self.assertTrue('return_element %r not found in graph_def.'
                      % ('A:B:0') in
                      str(e.exception))

  def testMissingInputMap(self):
    # Every input_map key must name a tensor in graph_def.
    with tf.Graph().as_default():
      with self.assertRaises(ValueError) as e:
        tf.import_graph_def(
            self._MakeGraphDef("""
            node { name: 'A' op: 'None' }
            """),
            input_map={'B:0': tf.constant(5.0)})
      self.assertTrue('not found in graph_def: [B:0]' in str(e.exception))

  def testInputMapTypeMismatch(self):
    # An input_map value whose dtype mismatches the consumer must raise.
    with tf.Graph().as_default():
      with self.assertRaises(ValueError) as e:
        tf.import_graph_def(
            self._MakeGraphDef("""
            node { name: 'A' op: 'Oi' }
            node { name: 'B' op: 'Ii' input: 'A:0' }
            """),
            input_map={'A:0': tf.constant(5.0)})
      self.assertTrue(
          'Cannot convert a tensor of type float32 to an input of type int32.'
          in str(e.exception))

  def testNoReturns(self):
    # Without return_elements, import returns None but still adds the ops.
    with tf.Graph().as_default() as g:
      ret = tf.import_graph_def(
          self._MakeGraphDef("""
          node { name: 'A' op: 'None' }
          """))
      self.assertEqual(ret, None)

      a = g.get_operation_by_name('import/A')
      self.assertEqual(a.type, 'None')

  def testOverrideNamePrefix(self):
    # The name= argument replaces the default 'import' prefix.
    with tf.Graph().as_default():
      a, = tf.import_graph_def(
          self._MakeGraphDef("""
          node { name: 'A' op: 'None' }
          """),
          return_elements=['A'], name='imported_graph')
      self.assertEqual(a.name, 'imported_graph/A')

  def testEmptyGraph(self):
    # Importing an empty GraphDef must not bump the graph version.
    with tf.Graph().as_default() as g:
      init_version = g.version
      tf.import_graph_def(self._MakeGraphDef(''))
      self.assertEqual(init_version, g.version)

  def testInvalidInputForGraphDef(self):
    # graph_def must be a GraphDef proto, not e.g. a string.
    with tf.Graph().as_default():
      with self.assertRaises(TypeError) as e:
        tf.import_graph_def('')
      self.assertEqual(
          'graph_def must be a GraphDef proto.', str(e.exception))

  def testInvalidInputForInputMap(self):
    # input_map must be a dict, not e.g. a list of tensors.
    with tf.Graph().as_default():
      with self.assertRaises(TypeError) as e:
        tf.import_graph_def(self._MakeGraphDef(''),
                            input_map=[tf.constant(5.0)])
      self.assertEqual('input_map must be a dictionary mapping strings to '
                       'Tensor objects.', str(e.exception))

  def testInvalidInputForReturnOperations(self):
    # return_elements must be strings, not e.g. ints.
    with tf.Graph().as_default():
      with self.assertRaises(TypeError) as e:
        tf.import_graph_def(self._MakeGraphDef(''), return_elements=[7])
      self.assertEqual(
          'return_elements must be a list of strings.', str(e.exception))

  def testWithExtensionAndAttr(self):
    # Round-trips a real graph (constant + pack) through as_graph_def()
    # and import, then evaluates the imported op.
    with tf.Graph().as_default() as g:
      c = tf.constant(5.0, dtype=tf.float32, name='c')
      tf.pack([c, c], name='pack')
      gdef = g.as_graph_def()

    with self.test_session():
      pack, = tf.import_graph_def(gdef, return_elements=['pack'])
      self.assertAllEqual(pack.outputs[0].eval(), [5.0, 5.0])

  def testWithDevice(self):
    # Checks how imported-node devices merge with the importing scope's
    # device function: scope fields fill in, original fields win conflicts.
    with tf.Graph().as_default() as g:
      # No device.
      a = tf.constant(3.0, name='a')

      with tf.device('/cpu:0'):
        b = tf.constant(4.0, name='b')
      with tf.device('/job:worker'):
        c = tf.constant(5.0, name='c')

    gdef = g.as_graph_def()

    with tf.Graph().as_default():
      # No device scope: original devices are preserved verbatim.
      a2, b2, c2 = tf.import_graph_def(
          gdef, return_elements=['a', 'b', 'c'])
      self.assertEqual(a.device, a2.device)
      self.assertEqual(b.device, b2.device)
      self.assertEqual(c.device, c2.device)

    with tf.Graph().as_default():
      with tf.device(device.merge_device('/task:0')):
        a3, b3, c3 = tf.import_graph_def(
            gdef, return_elements=['a', 'b', 'c'])
        self.assertEqual('/task:0', a3.device)
        self.assertEqual('/task:0/device:CPU:0', b3.device)  # canonicalized.
        self.assertEqual(c.device + '/task:0', c3.device)

    with tf.Graph().as_default():
      with tf.device(device.merge_device('/job:ps')):
        a4, b4, c4 = tf.import_graph_def(
            gdef, return_elements=['a', 'b', 'c'])
        self.assertEqual('/job:ps', a4.device)
        self.assertEqual('/job:ps/device:CPU:0', b4.device)  # canonicalized.
        self.assertEqual(c.device, c4.device)  # worker overrides ps.

    with tf.Graph().as_default():
      with tf.device(device.merge_device('/gpu:0')):
        a5, b5, c5 = tf.import_graph_def(
            gdef, return_elements=['a', 'b', 'c'])
        self.assertEqual('/device:GPU:0', a5.device)
        self.assertEqual('/device:CPU:0', b5.device)  # cpu overrides gpu.
        self.assertEqual(c.device + '/device:GPU:0', c5.device)

  def testGradient(self):
    # Gradients can flow through imported subgraphs when their
    # placeholders are remapped onto live tensors via input_map.
    with tf.Graph().as_default() as g:
      inputs = tf.placeholder(tf.float32, shape=[None, 100], name="input")
      weights = tf.placeholder(tf.float32, shape=[100, 10], name="weights")
      biases = tf.placeholder(tf.float32, shape=[10], name="biases")
      activations = tf.nn.relu(tf.matmul(inputs, weights) + biases,
                               name="activations")
      loss = tf.reduce_mean(activations, name="loss")
    gdef = g.as_graph_def()

    with tf.Graph().as_default() as g:
      input_placeholder = tf.placeholder(tf.float32, shape=[32, 100])
      weights_var = tf.Variable(tf.truncated_normal([100, 10]), name="weights")
      biases_var = tf.Variable(tf.zeros(10), name="biases")
      activations, loss = tf.import_graph_def(
          gdef,
          input_map={"input:0": input_placeholder,
                     "weights:0": weights_var,
                     "biases:0": biases_var},
          return_elements=["activations:0", "loss:0"])
      self.assertEqual([32, 10], activations.get_shape())
      self.assertEqual([], loss.get_shape())
      weights_grad, biases_grad = tf.gradients(loss, [weights_var, biases_var])
      self.assertEqual([100, 10], weights_grad.get_shape())
      self.assertEqual([10], biases_grad.get_shape())

  def testLargeGraph(self):
    with self.test_session():
      # The default message byte limit is 64M. Ours is 2G with a warning at 512.
      # Adding a 150M entries float32 tensor should blow through the warning,
      # but not the hard limit.
      input_shape = [150, 1024, 1024]
      # NOTE(review): tf.np appears to rely on tensorflow re-exporting numpy
      # from its package namespace at this version — confirm before changing.
      tensor_input = tf.np.random.rand(*input_shape).astype(tf.np.float32)
      t = tf.constant(tensor_input, shape=input_shape)
      g = tf.identity(t)
      g.eval()


if __name__ == '__main__':
  tf.test.main()