aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python/ops/op_def_library.py
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/python/ops/op_def_library.py')
-rw-r--r--tensorflow/python/ops/op_def_library.py640
1 files changed, 640 insertions, 0 deletions
diff --git a/tensorflow/python/ops/op_def_library.py b/tensorflow/python/ops/op_def_library.py
new file mode 100644
index 0000000000..5947b6df89
--- /dev/null
+++ b/tensorflow/python/ops/op_def_library.py
@@ -0,0 +1,640 @@
+"""Class to hold a library of OpDefs and use it to create Brain operations."""
+
+import numbers
+
+from tensorflow.core.framework import attr_value_pb2
+from tensorflow.core.framework import op_def_pb2
+from tensorflow.core.framework import tensor_pb2
+from tensorflow.core.framework import tensor_shape_pb2
+from tensorflow.core.framework import types_pb2
+from tensorflow.python.framework import ops
+from tensorflow.python.framework import tensor_shape
+from tensorflow.python.framework import types as types_lib
+from tensorflow.python.ops import constant_op
+from tensorflow.python.platform import logging
+
+
+def _Attr(op_def, name):
+ for attr in op_def.attr:
+ if attr.name == name:
+ return attr
+ raise TypeError("Inconsistent OpDef for '%s', missing attr '%s'" %
+ (op_def.name, name))
+
+
+def _AttrValue(attr_protos, name):
+ if name in attr_protos:
+ return attr_protos[name]
+ raise TypeError("Inconsistent OpDef, missing attr '%s' from '%s'." %
+ (name, attr_protos))
+
+
+def _SatisfiesTypeConstraint(dtype, attr_def):
+ if attr_def.HasField("allowed_values"):
+ allowed_list = attr_def.allowed_values.list.type
+ if dtype not in allowed_list:
+ raise TypeError(
+ "DataType %s for attr '%s' not in list of allowed values: %s" %
+ (types_lib.as_dtype(dtype).name, attr_def.name,
+ ", ".join(types_lib.as_dtype(x).name for x in allowed_list)))
+
+
+def _IsListParameter(arg):
+ if arg.number_attr:
+ return True
+ elif arg.type_list_attr:
+ return True
+ return False
+
+
+def _NumTypeFields(arg):
+ num = 0
+ if arg.type != types_pb2.DT_INVALID: num += 1
+ if arg.type_attr: num += 1
+ if arg.type_list_attr: num += 1
+ return num
+
+
+def _IsListValue(v):
+ return isinstance(v, (list, tuple))
+
+
+def _Flatten(l):
+ """Converts [1, 2, [3, 4], [5]] to [1, 2, 3, 4, 5]."""
+ # [1, 2, [3, 4], [5]] -> [[1], [2], [3, 4], [5]]
+ l_of_l = [x if _IsListValue(x) else [x] for x in l]
+ # [[1], [2], [3, 4], [5]] -> [1, 2, 3, 4, 5]
+ return [item for sublist in l_of_l for item in sublist]
+
+
+def _Restructure(l, structure):
+ """Returns the elements of list l structured according to the given structure.
+
+ A structure is represented by a list whose elements are either
+ `None` or a non-negative integer. `None` corresponds to a single
+ element in the output list, and an integer N corresponds to a nested
+ list of length N.
+
+ The function returns a data structure whose shape is given by
+ `structure`, and whose elements are taken from `l`. If `structure`
+ is a singleton, the function returns the single data structure
+ implied by the 0th element of `structure`. For example:
+
+ _Restructure(["foo", "bar", "baz", "qux"], [None, 2, None])
+ -> ["foo", ["bar", "baz"], "qux"]
+
+ _Restructure(["foo"], [None]) -> "foo"
+
+ _Restructure(["foo"], [1]) -> ["foo"]
+
+ _Restructure([], [0]) -> []
+
+ Args:
+ l: A list.
+ structure: A list whose elements are either `None` or a non-negative
+ integer.
+
+ Returns:
+ The elements of `l`, restructured according to `structure`. If
+ `structure` is a list of length 1, this function returns the
+ single data structure implied by `structure[0]`.
+
+ """
+ result = []
+ current_index = 0
+ for element in structure:
+ if element is None:
+ result.append(l[current_index])
+ current_index += 1
+ else:
+ result.append(l[current_index:current_index+element])
+ current_index += element
+
+ if len(result) == 1:
+ return result[0]
+ else:
+ return tuple(result)
+
+
+def _MakeFloat(v, arg_name):
+ if not isinstance(v, numbers.Real):
+ raise TypeError("Expected float for argument '%s' not %s." %
+ (arg_name, repr(v)))
+ return float(v)
+
+
+def _MakeInt(v, arg_name):
+ if isinstance(v, basestring):
+ raise TypeError("Expected int for argument '%s' not %s." %
+ (arg_name, repr(v)))
+ try:
+ return int(v)
+ except (ValueError, TypeError):
+ raise TypeError("Expected int for argument '%s' not %s." %
+ (arg_name, repr(v)))
+
+
+def _MakeStr(v, arg_name):
+ if not isinstance(v, basestring):
+ raise TypeError("Expected string for argument '%s' not %s." %
+ (arg_name, repr(v)))
+ return str(v) # Convert unicode strings to bytes.
+
+
+def _MakeBool(v, arg_name):
+ if not isinstance(v, bool):
+ raise TypeError("Expected bool for argument '%s' not %s." %
+ (arg_name, repr(v)))
+ return v
+
+
+def _MakeType(v, attr_def):
+ try:
+ v = types_lib.as_dtype(v)
+ except TypeError:
+ raise TypeError("Expected DataType for argument '%s' not %s." %
+ (attr_def.name, repr(v)))
+ i = v.as_datatype_enum
+ _SatisfiesTypeConstraint(i, attr_def)
+ return i
+
+
+def _MakeShape(v, arg_name):
+ """Convert v into a TensorShapeProto."""
+ # Args:
+ # v: A TensorShapeProto, a list of ints, or a tensor_shape.TensorShape.
+ # arg_name: String, for error messages.
+
+ # Returns:
+ # A TensorShapeProto.
+ if isinstance(v, tensor_shape_pb2.TensorShapeProto):
+ for d in v.dim:
+ if d.name:
+ logging.warning("Warning: TensorShapeProto with a named dimension: %s",
+ str(v))
+ break
+ return v
+ s = tensor_shape.as_shape(v)
+ ret = tensor_shape_pb2.TensorShapeProto()
+ for i in s.as_dimension_list():
+ ret.dim.add(size = i)
+ return ret
+
+
+def _MakeTensor(v, arg_name):
+ """Ensure v is a TensorProto."""
+ if isinstance(v, tensor_pb2.TensorProto):
+ return v
+ raise TypeError(
+ "Don't know how to convert %s to a TensorProto for argument '%s'" %
+ (repr(v), arg_name))
+
+
+class _OpInfo(object):
+ """All per-Op state we would like to precompute/validate."""
+
+ def __init__(self, op_def):
+ self.op_def = op_def
+ # TODO(josh11b): SWIG the ValidateOpDef() function from C++ and call it
+ # here, instead of these checks.
+ for arg in list(op_def.input_arg) + list(op_def.output_arg):
+ num_type_fields = _NumTypeFields(arg)
+ if num_type_fields != 1:
+ raise TypeError("Arg '%s' of '%s' must have one type field not %d" %
+ (arg.name, op_def.name, num_type_fields))
+ if arg.type_attr:
+ attr_type = _Attr(op_def, arg.type_attr).type
+ if attr_type != "type":
+ raise TypeError("Attr '%s' of '%s' used as a type_attr "
+ "but has type %s" %
+ (arg.type_attr, op_def.name, attr_type))
+ if arg.type_list_attr:
+ attr_type = _Attr(op_def, arg.type_list_attr).type
+ if attr_type != "list(type)":
+ raise TypeError(
+ "Attr '%s' of '%s' used as a type_list_attr but has type %s" %
+ (arg.type_attr, op_def.name, attr_type))
+ if arg.number_attr:
+ attr_type = _Attr(op_def, arg.number_attr).type
+ if attr_type != "int":
+ raise TypeError(
+ "Attr '%s' of '%s' used as a number_attr but has type %s" %
+ (arg.number_attr, op_def.name, attr_type))
+
+
+class OpDefLibrary(object):
+ """Holds a collection of OpDefs, can add the corresponding Ops to a graph."""
+
+ def __init__(self):
+ self._ops = {}
+
+ def add_op(self, op_def):
+ """Register an OpDef. May call apply_op with the name afterwards."""
+ if not isinstance(op_def, op_def_pb2.OpDef):
+ raise TypeError("%s is %s, not an op_def_pb2.OpDef" %
+ (op_def, type(op_def)))
+ if not op_def.name:
+ raise ValueError("%s missing name." % op_def)
+ if op_def.name in self._ops:
+ raise RuntimeError("Op name %s registered twice." % op_def.name)
+ self._ops[op_def.name] = _OpInfo(op_def)
+
+ def add_op_list(self, op_list):
+ """Register the OpDefs from an OpList."""
+ if not isinstance(op_list, op_def_pb2.OpList):
+ raise TypeError("%s is %s, not an op_def_pb2.OpList" %
+ (op_list, type(op_list)))
+ for op_def in op_list.op:
+ self.add_op(op_def)
+
+ def apply_op(self, op_type_name, g=None, name=None, **keywords):
+ # pylint: disable=g-doc-args
+ """Add a node invoking a registered Op to a graph.
+
+ Config proto extensions must be provided via the 'ext' keyword argument.
+ Example usage:
+ # input1 and input2 can be Tensors or anything ops.convert_to_tensor()
+ # will convert to a Tensor.
+ op_def_library.apply_op("op", input1=input1, input2=input2)
+ # If none of the inputs are Tensors and your session doesn't have a
+ # default graph, you will have to specify the graph.
+ op_def_library.apply_op("op", input1=input1, g=g)
+ # Can specify a node name.
+ op_def_library.apply_op("op", input1=input1, name="node_name")
+ # Must use keyword arguments, with the names specified in the OpDef.
+ op_def_library.apply_op("op", input_name=input, attr_name=attr)
+
+ All attrs must either be inferred from an input or specified.
+ (If inferred, the attr must not be specified.) If an attr has a default
+ value specified in the Op's OpDef, then you may pass None as the value
+ of that attr to get the default.
+
+ Args:
+ op_type_name: string. Must match the name field of a registered Op.
+ g: The graph context (optional)
+ name: string. Optional name of the created op.
+ **keywords: input Tensor and attr arguments specified by name,
+ and optional parameters to pass when constructing the Operation.
+
+ Returns:
+ The Tensor(s) representing the output of the operation, or the Operation
+ itself if there are no outputs.
+
+ Raises:
+ RuntimeError: On some errors.
+ TypeError: On some errors.
+ ValueError: On some errors.
+ """
+ op_info = self._ops.get(op_type_name, None)
+ if op_info is None:
+ raise RuntimeError("Unrecognized Op name " + op_type_name)
+ op_def = op_info.op_def
+
+ # Determine the graph context.
+ try:
+ # Need to flatten all the arguments into a list.
+ # pylint: disable=protected-access
+ g = ops._get_graph_from_inputs(_Flatten(keywords.values()), graph=g)
+ # pyline: enable=protected-access
+ except AssertionError as e:
+ raise RuntimeError(
+ "Need to specify g=graph to Op '%s' (could not determine graph due "
+ "to: %s)" % (op_type_name, e.message))
+
+ # Default name if not specified.
+ if name is None:
+ name = op_type_name
+
+ # Requires that op_def has passed validation (using the C++
+ # ValidateOpDef() from ../framework/op_def_util.h).
+ attrs = {}
+ inputs = []
+ input_types = []
+ with g.as_default(), ops.name_scope(name) as scope:
+
+ # Perform input type inference
+ inferred_from = {}
+ for input_arg in op_def.input_arg:
+ input_name = input_arg.name
+ if input_name in keywords:
+ values = keywords.pop(input_name)
+ elif input_name + "_" in keywords:
+ # Handle the case where the name is a keyword or built-in
+ # for Python so we use the name + _ instead.
+ input_name += "_"
+ values = keywords.pop(input_name)
+ else:
+ raise TypeError("No argument for input " + input_name)
+
+ # Goals:
+ # * Convert values to Tensors if it contains constants.
+ # * Verify that values is a list if that matches the input_arg's
+ # type.
+ # * If the input_arg's type is determined by attrs, either set
+ # those attrs and validate those attr values are legal (if
+ # they have not yet been set) or validate the input matches
+ # the type indicated by the attrs (if they have already been
+ # inferred via an earlier input).
+ # * If the input_arg has an explicit type, make sure the input
+ # conforms.
+
+ if _IsListParameter(input_arg):
+ if not _IsListValue(values):
+ raise TypeError(
+ "Expected list for '%s' argument to '%s' Op, not %s." %
+ (input_name, op_type_name, values))
+ # In cases where we expect all elements of the list to have the
+ # same dtype, try to cast non-Tensor elements to that type.
+ dtype = None
+ if input_arg.type != types_pb2.DT_INVALID:
+ dtype = input_arg.type
+ elif input_arg.number_attr:
+ if input_arg.type_attr in attrs:
+ dtype = attrs[input_arg.type_attr]
+ else:
+ for t in values:
+ if isinstance(t, ops.Tensor):
+ dtype = t.dtype
+ break
+
+ try:
+ values = ops.convert_n_to_tensor_or_indexed_slices(
+ values, name=input_arg.name,
+ dtype=types_lib.as_dtype(dtype).base_dtype if dtype else None)
+ except (TypeError, ValueError):
+ assert dtype is not None, "Should not fail if dtype is None"
+ assert input_arg.number_attr, "Should be number_attr case"
+ # What types does the conversion function think values have?
+ values = ops.convert_n_to_tensor_or_indexed_slices(values)
+ observed = ", ".join(v.dtype.base_dtype.name for v in values)
+
+ prefix = (
+ "Tensors in list passed to '%s' of '%s' Op have types [%s]" %
+ (input_name, op_type_name, observed))
+ if input_arg.type != types_pb2.DT_INVALID:
+ raise TypeError("%s that do not match expected type %s." %
+ (prefix, types_lib.as_dtype(dtype).name))
+ elif input_arg.type_attr in attrs:
+ raise TypeError("%s that do not match type %s inferred from "
+ "earlier arguments." %
+ (prefix, types_lib.as_dtype(dtype).name))
+ else:
+ raise TypeError("%s that don't all match." % prefix)
+
+ types = [x.dtype for x in values]
+ inputs.extend(values)
+ else:
+ # In cases where we have an expected type, try to convert non-Tensor
+ # arguments to that type.
+ dtype = None
+ if input_arg.type != types_pb2.DT_INVALID:
+ dtype = input_arg.type
+ elif input_arg.type_attr in attrs:
+ dtype = attrs[input_arg.type_attr]
+
+ try:
+ values = ops.convert_to_tensor(
+ values, name=input_arg.name, dtype=dtype)
+ except ValueError:
+ # What type does convert_to_tensor think it has?
+ observed = ops.convert_to_tensor(values).dtype.name
+ prefix = ("Input '%s' of '%s' Op has type %s that does not match" %
+ (input_name, op_type_name, observed))
+ if input_arg.type != types_pb2.DT_INVALID:
+ raise TypeError("%s expected type of %s." %
+ (prefix, types_lib.as_dtype(input_arg.type).name))
+ else:
+ raise TypeError(
+ "%s type %s of argument '%s'." %
+ (prefix, types_lib.as_dtype(attrs[input_arg.type_attr]).name,
+ inferred_from[input_arg.type_attr]))
+
+ types = [values.dtype]
+ inputs.append(values)
+ base_types = [x.base_dtype for x in types]
+
+ if input_arg.number_attr:
+ # <number-attr> * <type> or <number-attr> * <type-attr>
+ if input_arg.number_attr in attrs:
+ if len(values) != attrs[input_arg.number_attr]:
+ raise ValueError(
+ "List argument '%s' to '%s' Op with length %d must match "
+ "length %d of argument '%s'." %
+ (input_name, op_type_name, len(values),
+ attrs[input_arg.number_attr],
+ inferred_from[input_arg.number_attr]))
+ else:
+ attrs[input_arg.number_attr] = len(values)
+ inferred_from[input_arg.number_attr] = input_name
+ num_attr = _Attr(op_def, input_arg.number_attr)
+ if num_attr.has_minimum and len(values) < num_attr.minimum:
+ raise ValueError(
+ "List argument '%s' to '%s' Op with length %d shorter "
+ "than minimum length %d." %
+ (input_name, op_type_name, len(values), num_attr.minimum))
+ # All tensors must have the same base type.
+ if any([bt != base_types[0] for bt in base_types]):
+ raise TypeError(
+ "All tensors passed to '%s' of '%s' Op "
+ "must have the same type." %
+ (input_name, op_type_name))
+ if input_arg.type != types_pb2.DT_INVALID:
+ # <number-attr> * <type> case
+ if base_types and base_types[0] != input_arg.type:
+ assert False, "Unreachable"
+ elif input_arg.type_attr in attrs:
+ # <number-attr> * <type-attr> case, where <type-attr> already
+ # has an inferred value.
+ if base_types and base_types[0] != attrs[input_arg.type_attr]:
+ assert False, "Unreachable"
+ else:
+ # <number-attr> * <type-attr> case, where we are now setting
+ # the <type-attr> based on this input
+ if not base_types:
+ raise TypeError(
+ "Don't know how to infer type variable from empty input "
+ "list passed to input '%s' of '%s' Op." %
+ (input_name, op_type_name))
+ attrs[input_arg.type_attr] = base_types[0]
+ inferred_from[input_arg.type_attr] = input_name
+ type_attr = _Attr(op_def, input_arg.type_attr)
+ _SatisfiesTypeConstraint(base_types[0], type_attr)
+ elif input_arg.type_attr:
+ # <type-attr>
+ attr_value = base_types[0]
+ if input_arg.type_attr in attrs:
+ if attrs[input_arg.type_attr] != attr_value:
+ assert False, "Unreachable"
+ else:
+ for base_type in base_types:
+ _SatisfiesTypeConstraint(base_type,
+ _Attr(op_def, input_arg.type_attr))
+ attrs[input_arg.type_attr] = attr_value
+ inferred_from[input_arg.type_attr] = input_name
+ elif input_arg.type_list_attr:
+ # <type-list-attr>
+ attr_value = base_types
+ if input_arg.type_list_attr in attrs:
+ if attrs[input_arg.type_list_attr] != attr_value:
+ raise TypeError(
+ "Input '%s' of '%s' Op has type list of %s that does not "
+ "match type list %s of argument '%s'." %
+ (input_name, op_type_name,
+ ", ".join(types_lib.as_dtype(x).name for x in attr_value),
+ ", ".join(types_lib.as_dtype(x).name
+ for x in attrs[input_arg.type_list_attr]),
+ inferred_from[input_arg.type_list_attr]))
+ else:
+ for base_type in base_types:
+ _SatisfiesTypeConstraint(base_type,
+ _Attr(op_def, input_arg.type_list_attr))
+ attrs[input_arg.type_list_attr] = attr_value
+ inferred_from[input_arg.type_list_attr] = input_name
+ else:
+ # single Tensor with specified type
+ if base_types[0] != input_arg.type:
+ assert False, "Unreachable"
+
+ if input_arg.is_ref:
+ if not all(x.is_ref_dtype for x in types):
+ raise TypeError(
+ "Input '%s' of '%s' Op requires l-value input" %
+ (input_name, op_type_name))
+ input_types.extend(types)
+ else:
+ input_types.extend(base_types)
+
+ # Process remaining attrs
+ for attr in op_def.attr:
+ # Skip attrs that have already had their values inferred
+ if attr.name in attrs:
+ if attr.name in keywords:
+ raise TypeError(
+ "Should not specify value for inferred attr '%s'." % attr.name)
+ continue
+ if attr.name in keywords:
+ attrs[attr.name] = keywords.pop(attr.name)
+ elif attr.name + "_" in keywords:
+ # Attrs whose names match Python keywords have an extra '_'
+ # appended, so we must check for that as well.
+ attrs[attr.name] = keywords.pop(attr.name + "_")
+ else:
+ raise TypeError("No argument for attr " + attr.name)
+
+ # Convert attr values to AttrValue protos.
+ attr_protos = {}
+ for attr_def in op_def.attr:
+ key = attr_def.name
+ value = attrs[key]
+ attr_value = attr_value_pb2.AttrValue()
+ if attr_def.HasField("default_value") and value is None:
+ attr_value.CopyFrom(attr_def.default_value)
+ attr_protos[key] = attr_value
+ continue
+ if attr_def.type.startswith("list("):
+ if not _IsListValue(value):
+ raise TypeError("Expected list for attr " + key)
+ if attr_def.has_minimum:
+ if len(value) < attr_def.minimum:
+ raise ValueError("Attr '%s' of '%s' Op passed list of length %d "
+ "less than minimum %d." %
+ (key, op_type_name, len(value),
+ attr_def.minimum))
+ if attr_def.type == "string":
+ attr_value.s = _MakeStr(value, key)
+ if attr_def.HasField("allowed_values"):
+ if attr_value.s not in attr_def.allowed_values.list.s:
+ raise ValueError(
+ "Attr '%s' of '%s' Op passed string '%s' not in: \"%s\"." %
+ (key, op_type_name, attr_value.s,
+ '", "'.join(attr_def.allowed_values.list.s)))
+ elif attr_def.type == "list(string)":
+ attr_value.list.s.extend([_MakeStr(x, key) for x in value])
+ if attr_def.HasField("allowed_values"):
+ for x in attr_value.list.s:
+ if x not in attr_def.allowed_values.list.s:
+ raise ValueError(
+ "Attr '%s' of '%s' Op passed string '%s' not in: \"%s\"." %
+ (key, op_type_name, x,
+ '", "'.join(attr_def.allowed_values.list.s)))
+ elif attr_def.type == "int":
+ attr_value.i = _MakeInt(value, key)
+ if attr_def.has_minimum:
+ if attr_value.i < attr_def.minimum:
+ raise ValueError(
+ "Attr '%s' of '%s' Op passed %d less than minimum %d." %
+ (key, op_type_name, attr_value.i, attr_def.minimum))
+ elif attr_def.type == "list(int)":
+ attr_value.list.i.extend([_MakeInt(x, key) for x in value])
+ elif attr_def.type == "float":
+ attr_value.f = _MakeFloat(value, key)
+ elif attr_def.type == "list(float)":
+ attr_value.list.f.extend([_MakeFloat(x, key) for x in value])
+ elif attr_def.type == "bool":
+ attr_value.b = _MakeBool(value, key)
+ elif attr_def.type == "list(bool)":
+ attr_value.list.b.extend([_MakeBool(x, key) for x in value])
+ elif attr_def.type == "type":
+ attr_value.type = _MakeType(value, attr_def)
+ elif attr_def.type == "list(type)":
+ attr_value.list.type.extend(
+ [_MakeType(x, attr_def) for x in value])
+ elif attr_def.type == "shape":
+ attr_value.shape.CopyFrom(_MakeShape(value, key))
+ elif attr_def.type == "list(shape)":
+ attr_value.list.shape.extend(
+ [_MakeShape(x, key) for x in value])
+ elif attr_def.type == "tensor":
+ attr_value.tensor.CopyFrom(_MakeTensor(value, key))
+ elif attr_def.type == "list(tensor)":
+ attr_value.list.tensor.extend(
+ [_MakeTensor(x, key) for x in value])
+ else:
+ raise TypeError("Unrecognized Attr type " + attr_def.type)
+
+ attr_protos[key] = attr_value
+ del attrs # attrs is no longer authoritative, use attr_protos instead
+
+ # Determine output types (possibly using attrs)
+ output_types = []
+ output_structure = []
+ for arg in op_def.output_arg:
+ types = []
+ if arg.number_attr:
+ n = _AttrValue(attr_protos, arg.number_attr).i
+ if arg.type_attr:
+ types = [_AttrValue(attr_protos, arg.type_attr).type] * n
+ else:
+ types = [arg.type] * n
+ output_structure.append(n)
+ elif arg.type_attr:
+ t = _AttrValue(attr_protos, arg.type_attr)
+ types = [t.type]
+ output_structure.append(None)
+ elif arg.type_list_attr:
+ t = _AttrValue(attr_protos, arg.type_list_attr)
+ types = t.list.type
+ output_structure.append(len(t.list.type))
+ else:
+ types = [arg.type]
+ output_structure.append(None)
+ if arg.is_ref:
+ types = [types_lib.as_dtype(x).as_ref for x in types]
+ output_types.extend(types)
+
+ if keywords:
+ raise TypeError("apply_op() got unexpected keyword arguments: " +
+ ", ".join(sorted(keywords.keys())))
+
+ # Add Op to graph
+ if output_structure:
+ op = g.create_op(op_type_name, inputs, output_types, name=scope,
+ input_types=input_types, attrs=attr_protos,
+ op_def=op_def)
+ outputs = op.outputs
+ return _Restructure(ops.convert_n_to_tensor_or_indexed_slices(outputs),
+ output_structure)
+ else:
+ return g.create_op(op_type_name, inputs, output_types, name=scope,
+ input_types=input_types, attrs=attr_protos,
+ op_def=op_def)