# path: root/tensorflow/python/framework/common_shapes.py
# blob: 40788e24c486c4357042672e3697063a4c7fb381
# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""A library of common shape functions."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import numpy as np
import six.moves

from tensorflow.python import pywrap_tensorflow
from tensorflow.python.framework import cpp_shape_inference_pb2
from tensorflow.python.framework import errors
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_shape
from tensorflow.python.framework import tensor_util


def has_fully_defined_shape(tensor):
  """Reports whether `tensor` has a fully defined shape.

  Args:
    tensor: An `ops.EagerTensor`, or an object exposing a `shape` attribute
      that provides `is_fully_defined()`.

  Returns:
    True for any `ops.EagerTensor` (its shape is not consulted); otherwise the
    result of `tensor.shape.is_fully_defined()`.
  """
  if isinstance(tensor, ops.EagerTensor):
    return True
  return tensor.shape.is_fully_defined()


def rank(tensor):
  """Returns the rank of `tensor`, or None when it is not an `ops.Tensor`.

  Args:
    tensor: Any object.

  Returns:
    `tensor._rank()` if `tensor` is an instance of `ops.Tensor`, else None.
  """
  if not isinstance(tensor, ops.Tensor):
    return None
  return tensor._rank()  # pylint: disable=protected-access


def scalar_shape(unused_op):
  """Shape function for ops whose single output is a scalar.

  Args:
    unused_op: The operation; ignored.

  Returns:
    A one-element list holding `tensor_shape.scalar()`.
  """
  scalar = tensor_shape.scalar()
  return [scalar]


def unchanged_shape(op):
  """Shape function for ops whose output is shaped like their first input.

  Args:
    op: An operation with at least one input.

  Returns:
    A one-element list holding the shape of `op.inputs[0]`.
  """
  first_input = op.inputs[0]
  return [first_input.get_shape()]


def unchanged_shape_with_rank(rank):
  """Builds a shape function that requires the first input to have rank `rank`.

  Args:
    rank: The exact rank of the input and output.

  Returns:
    A shape function taking an op and returning a one-element list with the
    shape of the op's first input, constrained via `with_rank(rank)`.
  """

  def _ShapeFunction(op):
    first_input_shape = op.inputs[0].get_shape()
    return [first_input_shape.with_rank(rank)]

  return _ShapeFunction


def unchanged_shape_with_rank_at_least(rank):
  """Builds a shape function that bounds the first input's rank from below.

  Args:
    rank: A lower bound on the rank of the input and output.

  Returns:
    A shape function taking an op and returning a one-element list with the
    shape of the op's first input, constrained via `with_rank_at_least(rank)`.
  """

  def _ShapeFunction(op):
    first_input_shape = op.inputs[0].get_shape()
    return [first_input_shape.with_rank_at_least(rank)]

  return _ShapeFunction


def unchanged_shape_with_rank_at_most(rank):
  """Builds a shape function that bounds the first input's rank from above.

  Args:
    rank: An upper bound on the rank of the input and output.

  Returns:
    A shape function taking an op and returning a one-element list with the
    shape of the op's first input, constrained via `with_rank_at_most(rank)`.
  """

  def _ShapeFunction(op):
    first_input_shape = op.inputs[0].get_shape()
    return [first_input_shape.with_rank_at_most(rank)]

  return _ShapeFunction


def matmul_shape(op):
  """Shape function for a MatMul op.

  Both inputs must have rank 2. The "transpose_a" and "transpose_b" attrs
  select which dimension of each operand is the output dimension and which is
  the contracted (inner) dimension.

  Args:
    op: A MatMul Operation.

  Returns:
    A one-element list holding the `[rows, cols]` shape of the product.
  """
  lhs = op.inputs[0].get_shape().with_rank(2)
  lhs_transposed = op.get_attr("transpose_a")
  rhs = op.inputs[1].get_shape().with_rank(2)
  rhs_transposed = op.get_attr("transpose_b")

  # Split each operand into its outer (output) and inner (contracted) dims.
  if lhs_transposed:
    rows, lhs_inner = lhs[1], lhs[0]
  else:
    rows, lhs_inner = lhs[0], lhs[1]
  if rhs_transposed:
    cols, rhs_inner = rhs[0], rhs[1]
  else:
    cols, rhs_inner = rhs[1], rhs[0]

  lhs_inner.assert_is_compatible_with(rhs_inner)
  return [tensor_shape.TensorShape([rows, cols])]


