diff options
author | 2018-06-05 08:13:07 -0700 | |
---|---|---|
committer | 2018-06-05 08:15:52 -0700 | |
commit | 274f9510f68f237589df5c6a414e4b8e5ebcdba1 (patch) | |
tree | b779bc166e890b31c03f6f3e82987daf23070096 /tensorflow | |
parent | c0dc76a3994c743151404b1401599fefb9f37dd4 (diff) |
Remove _USE_C_API staging from ops.py.
PiperOrigin-RevId: 199298594
Diffstat (limited to 'tensorflow')
-rw-r--r-- | tensorflow/contrib/copy_graph/python/util/copy_elements.py | 1 | ||||
-rw-r--r-- | tensorflow/contrib/graph_editor/transform.py | 5 | ||||
-rw-r--r-- | tensorflow/python/framework/ops.py | 544 | ||||
-rw-r--r-- | tensorflow/python/framework/ops_test.py | 3 |
4 files changed, 160 insertions, 393 deletions
diff --git a/tensorflow/contrib/copy_graph/python/util/copy_elements.py b/tensorflow/contrib/copy_graph/python/util/copy_elements.py index 102bc460fd..a0dd3881a8 100644 --- a/tensorflow/contrib/copy_graph/python/util/copy_elements.py +++ b/tensorflow/contrib/copy_graph/python/util/copy_elements.py @@ -218,7 +218,6 @@ def copy_op_to_graph(org_instance, to_graph, variables, scope=''): new_control_inputs, input_types, new_original_op, op_def) #Use Graph's hidden methods to add the op - to_graph._add_op(new_op) # pylint: disable=protected-access to_graph._record_op_seen_by_control_dependencies(new_op) for device_function in reversed(to_graph._device_function_stack): new_op._set_device(device_function(new_op)) diff --git a/tensorflow/contrib/graph_editor/transform.py b/tensorflow/contrib/graph_editor/transform.py index 592d37b432..026a3d1200 100644 --- a/tensorflow/contrib/graph_editor/transform.py +++ b/tensorflow/contrib/graph_editor/transform.py @@ -189,9 +189,6 @@ def copy_op_handler(info, op, new_inputs, copy_shape=True, nodedef_fn=None): if op._original_op: op_._original_op = op._original_op - # Add op to the graph - info.graph_._add_op(op_) - return op_, op_.outputs @@ -492,7 +489,7 @@ class Transformer(object): t_ = info.transformed_ts[t] consumer_op_ = info.transformed_ops[consumer_op] t_index_ = list(consumer_op_.inputs).index(tmp_t_) - consumer_op_._update_input(t_index_, t_, update_dtype=False) # pylint: disable=protected-access + consumer_op_._update_input(t_index_, t_) # pylint: disable=protected-access def _connect_control_inputs(self, info): """Connect the previously copied ops.""" diff --git a/tensorflow/python/framework/ops.py b/tensorflow/python/framework/ops.py index eceea5276a..b2fd98f431 100644 --- a/tensorflow/python/framework/ops.py +++ b/tensorflow/python/framework/ops.py @@ -56,6 +56,7 @@ from tensorflow.python.platform import tf_logging as logging from tensorflow.python.util import compat from tensorflow.python.util import decorator_utils from tensorflow.python.util import tf_contextlib +from tensorflow.python.util.deprecation import deprecated_args from tensorflow.python.util.tf_export import tf_export @@ -288,15 +289,8 @@ class Tensor(_TensorLike): self._value_index = value_index self._dtype = dtypes.as_dtype(dtype) - if _USE_C_API: - # This will be set by set_shape_and_handle_data_for_outputs. - self._shape_val = None - else: - # The Python code requires all tensors start with a shape to support shape - # inference on imported while loops. This isn't necessary with the C API - # enabled because the C API provides the shapes for imported nodes. - # TODO(skyewm): remove when _USE_C_API is removed. - self._shape_val = tensor_shape.unknown_shape() + # This will be set by self.shape(). + self._shape_val = None # List of operations that use this Tensor as input. We maintain this list # to easily navigate a computation graph. @@ -384,7 +378,6 @@ class Tensor(_TensorLike): if _USE_C_SHAPES: self._shape_val = self._c_api_shape() else: - assert _USE_C_API # Call set_shape_and_handle_data_for_outputs in topological order on all # ops that are needed to compute self.op's shape. We do this instead of # having set_shape_and_handle_data_for_outputs recursively call @@ -508,8 +501,6 @@ class Tensor(_TensorLike): else: self._shape_val = self.shape.merge_with(shape) - if not self._op._graph._c_graph: return - # Update C shape even if _USE_C_SHAPES = False, since we still want # set_shape to be reflected in the C API graph for when we run it. if not isinstance(shape, tensor_shape.TensorShape): @@ -545,33 +536,14 @@ class Tensor(_TensorLike): Returns: A list of `Operation`s. """ - if self._op._c_op: # pylint: disable=protected-access - consumer_names = c_api.TF_OperationOutputConsumers_wrapper( - self._as_tf_output()) - # pylint: disable=protected-access - return [ - self.graph._get_operation_by_name_unsafe(name) - for name in consumer_names - ] - # pylint: enable=protected-access - else: - return self._consumers - - def _add_consumer(self, consumer): - """Add a consumer to this tensor. - - Args: - consumer: an Operation. - - Raises: - TypeError: if the consumer is not an Operation. - """ + consumer_names = c_api.TF_OperationOutputConsumers_wrapper( + self._as_tf_output()) # pylint: disable=protected-access - assert not self._op._c_op, "Tensor._add_consumer doesn't work with C API" + return [ + self.graph._get_operation_by_name_unsafe(name) + for name in consumer_names + ] # pylint: enable=protected-access - if not isinstance(consumer, Operation): - raise TypeError("Consumer must be an Operation: %s" % consumer) - self._consumers.append(consumer) def _as_node_def_input(self): """Return a value to use for the NodeDef "input" attribute. @@ -594,7 +566,6 @@ class Tensor(_TensorLike): def _as_tf_output(self): # pylint: disable=protected-access - assert self.op._c_op return c_api_util.tf_output(self.op._c_op, self.value_index) # pylint: enable=protected-access @@ -1722,18 +1693,8 @@ class Operation(object): "a Tensor, or IndexedSlices: %s" % c) control_input_ops.append(control_op) - # Don't set private fields with C API enabled to catch users who need to - # switch to public API. - # TODO(skyewm): delete these fields once we remove _USE_C_API - if not self._graph._c_graph: - self._inputs_val = list(inputs) # Defensive copy. - self._input_types_val = input_types - self._control_inputs_val = control_input_ops - self._node_def_val = copy.deepcopy(node_def) - self._op_def_val = op_def - else: - # This will be set by self.inputs. - self._inputs_val = None + # This will be set by self.inputs. + self._inputs_val = None self._id_value = self._graph._next_id() # pylint: disable=protected-access self._original_op = original_op @@ -1742,10 +1703,8 @@ class Operation(object): # Initialize self._c_op. if c_op: - # TODO(skyewm): remove this assert when we remove USE_C_API - assert self._graph._c_graph # pylint: disable=protected-access self._c_op = c_op - elif self._graph._c_graph: # pylint: disable=protected-access + else: if op_def is None: op_def = self._graph._get_op_def(node_def.op) # TODO(skyewm): op_def_library.apply_op() flattens the incoming inputs. @@ -1754,30 +1713,19 @@ class Operation(object): op_def, inputs, node_def.attr) self._c_op = _create_c_op(self._graph, node_def, grouped_inputs, control_input_ops) - else: - self._c_op = None - - # Mark that we consume the inputs. This is unnecessary and unsupported with - # the C API enabled, since the C API tracks the tensor consumers instead. - if not self._c_op: - for input_tensor in self._inputs_val: - input_tensor._add_consumer(self) # pylint: disable=protected-access # Initialize self._outputs. - if self._c_op: - num_outputs = c_api.TF_OperationNumOutputs(self._c_op) - output_types = [ - c_api.TF_OperationOutputType(c_api_util.tf_output(self._c_op, i)) - for i in range(num_outputs)] - assert output_types is not None - elif output_types is None: - output_types = [] - self._output_types_val = output_types + num_outputs = c_api.TF_OperationNumOutputs(self._c_op) + output_types = [ + c_api.TF_OperationOutputType(c_api_util.tf_output(self._c_op, i)) + for i in range(num_outputs)] self._outputs = [ Tensor(self, i, output_type) for i, output_type in enumerate(output_types) ] + self._graph._add_op(self) # pylint: disable=protected-access + if not c_op: self._control_flow_post_processing() @@ -1791,7 +1739,6 @@ class Operation(object): control_flow_util.CheckInputFromValidContext(self, input_tensor.op) if self._control_flow_context is not None: self._control_flow_context.AddOp(self) - self._recompute_node_def() def _reconstruct_sequence_inputs(self, op_def, inputs, attrs): """Regroups a flat list of input tensors into scalar and sequence inputs. @@ -1872,10 +1819,7 @@ class Operation(object): @property def name(self): """The full name of this operation.""" - if self._c_op: - return c_api.TF_OperationName(self._c_op) - else: - return self._node_def_val.name + return c_api.TF_OperationName(self._c_op) @property def _id(self): @@ -1891,10 +1835,7 @@ class Operation(object): assigned, or an empty string if it has not been assigned to a device. """ - if self._c_op: - return c_api.TF_OperationDevice(self._c_op) - else: - return self._node_def_val.device + return c_api.TF_OperationDevice(self._c_op) @property def _output_types(self): @@ -1907,28 +1848,21 @@ class Operation(object): The length of this list indicates the number of output endpoints of the operation. """ - if self._c_op: - num_outputs = c_api.TF_OperationNumOutputs(self._c_op) - output_types = [ - c_api.TF_OperationOutputType(self._tf_output(i)) - for i in xrange(num_outputs) - ] - # TODO(iga): Remove this assert after converting to C API by default. - # Just being a bit paranoid here. - assert self._output_types_val == output_types - # In all the tests we have output_types that are passed into - # Operation.__init__ are a list of ints (which is illegal according - # to the docstring), but input_types are instances of DType. - # This extra assert is to catch if we ever use DType for output_types. - if output_types: - assert isinstance(output_types[0], int) - return output_types - else: - return self._output_types_val + num_outputs = c_api.TF_OperationNumOutputs(self._c_op) + output_types = [ + c_api.TF_OperationOutputType(self._tf_output(i)) + for i in xrange(num_outputs) + ] + # In all the tests we have output_types that are passed into + # Operation.__init__ are a list of ints (which is illegal according + # to the docstring), but input_types are instances of DType. + # This extra assert is to catch if we ever use DType for output_types. + if output_types: + assert isinstance(output_types[0], int) + return output_types def _tf_output(self, output_idx): """Create and return a new TF_Output for output_idx'th output of this op.""" - assert self._c_op tf_output = c_api.TF_Output() tf_output.oper = self._c_op tf_output.index = output_idx @@ -1936,7 +1870,6 @@ class Operation(object): def _tf_input(self, input_idx): """Create and return a new TF_Input for input_idx'th input of this op.""" - assert self._c_op tf_input = c_api.TF_Input() tf_input.oper = self._c_op tf_input.index = input_idx @@ -1948,47 +1881,12 @@ class Operation(object): Args: device: string or device.. The device to set. """ - if self._c_op: - c_api.SetRequestedDevice( - self._graph._c_graph, # pylint: disable=protected-access - self._c_op, # pylint: disable=protected-access - compat.as_str(_device_string(device))) - else: - self._node_def_val.device = _device_string(device) - - def _add_input(self, tensor, dtype=None): - """Add a new input to this operation. - - Args: - tensor: the Tensor to add as an input. - dtype: tf.DType: type of the input; defaults to - the tensor's dtype. + c_api.SetRequestedDevice( + self._graph._c_graph, # pylint: disable=protected-access + self._c_op, # pylint: disable=protected-access + compat.as_str(_device_string(device))) - Raises: - TypeError: if tensor is not a Tensor, - or if input tensor type is not convertible to dtype. - ValueError: if the Tensor is from a different graph. - """ - assert not self._c_op, ( - "Operation._add_input doesn't work with C API") - if not isinstance(tensor, Tensor): - raise TypeError("tensor must be a Tensor: %s" % tensor) - _assert_same_graph(self, tensor) - if dtype is None: - dtype = tensor.dtype - else: - dtype = dtypes.as_dtype(dtype) - if not dtype.is_compatible_with(tensor.dtype): - raise TypeError( - "Cannot convert a tensor of type %s to an input of type %s" % - (tensor.dtype.name, dtype.name)) - self._inputs_val.append(tensor) - self._input_types_val.append(dtype) - tensor._add_consumer(self) # pylint: disable=protected-access - self._recompute_node_def() - - # TODO(skyewm): Remove `update_dtype` when we enable the C API. - def _update_input(self, index, tensor, update_dtype=True): + def _update_input(self, index, tensor): """Update the input to this operation at the given index. NOTE: This is for TF internal use only. Please don't use it. @@ -1996,7 +1894,6 @@ class Operation(object): Args: index: the index of the input to update. tensor: the Tensor to be used as the input at the given index. - update_dtype: If `False`, the type for this input is not updated. Raises: TypeError: if tensor is not a Tensor, @@ -2013,20 +1910,12 @@ class Operation(object): if not _USE_C_SHAPES: set_shape_and_handle_data_for_outputs(self) - if self._c_op: - # Reset cached inputs. - self._inputs_val = None - c_api.UpdateEdge( - self._graph._c_graph, # pylint: disable=protected-access - tensor._as_tf_output(), # pylint: disable=protected-access - self._tf_input(index)) - else: - self._inputs_val[index].consumers().remove(self) - self._inputs_val[index] = tensor - if update_dtype: - self._input_types_val[index] = tensor.dtype - tensor._add_consumer(self) # pylint: disable=protected-access - self._recompute_node_def() + # Reset cached inputs. + self._inputs_val = None + c_api.UpdateEdge( + self._graph._c_graph, # pylint: disable=protected-access + tensor._as_tf_output(), # pylint: disable=protected-access + self._tf_input(index)) def _add_control_inputs(self, ops): """Add a list of new control inputs to this operation. @@ -2038,19 +1927,10 @@ class Operation(object): TypeError: if ops is not a list of Operations. ValueError: if any op in ops is from a different graph. """ - if self._c_op: - for op in ops: - if not isinstance(op, Operation): - raise TypeError("op must be an Operation: %s" % op) - c_api.AddControlInput(self._graph._c_graph, self._c_op, op._c_op) # pylint: disable=protected-access - else: - if ops: - for op in ops: - if not isinstance(op, Operation): - raise TypeError("op must be an Operation: %s" % op) - _assert_same_graph(self, op) - self._control_inputs_val.append(op) - self._recompute_node_def() + for op in ops: + if not isinstance(op, Operation): + raise TypeError("op must be an Operation: %s" % op) + c_api.AddControlInput(self._graph._c_graph, self._c_op, op._c_op) # pylint: disable=protected-access def _add_control_input(self, op): """Add a new control input to this operation. @@ -2062,33 +1942,13 @@ class Operation(object): TypeError: if op is not an Operation. ValueError: if op is from a different graph. """ - if self._c_op: - if not isinstance(op, Operation): - raise TypeError("op must be an Operation: %s" % op) - c_api.AddControlInput(self._graph._c_graph, self._c_op, op._c_op) # pylint: disable=protected-access - else: - self._add_control_inputs([op]) + if not isinstance(op, Operation): + raise TypeError("op must be an Operation: %s" % op) + c_api.AddControlInput(self._graph._c_graph, self._c_op, op._c_op) # pylint: disable=protected-access def _remove_all_control_inputs(self): """Removes any control inputs to this operation.""" - if self._c_op: - c_api.RemoveAllControlInputs(self._graph._c_graph, self._c_op) # pylint: disable=protected-access - else: - del self.control_inputs[:] - - # Methods below are used when building the NodeDef and Graph proto. - def _recompute_node_def(self): - # TODO(skyewm): remove this function when we switch to C API - if self._c_op: return - - del self._node_def_val.input[:] - # pylint: disable=protected-access - self._node_def_val.input.extend( - [t._as_node_def_input() for t in self._inputs_val]) - # pylint: enable=protected-access - if self._control_inputs_val: - self._node_def_val.input.extend( - ["^%s" % op.name for op in self._control_inputs_val]) + c_api.RemoveAllControlInputs(self._graph._c_graph, self._c_op) # pylint: disable=protected-access def __str__(self): return str(self.node_def) @@ -2129,19 +1989,16 @@ class Operation(object): @property def inputs(self): """The list of `Tensor` objects representing the data inputs of this op.""" - if self._c_op: - if self._inputs_val is None: - tf_outputs = c_api.GetOperationInputs(self._c_op) - # pylint: disable=protected-access - retval = [ - self.graph._get_tensor_by_tf_output(tf_output) - for tf_output in tf_outputs - ] - # pylint: enable=protected-access - self._inputs_val = Operation._InputList(retval) - return self._inputs_val - else: - return Operation._InputList(self._inputs_val) + if self._inputs_val is None: + tf_outputs = c_api.GetOperationInputs(self._c_op) + # pylint: disable=protected-access + retval = [ + self.graph._get_tensor_by_tf_output(tf_output) + for tf_output in tf_outputs + ] + # pylint: enable=protected-access + self._inputs_val = Operation._InputList(retval) + return self._inputs_val @property def _inputs(self): @@ -2155,15 +2012,12 @@ class Operation(object): @property def _input_types(self): - if self._c_op: - num_inputs = c_api.TF_OperationNumInputs(self._c_op) - input_types = [ - dtypes.as_dtype(c_api.TF_OperationInputType(self._tf_input(i))) - for i in xrange(num_inputs) - ] - return input_types - else: - return self._input_types_val + num_inputs = c_api.TF_OperationNumInputs(self._c_op) + input_types = [ + dtypes.as_dtype(c_api.TF_OperationInputType(self._tf_input(i))) + for i in xrange(num_inputs) + ] + return input_types @_input_types.setter def _input_types(self, value): @@ -2183,16 +2037,13 @@ class Operation(object): A list of `Operation` objects. """ - if self._c_op: - control_c_ops = c_api.TF_OperationGetControlInputs_wrapper(self._c_op) - # pylint: disable=protected-access - return [ - self.graph._get_operation_by_name_unsafe( - c_api.TF_OperationName(c_op)) for c_op in control_c_ops - ] - # pylint: enable=protected-access - else: - return self._control_inputs_val + control_c_ops = c_api.TF_OperationGetControlInputs_wrapper(self._c_op) + # pylint: disable=protected-access + return [ + self.graph._get_operation_by_name_unsafe( + c_api.TF_OperationName(c_op)) for c_op in control_c_ops + ] + # pylint: enable=protected-access @property def _control_outputs(self): @@ -2205,18 +2056,13 @@ class Operation(object): A list of `Operation` objects. """ - if self._c_op: - control_c_ops = c_api.TF_OperationGetControlOutputs_wrapper(self._c_op) - # pylint: disable=protected-access - return [ - self.graph._get_operation_by_name_unsafe( - c_api.TF_OperationName(c_op)) for c_op in control_c_ops - ] - # pylint: enable=protected-access - else: - # TODO(apassos) this should be less inefficient. - return [o for o in self._graph.get_operations() - if self in o.control_inputs] + control_c_ops = c_api.TF_OperationGetControlOutputs_wrapper(self._c_op) + # pylint: disable=protected-access + return [ + self.graph._get_operation_by_name_unsafe( + c_api.TF_OperationName(c_op)) for c_op in control_c_ops + ] + # pylint: enable=protected-access @property def _control_inputs(self): @@ -2240,11 +2086,7 @@ class Operation(object): @property def type(self): """The type of the op (e.g. `"MatMul"`).""" - if self._c_op: - op_type = c_api.TF_OperationOpType(self._c_op) - return op_type - else: - return self._node_def_val.op + return c_api.TF_OperationOpType(self._c_op) @property def graph(self): @@ -2262,15 +2104,12 @@ class Operation(object): protocol buffer. """ # pylint: enable=line-too-long - if self._c_op: - with c_api_util.tf_buffer() as buf: - c_api.TF_OperationToNodeDef(self._c_op, buf) - data = c_api.TF_GetBuffer(buf) - node_def = node_def_pb2.NodeDef() - node_def.ParseFromString(compat.as_bytes(data)) - return node_def - else: - return self._node_def_val + with c_api_util.tf_buffer() as buf: + c_api.TF_OperationToNodeDef(self._c_op, buf) + data = c_api.TF_GetBuffer(buf) + node_def = node_def_pb2.NodeDef() + node_def.ParseFromString(compat.as_bytes(data)) + return node_def @property def _node_def(self): @@ -2289,10 +2128,7 @@ class Operation(object): protocol buffer. """ # pylint: enable=line-too-long - if self._c_op: - return self._graph._get_op_def(self.type) - else: - return self._op_def_val + return self._graph._get_op_def(self.type) @property def _op_def(self): @@ -2318,17 +2154,14 @@ class Operation(object): def _set_attr(self, attr_name, attr_value): """Private method used to set an attribute in the node_def.""" - if self._c_op: - buf = c_api.TF_NewBufferFromString( - compat.as_bytes(attr_value.SerializeToString())) - try: - # pylint: disable=protected-access - c_api.SetAttr(self._graph._c_graph, self._c_op, attr_name, buf) - # pylint: enable=protected-access - finally: - c_api.TF_DeleteBuffer(buf) - else: - self._node_def_val.attr[attr_name].CopyFrom(attr_value) + buf = c_api.TF_NewBufferFromString( + compat.as_bytes(attr_value.SerializeToString())) + try: + # pylint: disable=protected-access + c_api.SetAttr(self._graph._c_graph, self._c_op, attr_name, buf) + # pylint: enable=protected-access + finally: + c_api.TF_DeleteBuffer(buf) def get_attr(self, name): """Returns the value of the attr of this op with the given `name`. @@ -2343,21 +2176,15 @@ class Operation(object): ValueError: If this op does not have an attr with the given `name`. """ fields = ["s", "i", "f", "b", "type", "shape", "tensor", "func"] - if self._c_op: - try: - with c_api_util.tf_buffer() as buf: - c_api.TF_OperationGetAttrValueProto(self._c_op, name, buf) - data = c_api.TF_GetBuffer(buf) - except errors.InvalidArgumentError as e: - # Convert to ValueError for backwards compatibility. - raise ValueError(str(e)) - x = attr_value_pb2.AttrValue() - x.ParseFromString(data) - else: - if name not in self._node_def_val.attr: - raise ValueError( - "No attr named '" + name + "' in " + str(self._node_def_val)) - x = self._node_def_val.attr[name] + try: + with c_api_util.tf_buffer() as buf: + c_api.TF_OperationGetAttrValueProto(self._c_op, name, buf) + data = c_api.TF_GetBuffer(buf) + except errors.InvalidArgumentError as e: + # Convert to ValueError for backwards compatibility. + raise ValueError(str(e)) + x = attr_value_pb2.AttrValue() + x.ParseFromString(data) # Treat an empty oneof value as an empty list. if not x.WhichOneof("value"): @@ -2577,9 +2404,9 @@ def _set_shape_and_handle_data_for_outputs_c_api(op): def set_shape_and_handle_data_for_outputs(op): """Set the shapes and resource handle data for op's outputs. - When _USE_C_API = True, this is lazily called when a tensor's shape is first - requested. Usually this should work automatically, but some edge cases may - require manually calling this first to make sure Tensor._shape_val and + When _USE_C_SHAPES = False, this is lazily called when a tensor's shape is + first requested. Usually this should work automatically, but some edge cases + may require manually calling this first to make sure Tensor._shape_val and Tensor._handle_data are set (e.g. manually overriding _handle_data, copying a Tensor). """ @@ -3083,15 +2910,12 @@ class Graph(object): A `VersionDef`. """ # pylint: enable=line-too-long - if self._c_graph: - with c_api_util.tf_buffer() as buf: - c_api.TF_GraphVersions(self._c_graph, buf) - data = c_api.TF_GetBuffer(buf) - version_def = versions_pb2.VersionDef() - version_def.ParseFromString(compat.as_bytes(data)) - return version_def - else: - return self._graph_def_versions + with c_api_util.tf_buffer() as buf: + c_api.TF_GraphVersions(self._c_graph, buf) + data = c_api.TF_GetBuffer(buf) + version_def = versions_pb2.VersionDef() + version_def.ParseFromString(compat.as_bytes(data)) + return version_def @property def seed(self): @@ -3185,40 +3009,22 @@ class Graph(object): """ # pylint: enable=line-too-long - if self._c_graph: - with self._lock: - with c_api_util.tf_buffer() as buf: - c_api.TF_GraphToGraphDef(self._c_graph, buf) - data = c_api.TF_GetBuffer(buf) - graph = graph_pb2.GraphDef() - graph.ParseFromString(compat.as_bytes(data)) - # Strip the experimental library field iff it's empty. - if not graph.library.function: - graph.ClearField("library") - - if add_shapes: - for node in graph.node: - op = self._nodes_by_name[node.name] - if op.outputs: - node.attr["_output_shapes"].list.shape.extend( - [output.get_shape().as_proto() for output in op.outputs]) - else: - with self._lock: - graph = graph_pb2.GraphDef() - graph.versions.CopyFrom(self._graph_def_versions) - bytesize = 0 - for op_id in sorted(self._nodes_by_id): - op = self._nodes_by_id[op_id] - if from_version is None or op_id > from_version: - graph.node.extend([op.node_def]) - if op.outputs and add_shapes: - assert "_output_shapes" not in graph.node[-1].attr - graph.node[-1].attr["_output_shapes"].list.shape.extend( - [output.get_shape().as_proto() for output in op.outputs]) - bytesize += op.node_def.ByteSize() - if bytesize >= (1 << 31) or bytesize < 0: - raise ValueError("GraphDef cannot be larger than 2GB.") - self._copy_functions_to_graph_def(graph, bytesize) + with self._lock: + with c_api_util.tf_buffer() as buf: + c_api.TF_GraphToGraphDef(self._c_graph, buf) + data = c_api.TF_GetBuffer(buf) + graph = graph_pb2.GraphDef() + graph.ParseFromString(compat.as_bytes(data)) + # Strip the experimental library field iff it's empty. + if not graph.library.function: + graph.ClearField("library") + + if add_shapes: + for node in graph.node: + op = self._nodes_by_name[node.name] + if op.outputs: + node.attr["_output_shapes"].list.shape.extend( + [output.get_shape().as_proto() for output in op.outputs]) return graph, self._version def as_graph_def(self, from_version=None, add_shapes=False): @@ -3292,34 +3098,16 @@ class Graph(object): # Add function to graph # pylint: disable=protected-access - if self._c_graph: - # Handle functions created without using the C API. TODO(apassos,skyewm) - # remove this when all functions are generated using the C API by default - # as this will be unnecessary. - if not function._c_func: - serialized = function.definition.SerializeToString() - c_func = c_api.TF_FunctionImportFunctionDef(serialized) - function._c_func = c_api_util.ScopedTFFunction(c_func) - gradient = (function._grad_func._c_func.func if function._grad_func - else None) - c_api.TF_GraphCopyFunction(self._c_graph, function._c_func.func, gradient) - else: - # If there is already a function with the same name, raise an error - # if bodies are different. Else, do nothing. The C API version above - # has the same behavior. - previous = self._functions.get(name, None) - if previous: - # This check is not ideal as we can have a hash collision with only - # 32 bits in the hash, but the non C API mode is being deprecated. - # Don't bother changing it now. - if previous._hash_str == function._hash_str: - return - else: - raise ValueError("Cannot add function (%s, hash %s) to graph (%s). " - "Another function (%s, hash %s) is already defined " - "with that name (%s)" % ( - function, function._hash_str, self, - previous, previous._hash_str, name)) + # Handle functions created without using the C API. TODO(apassos,skyewm) + # remove this when all functions are generated using the C API by default + # as this will be unnecessary. + if not function._c_func: + serialized = function.definition.SerializeToString() + c_func = c_api.TF_FunctionImportFunctionDef(serialized) + function._c_func = c_api_util.ScopedTFFunction(c_func) + gradient = (function._grad_func._c_func.func if function._grad_func + else None) + c_api.TF_GraphCopyFunction(self._c_graph, function._c_func.func, gradient) # pylint: enable=protected-access self._functions[name] = function @@ -3334,6 +3122,9 @@ class Graph(object): return self._building_function # Helper functions to create operations. + @deprecated_args(None, + "Shapes are always computed; don't use the compute_shapes " + "as it has no effect.", "compute_shapes") def create_op( self, op_type, @@ -3370,8 +3161,8 @@ class Graph(object): proto). op_def: (Optional.) The `OpDef` proto that describes the `op_type` that the operation will have. - compute_shapes: (Optional.) If True, shape inference will be performed - to compute the shapes of the outputs. + compute_shapes: (Optional.) Deprecated. Has no effect (shapes are always + computed). compute_device: (Optional.) If True, device functions will be executed to compute the device property of the Operation. @@ -3381,8 +3172,9 @@ class Graph(object): Returns: An `Operation` object. - """ + del compute_shapes + self._check_not_finalized() for idx, a in enumerate(inputs): if not isinstance(a, Tensor): @@ -3412,18 +3204,7 @@ class Graph(object): input_types=input_types, original_op=self._default_original_op, op_def=op_def) - - # Note: shapes are lazily computed with the C API enabled. - # - # TODO(skyewm): unlike in the original Python implementation, the C API - # always computes shape information (even for function calls, which the - # original Python shape inference code doesn't handle). Deprecate the - # compute_shapes argument. - if not _USE_C_API and compute_shapes: - set_shape_and_handle_data_for_outputs(ret) - - self._create_op_helper(ret, compute_shapes=compute_shapes, - compute_device=compute_device) + self._create_op_helper(ret, compute_device=compute_device) return ret def _create_op_from_tf_operation(self, c_op, compute_device=True): @@ -3457,11 +3238,8 @@ class Graph(object): self._create_op_helper(ret, compute_device=compute_device) return ret - def _create_op_helper(self, op, compute_shapes=True, compute_device=True): + def _create_op_helper(self, op, compute_device=True): """Common logic for creating an op in this graph.""" - # TODO(b/XXXX): move to Operation.__init__ once _USE_C_API flag is removed. - self._add_op(op) - # Apply any additional attributes requested. Do not overwrite any existing # attributes. for key, value in self._attr_scope_map.items(): @@ -3528,8 +3306,7 @@ class Graph(object): # (2) "is_stateful" is set in OpDef # (3) "container" attribute is in OpDef # (4) "container" attribute is None - # TODO(skyewm): remove op.op_def check when _USE_C_API is removed. - if self._container and op.op_def and op.op_def.is_stateful: + if self._container and op.op_def.is_stateful: try: container_attr = op.get_attr("container") except ValueError: @@ -3816,17 +3593,14 @@ class Graph(object): def _get_op_def(self, type): # pylint: disable=redefined-builtin """Returns the `OpDef` proto for `type`. `type` is a string.""" - if self._c_graph: - with c_api_util.tf_buffer() as buf: - # pylint: disable=protected-access - c_api.TF_GraphGetOpDef(self._c_graph, compat.as_bytes(type), buf) - # pylint: enable=protected-access - data = c_api.TF_GetBuffer(buf) - op_def = op_def_pb2.OpDef() - op_def.ParseFromString(compat.as_bytes(data)) - return op_def - else: - return self._registered_ops[type] + with c_api_util.tf_buffer() as buf: + # pylint: disable=protected-access + c_api.TF_GraphGetOpDef(self._c_graph, compat.as_bytes(type), buf) + # pylint: enable=protected-access + data = c_api.TF_GetBuffer(buf) + op_def = op_def_pb2.OpDef() + op_def.ParseFromString(compat.as_bytes(data)) + return op_def def as_default(self): """Returns a context manager that makes this `Graph` the default graph. diff --git a/tensorflow/python/framework/ops_test.py b/tensorflow/python/framework/ops_test.py index e7732632f2..81355a279c 100644 --- a/tensorflow/python/framework/ops_test.py +++ b/tensorflow/python/framework/ops_test.py @@ -270,7 +270,6 @@ class OperationTest(test_util.TensorFlowTestCase): op1 = ops.Operation( ops._NodeDef("RefOutputFloatOutput", "op1"), g, [], [dtypes.float32_ref, dtypes.float32]) - g._add_op(op1) self.assertProtoEquals("op:'RefOutputFloatOutput' name:'op1'", op1.node_def) self.assertEquals([], list(op1.inputs)) ref_t, nonref_t = op1.values() @@ -279,14 +278,12 @@ class OperationTest(test_util.TensorFlowTestCase): ops._NodeDef("RefInputFloatInput", "op2"), g, [ref_t, nonref_t], [], input_types=[dtypes.float32_ref, dtypes.float32]) - g._add_op(op2) self.assertProtoEquals( "op:'RefInputFloatInput' name:'op2' input:'op1' input:'op1:1'", op2.node_def) self.assertEquals([ref_t, nonref_t], list(op2.inputs)) op3 = ops.Operation( ops._NodeDef("TwoFloatInputs", "op3"), g, [ref_t, nonref_t], []) - g._add_op(op3) self.assertProtoEquals( "op:'TwoFloatInputs' name:'op3' input:'op1' input:'op1:1'", op3.node_def) |