aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2017-07-11 16:10:06 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2017-07-11 16:14:04 -0700
commit8e82134acd68c313f0ecadb168a66916006b17a6 (patch)
treef22503d70d58f47b1e23092ab56bfec8abdd6928
parentcbe1ef05fff2d4fbd0ecc8f3a2c3a3a0f0dc312d (diff)
Use C API to implement Operation._input_types
This change first converts _input_types into a property and renames the member to _input_types_val. We keep _input_dtypes as an alias for _input_types as it was before this change. Similarly to _output_types, we can't enable normal tests yet. Instead, we add a simple temporary test for _input_types. Also, fix two minor typos in doc strings of function.py PiperOrigin-RevId: 161597185
-rw-r--r--tensorflow/python/framework/function.py4
-rw-r--r--tensorflow/python/framework/ops.py35
-rw-r--r--tensorflow/python/framework/ops_test.py43
3 files changed, 76 insertions, 6 deletions
diff --git a/tensorflow/python/framework/function.py b/tensorflow/python/framework/function.py
index dbd406ebd5..ff47c0dbf8 100644
--- a/tensorflow/python/framework/function.py
+++ b/tensorflow/python/framework/function.py
@@ -581,10 +581,10 @@ class _OverloadedFunction(object):
class _FuncGraph(ops.Graph):
- """A helper for construction a function.
+ """A helper for constructing a function.
_FuncGraph overrides ops.Graph's create_op() so that we can keep
- track of every inputs into every op created inside the function. If
+ track of all inputs into every op created inside the function. If
any input is from other graphs, we keep track of it in self.capture
and substitue the input with a place holder.
diff --git a/tensorflow/python/framework/ops.py b/tensorflow/python/framework/ops.py
index 638ed7ef4b..0ab5fdb390 100644
--- a/tensorflow/python/framework/ops.py
+++ b/tensorflow/python/framework/ops.py
@@ -1183,7 +1183,7 @@ class Operation(object):
self.node_def.name,
[i.dtype for i in self._inputs],
input_types))
- self._input_types = input_types
+ self._input_types_val = input_types
# Build the list of control inputs.
self._control_inputs = []
@@ -1416,7 +1416,14 @@ class Operation(object):
tf_output.index = output_idx
return tf_output
- def _set_device(self, device):
+ def _tf_input(self, input_idx):
+ """Create and return a new TF_Input for input_idx'th input of this op."""
+ tf_input = c_api.TF_Input()
+ tf_input.oper = self._c_op
+ tf_input.index = input_idx
+ return tf_input
+
+ def _set_device(self, device): # pylint: disable=redefined-outer-name
"""Set the device of this operation.
Args:
@@ -1438,6 +1445,7 @@ class Operation(object):
or if input tensor type is not convertible to dtype.
ValueError: if the Tensor is from a different graph.
"""
+ assert not _USE_C_API, "Operation._add_input doesn't work with C API"
if not isinstance(tensor, Tensor):
raise TypeError("tensor must be a Tensor: %s" % tensor)
_assert_same_graph(self, tensor)
@@ -1450,7 +1458,7 @@ class Operation(object):
"Cannot convert a tensor of type %s to an input of type %s"
% (tensor.dtype.name, dtype.name))
self._inputs.append(tensor)
- self._input_types.append(dtype)
+ self._input_types_val.append(dtype)
tensor._add_consumer(self) # pylint: disable=protected-access
self._recompute_node_def()
@@ -1470,6 +1478,7 @@ class Operation(object):
or if input tensor type is not convertible to dtype.
ValueError: if the Tensor is from a different graph.
"""
+ assert not _USE_C_API, "Operation._update_input doesn't work with C API"
if not isinstance(tensor, Tensor):
raise TypeError("tensor must be a Tensor: %s" % tensor)
_assert_same_graph(self, tensor)
@@ -1484,7 +1493,7 @@ class Operation(object):
self._inputs[index].consumers().remove(self)
self._inputs[index] = tensor
- self._input_types[index] = dtype
+ self._input_types_val[index] = dtype
tensor._add_consumer(self) # pylint: disable=protected-access
self._recompute_node_def()
@@ -1498,6 +1507,8 @@ class Operation(object):
TypeError: if ops is not a list of Operations.
ValueError: if any op in ops is from a different graph.
"""
+ assert not _USE_C_API, (
+ "Operation._add_control_inputs doesn't work with C API")
if ops:
for op in ops:
if not isinstance(op, Operation):
@@ -1516,6 +1527,8 @@ class Operation(object):
TypeError: if op is not an Operation.
ValueError: if op is from a different graph.
"""
+ assert not _USE_C_API, (
+ "Operation._add_control_input doesn't work with C API")
self._add_control_inputs([op])
# Methods below are used when building the NodeDef and Graph proto.
@@ -1570,6 +1583,20 @@ class Operation(object):
return self._input_types
@property
+ def _input_types(self):
+ if _USE_C_API:
+ num_inputs = c_api.TF_OperationNumInputs(self._c_op)
+ input_types = [dtypes.as_dtype(
+ c_api.TF_OperationInputType(self._tf_input(i)))
+ for i in xrange(num_inputs)]
+ # TODO(iga): Remove this assert after converting to C API by default.
+ # Just being a bit paranoid here.
+ assert self._input_types_val == input_types
+ return input_types
+ else:
+ return self._input_types_val
+
+ @property
def control_inputs(self):
"""The `Operation` objects on which this op has a control dependency.
diff --git a/tensorflow/python/framework/ops_test.py b/tensorflow/python/framework/ops_test.py
index 01841e6306..0ca9ad1d5f 100644
--- a/tensorflow/python/framework/ops_test.py
+++ b/tensorflow/python/framework/ops_test.py
@@ -1828,5 +1828,48 @@ class OutputTypesTest(test_util.TensorFlowTestCase):
# pylint: enable=protected-access
+class InputTypesTest(test_util.TensorFlowTestCase):
+ """Tests Operation._input_dtypes and Operation._input_types properties.
+
+ This test should not exist as _input_types is a private property.
+ This property is used by many tests that would normally cover its
+ behavior. However, we can't yet run these tests in C
+ API mode because they use _set_device method. This test will be deleted
+ once we port _set_device.
+ """
+ # TODO(iga): Remove this test
+
+ def setUp(self):
+ self.prev_use_c_api = ops._USE_C_API # pylint: disable=protected-access
+ ops._USE_C_API = True # pylint: disable=protected-access
+
+ def tearDown(self):
+ ops._USE_C_API = self.prev_use_c_api # pylint: disable=protected-access
+
+ def testZeroInputs(self):
+ g = ops.Graph()
+ with g.as_default():
+ # Using a constant because creating unregistered ops
+ # doesn't work with the C API.
+ op = constant_op.constant(12, dtype=dtypes.uint16).op
+ # pylint: disable=protected-access
+ self.assertEqual([], op._input_types)
+ self.assertEqual([], op._input_dtypes)
+ # pylint: enable=protected-access
+
+ def testTwoInputs(self):
+ g = ops.Graph()
+ with g.as_default():
+ x = constant_op.constant(1.0, dtype=dtypes.double)
+ y = constant_op.constant(2.0, dtype=dtypes.double)
+ z = math_ops.multiply(x, y)
+ # pylint: disable=protected-access
+ self.assertTrue(isinstance(z.op._input_types[0], dtypes.DType))
+ self.assertTrue(isinstance(z.op._input_types[1], dtypes.DType))
+ self.assertEqual([dtypes.double, dtypes.double], z.op._input_types)
+ self.assertEqual([dtypes.double, dtypes.double], z.op._input_dtypes)
+ # pylint: enable=protected-access
+
+
if __name__ == "__main__":
googletest.main()