def get_conv_output_size(input_size, filter_size, strides, padding_type):
  """Computes the spatial output size of an n-d convolution/pooling.

  Args:
    input_size: Iterable of spatial input sizes; each entry is converted with
      `tensor_shape.as_dimension`, so unknown sizes become None.
    filter_size: Iterable of spatial window sizes, converted the same way.
    strides: Iterable of per-dimension strides, each converted with `int`.
    padding_type: Either b"VALID" or b"SAME".

  Returns:
    A tuple with one output size per spatial dimension; an entry is None when
    the sizes it depends on are unknown.

  Raises:
    ValueError: If a known filter size exceeds the matching known input size,
      or if `padding_type` is neither b"VALID" nor b"SAME".
  """
  in_dims = tuple(tensor_shape.as_dimension(d).value for d in input_size)
  window_dims = tuple(tensor_shape.as_dimension(d).value for d in filter_size)
  steps = [int(s) for s in strides]

  def _all_ones(dims):
    return all(d == 1 for d in dims)

  # A 1x...x1 window over a 1x...x1 input leaves the size untouched; this
  # shortcut returns before the padding type is examined.
  if _all_ones(in_dims) and _all_ones(window_dims):
    return in_dims

  for window_dim, in_dim in zip(window_dims, in_dims):
    if window_dim is not None and in_dim is not None and window_dim > in_dim:
      raise ValueError("Filter must not be larger than the input: "
                       "Filter: %r Input: %r" % (window_dims, in_dims))

  if padding_type == b"VALID":
    # Needs both the input and the window size to be known.
    sizes = [
        None if in_dim is None or window_dim is None
        else (in_dim - window_dim + step) // step
        for in_dim, window_dim, step in zip(in_dims, window_dims, steps)
    ]
  elif padding_type == b"SAME":
    # Depends only on the input size and the stride (ceiling division).
    sizes = [
        None if in_dim is None else (in_dim + step - 1) // step
        for in_dim, step in zip(in_dims, steps)
    ]
  else:
    raise ValueError("Invalid padding: %r" % padding_type)

  return tuple(sizes)


def get2d_conv_output_size(input_height, input_width, filter_height,
                           filter_width, row_stride, col_stride, padding_type):
  """Computes the (rows, cols) size of a 2-D convolution/pooling output.

  Thin wrapper that packs its arguments into (height, width) pairs and
  delegates to `get_conv_output_size`.

  Args:
    input_height: Number of input rows.
    input_width: Number of input columns.
    filter_height: Number of window rows.
    filter_width: Number of window columns.
    row_stride: Stride along rows.
    col_stride: Stride along columns.
    padding_type: Padding type, forwarded unchanged.

  Returns:
    The tuple returned by `get_conv_output_size`.
  """
  spatial_input = (input_height, input_width)
  window = (filter_height, filter_width)
  steps = (row_stride, col_stride)
  return get_conv_output_size(spatial_input, window, steps, padding_type)


