aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow
diff options
context:
space:
mode:
authorGravatar Skye Wanderman-Milne <skyewm@google.com>2018-06-05 08:13:07 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-06-05 08:15:52 -0700
commit274f9510f68f237589df5c6a414e4b8e5ebcdba1 (patch)
treeb779bc166e890b31c03f6f3e82987daf23070096 /tensorflow
parentc0dc76a3994c743151404b1401599fefb9f37dd4 (diff)
Remove _USE_C_API staging from ops.py.
PiperOrigin-RevId: 199298594
Diffstat (limited to 'tensorflow')
-rw-r--r--tensorflow/contrib/copy_graph/python/util/copy_elements.py1
-rw-r--r--tensorflow/contrib/graph_editor/transform.py5
-rw-r--r--tensorflow/python/framework/ops.py544
-rw-r--r--tensorflow/python/framework/ops_test.py3
4 files changed, 160 insertions, 393 deletions
diff --git a/tensorflow/contrib/copy_graph/python/util/copy_elements.py b/tensorflow/contrib/copy_graph/python/util/copy_elements.py
index 102bc460fd..a0dd3881a8 100644
--- a/tensorflow/contrib/copy_graph/python/util/copy_elements.py
+++ b/tensorflow/contrib/copy_graph/python/util/copy_elements.py
@@ -218,7 +218,6 @@ def copy_op_to_graph(org_instance, to_graph, variables, scope=''):
new_control_inputs, input_types, new_original_op,
op_def)
#Use Graph's hidden methods to add the op
- to_graph._add_op(new_op) # pylint: disable=protected-access
to_graph._record_op_seen_by_control_dependencies(new_op)
for device_function in reversed(to_graph._device_function_stack):
new_op._set_device(device_function(new_op))
diff --git a/tensorflow/contrib/graph_editor/transform.py b/tensorflow/contrib/graph_editor/transform.py
index 592d37b432..026a3d1200 100644
--- a/tensorflow/contrib/graph_editor/transform.py
+++ b/tensorflow/contrib/graph_editor/transform.py
@@ -189,9 +189,6 @@ def copy_op_handler(info, op, new_inputs, copy_shape=True, nodedef_fn=None):
if op._original_op:
op_._original_op = op._original_op
- # Add op to the graph
- info.graph_._add_op(op_)
-
return op_, op_.outputs
@@ -492,7 +489,7 @@ class Transformer(object):
t_ = info.transformed_ts[t]
consumer_op_ = info.transformed_ops[consumer_op]
t_index_ = list(consumer_op_.inputs).index(tmp_t_)
- consumer_op_._update_input(t_index_, t_, update_dtype=False) # pylint: disable=protected-access
+ consumer_op_._update_input(t_index_, t_) # pylint: disable=protected-access
def _connect_control_inputs(self, info):
"""Connect the previously copied ops."""
diff --git a/tensorflow/python/framework/ops.py b/tensorflow/python/framework/ops.py
index eceea5276a..b2fd98f431 100644
--- a/tensorflow/python/framework/ops.py
+++ b/tensorflow/python/framework/ops.py
@@ -56,6 +56,7 @@ from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.util import compat
from tensorflow.python.util import decorator_utils
from tensorflow.python.util import tf_contextlib
+from tensorflow.python.util.deprecation import deprecated_args
from tensorflow.python.util.tf_export import tf_export
@@ -288,15 +289,8 @@ class Tensor(_TensorLike):
self._value_index = value_index
self._dtype = dtypes.as_dtype(dtype)
- if _USE_C_API:
- # This will be set by set_shape_and_handle_data_for_outputs.
- self._shape_val = None
- else:
- # The Python code requires all tensors start with a shape to support shape
- # inference on imported while loops. This isn't necessary with the C API
- # enabled because the C API provides the shapes for imported nodes.
- # TODO(skyewm): remove when _USE_C_API is removed.
- self._shape_val = tensor_shape.unknown_shape()
+ # This will be set by self.shape().
+ self._shape_val = None
# List of operations that use this Tensor as input. We maintain this list
# to easily navigate a computation graph.
@@ -384,7 +378,6 @@ class Tensor(_TensorLike):
if _USE_C_SHAPES:
self._shape_val = self._c_api_shape()
else:
- assert _USE_C_API
# Call set_shape_and_handle_data_for_outputs in topological order on all
# ops that are needed to compute self.op's shape. We do this instead of
# having set_shape_and_handle_data_for_outputs recursively call
@@ -508,8 +501,6 @@ class Tensor(_TensorLike):
else:
self._shape_val = self.shape.merge_with(shape)
- if not self._op._graph._c_graph: return
-
# Update C shape even if _USE_C_SHAPES = False, since we still want
# set_shape to be reflected in the C API graph for when we run it.
if not isinstance(shape, tensor_shape.TensorShape):
@@ -545,33 +536,14 @@ class Tensor(_TensorLike):
Returns:
A list of `Operation`s.
"""
- if self._op._c_op: # pylint: disable=protected-access
- consumer_names = c_api.TF_OperationOutputConsumers_wrapper(
- self._as_tf_output())
- # pylint: disable=protected-access
- return [
- self.graph._get_operation_by_name_unsafe(name)
- for name in consumer_names
- ]
- # pylint: enable=protected-access
- else:
- return self._consumers
-
- def _add_consumer(self, consumer):
- """Add a consumer to this tensor.
-
- Args:
- consumer: an Operation.
-
- Raises:
- TypeError: if the consumer is not an Operation.
- """
+ consumer_names = c_api.TF_OperationOutputConsumers_wrapper(
+ self._as_tf_output())
# pylint: disable=protected-access
- assert not self._op._c_op, "Tensor._add_consumer doesn't work with C API"
+ return [
+ self.graph._get_operation_by_name_unsafe(name)
+ for name in consumer_names
+ ]
# pylint: enable=protected-access
- if not isinstance(consumer, Operation):
- raise TypeError("Consumer must be an Operation: %s" % consumer)
- self._consumers.append(consumer)
def _as_node_def_input(self):
"""Return a value to use for the NodeDef "input" attribute.
@@ -594,7 +566,6 @@ class Tensor(_TensorLike):
def _as_tf_output(self):
# pylint: disable=protected-access
- assert self.op._c_op
return c_api_util.tf_output(self.op._c_op, self.value_index)
# pylint: enable=protected-access
@@ -1722,18 +1693,8 @@ class Operation(object):
"a Tensor, or IndexedSlices: %s" % c)
control_input_ops.append(control_op)
- # Don't set private fields with C API enabled to catch users who need to
- # switch to public API.
- # TODO(skyewm): delete these fields once we remove _USE_C_API
- if not self._graph._c_graph:
- self._inputs_val = list(inputs) # Defensive copy.
- self._input_types_val = input_types
- self._control_inputs_val = control_input_ops
- self._node_def_val = copy.deepcopy(node_def)
- self._op_def_val = op_def
- else:
- # This will be set by self.inputs.
- self._inputs_val = None
+ # This will be set by self.inputs.
+ self._inputs_val = None
self._id_value = self._graph._next_id() # pylint: disable=protected-access
self._original_op = original_op
@@ -1742,10 +1703,8 @@ class Operation(object):
# Initialize self._c_op.
if c_op:
- # TODO(skyewm): remove this assert when we remove USE_C_API
- assert self._graph._c_graph # pylint: disable=protected-access
self._c_op = c_op
- elif self._graph._c_graph: # pylint: disable=protected-access
+ else:
if op_def is None:
op_def = self._graph._get_op_def(node_def.op)
# TODO(skyewm): op_def_library.apply_op() flattens the incoming inputs.
@@ -1754,30 +1713,19 @@ class Operation(object):
op_def, inputs, node_def.attr)
self._c_op = _create_c_op(self._graph, node_def, grouped_inputs,
control_input_ops)
- else:
- self._c_op = None
-
- # Mark that we consume the inputs. This is unnecessary and unsupported with
- # the C API enabled, since the C API tracks the tensor consumers instead.
- if not self._c_op:
- for input_tensor in self._inputs_val:
- input_tensor._add_consumer(self) # pylint: disable=protected-access
# Initialize self._outputs.
- if self._c_op:
- num_outputs = c_api.TF_OperationNumOutputs(self._c_op)
- output_types = [
- c_api.TF_OperationOutputType(c_api_util.tf_output(self._c_op, i))
- for i in range(num_outputs)]
- assert output_types is not None
- elif output_types is None:
- output_types = []
- self._output_types_val = output_types
+ num_outputs = c_api.TF_OperationNumOutputs(self._c_op)
+ output_types = [
+ c_api.TF_OperationOutputType(c_api_util.tf_output(self._c_op, i))
+ for i in range(num_outputs)]
self._outputs = [
Tensor(self, i, output_type)
for i, output_type in enumerate(output_types)
]
+ self._graph._add_op(self) # pylint: disable=protected-access
+
if not c_op:
self._control_flow_post_processing()
@@ -1791,7 +1739,6 @@ class Operation(object):
control_flow_util.CheckInputFromValidContext(self, input_tensor.op)
if self._control_flow_context is not None:
self._control_flow_context.AddOp(self)
- self._recompute_node_def()
def _reconstruct_sequence_inputs(self, op_def, inputs, attrs):
"""Regroups a flat list of input tensors into scalar and sequence inputs.
@@ -1872,10 +1819,7 @@ class Operation(object):
@property
def name(self):
"""The full name of this operation."""
- if self._c_op:
- return c_api.TF_OperationName(self._c_op)
- else:
- return self._node_def_val.name
+ return c_api.TF_OperationName(self._c_op)
@property
def _id(self):
@@ -1891,10 +1835,7 @@ class Operation(object):
assigned, or an empty string if it has not been assigned to a
device.
"""
- if self._c_op:
- return c_api.TF_OperationDevice(self._c_op)
- else:
- return self._node_def_val.device
+ return c_api.TF_OperationDevice(self._c_op)
@property
def _output_types(self):
@@ -1907,28 +1848,21 @@ class Operation(object):
The length of this list indicates the number of output endpoints
of the operation.
"""
- if self._c_op:
- num_outputs = c_api.TF_OperationNumOutputs(self._c_op)
- output_types = [
- c_api.TF_OperationOutputType(self._tf_output(i))
- for i in xrange(num_outputs)
- ]
- # TODO(iga): Remove this assert after converting to C API by default.
- # Just being a bit paranoid here.
- assert self._output_types_val == output_types
- # In all the tests we have output_types that are passed into
- # Operation.__init__ are a list of ints (which is illegal according
- # to the docstring), but input_types are instances of DType.
- # This extra assert is to catch if we ever use DType for output_types.
- if output_types:
- assert isinstance(output_types[0], int)
- return output_types
- else:
- return self._output_types_val
+ num_outputs = c_api.TF_OperationNumOutputs(self._c_op)
+ output_types = [
+ c_api.TF_OperationOutputType(self._tf_output(i))
+ for i in xrange(num_outputs)
+ ]
+ # In all the tests we have output_types that are passed into
+ # Operation.__init__ are a list of ints (which is illegal according
+ # to the docstring), but input_types are instances of DType.
+ # This extra assert is to catch if we ever use DType for output_types.
+ if output_types:
+ assert isinstance(output_types[0], int)
+ return output_types
def _tf_output(self, output_idx):
"""Create and return a new TF_Output for output_idx'th output of this op."""
- assert self._c_op
tf_output = c_api.TF_Output()
tf_output.oper = self._c_op
tf_output.index = output_idx
@@ -1936,7 +1870,6 @@ class Operation(object):
def _tf_input(self, input_idx):
"""Create and return a new TF_Input for input_idx'th input of this op."""
- assert self._c_op
tf_input = c_api.TF_Input()
tf_input.oper = self._c_op
tf_input.index = input_idx
@@ -1948,47 +1881,12 @@ class Operation(object):
Args:
device: string or device.. The device to set.
"""
- if self._c_op:
- c_api.SetRequestedDevice(
- self._graph._c_graph, # pylint: disable=protected-access
- self._c_op, # pylint: disable=protected-access
- compat.as_str(_device_string(device)))
- else:
- self._node_def_val.device = _device_string(device)
-
- def _add_input(self, tensor, dtype=None):
- """Add a new input to this operation.
-
- Args:
- tensor: the Tensor to add as an input.
- dtype: tf.DType: type of the input; defaults to
- the tensor's dtype.
+ c_api.SetRequestedDevice(
+ self._graph._c_graph, # pylint: disable=protected-access
+ self._c_op, # pylint: disable=protected-access
+ compat.as_str(_device_string(device)))
- Raises:
- TypeError: if tensor is not a Tensor,
- or if input tensor type is not convertible to dtype.
- ValueError: if the Tensor is from a different graph.
- """
- assert not self._c_op, (
- "Operation._add_input doesn't work with C API")
- if not isinstance(tensor, Tensor):
- raise TypeError("tensor must be a Tensor: %s" % tensor)
- _assert_same_graph(self, tensor)
- if dtype is None:
- dtype = tensor.dtype
- else:
- dtype = dtypes.as_dtype(dtype)
- if not dtype.is_compatible_with(tensor.dtype):
- raise TypeError(
- "Cannot convert a tensor of type %s to an input of type %s" %
- (tensor.dtype.name, dtype.name))
- self._inputs_val.append(tensor)
- self._input_types_val.append(dtype)
- tensor._add_consumer(self) # pylint: disable=protected-access
- self._recompute_node_def()
-
- # TODO(skyewm): Remove `update_dtype` when we enable the C API.
- def _update_input(self, index, tensor, update_dtype=True):
+ def _update_input(self, index, tensor):
"""Update the input to this operation at the given index.
NOTE: This is for TF internal use only. Please don't use it.
@@ -1996,7 +1894,6 @@ class Operation(object):
Args:
index: the index of the input to update.
tensor: the Tensor to be used as the input at the given index.
- update_dtype: If `False`, the type for this input is not updated.
Raises:
TypeError: if tensor is not a Tensor,
@@ -2013,20 +1910,12 @@ class Operation(object):
if not _USE_C_SHAPES:
set_shape_and_handle_data_for_outputs(self)
- if self._c_op:
- # Reset cached inputs.
- self._inputs_val = None
- c_api.UpdateEdge(
- self._graph._c_graph, # pylint: disable=protected-access
- tensor._as_tf_output(), # pylint: disable=protected-access
- self._tf_input(index))
- else:
- self._inputs_val[index].consumers().remove(self)
- self._inputs_val[index] = tensor
- if update_dtype:
- self._input_types_val[index] = tensor.dtype
- tensor._add_consumer(self) # pylint: disable=protected-access
- self._recompute_node_def()
+ # Reset cached inputs.
+ self._inputs_val = None
+ c_api.UpdateEdge(
+ self._graph._c_graph, # pylint: disable=protected-access
+ tensor._as_tf_output(), # pylint: disable=protected-access
+ self._tf_input(index))
def _add_control_inputs(self, ops):
"""Add a list of new control inputs to this operation.
@@ -2038,19 +1927,10 @@ class Operation(object):
TypeError: if ops is not a list of Operations.
ValueError: if any op in ops is from a different graph.
"""
- if self._c_op:
- for op in ops:
- if not isinstance(op, Operation):
- raise TypeError("op must be an Operation: %s" % op)
- c_api.AddControlInput(self._graph._c_graph, self._c_op, op._c_op) # pylint: disable=protected-access
- else:
- if ops:
- for op in ops:
- if not isinstance(op, Operation):
- raise TypeError("op must be an Operation: %s" % op)
- _assert_same_graph(self, op)
- self._control_inputs_val.append(op)
- self._recompute_node_def()
+ for op in ops:
+ if not isinstance(op, Operation):
+ raise TypeError("op must be an Operation: %s" % op)
+ c_api.AddControlInput(self._graph._c_graph, self._c_op, op._c_op) # pylint: disable=protected-access
def _add_control_input(self, op):
"""Add a new control input to this operation.
@@ -2062,33 +1942,13 @@ class Operation(object):
TypeError: if op is not an Operation.
ValueError: if op is from a different graph.
"""
- if self._c_op:
- if not isinstance(op, Operation):
- raise TypeError("op must be an Operation: %s" % op)
- c_api.AddControlInput(self._graph._c_graph, self._c_op, op._c_op) # pylint: disable=protected-access
- else:
- self._add_control_inputs([op])
+ if not isinstance(op, Operation):
+ raise TypeError("op must be an Operation: %s" % op)
+ c_api.AddControlInput(self._graph._c_graph, self._c_op, op._c_op) # pylint: disable=protected-access
def _remove_all_control_inputs(self):
"""Removes any control inputs to this operation."""
- if self._c_op:
- c_api.RemoveAllControlInputs(self._graph._c_graph, self._c_op) # pylint: disable=protected-access
- else:
- del self.control_inputs[:]
-
- # Methods below are used when building the NodeDef and Graph proto.
- def _recompute_node_def(self):
- # TODO(skyewm): remove this function when we switch to C API
- if self._c_op: return
-
- del self._node_def_val.input[:]
- # pylint: disable=protected-access
- self._node_def_val.input.extend(
- [t._as_node_def_input() for t in self._inputs_val])
- # pylint: enable=protected-access
- if self._control_inputs_val:
- self._node_def_val.input.extend(
- ["^%s" % op.name for op in self._control_inputs_val])
+ c_api.RemoveAllControlInputs(self._graph._c_graph, self._c_op) # pylint: disable=protected-access
def __str__(self):
return str(self.node_def)
@@ -2129,19 +1989,16 @@ class Operation(object):
@property
def inputs(self):
"""The list of `Tensor` objects representing the data inputs of this op."""
- if self._c_op:
- if self._inputs_val is None:
- tf_outputs = c_api.GetOperationInputs(self._c_op)
- # pylint: disable=protected-access
- retval = [
- self.graph._get_tensor_by_tf_output(tf_output)
- for tf_output in tf_outputs
- ]
- # pylint: enable=protected-access
- self._inputs_val = Operation._InputList(retval)
- return self._inputs_val
- else:
- return Operation._InputList(self._inputs_val)
+ if self._inputs_val is None:
+ tf_outputs = c_api.GetOperationInputs(self._c_op)
+ # pylint: disable=protected-access
+ retval = [
+ self.graph._get_tensor_by_tf_output(tf_output)
+ for tf_output in tf_outputs
+ ]
+ # pylint: enable=protected-access
+ self._inputs_val = Operation._InputList(retval)
+ return self._inputs_val
@property
def _inputs(self):
@@ -2155,15 +2012,12 @@ class Operation(object):
@property
def _input_types(self):
- if self._c_op:
- num_inputs = c_api.TF_OperationNumInputs(self._c_op)
- input_types = [
- dtypes.as_dtype(c_api.TF_OperationInputType(self._tf_input(i)))
- for i in xrange(num_inputs)
- ]
- return input_types
- else:
- return self._input_types_val
+ num_inputs = c_api.TF_OperationNumInputs(self._c_op)
+ input_types = [
+ dtypes.as_dtype(c_api.TF_OperationInputType(self._tf_input(i)))
+ for i in xrange(num_inputs)
+ ]
+ return input_types
@_input_types.setter
def _input_types(self, value):
@@ -2183,16 +2037,13 @@ class Operation(object):
A list of `Operation` objects.
"""
- if self._c_op:
- control_c_ops = c_api.TF_OperationGetControlInputs_wrapper(self._c_op)
- # pylint: disable=protected-access
- return [
- self.graph._get_operation_by_name_unsafe(
- c_api.TF_OperationName(c_op)) for c_op in control_c_ops
- ]
- # pylint: enable=protected-access
- else:
- return self._control_inputs_val
+ control_c_ops = c_api.TF_OperationGetControlInputs_wrapper(self._c_op)
+ # pylint: disable=protected-access
+ return [
+ self.graph._get_operation_by_name_unsafe(
+ c_api.TF_OperationName(c_op)) for c_op in control_c_ops
+ ]
+ # pylint: enable=protected-access
@property
def _control_outputs(self):
@@ -2205,18 +2056,13 @@ class Operation(object):
A list of `Operation` objects.
"""
- if self._c_op:
- control_c_ops = c_api.TF_OperationGetControlOutputs_wrapper(self._c_op)
- # pylint: disable=protected-access
- return [
- self.graph._get_operation_by_name_unsafe(
- c_api.TF_OperationName(c_op)) for c_op in control_c_ops
- ]
- # pylint: enable=protected-access
- else:
- # TODO(apassos) this should be less inefficient.
- return [o for o in self._graph.get_operations()
- if self in o.control_inputs]
+ control_c_ops = c_api.TF_OperationGetControlOutputs_wrapper(self._c_op)
+ # pylint: disable=protected-access
+ return [
+ self.graph._get_operation_by_name_unsafe(
+ c_api.TF_OperationName(c_op)) for c_op in control_c_ops
+ ]
+ # pylint: enable=protected-access
@property
def _control_inputs(self):
@@ -2240,11 +2086,7 @@ class Operation(object):
@property
def type(self):
"""The type of the op (e.g. `"MatMul"`)."""
- if self._c_op:
- op_type = c_api.TF_OperationOpType(self._c_op)
- return op_type
- else:
- return self._node_def_val.op
+ return c_api.TF_OperationOpType(self._c_op)
@property
def graph(self):
@@ -2262,15 +2104,12 @@ class Operation(object):
protocol buffer.
"""
# pylint: enable=line-too-long
- if self._c_op:
- with c_api_util.tf_buffer() as buf:
- c_api.TF_OperationToNodeDef(self._c_op, buf)
- data = c_api.TF_GetBuffer(buf)
- node_def = node_def_pb2.NodeDef()
- node_def.ParseFromString(compat.as_bytes(data))
- return node_def
- else:
- return self._node_def_val
+ with c_api_util.tf_buffer() as buf:
+ c_api.TF_OperationToNodeDef(self._c_op, buf)
+ data = c_api.TF_GetBuffer(buf)
+ node_def = node_def_pb2.NodeDef()
+ node_def.ParseFromString(compat.as_bytes(data))
+ return node_def
@property
def _node_def(self):
@@ -2289,10 +2128,7 @@ class Operation(object):
protocol buffer.
"""
# pylint: enable=line-too-long
- if self._c_op:
- return self._graph._get_op_def(self.type)
- else:
- return self._op_def_val
+ return self._graph._get_op_def(self.type)
@property
def _op_def(self):
@@ -2318,17 +2154,14 @@ class Operation(object):
def _set_attr(self, attr_name, attr_value):
"""Private method used to set an attribute in the node_def."""
- if self._c_op:
- buf = c_api.TF_NewBufferFromString(
- compat.as_bytes(attr_value.SerializeToString()))
- try:
- # pylint: disable=protected-access
- c_api.SetAttr(self._graph._c_graph, self._c_op, attr_name, buf)
- # pylint: enable=protected-access
- finally:
- c_api.TF_DeleteBuffer(buf)
- else:
- self._node_def_val.attr[attr_name].CopyFrom(attr_value)
+ buf = c_api.TF_NewBufferFromString(
+ compat.as_bytes(attr_value.SerializeToString()))
+ try:
+ # pylint: disable=protected-access
+ c_api.SetAttr(self._graph._c_graph, self._c_op, attr_name, buf)
+ # pylint: enable=protected-access
+ finally:
+ c_api.TF_DeleteBuffer(buf)
def get_attr(self, name):
"""Returns the value of the attr of this op with the given `name`.
@@ -2343,21 +2176,15 @@ class Operation(object):
ValueError: If this op does not have an attr with the given `name`.
"""
fields = ["s", "i", "f", "b", "type", "shape", "tensor", "func"]
- if self._c_op:
- try:
- with c_api_util.tf_buffer() as buf:
- c_api.TF_OperationGetAttrValueProto(self._c_op, name, buf)
- data = c_api.TF_GetBuffer(buf)
- except errors.InvalidArgumentError as e:
- # Convert to ValueError for backwards compatibility.
- raise ValueError(str(e))
- x = attr_value_pb2.AttrValue()
- x.ParseFromString(data)
- else:
- if name not in self._node_def_val.attr:
- raise ValueError(
- "No attr named '" + name + "' in " + str(self._node_def_val))
- x = self._node_def_val.attr[name]
+ try:
+ with c_api_util.tf_buffer() as buf:
+ c_api.TF_OperationGetAttrValueProto(self._c_op, name, buf)
+ data = c_api.TF_GetBuffer(buf)
+ except errors.InvalidArgumentError as e:
+ # Convert to ValueError for backwards compatibility.
+ raise ValueError(str(e))
+ x = attr_value_pb2.AttrValue()
+ x.ParseFromString(data)
# Treat an empty oneof value as an empty list.
if not x.WhichOneof("value"):
@@ -2577,9 +2404,9 @@ def _set_shape_and_handle_data_for_outputs_c_api(op):
def set_shape_and_handle_data_for_outputs(op):
"""Set the shapes and resource handle data for op's outputs.
- When _USE_C_API = True, this is lazily called when a tensor's shape is first
- requested. Usually this should work automatically, but some edge cases may
- require manually calling this first to make sure Tensor._shape_val and
+ When _USE_C_SHAPES = False, this is lazily called when a tensor's shape is
+ first requested. Usually this should work automatically, but some edge cases
+ may require manually calling this first to make sure Tensor._shape_val and
Tensor._handle_data are set (e.g. manually overriding _handle_data, copying a
Tensor).
"""
@@ -3083,15 +2910,12 @@ class Graph(object):
A `VersionDef`.
"""
# pylint: enable=line-too-long
- if self._c_graph:
- with c_api_util.tf_buffer() as buf:
- c_api.TF_GraphVersions(self._c_graph, buf)
- data = c_api.TF_GetBuffer(buf)
- version_def = versions_pb2.VersionDef()
- version_def.ParseFromString(compat.as_bytes(data))
- return version_def
- else:
- return self._graph_def_versions
+ with c_api_util.tf_buffer() as buf:
+ c_api.TF_GraphVersions(self._c_graph, buf)
+ data = c_api.TF_GetBuffer(buf)
+ version_def = versions_pb2.VersionDef()
+ version_def.ParseFromString(compat.as_bytes(data))
+ return version_def
@property
def seed(self):
@@ -3185,40 +3009,22 @@ class Graph(object):
"""
# pylint: enable=line-too-long
- if self._c_graph:
- with self._lock:
- with c_api_util.tf_buffer() as buf:
- c_api.TF_GraphToGraphDef(self._c_graph, buf)
- data = c_api.TF_GetBuffer(buf)
- graph = graph_pb2.GraphDef()
- graph.ParseFromString(compat.as_bytes(data))
- # Strip the experimental library field iff it's empty.
- if not graph.library.function:
- graph.ClearField("library")
-
- if add_shapes:
- for node in graph.node:
- op = self._nodes_by_name[node.name]
- if op.outputs:
- node.attr["_output_shapes"].list.shape.extend(
- [output.get_shape().as_proto() for output in op.outputs])
- else:
- with self._lock:
- graph = graph_pb2.GraphDef()
- graph.versions.CopyFrom(self._graph_def_versions)
- bytesize = 0
- for op_id in sorted(self._nodes_by_id):
- op = self._nodes_by_id[op_id]
- if from_version is None or op_id > from_version:
- graph.node.extend([op.node_def])
- if op.outputs and add_shapes:
- assert "_output_shapes" not in graph.node[-1].attr
- graph.node[-1].attr["_output_shapes"].list.shape.extend(
- [output.get_shape().as_proto() for output in op.outputs])
- bytesize += op.node_def.ByteSize()
- if bytesize >= (1 << 31) or bytesize < 0:
- raise ValueError("GraphDef cannot be larger than 2GB.")
- self._copy_functions_to_graph_def(graph, bytesize)
+ with self._lock:
+ with c_api_util.tf_buffer() as buf:
+ c_api.TF_GraphToGraphDef(self._c_graph, buf)
+ data = c_api.TF_GetBuffer(buf)
+ graph = graph_pb2.GraphDef()
+ graph.ParseFromString(compat.as_bytes(data))
+ # Strip the experimental library field iff it's empty.
+ if not graph.library.function:
+ graph.ClearField("library")
+
+ if add_shapes:
+ for node in graph.node:
+ op = self._nodes_by_name[node.name]
+ if op.outputs:
+ node.attr["_output_shapes"].list.shape.extend(
+ [output.get_shape().as_proto() for output in op.outputs])
return graph, self._version
def as_graph_def(self, from_version=None, add_shapes=False):
@@ -3292,34 +3098,16 @@ class Graph(object):
# Add function to graph
# pylint: disable=protected-access
- if self._c_graph:
- # Handle functions created without using the C API. TODO(apassos,skyewm)
- # remove this when all functions are generated using the C API by default
- # as this will be unnecessary.
- if not function._c_func:
- serialized = function.definition.SerializeToString()
- c_func = c_api.TF_FunctionImportFunctionDef(serialized)
- function._c_func = c_api_util.ScopedTFFunction(c_func)
- gradient = (function._grad_func._c_func.func if function._grad_func
- else None)
- c_api.TF_GraphCopyFunction(self._c_graph, function._c_func.func, gradient)
- else:
- # If there is already a function with the same name, raise an error
- # if bodies are different. Else, do nothing. The C API version above
- # has the same behavior.
- previous = self._functions.get(name, None)
- if previous:
- # This check is not ideal as we can have a hash collision with only
- # 32 bits in the hash, but the non C API mode is being deprecated.
- # Don't bother changing it now.
- if previous._hash_str == function._hash_str:
- return
- else:
- raise ValueError("Cannot add function (%s, hash %s) to graph (%s). "
- "Another function (%s, hash %s) is already defined "
- "with that name (%s)" % (
- function, function._hash_str, self,
- previous, previous._hash_str, name))
+ # Handle functions created without using the C API. TODO(apassos,skyewm)
+ # remove this when all functions are generated using the C API by default
+ # as this will be unnecessary.
+ if not function._c_func:
+ serialized = function.definition.SerializeToString()
+ c_func = c_api.TF_FunctionImportFunctionDef(serialized)
+ function._c_func = c_api_util.ScopedTFFunction(c_func)
+ gradient = (function._grad_func._c_func.func if function._grad_func
+ else None)
+ c_api.TF_GraphCopyFunction(self._c_graph, function._c_func.func, gradient)
# pylint: enable=protected-access
self._functions[name] = function
@@ -3334,6 +3122,9 @@ class Graph(object):
return self._building_function
# Helper functions to create operations.
+ @deprecated_args(None,
+ "Shapes are always computed; don't use the compute_shapes "
+ "as it has no effect.", "compute_shapes")
def create_op(
self,
op_type,
@@ -3370,8 +3161,8 @@ class Graph(object):
proto).
op_def: (Optional.) The `OpDef` proto that describes the `op_type` that
the operation will have.
- compute_shapes: (Optional.) If True, shape inference will be performed
- to compute the shapes of the outputs.
+ compute_shapes: (Optional.) Deprecated. Has no effect (shapes are always
+ computed).
compute_device: (Optional.) If True, device functions will be executed
to compute the device property of the Operation.
@@ -3381,8 +3172,9 @@ class Graph(object):
Returns:
An `Operation` object.
-
"""
+ del compute_shapes
+
self._check_not_finalized()
for idx, a in enumerate(inputs):
if not isinstance(a, Tensor):
@@ -3412,18 +3204,7 @@ class Graph(object):
input_types=input_types,
original_op=self._default_original_op,
op_def=op_def)
-
- # Note: shapes are lazily computed with the C API enabled.
- #
- # TODO(skyewm): unlike in the original Python implementation, the C API
- # always computes shape information (even for function calls, which the
- # original Python shape inference code doesn't handle). Deprecate the
- # compute_shapes argument.
- if not _USE_C_API and compute_shapes:
- set_shape_and_handle_data_for_outputs(ret)
-
- self._create_op_helper(ret, compute_shapes=compute_shapes,
- compute_device=compute_device)
+ self._create_op_helper(ret, compute_device=compute_device)
return ret
def _create_op_from_tf_operation(self, c_op, compute_device=True):
@@ -3457,11 +3238,8 @@ class Graph(object):
self._create_op_helper(ret, compute_device=compute_device)
return ret
- def _create_op_helper(self, op, compute_shapes=True, compute_device=True):
+ def _create_op_helper(self, op, compute_device=True):
"""Common logic for creating an op in this graph."""
- # TODO(b/XXXX): move to Operation.__init__ once _USE_C_API flag is removed.
- self._add_op(op)
-
# Apply any additional attributes requested. Do not overwrite any existing
# attributes.
for key, value in self._attr_scope_map.items():
@@ -3528,8 +3306,7 @@ class Graph(object):
# (2) "is_stateful" is set in OpDef
# (3) "container" attribute is in OpDef
# (4) "container" attribute is None
- # TODO(skyewm): remove op.op_def check when _USE_C_API is removed.
- if self._container and op.op_def and op.op_def.is_stateful:
+ if self._container and op.op_def.is_stateful:
try:
container_attr = op.get_attr("container")
except ValueError:
@@ -3816,17 +3593,14 @@ class Graph(object):
def _get_op_def(self, type): # pylint: disable=redefined-builtin
"""Returns the `OpDef` proto for `type`. `type` is a string."""
- if self._c_graph:
- with c_api_util.tf_buffer() as buf:
- # pylint: disable=protected-access
- c_api.TF_GraphGetOpDef(self._c_graph, compat.as_bytes(type), buf)
- # pylint: enable=protected-access
- data = c_api.TF_GetBuffer(buf)
- op_def = op_def_pb2.OpDef()
- op_def.ParseFromString(compat.as_bytes(data))
- return op_def
- else:
- return self._registered_ops[type]
+ with c_api_util.tf_buffer() as buf:
+ # pylint: disable=protected-access
+ c_api.TF_GraphGetOpDef(self._c_graph, compat.as_bytes(type), buf)
+ # pylint: enable=protected-access
+ data = c_api.TF_GetBuffer(buf)
+ op_def = op_def_pb2.OpDef()
+ op_def.ParseFromString(compat.as_bytes(data))
+ return op_def
def as_default(self):
"""Returns a context manager that makes this `Graph` the default graph.
diff --git a/tensorflow/python/framework/ops_test.py b/tensorflow/python/framework/ops_test.py
index e7732632f2..81355a279c 100644
--- a/tensorflow/python/framework/ops_test.py
+++ b/tensorflow/python/framework/ops_test.py
@@ -270,7 +270,6 @@ class OperationTest(test_util.TensorFlowTestCase):
op1 = ops.Operation(
ops._NodeDef("RefOutputFloatOutput", "op1"), g, [],
[dtypes.float32_ref, dtypes.float32])
- g._add_op(op1)
self.assertProtoEquals("op:'RefOutputFloatOutput' name:'op1'", op1.node_def)
self.assertEquals([], list(op1.inputs))
ref_t, nonref_t = op1.values()
@@ -279,14 +278,12 @@ class OperationTest(test_util.TensorFlowTestCase):
ops._NodeDef("RefInputFloatInput", "op2"),
g, [ref_t, nonref_t], [],
input_types=[dtypes.float32_ref, dtypes.float32])
- g._add_op(op2)
self.assertProtoEquals(
"op:'RefInputFloatInput' name:'op2' input:'op1' input:'op1:1'",
op2.node_def)
self.assertEquals([ref_t, nonref_t], list(op2.inputs))
op3 = ops.Operation(
ops._NodeDef("TwoFloatInputs", "op3"), g, [ref_t, nonref_t], [])
- g._add_op(op3)
self.assertProtoEquals(
"op:'TwoFloatInputs' name:'op3' input:'op1' input:'op1:1'",
op3.node_def)