aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python/framework/importer_test.py
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/python/framework/importer_test.py')
-rw-r--r--tensorflow/python/framework/importer_test.py546
1 files changed, 546 insertions, 0 deletions
diff --git a/tensorflow/python/framework/importer_test.py b/tensorflow/python/framework/importer_test.py
new file mode 100644
index 0000000000..470092313a
--- /dev/null
+++ b/tensorflow/python/framework/importer_test.py
@@ -0,0 +1,546 @@
+"""Tests for tensorflow.python.framework.importer."""
+
+import tensorflow.python.platform
+
+import tensorflow as tf
+
+from google.protobuf import text_format
+
+from tensorflow.core.framework import op_def_pb2
+from tensorflow.python.framework import device
+from tensorflow.python.framework import op_def_registry
+
+
+_op_list = op_def_pb2.OpList()
+text_format.Merge("""
+ op {
+ name: 'None'
+ }
+ op {
+ name: 'Oi'
+ output_arg { name: 'a' type: DT_INT32 }
+ }
+ op {
+ name: 'Or'
+ output_arg { name: 'a' type: DT_INT32 is_ref: true }
+ }
+ op {
+ name: 'Of'
+ output_arg { name: 'a' type: DT_FLOAT }
+ }
+ op {
+ name: 'Ii'
+ input_arg { name: 'a' type: DT_INT32 }
+ }
+ op {
+ name: 'If'
+ input_arg { name: 'a' type: DT_FLOAT }
+ }
+ op {
+ name: 'Oii'
+ output_arg { name: 'a' type: DT_INT32 }
+ output_arg { name: 'b' type: DT_INT32 }
+ }
+ op {
+ name: 'Oif'
+ output_arg { name: 'a' type: DT_INT32 }
+ output_arg { name: 'b' type: DT_FLOAT }
+ }
+ op {
+ name: 'Iii'
+ input_arg { name: 'a' type: DT_INT32 }
+ input_arg { name: 'b' type: DT_INT32 }
+ }
+ op {
+ name: 'Iff'
+ input_arg { name: 'a' type: DT_FLOAT }
+ input_arg { name: 'b' type: DT_FLOAT }
+ }
+ op {
+ name: 'Iif'
+ input_arg { name: 'a' type: DT_INT32 }
+ input_arg { name: 'b' type: DT_FLOAT }
+ }
+ op {
+ name: 'Iri'
+ input_arg { name: 'a' type: DT_INT32 is_ref: true }
+ input_arg { name: 'b' type: DT_INT32 }
+ }
+ op {
+ name: 'In'
+ input_arg { name: 'a' number_attr: 'N' type_attr: 'T' }
+ attr { name: 'N' type: 'int' minimum: 1 }
+ attr { name: 'T' type: 'type' }
+ }
+ op {
+ name: 'Otl'
+ output_arg { name: 'a' type_list_attr: 't' }
+ attr { name: 'T' type: 'list(type)' minimum: 1 }
+ }
+ op {
+ name: 'Unary'
+ input_arg { name: 'a' type_attr: 'T' }
+ output_arg { name: 'b' type_attr: 'T' }
+ attr { name: 'T' type: 'type' }
+ }
+""", _op_list)
+op_def_registry.register_op_list(_op_list)
+# NOTE(mrry): Dummy shape registrations for ops used in the tests.
+for op_def in _op_list.op:
+ tf.RegisterShape(op_def.name)(None)
+
+class ImportGraphDefTest(tf.test.TestCase):
+
+ def _MakeGraphDef(self, text):
+ ret = tf.GraphDef()
+ text_format.Merge(text, ret)
+ return ret
+
+ def testBasic(self):
+ with tf.Graph().as_default():
+ a, b, c, d = tf.import_graph_def(
+ self._MakeGraphDef("""
+ node { name: 'A' op: 'Oif' }
+ node { name: 'B' op: 'Otl'
+ attr { key: 't'
+ value { list { type: DT_INT32 type: DT_FLOAT } } } }
+ node { name: 'C' op: 'In'
+ attr { key: 'N' value { i: 2 } }
+ attr { key: 'T' value { type: DT_INT32 } }
+ input: 'A:0' input: 'B:0' }
+ node { name: 'D' op: 'In'
+ attr { key: 'N' value { i: 2 } }
+ attr { key: 'T' value { type: DT_FLOAT } }
+ input: 'A:1' input: 'B:1' }
+ """),
+ return_elements=['A', 'B', 'C', 'D'],
+ name='import')
+
+ # Assert that the import process creates distinct tensors.
+ self.assertNotEqual(a.outputs[0].name, a.outputs[1].name)
+ self.assertNotEqual(b.outputs[0].name, b.outputs[1].name)
+ self.assertNotEqual(a.outputs[0].name, b.outputs[0].name)
+ self.assertNotEqual(a.outputs[0].name, b.outputs[1].name)
+ self.assertNotEqual(a.outputs[1].name, b.outputs[0].name)
+ self.assertNotEqual(a.outputs[1].name, b.outputs[1].name)
+
+ # Assert that the ops are connected according to the GraphDef topology.
+ self.assertEqual(c.inputs[0], a.outputs[0])
+ self.assertEqual(c.inputs[1], b.outputs[0])
+ self.assertEqual(d.inputs[0], a.outputs[1])
+ self.assertEqual(d.inputs[1], b.outputs[1])
+
+ # Check the types of the returned ops and tensors.
+ self.assertEqual(a.type, 'Oif')
+ self.assertEqual(b.type, 'Otl')
+ self.assertEqual(c.type, 'In')
+ self.assertEqual(d.type, 'In')
+ self.assertEqual(a.outputs[0].dtype, tf.int32)
+ self.assertEqual(a.outputs[1].dtype, tf.float32)
+ self.assertEqual(b.outputs[0].dtype, tf.int32)
+ self.assertEqual(b.outputs[1].dtype, tf.float32)
+
+ # Check the names of the returned ops.
+ self.assertEqual(a.name, 'import/A')
+ self.assertEqual(b.name, 'import/B')
+ self.assertEqual(c.name, 'import/C')
+ self.assertEqual(d.name, 'import/D')
+
+ def testInputMap(self):
+ with tf.Graph().as_default():
+ feed_a_0 = tf.constant(0, dtype=tf.int32)
+ feed_b_1 = tf.constant(1, dtype=tf.int32)
+
+ a, b, c, d = tf.import_graph_def(
+ self._MakeGraphDef("""
+ node { name: 'A' op: 'Oii' }
+ node { name: 'B' op: 'Oii' }
+ node { name: 'C' op: 'In'
+ attr { key: 'N' value { i: 2 } }
+ attr { key: 'T' value { type: DT_INT32 } }
+ input: 'A:0' input: 'B:0' }
+ node { name: 'D' op: 'In'
+ attr { key: 'N' value { i: 2 } }
+ attr { key: 'T' value { type: DT_INT32 } }
+ input: 'A:1' input: 'B:1' }
+ """),
+ input_map={'A:0': feed_a_0, 'B:1': feed_b_1},
+ return_elements=['A', 'B', 'C', 'D'])
+
+ self.assertEqual(c.inputs[0], feed_a_0)
+ self.assertEqual(c.inputs[1], b.outputs[0])
+ self.assertEqual(d.inputs[0], a.outputs[1])
+ self.assertEqual(d.inputs[1], feed_b_1)
+
+ def testImplicitZerothOutput(self):
+ with tf.Graph().as_default():
+ a, b = tf.import_graph_def(
+ self._MakeGraphDef("""
+ node { name: 'A' op: 'Oii' }
+ node { name: 'B' op: 'Ii' input: 'A' }
+ """),
+ return_elements=['A', 'B'])
+
+ self.assertEqual(b.inputs[0], a.outputs[0])
+
+ def testInputMapImplicitZerothOutput(self):
+ with tf.Graph().as_default():
+ feed_a_0 = tf.constant(0, dtype=tf.int32)
+ b, = tf.import_graph_def(
+ self._MakeGraphDef("""
+ node { name: 'A' op: 'Oii' }
+ node { name: 'B' op: 'Ii' input: 'A:0' }
+ """),
+ input_map={'A': feed_a_0},
+ return_elements=['B'])
+
+ self.assertEqual(b.inputs[0], feed_a_0)
+
+ def testWithControlDependency(self):
+ with tf.Graph().as_default():
+ a, b = tf.import_graph_def(
+ self._MakeGraphDef("""
+ node { name: 'A' op: 'None' }
+ node { name: 'B' op: 'None' input: '^A' }
+ """),
+ return_elements=['A', 'B'])
+
+ self.assertEqual(b.control_inputs, [a])
+
+ def testWithRefs(self):
+ with tf.Graph().as_default():
+ a, b, c, d = tf.import_graph_def(
+ self._MakeGraphDef("""
+ node { name: 'A' op: 'Or' }
+ node { name: 'B' op: 'Oi' }
+ node { name: 'C' op: 'Iii' input: 'A:0' input: 'B:0' }
+ node { name: 'D' op: 'Iri' input: 'A:0' input: 'B:0' }
+ """),
+ return_elements=['A', 'B', 'C', 'D'])
+
+ self.assertEqual(c.inputs[0], a.outputs[0])
+ self.assertEqual(c.inputs[1], b.outputs[0])
+ self.assertEqual(d.inputs[0], a.outputs[0])
+ self.assertEqual(d.inputs[1], b.outputs[0])
+
+ self.assertEqual(a.outputs[0].dtype, tf.int32_ref)
+ self.assertEqual(c._input_dtypes, [tf.int32, tf.int32])
+ self.assertEqual(c.outputs, [])
+ self.assertEqual(d._input_dtypes,
+ [tf.int32_ref, tf.int32])
+ self.assertEqual(d.outputs, [])
+
+ def testCyclic(self):
+ with tf.Graph().as_default():
+ a, b = tf.import_graph_def(
+ self._MakeGraphDef("""
+ node { name: 'A' op: 'Unary'
+ attr { key: 'T' value { type: DT_INT32 } } input: 'B:0' }
+ node { name: 'B' op: 'Unary'
+ attr { key: 'T' value { type: DT_INT32 } } input: 'A:0' }
+ """),
+ return_elements=['A', 'B'])
+
+ self.assertEqual(a.inputs[0], b.outputs[0])
+ self.assertEqual(b.inputs[0], a.outputs[0])
+
+ def testTypeMismatchInGraphDef(self):
+ with tf.Graph().as_default():
+ with self.assertRaises(ValueError) as e:
+ tf.import_graph_def(
+ self._MakeGraphDef("""
+ node { name: 'A' op: 'Oi' }
+ node { name: 'B' op: 'If' input: 'A:0' }
+ """))
+ self.assertTrue(
+ 'Cannot convert a tensor of type int32 to an input of type float' in
+ str(e.exception))
+
+ def testInvalidSignatureTooManyInputsInGraphDef(self):
+ with tf.Graph().as_default():
+ with self.assertRaises(ValueError) as e:
+ tf.import_graph_def(
+ self._MakeGraphDef("""
+ node { name: 'A' op: 'Oi' }
+ node { name: 'B' op: 'None' input: 'A:0' }
+ """))
+ self.assertTrue('More inputs specified (u\'A:0\') than the op expects' in
+ str(e.exception))
+
+ def testInvalidSignatureNotEnoughInputsInGraphDef(self):
+ with tf.Graph().as_default():
+ with self.assertRaises(ValueError) as e:
+ tf.import_graph_def(
+ self._MakeGraphDef("""
+ node { name: 'A' op: 'Oi' }
+ node { name: 'B' op: 'Iif' input: 'A:0' }
+ """))
+ self.assertTrue('Input types mismatch (expected \'int32, float32\' but '
+ 'got \'int32\')' in str(e.exception))
+
+ def testMissingInputOpInGraphDef(self):
+ with tf.Graph().as_default():
+ with self.assertRaises(ValueError) as e:
+ tf.import_graph_def(
+ self._MakeGraphDef("""
+ node { name: 'B' op: 'If' input: 'A:0' }
+ """))
+ self.assertTrue('Input tensor %r not found' % (u'A:0',) in
+ str(e.exception))
+
+ def testMissingInputOpInGraphDefButAppearsInInputMap(self):
+ with tf.Graph().as_default():
+ feed_a_0 = tf.constant(5.0)
+ b, = tf.import_graph_def(
+ self._MakeGraphDef("""
+ node { name: 'B' op: 'If' input: 'A:0' }
+ """),
+ input_map={'A:0': feed_a_0},
+ return_elements=['B'])
+ self.assertEqual(b.inputs[0], feed_a_0)
+
+ def testMissingInputTensorInGraphDef(self):
+ with tf.Graph().as_default():
+ with self.assertRaises(ValueError) as e:
+ tf.import_graph_def(
+ self._MakeGraphDef("""
+ node { name: 'A' op: 'Of' }
+ node { name: 'B' op: 'If' input: 'A:1' }
+ """))
+ self.assertTrue('Input tensor %r not found' % (u'A:1',) in
+ str(e.exception))
+
+ def testMissingControlInputInGraphDef(self):
+ with tf.Graph().as_default():
+ with self.assertRaises(ValueError) as e:
+ tf.import_graph_def(
+ self._MakeGraphDef("""
+ node { name: 'B' op: 'None' input: '^A' }
+ """))
+ self.assertTrue('Control input %r not found' % (u'^A',) in
+ str(e.exception))
+
+ def testInvalidTensorNameOutputIndexInGraphDef(self):
+ with tf.Graph().as_default():
+ with self.assertRaises(ValueError) as e:
+ tf.import_graph_def(
+ self._MakeGraphDef("""
+ node { name: 'B' op: 'None' input: 'A:B' }
+ """))
+ self.assertEqual(
+ 'Cannot convert %r to a tensor name.' % (u'A:B',), str(e.exception))
+
+ def testInvalidTensorNameInGraphDef(self):
+ with tf.Graph().as_default():
+ with self.assertRaises(ValueError) as e:
+ tf.import_graph_def(
+ self._MakeGraphDef("""
+ node { name: 'B' op: 'None' input: 'A:B:0' }
+ """))
+ self.assertEqual(
+ 'Cannot convert %r to a tensor name.' % (u'A:B:0',), str(e.exception))
+
+ def testMissingReturnOperation(self):
+ with tf.Graph().as_default():
+ with self.assertRaises(ValueError) as e:
+ tf.import_graph_def(
+ self._MakeGraphDef("""
+ node { name: 'A' op: 'None' }
+ """),
+ return_elements=['B'])
+ self.assertTrue('return_element %r not found in graph_def.' % ('B') in
+ str(e.exception))
+
+ def testMissingReturnTensor(self):
+ with tf.Graph().as_default():
+ with self.assertRaises(ValueError) as e:
+ tf.import_graph_def(
+ self._MakeGraphDef("""
+ node { name: 'A' op: 'Oi' }
+ """),
+ return_elements=['A:1'])
+ self.assertTrue('return_element %r not found in graph_def.' % ('A:1') in
+ str(e.exception))
+
+ with self.assertRaises(ValueError) as e:
+ tf.import_graph_def(
+ self._MakeGraphDef("""
+ node { name: 'A' op: 'Oi' }
+ """),
+ return_elements=['B:0'])
+ self.assertTrue('return_element %r not found in graph_def.' % ('B:0') in
+ str(e.exception))
+
+ with self.assertRaises(ValueError) as e:
+ tf.import_graph_def(
+ self._MakeGraphDef("""
+ node { name: 'A' op: 'Oi' }
+ """),
+ return_elements=['A:B:0'])
+ self.assertTrue('return_element %r not found in graph_def.' % ('A:B:0') in
+ str(e.exception))
+
+ def testMissingInputMap(self):
+ with tf.Graph().as_default():
+ with self.assertRaises(ValueError) as e:
+ tf.import_graph_def(
+ self._MakeGraphDef("""
+ node { name: 'A' op: 'None' }
+ """),
+ input_map={'B:0': tf.constant(5.0)})
+ self.assertTrue('not found in graph_def: [B:0]' in str(e.exception))
+
+ def testInputMapTypeMismatch(self):
+ with tf.Graph().as_default():
+ with self.assertRaises(ValueError) as e:
+ tf.import_graph_def(
+ self._MakeGraphDef("""
+ node { name: 'A' op: 'Oi' }
+ node { name: 'B' op: 'Ii' input: 'A:0' }
+ """),
+ input_map={'A:0': tf.constant(5.0)})
+ self.assertTrue(
+ 'Cannot convert a tensor of type float32 to an input of type int32.'
+ in str(e.exception))
+
+ def testNoReturns(self):
+ with tf.Graph().as_default() as g:
+ ret = tf.import_graph_def(
+ self._MakeGraphDef("""
+ node { name: 'A' op: 'None' }
+ """))
+ self.assertEqual(ret, None)
+
+ a = g.get_operation_by_name('import/A')
+ self.assertEqual(a.type, 'None')
+
+ def testOverrideNamePrefix(self):
+ with tf.Graph().as_default():
+ a, = tf.import_graph_def(
+ self._MakeGraphDef("""
+ node { name: 'A' op: 'None' }
+ """),
+ return_elements=['A'], name='imported_graph')
+ self.assertEqual(a.name, 'imported_graph/A')
+
+ def testEmptyGraph(self):
+ with tf.Graph().as_default() as g:
+ init_version = g.version
+ tf.import_graph_def(self._MakeGraphDef(''))
+ self.assertEqual(init_version, g.version)
+
+ def testInvalidInputForGraphDef(self):
+ with tf.Graph().as_default():
+ with self.assertRaises(TypeError) as e:
+ tf.import_graph_def('')
+ self.assertEqual(
+ 'graph_def must be a GraphDef proto.', str(e.exception))
+
+ def testInvalidInputForInputMap(self):
+ with tf.Graph().as_default():
+ with self.assertRaises(TypeError) as e:
+ tf.import_graph_def(self._MakeGraphDef(''),
+ input_map=[tf.constant(5.0)])
+ self.assertEqual('input_map must be a dictionary mapping strings to '
+ 'Tensor objects.', str(e.exception))
+
+ def testInvalidInputForReturnOperations(self):
+ with tf.Graph().as_default():
+ with self.assertRaises(TypeError) as e:
+ tf.import_graph_def(self._MakeGraphDef(''), return_elements=[7])
+ self.assertEqual(
+ 'return_elements must be a list of strings.', str(e.exception))
+
+ def testWithExtensionAndAttr(self):
+ with tf.Graph().as_default() as g:
+ c = tf.constant(5.0, dtype=tf.float32, name='c')
+ tf.pack([c, c], name='pack')
+ gdef = g.as_graph_def()
+
+ with self.test_session():
+ pack, = tf.import_graph_def(gdef, return_elements=['pack'])
+ self.assertAllEqual(pack.outputs[0].eval(), [5.0, 5.0])
+
+ def testWithDevice(self):
+ with tf.Graph().as_default() as g:
+ # No device.
+ a = tf.constant(3.0, name='a')
+
+ with tf.device('/cpu:0'):
+ b = tf.constant(4.0, name='b')
+ with tf.device('/job:worker'):
+ c = tf.constant(5.0, name='c')
+
+ gdef = g.as_graph_def()
+
+ with tf.Graph().as_default():
+ a2, b2, c2 = tf.import_graph_def(
+ gdef, return_elements=['a', 'b', 'c'])
+ self.assertEqual(a.device, a2.device)
+ self.assertEqual(b.device, b2.device)
+ self.assertEqual(c.device, c2.device)
+
+ with tf.Graph().as_default():
+ with tf.device(device.merge_device('/task:0')):
+ a3, b3, c3 = tf.import_graph_def(
+ gdef, return_elements=['a', 'b', 'c'])
+ self.assertEqual('/task:0', a3.device)
+ self.assertEqual('/task:0/device:CPU:0', b3.device) # canonicalized.
+ self.assertEqual(c.device + '/task:0', c3.device)
+
+ with tf.Graph().as_default():
+ with tf.device(device.merge_device('/job:ps')):
+ a4, b4, c4 = tf.import_graph_def(
+ gdef, return_elements=['a', 'b', 'c'])
+ self.assertEqual('/job:ps', a4.device)
+ self.assertEqual('/job:ps/device:CPU:0', b4.device) # canonicalized.
+ self.assertEqual(c.device, c4.device) # worker overrides ps.
+
+ with tf.Graph().as_default():
+ with tf.device(device.merge_device('/gpu:0')):
+ a5, b5, c5 = tf.import_graph_def(
+ gdef, return_elements=['a', 'b', 'c'])
+ self.assertEqual('/device:GPU:0', a5.device)
+ self.assertEqual('/device:CPU:0', b5.device) # cpu overrides gpu.
+ self.assertEqual(c.device + '/device:GPU:0', c5.device)
+
+ def testGradient(self):
+ with tf.Graph().as_default() as g:
+ inputs = tf.placeholder(tf.float32, shape=[None, 100], name="input")
+ weights = tf.placeholder(tf.float32, shape=[100, 10], name="weights")
+ biases = tf.placeholder(tf.float32, shape=[10], name="biases")
+ activations = tf.nn.relu(tf.matmul(inputs, weights) + biases,
+ name="activations")
+ loss = tf.reduce_mean(activations, name="loss")
+ gdef = g.as_graph_def()
+
+ with tf.Graph().as_default() as g:
+ input_placeholder = tf.placeholder(tf.float32, shape=[32, 100])
+ weights_var = tf.Variable(tf.truncated_normal([100, 10]), name="weights")
+ biases_var = tf.Variable(tf.zeros(10), name="biases")
+ activations, loss = tf.import_graph_def(
+ gdef,
+ input_map={"input:0": input_placeholder,
+ "weights:0": weights_var,
+ "biases:0": biases_var},
+ return_elements=["activations:0", "loss:0"])
+ self.assertEqual([32, 10], activations.get_shape())
+ self.assertEqual([], loss.get_shape())
+ weights_grad, biases_grad = tf.gradients(loss, [weights_var, biases_var])
+ self.assertEqual([100, 10], weights_grad.get_shape())
+ self.assertEqual([10], biases_grad.get_shape())
+
+ def testLargeGraph(self):
+ with self.test_session():
+ # The default message byte limit is 64M. Ours is 2G with a warning at 512.
+ # Adding a 150M entries float32 tensor should blow through the warning,
+ # but not the hard limit.
+ input_shape = [150, 1024, 1024]
+ tensor_input = tf.np.random.rand(*input_shape).astype(tf.np.float32)
+ t = tf.constant(tensor_input, shape=input_shape)
+ g = tf.identity(t)
+ g.eval()
+
+
+if __name__ == '__main__':
+ tf.test.main()