def conv2d_shape(op):
  """Shape function for a Conv2D op.

  The op takes two inputs:

  * input, a 4D tensor with shape = [batch_size, rows, cols, depth_in]
    (or [batch_size, depth_in, rows, cols] when "data_format" is b"NCHW")
  * filter, a 4D tensor with shape = [filter_rows, filter_cols,
    depth_in, depth_out]

  The output is a 4D tensor laid out like the input, with spatial sizes
  out_rows and out_cols determined by the op's "padding" and "strides" attrs
  and with depth_out channels.

  Args:
    op: A Conv2D Operation.

  Returns:
    A list containing the Shape of the Conv2D output.

  Raises:
    ValueError: If the shapes of the input or filter are incompatible, or if
      the batch or depth stride is not 1.
  """
  in_shape = op.inputs[0].get_shape().with_rank(4)
  kernel_shape = op.inputs[1].get_shape().with_rank(4)

  # Ops without a "data_format" attr are treated as the default layout.
  try:
    data_format = op.get_attr("data_format")
  except ValueError:
    data_format = None
  channels_first = data_format == b"NCHW"

  # Pick the dimensions out of the input according to its layout.
  if channels_first:
    batch_size, in_depth, in_rows, in_cols = (
        in_shape[0], in_shape[1], in_shape[2], in_shape[3])
  else:
    batch_size, in_rows, in_cols, in_depth = (
        in_shape[0], in_shape[1], in_shape[2], in_shape[3])

  kernel_rows, kernel_cols, kernel_depth_in, depth_out = (
      kernel_shape[0], kernel_shape[1], kernel_shape[2], kernel_shape[3])
  # The input depth must agree with the filter's input depth.
  in_depth.assert_is_compatible_with(kernel_depth_in)

  if channels_first:
    stride_b, stride_d, stride_r, stride_c = op.get_attr("strides")
  else:
    stride_b, stride_r, stride_c, stride_d = op.get_attr("strides")

  if not (stride_b == 1 and stride_d == 1):
    raise ValueError("Current implementation does not yet support "
                     "strides in the batch and depth dimensions.")
  # TODO(mrry,shlens): Raise an error if the stride would cause
  # information in the input to be ignored. This will require a change
  # in the kernel implementation.
  padding = op.get_attr("padding")
  out_rows, out_cols = get2d_conv_output_size(in_rows, in_cols, kernel_rows,
                                              kernel_cols, stride_r, stride_c,
                                              padding)

  # Emit the result in the same layout the input used.
  if channels_first:
    out_dims = [batch_size, depth_out, out_rows, out_cols]
  else:
    out_dims = [batch_size, out_rows, out_cols, depth_out]
  return [tensor_shape.TensorShape(out_dims)]


def depthwise_conv2d_native_shape(op):
  """Shape function for a DepthwiseConv2D op.

  The op takes two inputs:

  * input, a 4D tensor with shape = [batch_size, rows, cols, depth_in]
  * filter, a 4D tensor with shape =  [filter_rows, filter_cols,
    depth_in, depthwise_multiplier]

  The output is a 4D tensor with shape = [batch_size, out_rows,
  out_cols, depth_in*depthwise_multiplier], where out_rows and out_cols
  follow from the op's "padding" and "strides" attrs.

  Args:
    op: A DepthwiseConv2dNative Operation.

  Returns:
    A list containing the Shape of the DepthwiseConv2DNative output.

  Raises:
    ValueError: If the shapes of the input or filter are incompatible, if the
      batch or depth stride is not 1, or if the row and column strides differ.
  """
  in_shape = op.inputs[0].get_shape().with_rank(4)
  kernel_shape = op.inputs[1].get_shape().with_rank(4)

  batch_size, in_rows, in_cols, in_depth = (
      in_shape[0], in_shape[1], in_shape[2], in_shape[3])
  kernel_rows, kernel_cols, kernel_depth_in, multiplier = (
      kernel_shape[0], kernel_shape[1], kernel_shape[2], kernel_shape[3])

  # Every input channel yields `multiplier` output channels.
  depth_out = multiplier * kernel_depth_in
  # The input depth must agree with the filter's input depth.
  in_depth.assert_is_compatible_with(kernel_depth_in)

  stride_b, stride_r, stride_c, stride_d = op.get_attr("strides")
  if not (stride_b == 1 and stride_d == 1):
    raise ValueError("Current implementation does not yet support "
                     "strides in the batch and depth dimensions.")
  if stride_r != stride_c:
    # TODO(shlens): Add support for this.
    raise ValueError("Current implementation only supports equal length "
                     "strides in the row and column dimensions.")

  # TODO(mrry,shlens): Raise an error if the stride would cause
  # information in the input to be ignored. This will require a change
  # in the kernel implementation.
  padding = op.get_attr("padding")
  # The row stride is used for both spatial dimensions (they are equal here).
  out_rows, out_cols = get2d_conv_output_size(in_rows, in_cols, kernel_rows,
                                              kernel_cols, stride_r, stride_r,
                                              padding)

  out_dims = [batch_size, out_rows, out_cols, depth_out]
  return [tensor_shape.TensorShape(out_dims)]


