aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python/ops/op_def_library.py
blob: 946d543b44a879833ea097eea4b33af757fc3db9 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
# Copyright 2015 Google Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

"""Class to hold a library of OpDefs and use it to create Brain operations."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import six

from tensorflow.core.framework import attr_value_pb2
from tensorflow.core.framework import op_def_pb2
from tensorflow.core.framework import tensor_pb2
from tensorflow.core.framework import tensor_shape_pb2
from tensorflow.core.framework import types_pb2
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_shape
from tensorflow.python.ops import constant_op
from tensorflow.python.platform import logging
from tensorflow.python.util import compat


def _Attr(op_def, name):
  for attr in op_def.attr:
    if attr.name == name:
      return attr
  raise TypeError("Inconsistent OpDef for '%s', missing attr '%s'" %
                  (op_def.name, name))


def _AttrValue(attr_protos, name):
  if name in attr_protos:
    return attr_protos[name]
  raise TypeError("Inconsistent OpDef, missing attr '%s' from '%s'." %
                  (name, attr_protos))


def _SatisfiesTypeConstraint(dtype, attr_def):
  """Checks `dtype` against the allowed types of `attr_def`, if it has any.

  Despite the name, this returns nothing: it either passes silently or raises.

  Args:
    dtype: The type value to check; it is compared directly against the
      entries of `attr_def.allowed_values.list.type`.
    attr_def: Attr definition that may have its `allowed_values` field set.

  Raises:
    TypeError: If `attr_def` has `allowed_values` and `dtype` is not among
      them.
  """
  if not attr_def.HasField("allowed_values"):
    # No constraint on this attr: every dtype is acceptable.
    return
  permitted = attr_def.allowed_values.list.type
  if dtype in permitted:
    return
  rejected_name = dtypes.as_dtype(dtype).name
  attr_name = attr_def.name
  permitted_names = ", ".join(dtypes.as_dtype(x).name for x in permitted)
  raise TypeError(
      "DataType %s for attr '%s' not in list of allowed values: %s" %
      (rejected_name, attr_name, permitted_names))


def _IsListParameter(arg):
  if arg.number_attr:
    return True
  elif arg.type_list_attr:
    return True
  return False


def _NumTypeFields(arg):
  """Counts how many of the three type-specifying fields of `arg` are set.

  The fields considered are `type` (set when it differs from
  `types_pb2.DT_INVALID`), `type_attr` and `type_list_attr` (set when
  truthy).

  Args:
    arg: Object exposing `type`, `type_attr` and `type_list_attr`.

  Returns:
    An int between 0 and 3.
  """
  fields_set = (
      arg.type != types_pb2.DT_INVALID,
      bool(arg.type_attr),
      bool(arg.type_list_attr),
  )
  return sum(1 for is_set in fields_set if is_set)


def _IsListValue(v):
  return isinstance(v, (list, tuple))


def _Flatten(l):
  """Flattens one level of nesting: [1, 2, [3, 4], [5]] -> [1, 2, 3, 4, 5].

  Only elements that `_IsListValue` accepts are expanded, and only by one
  level; anything nested deeper is copied over unchanged.

  Args:
    l: An iterable of values.

  Returns:
    A new list.
  """
  flat = []
  for entry in l:
    if _IsListValue(entry):
      # Splice the members of a nested list/tuple in place.
      flat.extend(entry)
    else:
      flat.append(entry)
  return flat


def _Restructure(l, structure):
  """Returns the elements of list l structured according to the given structure.

  A structure is represented by a list whose elements are either
  `None` or a non-negative integer. `None` corresponds to a single
  element in the output list, and an integer N corresponds to a nested
  list of length N.

  The function returns a data structure whose shape is given by
  `structure`, and whose elements are taken from `l`. If `structure`
  is a singleton, the function returns the single data structure
  implied by the 0th element of `structure`. For example:

      _Restructure(["foo", "bar", "baz", "qux"], [None, 2, None])
        -> ["foo", ["bar", "baz"], "qux"]

      _Restructure(["foo"], [None]) -> "foo"

      _Restructure(["foo"], [1]) -> ["foo"]

      _Restructure([], [0]) -> []

  Args:
    l: A list.
    structure: A list whose elements are either `None` or a non-negative
      integer.

  Returns:
    The elements of `l`, restructured according to `structure`. If
    `structure` is a list of length 1, this function returns the
    single data structure implied by `structure[0]`.

  """
  result = []
  current_index = 0
  for element in structure:
    if element is None:
      result.append(l[current_index])
      current_index += 1
    else:
      result.append(l[current_index:current_index+element])
      current_index += element

  if len(result) == 1:
    return result[0]
  else:
    return tuple(result)


def _MakeFloat(v, arg_name):
  """Returns `v` as a `float`, accepting only `compat.real_types` instances.

  Args:
    v: The value to convert.
    arg_name: Name of the argument, used in the error message.

  Returns:
    `float(v)`.

  Raises:
    TypeError: If `v` is not an instance of `compat.real_types`.
  """
  if isinstance(v, compat.real_types):
    return float(v)
  raise TypeError("Expected float for argument '%s' not %s." %
                  (arg_name, repr(v)))


def _MakeInt(v, arg_name):
  """Returns `int(v)`, refusing strings and values `int()` cannot convert.

  Args:
    v: The value to convert.
    arg_name: Name of the argument, used in the error message.

  Returns:
    `int(v)`.

  Raises:
    TypeError: If `v` is a string, or if `int(v)` raises `ValueError` or
      `TypeError`.
  """
  rejection = "Expected int for argument '%s' not %s."
  # Strings are refused up front, before int() gets a chance to parse them.
  if isinstance(v, six.string_types):
    raise TypeError(rejection % (arg_name, repr(v)))
  try:
    as_int = int(v)
  except (ValueError, TypeError):
    raise TypeError(rejection % (arg_name, repr(v)))
  return as_int


def _MakeStr(v, arg_name):
  """Returns `v` passed through `compat.as_bytes`.

  Args:
    v: The value to convert; must be an instance of
      `compat.bytes_or_text_types`.
    arg_name: Name of the argument, used in the error message.

  Returns:
    The result of `compat.as_bytes(v)`.

  Raises:
    TypeError: If `v` is not an instance of `compat.bytes_or_text_types`.
  """
  if isinstance(v, compat.bytes_or_text_types):
    # Unicode strings are converted to bytes here.
    return compat.as_bytes(v)
  raise TypeError("Expected string for argument '%s' not %s." %
                  (arg_name, repr(v)))


def _MakeBool(v, arg_name):
  if not isinstance(v, bool):
    raise TypeError("Expected bool for argument '%s' not %s." %
                    (arg_name, repr(v)))
  return v


def _MakeType(v, attr_def):
  """Converts `v` to a datatype enum value and checks it against `attr_def`.

  Args:
    v: Any value that `dtypes.as_dtype()` accepts.
    attr_def: Attr definition; its `name` is used in the error message and
      it is passed to `_SatisfiesTypeConstraint` for the check.

  Returns:
    The `as_datatype_enum` value of the converted dtype.

  Raises:
    TypeError: If `dtypes.as_dtype(v)` raises `TypeError`, or if
      `_SatisfiesTypeConstraint` rejects the resulting type.
  """
  try:
    converted = dtypes.as_dtype(v)
  except TypeError:
    raise TypeError("Expected DataType for argument '%s' not %s." %
                    (attr_def.name, repr(v)))
  enum_value = converted.as_datatype_enum
  _SatisfiesTypeConstraint(enum_value, attr_def)
  return enum_value


def _MakeShape(v, arg_name):
  """Convert v into a TensorShapeProto.

  Args:
    v: A TensorShapeProto, a list of ints, or a tensor_shape.TensorShape.
    arg_name: String, for error messages. Not used by the current body.

  Returns:
    A TensorShapeProto. If `v` already is one it is returned as is (after
    logging a warning when any of its dimensions has a name); otherwise a
    new proto is built with one `dim` per entry of
    `tensor_shape.as_shape(v).as_dimension_list()`.
  """
  if not isinstance(v, tensor_shape_pb2.TensorShapeProto):
    shape = tensor_shape.as_shape(v)
    proto = tensor_shape_pb2.TensorShapeProto()
    for dim_size in shape.as_dimension_list():
      proto.dim.add(size=dim_size)
    return proto

  # Already a proto: warn once if any dimension is named, then pass through.
  if any(d.name for d in v.dim):
    logging.warning("Warning: TensorShapeProto with a named dimension: %s",
                    str(v))
  return v


def _MakeTensor(v, arg_name):
  """Returns `v` unchanged after checking that it is a TensorProto.

  Args:
    v: The value to check.
    arg_name: Name of the argument, used in the error message.

  Returns:
    `v` itself.

  Raises:
    TypeError: If `v` is not a `tensor_pb2.TensorProto`; no conversion is
      attempted.
  """
  if not isinstance(v, tensor_pb2.TensorProto):
    raise TypeError(
        "Don't know how to convert %s to a TensorProto for argument '%s'" %
        (repr(v), arg_name))
  return v


class _OpInfo(object):
  """All per-Op state we would like to precompute/validate.

  Attributes:
    op_def: The OpDef passed to the constructor, stored as is.
  """

  def __init__(self, op_def):
    """Stores `op_def` and validates its input and output args.

    For every input and output arg this checks that exactly one of `type`,
    `type_attr` and `type_list_attr` is set, and that each attr referenced
    through `type_attr`, `type_list_attr` or `number_attr` is declared on
    `op_def` with type "type", "list(type)" or "int" respectively.

    Args:
      op_def: The OpDef to validate.

    Raises:
      TypeError: If any of the checks above fails, or if a referenced attr
        is missing from `op_def`.
    """
    self.op_def = op_def
    # TODO(josh11b): SWIG the ValidateOpDef() function from C++ and call it
    # here, instead of these checks.
    for arg in list(op_def.input_arg) + list(op_def.output_arg):
      num_type_fields = _NumTypeFields(arg)
      if num_type_fields != 1:
        raise TypeError("Arg '%s' of '%s' must have one type field not %d" %
                        (arg.name, op_def.name, num_type_fields))
      if arg.type_attr:
        attr_type = _Attr(op_def, arg.type_attr).type
        if attr_type != "type":
          raise TypeError("Attr '%s' of '%s' used as a type_attr "
                          "but has type %s" %
                          (arg.type_attr, op_def.name, attr_type))
      if arg.type_list_attr:
        attr_type = _Attr(op_def, arg.type_list_attr).type
        if attr_type != "list(type)":
          # Report the attr that was actually looked up; arg.type_attr is
          # empty whenever arg.type_list_attr is the one type field set.
          raise TypeError(
              "Attr '%s' of '%s' used as a type_list_attr but has type %s" %
              (arg.type_list_attr, op_def.name, attr_type))
      if arg.number_attr:
        attr_type = _Attr(op_def, arg.number_attr).type
        if attr_type != "int":
          raise TypeError(
              "Attr '%s' of '%s' used as a number_attr but has type %s" %
              (arg.number_attr, op_def.name, attr_type))


class OpDefLibrary(object):
  """Holds a collection of OpDefs, can add the corresponding Ops to a graph."""

  def __init__(self):
    """Creates a library with no registered ops."""
    # Maps op name -> _OpInfo wrapping the registered OpDef.
    self._ops = {}

  def add_op(self, op_def):
    """Register an OpDef. May call apply_op with the name afterwards.

    Args:
      op_def: The `op_def_pb2.OpDef` to register.

    Raises:
      TypeError: If `op_def` is not an `op_def_pb2.OpDef`, or if the
        validation done by `_OpInfo` fails.
      ValueError: If `op_def` has an empty name.
      RuntimeError: If an op with the same name is already registered.
    """
    if not isinstance(op_def, op_def_pb2.OpDef):
      raise TypeError("%s is %s, not an op_def_pb2.OpDef" %
                      (op_def, type(op_def)))
    if not op_def.name:
      raise ValueError("%s missing name." % op_def)
    if op_def.name in self._ops:
      raise RuntimeError("Op name %s registered twice." % op_def.name)
    self._ops[op_def.name] = _OpInfo(op_def)

  def add_op_list(self, op_list):
    """Register the OpDefs from an OpList.

    Args:
      op_list: An `op_def_pb2.OpList`; each entry of `op_list.op` is passed
        to `add_op`.

    Raises:
      TypeError: If `op_list` is not an `op_def_pb2.OpList`.
    """
    if not isinstance(op_list, op_def_pb2.OpList):
      raise TypeError("%s is %s, not an op_def_pb2.OpList" %
                      (op_list, type(op_list)))
    for op_def in op_list.op:
      self.add_op(op_def)

  def apply_op(self, op_type_name, g=None, name=None, **keywords):
    # pylint: disable=g-doc-args
    """Add a node invoking a registered Op to a graph.

    Config proto extensions must be provided via the 'ext' keyword argument.
    Example usage:
       # input1 and input2 can be Tensors or anything ops.convert_to_tensor()
       # will convert to a Tensor.
       op_def_library.apply_op("op", input1=input1, input2=input2)
       # If none of the inputs are Tensors and your session doesn't have a
       # default graph, you will have to specify the graph.
       op_def_library.apply_op("op", input1=input1, g=g)
       # Can specify a node name.
       op_def_library.apply_op("op", input1=input1, name="node_name")
       # Must use keyword arguments, with the names specified in the OpDef.
       op_def_library.apply_op("op", input_name=input, attr_name=attr)

    All attrs must either be inferred from an input or specified.
    (If inferred, the attr must not be specified.)  If an attr has a default
    value specified in the Op's OpDef, then you may pass None as the value
    of that attr to get the default.

    An input or attr whose name clashes with a Python keyword or built-in
    may be passed with a trailing underscore (e.g. `name_` for `name`).

    Args:
      op_type_name: string. Must match the name field of a registered Op.
      g: The graph context (optional)
      name: string. Optional name of the created op. Defaults to
        `op_type_name`.
      **keywords: input Tensor and attr arguments specified by name,
        and optional parameters to pass when constructing the Operation.

    Returns:
      The Tensor(s) representing the output of the operation, or the Operation
      itself if there are no outputs. With exactly one output arg the single
      restructured value is returned; with several, a tuple of them.

    Raises:
      RuntimeError: If `op_type_name` is not registered, or if the graph
        could not be determined from the inputs and `g`.
      TypeError: If an input or attr is missing, has the wrong type or does
        not match a type inferred from or allowed for it, or if unexpected
        keyword arguments remain.
      ValueError: If a list input or attr is shorter than its minimum, a
        list input length conflicts with an already-inferred length, an int
        attr is below its minimum, or a string attr is not an allowed value.
    """
    op_info = self._ops.get(op_type_name, None)
    if op_info is None:
      raise RuntimeError("Unrecognized Op name " + op_type_name)
    op_def = op_info.op_def

    # Determine the graph context.
    try:
      # Need to flatten all the arguments into a list.
      # pylint: disable=protected-access
      g = ops._get_graph_from_inputs(_Flatten(keywords.values()), graph=g)
      # pylint: enable=protected-access
    except AssertionError as e:
      # str(e) rather than e.message: exceptions have no `message` attribute
      # on Python 3.
      raise RuntimeError(
          "Need to specify g=graph to Op '%s' (could not determine graph due "
          "to: %s)" % (op_type_name, str(e)))

    # Default name if not specified.
    if name is None:
      name = op_type_name

    # Requires that op_def has passed validation (using the C++
    # ValidateOpDef() from ../framework/op_def_util.h).
    attrs = {}
    inputs = []
    input_types = []
    with g.as_default(), ops.name_scope(name) as scope:

      # Perform input type inference
      # inferred_from maps an attr name to the name of the input whose value
      # determined it; it is only used to build error messages.
      inferred_from = {}
      for input_arg in op_def.input_arg:
        input_name = input_arg.name
        if input_name in keywords:
          values = keywords.pop(input_name)
        elif input_name + "_" in keywords:
          # Handle the case where the name is a keyword or built-in
          # for Python so we use the name + _ instead.
          input_name += "_"
          values = keywords.pop(input_name)
        else:
          raise TypeError("No argument for input " + input_name)

        # Goals:
        # * Convert values to Tensors if it contains constants.
        # * Verify that values is a list if that matches the input_arg's
        #   type.
        # * If the input_arg's type is determined by attrs, either set
        #   those attrs and validate those attr values are legal (if
        #   they have not yet been set) or validate the input matches
        #   the type indicated by the attrs (if they have already been
        #   inferred via an earlier input).
        # * If the input_arg has an explicit type, make sure the input
        #   conforms.

        if _IsListParameter(input_arg):
          if not _IsListValue(values):
            raise TypeError(
                "Expected list for '%s' argument to '%s' Op, not %s." %
                (input_name, op_type_name, values))
          # In cases where we expect all elements of the list to have the
          # same dtype, try to cast non-Tensor elements to that type.
          dtype = None
          if input_arg.type != types_pb2.DT_INVALID:
            dtype = input_arg.type
          elif input_arg.number_attr:
            if input_arg.type_attr in attrs:
              dtype = attrs[input_arg.type_attr]
            else:
              for t in values:
                if isinstance(t, ops.Tensor):
                  dtype = t.dtype
                  break

          try:
            if not input_arg.is_ref and dtype:
              dtype = dtypes.as_dtype(dtype).base_dtype
            values = ops.convert_n_to_tensor_or_indexed_slices(
                values, name=input_arg.name,
                dtype=dtype if dtype else None,
                as_ref=input_arg.is_ref)
          except (TypeError, ValueError):
            assert dtype is not None, "Should not fail if dtype is None"
            assert input_arg.number_attr, "Should be number_attr case"
            # What types does the conversion function think values have?
            values = ops.convert_n_to_tensor_or_indexed_slices(
                values, as_ref=input_arg.is_ref)
            observed = ", ".join(v.dtype.base_dtype.name for v in values)

            prefix = (
                "Tensors in list passed to '%s' of '%s' Op have types [%s]" %
                (input_name, op_type_name, observed))
            # dtype may still be a raw enum value here (it is only
            # normalized above for non-ref inputs), so go through
            # dtypes.as_dtype() to get its name.
            if input_arg.type != types_pb2.DT_INVALID:
              raise TypeError("%s that do not match expected type %s." %
                              (prefix, dtypes.as_dtype(dtype).name))
            elif input_arg.type_attr in attrs:
              raise TypeError("%s that do not match type %s inferred from "
                              "earlier arguments." %
                              (prefix, dtypes.as_dtype(dtype).name))
            else:
              raise TypeError("%s that don't all match." % prefix)

          types = [x.dtype for x in values]
          inputs.extend(values)
        else:
          # In cases where we have an expected type, try to convert non-Tensor
          # arguments to that type.
          dtype = None
          if input_arg.type != types_pb2.DT_INVALID:
            dtype = input_arg.type
          elif input_arg.type_attr in attrs:
            dtype = attrs[input_arg.type_attr]
          try:
            values = ops.convert_to_tensor(
                values, name=input_arg.name, dtype=dtype,
                as_ref=input_arg.is_ref)
          except ValueError:
            # What type does convert_to_tensor think it has?
            observed = ops.convert_to_tensor(values,
                                             as_ref=input_arg.is_ref).dtype.name
            prefix = ("Input '%s' of '%s' Op has type %s that does not match" %
                      (input_name, op_type_name, observed))
            if input_arg.type != types_pb2.DT_INVALID:
              raise TypeError("%s expected type of %s." %
                              (prefix, dtypes.as_dtype(input_arg.type).name))
            else:
              raise TypeError(
                  "%s type %s of argument '%s'." %
                  (prefix, dtypes.as_dtype(attrs[input_arg.type_attr]).name,
                   inferred_from[input_arg.type_attr]))

          types = [values.dtype]
          inputs.append(values)
        base_types = [x.base_dtype for x in types]

        if input_arg.number_attr:
          # <number-attr> * <type> or <number-attr> * <type-attr>
          if input_arg.number_attr in attrs:
            if len(values) != attrs[input_arg.number_attr]:
              raise ValueError(
                  "List argument '%s' to '%s' Op with length %d must match "
                  "length %d of argument '%s'." %
                  (input_name, op_type_name, len(values),
                   attrs[input_arg.number_attr],
                   inferred_from[input_arg.number_attr]))
          else:
            attrs[input_arg.number_attr] = len(values)
            inferred_from[input_arg.number_attr] = input_name
            num_attr = _Attr(op_def, input_arg.number_attr)
            if num_attr.has_minimum and len(values) < num_attr.minimum:
              raise ValueError(
                  "List argument '%s' to '%s' Op with length %d shorter "
                  "than minimum length %d." %
                  (input_name, op_type_name, len(values), num_attr.minimum))
          # All tensors must have the same base type.
          if any([bt != base_types[0] for bt in base_types]):
            raise TypeError(
                "All tensors passed to '%s' of '%s' Op "
                "must have the same type." %
                (input_name, op_type_name))
          if input_arg.type != types_pb2.DT_INVALID:
            # <number-attr> * <type> case
            if base_types and base_types[0] != input_arg.type:
              assert False, "Unreachable"
          elif input_arg.type_attr in attrs:
            # <number-attr> * <type-attr> case, where <type-attr> already
            # has an inferred value.
            if base_types and base_types[0] != attrs[input_arg.type_attr]:
              assert False, "Unreachable"
          else:
            # <number-attr> * <type-attr> case, where we are now setting
            # the <type-attr> based on this input
            if not base_types:
              raise TypeError(
                  "Don't know how to infer type variable from empty input "
                  "list passed to input '%s' of '%s' Op." %
                  (input_name, op_type_name))
            attrs[input_arg.type_attr] = base_types[0]
            inferred_from[input_arg.type_attr] = input_name
            type_attr = _Attr(op_def, input_arg.type_attr)
            _SatisfiesTypeConstraint(base_types[0], type_attr)
        elif input_arg.type_attr:
          # <type-attr>
          attr_value = base_types[0]
          if input_arg.type_attr in attrs:
            if attrs[input_arg.type_attr] != attr_value:
              assert False, "Unreachable"
          else:
            for base_type in base_types:
              _SatisfiesTypeConstraint(base_type,
                                       _Attr(op_def, input_arg.type_attr))
            attrs[input_arg.type_attr] = attr_value
            inferred_from[input_arg.type_attr] = input_name
        elif input_arg.type_list_attr:
          # <type-list-attr>
          attr_value = base_types
          if input_arg.type_list_attr in attrs:
            if attrs[input_arg.type_list_attr] != attr_value:
              raise TypeError(
                  "Input '%s' of '%s' Op has type list of %s that does not "
                  "match type list %s of argument '%s'." %
                  (input_name, op_type_name,
                   ", ".join(dtypes.as_dtype(x).name for x in attr_value),
                   ", ".join(dtypes.as_dtype(x).name
                             for x in attrs[input_arg.type_list_attr]),
                   inferred_from[input_arg.type_list_attr]))
          else:
            for base_type in base_types:
              _SatisfiesTypeConstraint(base_type,
                                       _Attr(op_def, input_arg.type_list_attr))
            attrs[input_arg.type_list_attr] = attr_value
            inferred_from[input_arg.type_list_attr] = input_name
        else:
          # single Tensor with specified type
          if base_types[0] != input_arg.type:
            assert False, "Unreachable"

        if input_arg.is_ref:
          if not all(x.is_ref_dtype for x in types):
            raise TypeError(
                "Input '%s' of '%s' Op requires l-value input" %
                (input_name, op_type_name))
          input_types.extend(types)
        else:
          input_types.extend(base_types)

      # Process remaining attrs
      for attr in op_def.attr:
        # Skip attrs that have already had their values inferred
        if attr.name in attrs:
          if attr.name in keywords:
            raise TypeError(
                "Should not specify value for inferred attr '%s'." % attr.name)
          continue
        if attr.name in keywords:
          attrs[attr.name] = keywords.pop(attr.name)
        elif attr.name + "_" in keywords:
          # Attrs whose names match Python keywords have an extra '_'
          # appended, so we must check for that as well.
          attrs[attr.name] = keywords.pop(attr.name + "_")
        else:
          raise TypeError("No argument for attr " + attr.name)

      # Convert attr values to AttrValue protos.
      attr_protos = {}
      for attr_def in op_def.attr:
        key = attr_def.name
        value = attrs[key]
        attr_value = attr_value_pb2.AttrValue()
        # None means "use the default", but only when the OpDef declares one.
        if attr_def.HasField("default_value") and value is None:
          attr_value.CopyFrom(attr_def.default_value)
          attr_protos[key] = attr_value
          continue
        if attr_def.type.startswith("list("):
          if not _IsListValue(value):
            raise TypeError("Expected list for attr " + key)
          if attr_def.has_minimum:
            if len(value) < attr_def.minimum:
              raise ValueError("Attr '%s' of '%s' Op passed list of length %d "
                               "less than minimum %d." %
                               (key, op_type_name, len(value),
                                attr_def.minimum))
        if attr_def.type == "string":
          attr_value.s = _MakeStr(value, key)
          if attr_def.HasField("allowed_values"):
            if attr_value.s not in attr_def.allowed_values.list.s:
              raise ValueError(
                  "Attr '%s' of '%s' Op passed string '%s' not in: \"%s\"." %
                  (key, op_type_name, compat.as_text(attr_value.s),
                   '", "'.join(map(compat.as_text,
                                   attr_def.allowed_values.list.s))))
        elif attr_def.type == "list(string)":
          attr_value.list.s.extend([_MakeStr(x, key) for x in value])
          if attr_def.HasField("allowed_values"):
            for x in attr_value.list.s:
              if x not in attr_def.allowed_values.list.s:
                raise ValueError(
                    "Attr '%s' of '%s' Op passed string '%s' not in: \"%s\"." %
                    (key, op_type_name, compat.as_text(x),
                     '", "'.join(map(compat.as_text,
                                     attr_def.allowed_values.list.s))))
        elif attr_def.type == "int":
          attr_value.i = _MakeInt(value, key)
          if attr_def.has_minimum:
            if attr_value.i < attr_def.minimum:
              raise ValueError(
                  "Attr '%s' of '%s' Op passed %d less than minimum %d." %
                  (key, op_type_name, attr_value.i, attr_def.minimum))
        elif attr_def.type == "list(int)":
          attr_value.list.i.extend([_MakeInt(x, key) for x in value])
        elif attr_def.type == "float":
          attr_value.f = _MakeFloat(value, key)
        elif attr_def.type == "list(float)":
          attr_value.list.f.extend([_MakeFloat(x, key) for x in value])
        elif attr_def.type == "bool":
          attr_value.b = _MakeBool(value, key)
        elif attr_def.type == "list(bool)":
          attr_value.list.b.extend([_MakeBool(x, key) for x in value])
        elif attr_def.type == "type":
          attr_value.type = _MakeType(value, attr_def)
        elif attr_def.type == "list(type)":
          attr_value.list.type.extend(
              [_MakeType(x, attr_def) for x in value])
        elif attr_def.type == "shape":
          attr_value.shape.CopyFrom(_MakeShape(value, key))
        elif attr_def.type == "list(shape)":
          attr_value.list.shape.extend(
              [_MakeShape(x, key) for x in value])
        elif attr_def.type == "tensor":
          attr_value.tensor.CopyFrom(_MakeTensor(value, key))
        elif attr_def.type == "list(tensor)":
          attr_value.list.tensor.extend(
              [_MakeTensor(x, key) for x in value])
        elif attr_def.type == "func":
          if not isinstance(value, compat.bytes_or_text_types):
            raise TypeError("Expects a string for the func name")
          attr_value.func.name = value
        else:
          raise TypeError("Unrecognized Attr type " + attr_def.type)

        attr_protos[key] = attr_value
      del attrs  # attrs is no longer authoritative, use attr_protos instead

      # Determine output types (possibly using attrs)
      # output_structure gets one entry per output arg: None for a single
      # tensor, or the list length for a list output. It is the `structure`
      # argument later handed to _Restructure().
      output_types = []
      output_structure = []
      for arg in op_def.output_arg:
        types = []
        if arg.number_attr:
          n = _AttrValue(attr_protos, arg.number_attr).i
          if arg.type_attr:
            types = [_AttrValue(attr_protos, arg.type_attr).type] * n
          else:
            types = [arg.type] * n
          output_structure.append(n)
        elif arg.type_attr:
          t = _AttrValue(attr_protos, arg.type_attr)
          types = [t.type]
          output_structure.append(None)
        elif arg.type_list_attr:
          t = _AttrValue(attr_protos, arg.type_list_attr)
          types = t.list.type
          output_structure.append(len(t.list.type))
        else:
          types = [arg.type]
          output_structure.append(None)
        if arg.is_ref:
          types = [dtypes.as_dtype(x).as_ref for x in types]
        output_types.extend(types)

      # Every input and attr has been popped by now; anything left over was
      # not named in the OpDef.
      if keywords:
        raise TypeError("apply_op() got unexpected keyword arguments: " +
                        ", ".join(sorted(keywords.keys())))

      # Add Op to graph
      if output_structure:
        op = g.create_op(op_type_name, inputs, output_types, name=scope,
                         input_types=input_types, attrs=attr_protos,
                         op_def=op_def)
        outputs = op.outputs
        return _Restructure(ops.convert_n_to_tensor_or_indexed_slices(outputs),
                            output_structure)
      else:
        return g.create_op(op_type_name, inputs, output_types, name=scope,
                           input_types=input_types, attrs=attr_protos,
                           op_def=op_def)