diff options
author | A. Unique TensorFlower <gardener@tensorflow.org> | 2017-07-11 16:10:06 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2017-07-11 16:14:04 -0700 |
commit | 8e82134acd68c313f0ecadb168a66916006b17a6 (patch) | |
tree | f22503d70d58f47b1e23092ab56bfec8abdd6928 | |
parent | cbe1ef05fff2d4fbd0ecc8f3a2c3a3a0f0dc312d (diff) |
Use C API to implement Operation._input_types
This change first converts _input_types into a property and renames the
member to _input_types_val. We keep _input_dtypes as an alias for
_input_types as it was before this change.
Similarly to _output_types, we can't enable normal tests yet. Instead,
we add a simple temporary test for _input_types.
Also, fix two minor typos in doc strings of function.py
PiperOrigin-RevId: 161597185
-rw-r--r-- | tensorflow/python/framework/function.py | 4 | ||||
-rw-r--r-- | tensorflow/python/framework/ops.py | 35 | ||||
-rw-r--r-- | tensorflow/python/framework/ops_test.py | 43 |
3 files changed, 76 insertions, 6 deletions
diff --git a/tensorflow/python/framework/function.py b/tensorflow/python/framework/function.py index dbd406ebd5..ff47c0dbf8 100644 --- a/tensorflow/python/framework/function.py +++ b/tensorflow/python/framework/function.py @@ -581,10 +581,10 @@ class _OverloadedFunction(object): class _FuncGraph(ops.Graph): - """A helper for construction a function. + """A helper for constructing a function. _FuncGraph overrides ops.Graph's create_op() so that we can keep - track of every inputs into every op created inside the function. If + track of all inputs into every op created inside the function. If any input is from other graphs, we keep track of it in self.capture and substitue the input with a place holder. diff --git a/tensorflow/python/framework/ops.py b/tensorflow/python/framework/ops.py index 638ed7ef4b..0ab5fdb390 100644 --- a/tensorflow/python/framework/ops.py +++ b/tensorflow/python/framework/ops.py @@ -1183,7 +1183,7 @@ class Operation(object): self.node_def.name, [i.dtype for i in self._inputs], input_types)) - self._input_types = input_types + self._input_types_val = input_types # Build the list of control inputs. self._control_inputs = [] @@ -1416,7 +1416,14 @@ class Operation(object): tf_output.index = output_idx return tf_output - def _set_device(self, device): + def _tf_input(self, input_idx): + """Create and return a new TF_Input for input_idx'th input of this op.""" + tf_input = c_api.TF_Input() + tf_input.oper = self._c_op + tf_input.index = input_idx + return tf_input + + def _set_device(self, device): # pylint: disable=redefined-outer-name """Set the device of this operation. Args: @@ -1438,6 +1445,7 @@ class Operation(object): or if input tensor type is not convertible to dtype. ValueError: if the Tensor is from a different graph. """ + assert not _USE_C_API, "Operation._add_input doesn't work with C API" if not isinstance(tensor, Tensor): raise TypeError("tensor must be a Tensor: %s" % tensor) _assert_same_graph(self, tensor) @@ -1450,7 +1458,7 @@ class Operation(object): "Cannot convert a tensor of type %s to an input of type %s" % (tensor.dtype.name, dtype.name)) self._inputs.append(tensor) - self._input_types.append(dtype) + self._input_types_val.append(dtype) tensor._add_consumer(self) # pylint: disable=protected-access self._recompute_node_def() @@ -1470,6 +1478,7 @@ class Operation(object): or if input tensor type is not convertible to dtype. ValueError: if the Tensor is from a different graph. """ + assert not _USE_C_API, "Operation._update_input doesn't work with C API" if not isinstance(tensor, Tensor): raise TypeError("tensor must be a Tensor: %s" % tensor) _assert_same_graph(self, tensor) @@ -1484,7 +1493,7 @@ class Operation(object): self._inputs[index].consumers().remove(self) self._inputs[index] = tensor - self._input_types[index] = dtype + self._input_types_val[index] = dtype tensor._add_consumer(self) # pylint: disable=protected-access self._recompute_node_def() @@ -1498,6 +1507,8 @@ class Operation(object): TypeError: if ops is not a list of Operations. ValueError: if any op in ops is from a different graph. """ + assert not _USE_C_API, ( + "Operation._add_control_inputs doesn't work with C API") if ops: for op in ops: if not isinstance(op, Operation): @@ -1516,6 +1527,8 @@ class Operation(object): TypeError: if op is not an Operation. ValueError: if op is from a different graph. """ + assert not _USE_C_API, ( + "Operation._add_control_input doesn't work with C API") self._add_control_inputs([op]) # Methods below are used when building the NodeDef and Graph proto. @@ -1570,6 +1583,20 @@ class Operation(object): return self._input_types @property + def _input_types(self): + if _USE_C_API: + num_inputs = c_api.TF_OperationNumInputs(self._c_op) + input_types = [dtypes.as_dtype( + c_api.TF_OperationInputType(self._tf_input(i))) + for i in xrange(num_inputs)] + # TODO(iga): Remove this assert after converting to C API by default. + # Just being a bit paranoid here. + assert self._input_types_val == input_types + return input_types + else: + return self._input_types_val + + @property def control_inputs(self): """The `Operation` objects on which this op has a control dependency. diff --git a/tensorflow/python/framework/ops_test.py b/tensorflow/python/framework/ops_test.py index 01841e6306..0ca9ad1d5f 100644 --- a/tensorflow/python/framework/ops_test.py +++ b/tensorflow/python/framework/ops_test.py @@ -1828,5 +1828,48 @@ class OutputTypesTest(test_util.TensorFlowTestCase): # pylint: enable=protected-access +class InputTypesTest(test_util.TensorFlowTestCase): + """Tests Operation._input_dtypes and Operation._input_types properties. + + This test should not exist as _input_types is a private property. + This property is used by many tests that would normally cover its + behavior. However, we can't yet run these tests in C + API mode because they use _set_device method. This test will be deleted + once we port _set_device. + """ + # TODO(iga): Remove this test + + def setUp(self): + self.prev_use_c_api = ops._USE_C_API # pylint: disable=protected-access + ops._USE_C_API = True # pylint: disable=protected-access + + def tearDown(self): + ops._USE_C_API = self.prev_use_c_api # pylint: disable=protected-access + + def testZeroInputs(self): + g = ops.Graph() + with g.as_default(): + # Using a constant because creating unregistered ops + # doesn't work with the C API. + op = constant_op.constant(12, dtype=dtypes.uint16).op + # pylint: disable=protected-access + self.assertEqual([], op._input_types) + self.assertEqual([], op._input_dtypes) + # pylint: enable=protected-access + + def testTwoInputs(self): + g = ops.Graph() + with g.as_default(): + x = constant_op.constant(1.0, dtype=dtypes.double) + y = constant_op.constant(2.0, dtype=dtypes.double) + z = math_ops.multiply(x, y) + # pylint: disable=protected-access + self.assertTrue(isinstance(z.op._input_types[0], dtypes.DType)) + self.assertTrue(isinstance(z.op._input_types[1], dtypes.DType)) + self.assertEqual([dtypes.double, dtypes.double], z.op._input_types) + self.assertEqual([dtypes.double, dtypes.double], z.op._input_dtypes) + # pylint: enable=protected-access + + if __name__ == "__main__": googletest.main() |