def separable_conv2d_shape(op):
  """Shape function for a SeparableConv2D op.

  The op takes three inputs:

  * input, a 4D tensor with shape = [batch_size, rows, cols, depth_in]

  * depthwise_filter, a 4D tensor with shape = [filter_rows,
    filter_cols, depth_in, depth_multiplier]

  * pointwise_filter, a 4D tensor with shape = [1, 1, depth_in *
    depth_multiplier, depth_out]

  The output is a 4D tensor with shape = [batch_size, out_rows,
  out_cols, depth_out], where out_rows and out_cols follow from the op's
  "padding" and "strides" attrs.

  Args:
    op: A SeparableConv2D Operation.

  Returns:
    A list containing the Shape of the SeparableConv2D output.

  Raises:
    ValueError: If the shapes of the input or filters are incompatible, if the
      batch or depth stride is not 1, or if the row and column strides differ.
  """
  in_shape = op.inputs[0].get_shape().with_rank(4)

  # The depthwise filter's third dimension must match the input depth.
  depthwise_template = tensor_shape.TensorShape(
      [None, None, in_shape[3], None])
  depthwise_shape = op.inputs[1].get_shape().merge_with(depthwise_template)

  # The pointwise filter is 1x1 and consumes depth_in * depth_multiplier
  # channels.
  pointwise_depth_in = depthwise_shape[2] * depthwise_shape[3]
  pointwise_template = tensor_shape.TensorShape(
      [1, 1, pointwise_depth_in, None])
  pointwise_shape = op.inputs[2].get_shape().merge_with(pointwise_template)

  batch_size, in_rows, in_cols = in_shape[0], in_shape[1], in_shape[2]
  kernel_rows, kernel_cols = depthwise_shape[0], depthwise_shape[1]
  depth_out = pointwise_shape[3]

  stride_b, stride_r, stride_c, stride_d = op.get_attr("strides")
  if not (stride_b == 1 and stride_d == 1):
    raise ValueError("Current implementation does not yet support "
                     "strides in the batch and depth dimensions.")
  if stride_r != stride_c:
    # TODO(shlens): Add support for this.
    raise ValueError("Current implementation only supports equal length "
                     "strides in the row and column dimensions.")

  # TODO(mrry,shlens): Raise an error if the stride would cause
  # information in the input to be ignored. This will require a change
  # in the kernel implementation.
  padding = op.get_attr("padding")
  # The row stride is used for both spatial dimensions (they are equal here).
  out_rows, out_cols = get2d_conv_output_size(in_rows, in_cols, kernel_rows,
                                              kernel_cols, stride_r, stride_r,
                                              padding)

  out_dims = [batch_size, out_rows, out_cols, depth_out]
  return [tensor_shape.TensorShape(out_dims)]


def avg_pool_shape(op):
  """Shape function for an AvgPool op.

  The op takes one input:

  * input, a 4D tensor with shape = [batch_size, rows, cols, depth]
    (or [batch_size, depth, rows, cols] when "data_format" is b"NCHW")

  The output is a 4D tensor laid out like the input, with the input's batch
  size and depth, and with spatial sizes out_rows and out_cols determined by
  the op's "ksize", "strides", and "padding" attrs.

  Args:
    op: An AvgPool Operation.

  Returns:
    A single-element list containing the Shape of the AvgPool output.

  Raises:
    ValueError: If the shape of the input is invalid or incompatible with
      the values of the attrs.
  """
  in_shape = op.inputs[0].get_shape().with_rank(4)

  # Ops without a "data_format" attr are treated as the default layout.
  try:
    data_format = op.get_attr("data_format")
  except ValueError:
    data_format = None
  channels_first = data_format == b"NCHW"

  # Read dimensions, window and strides according to the layout.
  if channels_first:
    batch_size, depth, in_rows, in_cols = (
        in_shape[0], in_shape[1], in_shape[2], in_shape[3])
    ksize_b, ksize_d, ksize_r, ksize_c = op.get_attr("ksize")
    stride_b, stride_d, stride_r, stride_c = op.get_attr("strides")
  else:
    batch_size, in_rows, in_cols, depth = (
        in_shape[0], in_shape[1], in_shape[2], in_shape[3])
    ksize_b, ksize_r, ksize_c, ksize_d = op.get_attr("ksize")
    stride_b, stride_r, stride_c, stride_d = op.get_attr("strides")

  if not (ksize_b == 1 and ksize_d == 1):
    raise ValueError("Current implementation does not support pooling "
                     "in the batch and depth dimensions.")
  if not (stride_b == 1 and stride_d == 1):
    raise ValueError("Current implementation does not support strides "
                     "in the batch and depth dimensions.")

  # TODO(mrry,shlens): Raise an error if the stride would cause
  # information in the input to be ignored. This will require a change
  # in the kernel implementation.
  padding = op.get_attr("padding")
  out_rows, out_cols = get2d_conv_output_size(in_rows, in_cols, ksize_r,
                                              ksize_c, stride_r, stride_c,
                                              padding)

  # Emit the result in the same layout the input used.
  if channels_first:
    out_dims = [batch_size, depth, out_rows, out_cols]
  else:
    out_dims = [batch_size, out_rows, out_cols, depth]
  return [tensor_shape.TensorShape(out_dims)]


def max_pool_shape(op):
  """Shape function for a MaxPool op.

  The op takes one input:

  * input, a 4D tensor with shape = [batch_size, rows, cols, depth_in]
    (or [batch_size, depth_in, rows, cols] when "data_format" is b"NCHW")

  The output is a 4D tensor laid out like the input. Pooling happens either
  over the spatial dimensions (depth window of 1), in which case out_rows and
  out_cols follow from the "ksize", "strides", and "padding" attrs, or over
  depth only (1x1 spatial window), in which case the depth is divided by the
  depth window.

  Args:
    op: A MaxPool Operation.

  Returns:
    A single-element list containing the Shape of the MaxPool output.

  Raises:
    ValueError: If the shape of the input is invalid or incompatible with
      the values of the attrs.
  """
  in_shape = op.inputs[0].get_shape().with_rank(4)

  # Ops without a "data_format" attr are treated as the default layout.
  try:
    data_format = op.get_attr("data_format")
  except ValueError:
    data_format = None
  channels_first = data_format == b"NCHW"

  # Read dimensions, window and strides according to the layout.
  if channels_first:
    batch_size, depth, in_rows, in_cols = (
        in_shape[0], in_shape[1], in_shape[2], in_shape[3])
    ksize_b, ksize_d, ksize_r, ksize_c = op.get_attr("ksize")
    stride_b, stride_d, stride_r, stride_c = op.get_attr("strides")
  else:
    batch_size, in_rows, in_cols, depth = (
        in_shape[0], in_shape[1], in_shape[2], in_shape[3])
    ksize_b, ksize_r, ksize_c, ksize_d = op.get_attr("ksize")
    stride_b, stride_r, stride_c, stride_d = op.get_attr("strides")

  if ksize_b != 1:
    raise ValueError("Current implementation does not support pooling "
                     "in the batch dimension.")
  if stride_b != 1:
    raise ValueError("Current implementation does not support strides "
                     "in the batch dimension.")

  # A window may span the spatial dimensions or the depth, never both.
  if (ksize_r != 1 or ksize_c != 1) and ksize_d != 1:
    raise ValueError("MaxPooling supports exactly one of pooling across depth "
                     "or pooling across width/height.")

  # TODO(mrry,shlens): Raise an error if the stride would cause
  # information in the input to be ignored. This will require a change
  # in the kernel implementation.
  if ksize_d == 1:
    # Spatial pooling: depth is untouched.
    padding = op.get_attr("padding")
    out_rows, out_cols = get2d_conv_output_size(in_rows, in_cols, ksize_r,
                                                ksize_c, stride_r, stride_c,
                                                padding)
    out_depth = depth
  else:
    # Depthwise pooling: spatial sizes are untouched.
    if depth % ksize_d > 0:
      raise ValueError("Depthwise max pooling requires the depth window "
                       "to evenly divide the input depth.")
    if stride_d != ksize_d:
      raise ValueError("Depthwise max pooling requires the depth window "
                       "to equal the depth stride.")
    out_rows, out_cols = in_rows, in_cols
    out_depth = depth // ksize_d

  # Emit the result in the same layout the input used.
  if channels_first:
    out_dims = [batch_size, out_depth, out_rows, out_cols]
  else:
    out_dims = [batch_size, out_rows, out_cols, out_depth]
  return [tensor_shape.TensorShape(out_dims)]


def no_outputs(unused_op):
  """Shape function for ops that produce no outputs.

  Args:
    unused_op: The operation; ignored.

  Returns:
    A new empty list.
  """
  return list()


def unknown_shape(op):
  """Shape function for ops whose output shapes cannot be determined.

  Args:
    op: An operation; only `op.outputs` is consulted.

  Returns:
    A list with one `tensor_shape.unknown_shape()` per output of `op`.
  """
  shapes = []
  for _ in op.outputs:
    shapes.append(tensor_shape.unknown_shape())
  return shapes


def _broadcast_shape_helper(shape_x, shape_y):
  """Shared implementation of is_broadcast_compatible and broadcast_shape.

  Args:
    shape_x: A `TensorShape`
    shape_y: A `TensorShape`

  Returns:
    None if the shapes are not broadcast compatible; otherwise a list with one
    entry per broadcast dimension, where an entry is either a `Dimension` or
    None for a dimension that cannot be determined.
  """
  # Align the two shapes on their trailing dimensions, padding the shorter
  # one with 1s on the left, then restore leading-to-trailing order.
  unit = tensor_shape.Dimension(1)
  aligned = list(six.moves.zip_longest(
      reversed(shape_x.dims), reversed(shape_y.dims), fillvalue=unit))
  aligned.reverse()

  # Combine each pair following the numpy broadcasting rules.
  # http://docs.scipy.org/doc/numpy/user/basics.broadcasting.html
  combined = []
  for dim_x, dim_y in aligned:
    x_size = dim_x.value
    y_size = dim_y.value
    if x_size is None or y_size is None:
      # At least one side is unknown. A known side greater than 1 is taken as
      # the result, assuming the program is correct and the unknown side will
      # be broadcast to match it; otherwise the result stays unknown.
      # TODO(mrry): If we eliminate the shape checks in C++, we must still
      # assert that the unknown dim is either 1 or the same as the known dim.
      if x_size is not None and x_size > 1:
        chosen = dim_x
      elif y_size is not None and y_size > 1:
        chosen = dim_y
      else:
        chosen = None
    elif x_size == 1:
      # dim_x is broadcast to dim_y.
      chosen = dim_y
    elif y_size == 1:
      # dim_y is broadcast to dim_x.
      chosen = dim_x
    elif x_size == y_size:
      # Equal sizes: the output keeps that size.
      chosen = dim_x.merge_with(dim_y)
    else:
      # Two different sizes, neither of them 1: not broadcastable.
      return None
    combined.append(chosen)
  return combined


def is_broadcast_compatible(shape_x, shape_y):
  """Reports whether `shape_x` and `shape_y` can be broadcast together.

  Args:
    shape_x: A `TensorShape`
    shape_y: A `TensorShape`

  Returns:
    True if a shape exists that both `shape_x` and `shape_y` can be broadcasted
    to.  False otherwise, including whenever either shape has unknown rank.
  """
  if any(shape.ndims is None for shape in (shape_x, shape_y)):
    # Unknown rank is reported as not compatible.
    return False
  broadcast_dims = _broadcast_shape_helper(shape_x, shape_y)
  return broadcast_dims is not None


def broadcast_shape(shape_x, shape_y):
  """Computes the shape obtained by broadcasting `shape_x` with `shape_y`.

  Args:
    shape_x: A `TensorShape`
    shape_y: A `TensorShape`

  Returns:
    A `TensorShape` representing the broadcasted shape; an unknown shape when
    either argument has unknown rank.

  Raises:
    ValueError: If the two shapes can not be broadcasted.
  """
  rank_known = shape_x.ndims is not None and shape_y.ndims is not None
  if not rank_known:
    return tensor_shape.unknown_shape()

  broadcast_dims = _broadcast_shape_helper(shape_x, shape_y)
  if broadcast_dims is not None:
    return tensor_shape.TensorShape(broadcast_dims)
  raise ValueError("Incompatible shapes for broadcasting: %s and %s"
                   % (shape_x, shape_y))


def call_cpp_shape_fn(op, require_shape_fn=True):
  """A shape function that delegates to the registered C++ shape function.

  The C++ function may report that it needs the values of some inputs (as
  tensors or as shapes). In that case the call is repeated with the newly
  requested input indices added, until nothing new is requested.

  Args:
    op: the node in the graph for which to compute output shapes.
    require_shape_fn: If true, and the C++ shape function is not registered
      in the current binary then an exception is raised; otherwise, if the
      C++ shape function is not registered then unknown_shape is used.

  Returns:
    Normally a dictionary with the following keys:
      shapes: A list with the output shapes of the op, as computed using the
        C++ shape inference function registered for the op.
      handle_data: A list with one entry per output: the handle data reported
        for that output, or None.
      inputs_needed: The serialized description of the inputs the C++ shape
        function asked for (absent for "Const" ops, which are special-cased).
    When the C++ shape function is not registered and `require_shape_fn` is
    false, the list returned by `unknown_shape(op)` is returned instead.

  Raises:
    ValueError: If the C++ shape function returned an error (e.g. because the
      shapes of the inputs are of the wrong rank or otherwise incompatible
      according to the shape function).
    RuntimeError: If the C++ shape function is not registered and
      <require_shape_fn> is True.
  """
  if op.type == "Const":
    # To avoid serializing large constants, we special-case constant
    # here, even though it has a C++ shape function.  When Python
    # calls the C / C-API directly, we should be able to remove this.
    const_shape = tensor_shape.TensorShape(op.get_attr("value").tensor_shape)
    return {"shapes": [const_shape], "handle_data": [None]}

  input_tensors_needed = []
  input_tensors_as_shapes_needed = []

  while True:
    res = _call_cpp_shape_fn_impl(op, input_tensors_needed,
                                  input_tensors_as_shapes_needed,
                                  require_shape_fn)
    if not isinstance(res, dict):
      # _call_cpp_shape_fn_impl fell back to unknown_shape(op).
      return res

    serialized_request = res["inputs_needed"]
    if not serialized_request:
      # No input values were requested.
      return res

    request = cpp_shape_inference_pb2.CppShapeInferenceInputsNeeded()
    request = request.FromString(serialized_request)

    # Record any index not seen before; retry only if something was added.
    grew = False
    wanted_and_known = (
        (request.input_tensors_needed, input_tensors_needed),
        (request.input_tensors_as_shapes_needed,
         input_tensors_as_shapes_needed),
    )
    for wanted, known in wanted_and_known:
      for idx in wanted:
        if idx not in known:
          known.append(idx)
          grew = True
    if not grew:
      return res


def _call_cpp_shape_fn_impl(
    op, input_tensors_needed, input_tensors_as_shapes_needed, require_shape_fn):
  """Core implementation of call_cpp_shape_fn.

  Serializes the op's node def and input shapes, supplies any requested input
  values that can be resolved to constants, and runs the C++ shape inference
  once.

  Args:
    op: The operation to run shape inference for.
    input_tensors_needed: Indices of inputs whose constant values should be
      passed along, when `tensor_util.constant_value` can determine them.
    input_tensors_as_shapes_needed: Indices of inputs whose values should be
      passed along as shapes, when `tensor_util.constant_value_as_shape` can
      determine them.
    require_shape_fn: Whether a missing C++ shape function is an error.

  Returns:
    A dict with keys "shapes", "handle_data" and "inputs_needed", or the result
    of `unknown_shape(op)` when no C++ shape function exists and
    `require_shape_fn` is false.

  Raises:
    ValueError: If shape inference fails with an invalid-argument error other
      than a missing shape function.
    RuntimeError: If no C++ shape function exists and `require_shape_fn` is
      true.
  """
  graph_def_version = op.graph.graph_def_versions.producer
  node_def_str = op.node_def.SerializeToString()

  def _serialize_input(tensor):
    """Packs a tensor's shape (and handle data, if any) into a proto string."""
    proto = cpp_shape_inference_pb2.CppShapeInferenceResult()
    proto.shape.CopyFrom(tensor.get_shape().as_proto())
    handle_data = tensor._handle_data  # pylint: disable=protected-access
    if handle_data is not None:
      proto.handle_data.CopyFrom(handle_data)
    return proto.SerializeToString()

  input_shapes = [_serialize_input(tensor) for tensor in op.inputs]
  num_inputs = len(input_shapes)

  # Requested input values default to None unless a constant is available.
  input_tensors = [None] * num_inputs
  for idx in input_tensors_needed:
    value = tensor_util.constant_value(op.inputs[idx])
    if value is not None:
      input_tensors[idx] = np.asarray(value)

  # Requested shape-valued inputs default to a serialized unknown shape.
  serialized_unknown_shape = (
      tensor_shape.TensorShape(None).as_proto().SerializeToString())
  input_tensors_as_shapes = [serialized_unknown_shape] * num_inputs
  for idx in input_tensors_as_shapes_needed:
    shape = tensor_util.constant_value_as_shape(op.inputs[idx])
    if shape is not None:
      input_tensors_as_shapes[idx] = shape.as_proto().SerializeToString()

  # The missing-function case is flagged here and handled after the `except`
  # block, so the RuntimeError below is not raised from inside the handler.
  missing_shape_fn = False
  try:
    with errors.raise_exception_on_not_ok_status() as status:
      output = pywrap_tensorflow.RunCppShapeInference(
          graph_def_version, node_def_str, input_shapes, input_tensors,
          input_tensors_as_shapes, status)
  except errors.InvalidArgumentError as err:
    if not err.message.startswith("No shape inference function exists for op"):
      raise ValueError(err.message)
    missing_shape_fn = True

  if missing_shape_fn:
    if require_shape_fn:
      raise RuntimeError(
          "No C++ shape function registered for standard op: %s" % op.type)
    return unknown_shape(op)

  # All but the last element are serialized per-output results; the last
  # element is the serialized "inputs needed" message.
  serialized_results = output[:-1]
  result_protos = [
      cpp_shape_inference_pb2.CppShapeInferenceResult().FromString(serialized)
      for serialized in serialized_results
  ]

  shapes = []
  handle_data = []
  for proto in result_protos:
    shapes.append(proto.shape)
    handle_data.append(proto.handle_data if proto.handle_data.is_set else None)

  return {
      "shapes": shapes,
      "handle_data": handle_data,
      "inputs_needed": output[-1]
  }

# Hand `call_cpp_shape_fn` to the `ops` module at import time.
# NOTE(review): presumably this hook exists so `ops` can invoke the function
# without importing this module — confirm against `ops._set_call_cpp_shape_fn`.
# pylint: disable=protected-access
ops._set_call_cpp_shape_fn(call_cpp_shape_fn)
# pylint: enable=protected-access