aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core/framework/shape_inference.h
blob: accc587000767554f87a195e0ea33640cd696244 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_CORE_FRAMEWORK_SHAPE_INFERENCE_H_
#define TENSORFLOW_CORE_FRAMEWORK_SHAPE_INFERENCE_H_

#include <vector>

#include "tensorflow/core/framework/node_def_util.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/gtl/inlined_vector.h"
#include "tensorflow/core/platform/macros.h"

namespace tensorflow {

class ShapeRefiner;
class ShapeRefinerTest;

namespace grappler {
class GraphProperties;
class SymbolicShapeManager;
}  // namespace grappler

namespace shape_inference {

struct DimensionOrConstant;
class InferenceContext;

// Dimension values are accessed through InferenceContext.
class Dimension {
 private:
  // Everything is private: only this class and the friends declared below
  // can create or destroy a Dimension.
  Dimension();
  Dimension(int64 value);
  ~Dimension() {}

  // Fixed at construction time (const member).
  const int64 value_;

  friend class ShapeManager;
  friend class InferenceContext;
  TF_DISALLOW_COPY_AND_ASSIGN(Dimension);
};

class DimensionHandle {
 public:
  DimensionHandle() {}
  bool SameHandle(DimensionHandle d) const { return ptr_ == d.ptr_; }
  std::size_t Handle() const { return reinterpret_cast<std::size_t>(ptr_); }

 private:
  DimensionHandle(const Dimension* dim) { ptr_ = dim; }

  const Dimension* operator->() const { return ptr_; }
  bool IsSet() const { return ptr_ != nullptr; }

  const Dimension* ptr_ = nullptr;

  friend struct DimensionOrConstant;
  friend class InferenceContext;
  friend class ShapeInferenceTest;
  friend class ShapeInferenceTestutil;
  friend class ::tensorflow::ShapeRefinerTest;
  friend class ShapeManager;
  friend class ::tensorflow::grappler::GraphProperties;
  friend class ::tensorflow::grappler::SymbolicShapeManager;

  // Intentionally copyable.
};

// Shape rank and dimensions are accessed through InferenceContext.
class Shape {
 private:
  // Everything is private: only this class and the friends declared below
  // can create or destroy a Shape.
  Shape();
  Shape(const std::vector<DimensionHandle>& dims);
  ~Shape() {}

  // Both members are const and therefore fixed at construction time.
  const int32 rank_;
  const std::vector<DimensionHandle> dims_;

  friend class ShapeManager;
  friend class InferenceContext;
  friend class ::tensorflow::grappler::SymbolicShapeManager;

  TF_DISALLOW_COPY_AND_ASSIGN(Shape);
};

class ShapeHandle {
 public:
  ShapeHandle() {}
  bool SameHandle(ShapeHandle s) const { return ptr_ == s.ptr_; }
  std::size_t Handle() const { return reinterpret_cast<std::size_t>(ptr_); }

 private:
  ShapeHandle(const Shape* shape) { ptr_ = shape; }
  const Shape* operator->() const { return ptr_; }
  bool IsSet() const { return ptr_ != nullptr; }

  const Shape* ptr_ = nullptr;

  friend class InferenceContext;
  friend class ShapeInferenceTest;
  friend class ShapeInferenceTestutil;
  friend class ::tensorflow::ShapeRefinerTest;
  friend class ShapeManager;
  friend class ::tensorflow::grappler::SymbolicShapeManager;

  // Intentionally copyable.
};

// Struct used to allow functions to take DimensionHandle or a dimension value.
// Not meant to be constructed directly.
struct DimensionOrConstant {
 public:
  // Both converting constructors are intentionally not explicit, so that a
  // plain integer or a DimensionHandle can be passed wherever a
  // DimensionOrConstant parameter is expected.

  // val must be non-negative or InferenceContext::kUnknownDim.
  DimensionOrConstant(int64 val);

  DimensionOrConstant(DimensionHandle dim);

  // dim takes precedence. If dim != nullptr, val is ignored.
  DimensionHandle dim;
  int64 val;

 private:
  // Default construction is not available to outside code.
  DimensionOrConstant();
};

// A shape handle paired with a data type.
struct ShapeAndType {
  // Leaves `shape` as an unset handle and `dtype` as DT_INVALID.
  ShapeAndType() {}

  // Stores the given shape handle and data type as-is.
  ShapeAndType(ShapeHandle s, DataType t) : shape(s), dtype(t) {}

  ShapeHandle shape;
  DataType dtype = DT_INVALID;
};

// Shape inference functions registered on ops in REGISTER_OP implement
// their shape functions in terms of this InferenceContext.  An InferenceContext
// is created by the framework and passed to a shape inference function.  The
// shape inference function calls functions on the context, and should call
// set_output() to set the shape on all outputs.
//
// To infer shapes for user-defined functions see ShapeRefiner.
//
// All Shape* and Dimension* returned by functions of InferenceContext are owned
// by the InferenceContext.
class InferenceContext {
 public:
  static constexpr int64 kUnknownDim = -1;
  static constexpr int32 kUnknownRank = -1;

  // <input_tensors> is NULL-padded to be the same size as <input_shapes>.
  //
  // Elements of <input_tensors_as_shapes> are used for when a shape function
  // makes a call to MakeShapeFromShapeTensor; in particular, when the
  // input_tensors[i] is nullptr but the shape represented by it is partially
  // known from analysis of the graph.
  // <input_tensors_as_shapes> can have fewer elements than <input_shapes>.
  // Values of <input_tensors_as_shapes> do not need to outlive the context.
  //
  // REQUIRES: <node_def> is not NULL, and must outlive the InferenceContext.
  InferenceContext(int graph_def_version, const NodeDef* node_def,
                   const OpDef& op_def,
                   const std::vector<ShapeHandle>& input_shapes,
                   const std::vector<const Tensor*>& input_tensors,
                   const std::vector<ShapeHandle>& input_tensors_as_shapes,
                   std::vector<std::unique_ptr<std::vector<ShapeAndType>>>
                       input_handle_shapes_and_types);

  // <input_tensors> is NULL-padded to be the same size as <input_shapes>.
  //
  // Elements of <input_tensors_as_shapes> are used for when a shape
  // function makes a call to MakeShapeFromShapeTensor; in particular, when
  // the input_tensors[i] is nullptr but the shape represented by it is
  // partially known from analysis of the graph. <input_tensors_as_shapes>
  // can have fewer elements than <input_shapes>. Values of
  // <input_tensors_as_shapes> do not need to outlive the context.
  //
  // REQUIRES: <node_def> is not NULL, and must outlive the
  // InferenceContext.
  InferenceContext(
      int graph_def_version, const NodeDef* node_def, const OpDef& op_def,
      const std::vector<TensorShapeProto>& input_shapes,
      const std::vector<const Tensor*>& input_tensors,
      const std::vector<TensorShapeProto>& input_tensors_as_shapes,
      const std::vector<
          std::unique_ptr<std::vector<std::pair<TensorShapeProto, DataType>>>>&
          input_handle_shapes_and_types);

  // <input_tensors> is NULL-padded to be the same size as <input_shapes>.
  //
  // Elements of <input_tensors_as_shapes> are used for when a shape
  // function makes a call to MakeShapeFromShapeTensor; in particular, when
  // the input_tensors[i] is nullptr but the shape represented by it is
  // partially known from analysis of the graph. <input_tensors_as_shapes>
  // can have fewer elements than <input_shapes>. Values of
  // <input_tensors_as_shapes> do not need to outlive the context.
  //
  // REQUIRES: <node_def> is not NULL, and must outlive the
  // InferenceContext.
  InferenceContext(
      int graph_def_version, const NodeDef* node_def, const OpDef& op_def,
      const std::vector<PartialTensorShape>& input_shapes,
      const std::vector<const Tensor*>& input_tensors,
      const std::vector<PartialTensorShape>& input_tensors_as_shapes,
      const std::vector<std::unique_ptr<
          std::vector<std::pair<PartialTensorShape, DataType>>>>&
          input_handle_shapes_and_types);

  ~InferenceContext();

  // Runs the shape inference function 'fn' with 'this' as the
  // argument, returns the status of the inference.
  //
  // On error, additional context is provided in the error message.
  Status Run(
      const std::function<Status(shape_inference::InferenceContext* c)>& fn);

  // Merge the stored shape of the input in position idx with <shape> according
  // to the following rules:
  //
  // - If the ShapeHandles are the same or <shape> is unknown, there will be no
  //   change. Otherwise if the stored shape is unknown, the new shape will be
  //   <shape>.
  // - If both shapes are known, then they must have the same rank.
  // - For any one dimension, if the values for that dimension in both shapes
  //   are known, then the values must match.
  // - If one shape has equal or more information than the other shape in every
  //   dimension, the new shape will become the shape with more information.
  // - Example: merging [2,?] and [?,2] results in [2,2]
  // - Example: [2,2] cannot be merged with [1,2]
  //
  // This requires idx to be in the [0, num_inputs) range. If the merge is
  // successful, return true. Return false otherwise.
  bool MergeInput(int idx, ShapeHandle shape) {
    ShapeHandle new_shape;
    if (!Merge(inputs_[idx], shape, &new_shape).ok()) return false;
    inputs_[idx] = new_shape;
    return true;
  }

  // Relax the stored shape of the input in position idx with <shape> according
  // to the following rules:
  //
  // - If the ShapeHandles are the same then the stored shape will be returned.
  // - If either of the ShapeHandles are unknown, then a new UnknownShape will
  //   be returned. A new shape must be returned because we cannot claim that
  //   the resulting shape is necessarily the same as either of the input
  //   shapes.
  // - If the shapes both have known ranks but their ranks are different, a new
  //   UnknownShape will be returned.
  // - For any one dimension, if the value for that dimension in either of the
  //   shapes is unknown, a new shape will be returned with a new UnknownDim in
  //   that dimension.
  // - For any one dimension, if the values for that dimension in both shapes
  //   are known but do not match, a new shape will be returned with a new
  //   UnknownDim in that dimension.
  // - If both shapes have the same known rank and match in every dimension,
  //   the stored shape will be returned.
  // - Example: relaxing [2,?] and [?,2] results in [?,?]
  // - Example: relaxing [2,2] and [3,2] results in [?,2]
  // - Example: relaxing [2,2] with [1,2,3] results in ?
  //
  // This requires idx to be in the [0, num_inputs) range. If the relax is
  // successful and the new shape differs from the old one, store the new
  // shape and return true. Return false otherwise.
  bool RelaxInput(int idx, ShapeHandle shape) {
    ShapeHandle new_shape;
    Relax(inputs_[idx], shape, &new_shape);
    if (inputs_[idx].SameHandle(new_shape)) {
      return false;
    }
    inputs_[idx] = new_shape;
    return true;
  }

  // Shape currently stored for input <idx>; the index is not range-checked.
  ShapeHandle input(int64 idx) const {
    return inputs_[idx];
  }
  Status input(StringPiece input_name, std::vector<ShapeHandle>* output) const;
  // Number of stored input shapes.
  int num_inputs() const {
    return static_cast<int>(inputs_.size());
  }

  // Returns the input tensor at index <idx>, or nullptr if the input tensor is
  // not available at the time of shape inference.
  const Tensor* input_tensor(int idx) {
    // Record the request first so requested_input_tensor(idx) reports it,
    // whether or not a tensor is actually available.
    requested_input_tensor_[idx] = true;
    const Tensor* const tensor = input_tensors_[idx];
    return tensor;
  }

  // Returns true iff input_tensor(idx) was called by the shape function.
  bool requested_input_tensor(int idx) const {
    // The flag is set by input_tensor(idx).
    const bool was_requested = requested_input_tensor_[idx];
    return was_requested;
  }

  // Returns true if MakeShapeFromShapeTensor was called but the constant
  // input_tensor was not present.
  bool requested_input_tensor_as_partial_shape(int idx) const {
    // Plain read of the per-input flag; idx is not range-checked.
    const bool was_requested = requested_input_tensor_as_partial_shape_[idx];
    return was_requested;
  }

  void set_input_tensors(const std::vector<const Tensor*>& input_tensors) {
    input_tensors_ = input_tensors;
  }

  void set_input_tensors_as_shapes(
      const std::vector<ShapeHandle>& input_tensors_as_shapes) {
    input_tensors_as_shapes_ = input_tensors_as_shapes;
  }

  // Shape currently recorded for output <idx>; no range check is performed.
  ShapeHandle output(int64 idx) const {
    return outputs_[idx];
  }
  // Overwrites the shape recorded for output <idx>.
  void set_output(int idx, ShapeHandle shape) {
    outputs_[idx] = shape;
  }
  Status set_output(StringPiece output_name,
                    const std::vector<ShapeHandle>& shapes);

  // Number of output slots.
  int num_outputs() const {
    return static_cast<int>(outputs_.size());
  }
  // Same lookup as the int64 overload, for int indices.
  ShapeHandle output(int idx) const {
    return outputs_[idx];
  }
  Status output(StringPiece output_name,
                std::vector<ShapeHandle>* output) const;

  // Attribute view built over the stored NodeDef.
  AttrSlice attrs() const {
    const NodeDef& def = *node_def_;
    return AttrSlice(def);
  }

  string op() const;

  // idx can be negative for an offset from end of dimensions.
  // idx must be in the range [-1 * s.rank, s.rank).
  DimensionHandle Dim(ShapeHandle s, int64 idx) {
    // With an unknown rank there is no dimension to look up, so a fresh
    // unknown dimension is handed back instead.
    return (s->rank_ == kUnknownRank) ? UnknownDim() : DimKnownRank(s, idx);
  }
  // As above, but asserts that the rank of the shape is known.
  static DimensionHandle DimKnownRank(ShapeHandle s, int64 idx) {
    CHECK_NE(s->rank_, kUnknownRank);
    const std::vector<DimensionHandle>& dims = s->dims_;
    if (idx >= 0) {
      return dims[idx];
    }
    // Negative indices count back from the end of the dimension list.
    return dims[dims.size() + idx];
  }

  // Rank of <s>. An unset handle trips the DCHECK when DCHECKs are enabled;
  // otherwise it is reported as kUnknownRank.
  static int32 Rank(ShapeHandle s) {
    DCHECK(s.IsSet());
    if (!s.IsSet()) {
      return kUnknownRank;
    }
    return s->rank_;
  }
  // True only for a set handle whose rank is not kUnknownRank.
  static bool RankKnown(ShapeHandle s) {
    if (!s.IsSet()) {
      return false;
    }
    return Rank(s) != kUnknownRank;
  }
  // Value carried by <d>: the handle's dimension value when the handle is
  // set, otherwise the plain constant.
  static inline int64 Value(DimensionOrConstant d) {
    if (d.dim.IsSet()) {
      return d.dim->value_;
    }
    return d.val;
  }
  // True when the value carried by <d> is anything other than kUnknownDim.
  static inline bool ValueKnown(DimensionOrConstant d) {
    const int64 value = Value(d);
    return value != kUnknownDim;
  }

  // Fills the output proto with the shape defined by the handle.
  // "proto" is expected to be empty prior to the call.
  void ShapeHandleToProto(ShapeHandle handle, TensorShapeProto* proto);

  // Returns true if the rank and all dimensions of the Shape are known.
  bool FullyDefined(ShapeHandle s);

  // Returns the total number of elements, or an unknown dimension for an
  // incomplete shape.
  DimensionHandle NumElements(ShapeHandle s);

  string DebugString(ShapeHandle s);
  string DebugString(DimensionHandle d);

  // Describes the whole context, for debugging purposes.
  string DebugString() const;

  // If <shape> has rank <rank>, or its rank is unknown, return OK and return
  // the shape with asserted rank in <*out>. Otherwise return an error.
  //
  // Note that <*out> may be set to <shape>.
  Status WithRank(ShapeHandle shape, int64 rank,
                  ShapeHandle* out) TF_MUST_USE_RESULT;
  Status WithRankAtLeast(ShapeHandle shape, int64 rank,
                         ShapeHandle* out) TF_MUST_USE_RESULT;
  Status WithRankAtMost(ShapeHandle shape, int64 rank,
                        ShapeHandle* out) TF_MUST_USE_RESULT;

  // If <dim> has value <value>, or its value is unknown, returns OK and returns
  // the dimension with asserted value in <*out>. Otherwise returns an error.
  //
  // Note that <*out> may be set to <dim>.
  Status WithValue(DimensionHandle dim, int64 value,
                   DimensionHandle* out) TF_MUST_USE_RESULT;

  // Merges <s0> and <s1> and returns the merged shape in <*out>. See
  // 'MergeInput' function for full details and examples.
  Status Merge(ShapeHandle s0, ShapeHandle s1,
               ShapeHandle* out) TF_MUST_USE_RESULT;

  // Asserts that <s>'s rank >= <prefix>'s rank, and the first
  // <prefix.rank> dimensions of <s> are compatible with the dimensions of
  // <prefix>.
  // Returns the merged results in <*s_out> and <*prefix_out>.
  Status MergePrefix(ShapeHandle s, ShapeHandle prefix, ShapeHandle* s_out,
                     ShapeHandle* prefix_out) TF_MUST_USE_RESULT;

  // Merges <d0> and <d1> and returns the merged dimension in <*out>. If <d0>
  // and <d1> have incompatible values, returns an error.
  //
  // Note that <*out> may be set to <d0> or <d1>.
  Status Merge(DimensionHandle d0, DimensionHandle d1,
               DimensionHandle* out) TF_MUST_USE_RESULT;

  // Returns in <*out> a sub-shape of <s> with dimensions [start:].
  // <start> can be negative to index from the end of the shape. If <start> >
  // rank of <s>, then an empty subshape is returned.
  Status Subshape(ShapeHandle s, int64 start,
                  ShapeHandle* out) TF_MUST_USE_RESULT;

  // Returns in <*out> a sub-shape of <s>, with dimensions [start:end].
  // <start> and <end> can be negative, to index from the end of the shape.
  // <start> and <end> are set to the rank of <s> if > rank of <s>.
  Status Subshape(ShapeHandle s, int64 start, int64 end,
                  ShapeHandle* out) TF_MUST_USE_RESULT;

  // Returns in <*out> the result of appending the dimensions of <s2> to those
  // of <s1>.
  Status Concatenate(ShapeHandle s1, ShapeHandle s2,
                     ShapeHandle* out) TF_MUST_USE_RESULT;

  // Returns in <out> the shape from replacing <s.dim[dim_index]> with
  // <new_dim>.
  Status ReplaceDim(ShapeHandle s, int64 dim_index, DimensionHandle new_dim,
                    ShapeHandle* out) TF_MUST_USE_RESULT;

  // Returns a new shape with the given dims. The returned value is owned by
  // this context.
  ShapeHandle MakeShape(const std::vector<DimensionHandle>& dims);
  ShapeHandle MakeShape(std::initializer_list<DimensionOrConstant> dims);

  // Returns a new unknown shape.
  ShapeHandle UnknownShape();

  // Returns a shape with specified rank but unknown dims.
  ShapeHandle UnknownShapeOfRank(int64 rank);

  // Returns a new shape of zero dimensions.
  ShapeHandle Scalar();

  // Returns a new shape of one dimension.
  ShapeHandle Vector(DimensionOrConstant dim);

  // Returns a new shape of two dimensions.
  ShapeHandle Matrix(DimensionOrConstant dim1, DimensionOrConstant dim2);

  // Returns in <out> a new shape whose dimension sizes come from input tensor
  // <input_idx>. The tensor must be a 1-dimensional int32 or int64 tensor.  If
  // the input tensor is NULL, then an unknown shape is returned.
  Status MakeShapeFromShapeTensor(int input_idx, ShapeHandle* out);

  // Returns in <out> a new shape corresponding to <proto>.
  Status MakeShapeFromShapeProto(const TensorShapeProto& proto,
                                 ShapeHandle* out);

  // Returns in <out> a new shape corresponding to <partial_shape>.
  Status MakeShapeFromPartialTensorShape(
      const PartialTensorShape& partial_shape, ShapeHandle* out);

  // Returns in <out> a new shape corresponding to <shape>.
  Status MakeShapeFromTensorShape(const TensorShape& shape, ShapeHandle* out);

  // Returns a new dimension of the given size.  The returned value is owned by
  // this context.
  inline DimensionHandle MakeDim(DimensionOrConstant d) {
    // Creation (when needed) is delegated to the context's ShapeManager.
    return shape_manager_.MakeDim(d);
  }

  // Convenience wrapper: a dimension whose value is kUnknownDim.
  inline DimensionHandle UnknownDim() {
    return MakeDim(DimensionOrConstant(kUnknownDim));
  }

  // Returns in <val> a scalar value from an input tensor <t>.  The input tensor
  // must be a 0-dimensional (scalar) int32 or int64 tensor.  Caller must ensure
  // that the input tensor is not NULL.
  Status GetScalarFromTensor(const Tensor* t, int64* val);

  // Returns a new dimension whose value is given by a scalar input tensor.
  // The input tensor must be in host memory, since it is dereferenced to get
  // the value.
  Status MakeDimForScalarInput(int idx, DimensionHandle* out);

  // Returns a new dimension whose value is given by a scalar input tensor.
  // This allows for a negative input dimension given the rank of a separate
  // tensor.  This rank can be negative if unknown.
  // The input tensor must be in host memory, since it is dereferenced to get
  // the value.
  Status MakeDimForScalarInputWithNegativeIndexing(int idx, int input_rank,
                                                   DimensionHandle* out);

  // Look up the attr for the NodeDef being evaluated with name attr_name and
  // set *value to its value.  If no attr with attr_name is found in def(), or
  // the attr does not have a matching type, a non-ok status will be returned.
  template <class T>
  Status GetAttr(StringPiece attr_name, T* value) const;

  // Returns in <out> the result of dividing <dividend> by <divisor>.
  // Returns an error if <divisor>  is not positive or if <evenly_divisible>
  // and <divisor> does not evenly divide <dividend>.
  Status Divide(DimensionHandle dividend, DimensionOrConstant divisor,
                bool evenly_divisible, DimensionHandle* out);

  // Returns in <out> the sum of <first> and <second>.
  Status Add(DimensionHandle first, DimensionOrConstant second,
             DimensionHandle* out);

  // Returns in <out> the dimension that is <first> minus <second>.
  Status Subtract(DimensionHandle first, DimensionOrConstant second,
                  DimensionHandle* out);

  // Returns in <out> the product of <first> and <second>.
  Status Multiply(DimensionHandle first, DimensionOrConstant second,
                  DimensionHandle* out);

  // Returns in <out> the minimum of <first> and <second>. If either <first> or
  // <second> is zero the result is zero. Otherwise, if either <first> or
  // <second> is unknown the result is unknown.
  Status Min(DimensionHandle first, DimensionOrConstant second,
             DimensionHandle* out);

  // Returns in <out> the maximum of <first> and <second>. If either <first> or
  // <second> is unknown the result is unknown.
  Status Max(DimensionHandle first, DimensionOrConstant second,
             DimensionHandle* out);

  // Returns a copy of the status stored in construction_status_.
  Status construction_status() const {
    return construction_status_;
  }

  // Methods to propagate shape and dtype on edges of handles. Handles are the
  // dtype DT_RESOURCE which can be used to access state stored in a
  // ResourceManager. When ops (such as variables) consume these handles to
  // produce tensors they might need to know side-information about the shapes
  // and dtypes of tensors which can be accessed via the handle. These methods
  // propagate that information. Output handle dtypes and shapes are ignored if
  // the output tensor is not of type DT_RESOURCE.

  // Merge the stored shapes and types corresponding to the input handle in
  // position idx with the specified shapes and types. This requires idx to be
  // in the [0, num_inputs) range.
  //
  // If the merge is successful and any of the new shapes differs from the old
  // one, or any of the old dtypes was DT_INVALID, store the new shapes and
  // return true.  Return false otherwise.
  //
  // See 'MergeInput' function for full details and examples.
  bool MergeInputHandleShapesAndTypes(
      int idx,
      const std::vector<ShapeAndType>& shapes_and_types) TF_MUST_USE_RESULT;

  // As MergeInputHandleShapesAndTypes, but for an output.
  bool MergeOutputHandleShapesAndTypes(
      int idx,
      const std::vector<ShapeAndType>& shapes_and_types) TF_MUST_USE_RESULT;

  // Relaxes the stored shapes and types corresponding to the input handle in
  // position idx with the specified shapes and types. This requires idx to be
  // in the [0, num_inputs) range.
  //
  // If the relax is successful and any of the new shapes differs from the old
  // one, or any of the old dtypes was DT_INVALID, store the new shapes and
  // return true.  Return false otherwise.
  //
  // See 'RelaxInput' function for full details and examples.
  bool RelaxInputHandleShapesAndMergeTypes(
      int idx,
      const std::vector<ShapeAndType>& shapes_and_types) TF_MUST_USE_RESULT;

  // As RelaxInputHandleShapesAndMergeTypes, but for an output.
  bool RelaxOutputHandleShapesAndMergeTypes(
      int idx,
      const std::vector<ShapeAndType>& shapes_and_types) TF_MUST_USE_RESULT;

  // Returns the output handle shapes and types, for the resource tensor output
  // at index <idx>. Returns NULL if the shape and types were never set.
  const std::vector<ShapeAndType>* output_handle_shapes_and_types(int idx) {
    // The slot keeps ownership; callers only receive a borrowed pointer.
    const auto& slot = output_handle_shapes_and_types_[idx];
    return slot.get();
  }

  // Returns the input handle shapes and types, for the resource tensor input
  // at index <idx>. Returns NULL if the shape and types were not available.
  const std::vector<ShapeAndType>* input_handle_shapes_and_types(int idx) {
    // The slot keeps ownership; callers only receive a borrowed pointer.
    const auto& slot = input_handle_shapes_and_types_[idx];
    return slot.get();
  }

  // Stores a private copy of <shapes_and_types> for output <idx>, replacing
  // (and freeing) whatever list was stored there before.
  void set_output_handle_shapes_and_types(
      int idx, const std::vector<ShapeAndType>& shapes_and_types) {
    auto* copy = new std::vector<ShapeAndType>(shapes_and_types);
    output_handle_shapes_and_types_[idx].reset(copy);
  }

  // Note that shape functions should usually call MakeShapeFromShapeTensor,
  // as it does more analysis to provide partial shapes.
  //
  // Returns in <out> a new shape whose dimension sizes come from tensor <t>.
  // The tensor must be a 1-dimensional int32 or int64 tensor.  If <t> is NULL,
  // then an unknown shape is returned.
  Status MakeShapeFromTensor(const Tensor* t, ShapeHandle tensor_shape,
                             ShapeHandle* out);

  int graph_def_version() const { return graph_def_version_; }

  const std::vector<std::pair<ShapeHandle, ShapeHandle>>& MergedShapes() const {
    return merged_shapes_;
  }
  const std::vector<std::pair<DimensionHandle, DimensionHandle>>& MergedDims()
      const {
    return merged_dims_;
  }

 private:
  // Creates and stores shapes for use in InferenceContext.
  class ShapeManager {
   public:
    ShapeManager();
    ~ShapeManager();

    // Returns a new shape with the given dims. The returned value is owned by
    // this class.
    ShapeHandle MakeShape(const std::vector<DimensionHandle>& dims);

    // Returns a new unknown shape.
    ShapeHandle UnknownShape();

    // Returns a dimension for <d>. When <d> already carries a handle, that
    // handle is returned unchanged; otherwise a new Dimension holding d.val is
    // created, recorded in all_dims_ (which owns it) and returned.
    inline DimensionHandle MakeDim(DimensionOrConstant d) {
      if (d.dim.IsSet()) {
        return d.dim;
      }
      Dimension* created = new Dimension(d.val);
      all_dims_.push_back(created);
      return created;
    }

   private:
    std::vector<Shape*> all_shapes_;    // values are owned.
    std::vector<Dimension*> all_dims_;  // values are owned.
  };

  friend class ::tensorflow::grappler::GraphProperties;

  // Friend for user-defined function shape inference purposes.
  friend class ::tensorflow::ShapeRefiner;

  friend class ShapeInferenceTest;      // For testing Relax functions.
  friend class ShapeInferenceTestutil;  // For testing shapes.

  // Shared initialization across the two constructors.  Remove
  // once we get rid of one of them.
  void PreInputInit(const OpDef& op_def,
                    const std::vector<const Tensor*>& input_tensors,
                    const std::vector<ShapeHandle>& input_tensors_as_shapes);
  void PostInputInit(std::vector<std::unique_ptr<std::vector<ShapeAndType>>>
                         input_handle_data);

  DimensionHandle GetDimension(const DimensionOrConstant& d);

  // Stores a new unknown shape in <*out> and always reports OK.
  Status ReturnUnknownShape(ShapeHandle* out) {
    const ShapeHandle unknown = UnknownShape();
    *out = unknown;
    return Status::OK();
  }
  // Stores a new shape built from <dims> in <*out> and always reports OK.
  Status ReturnCreatedShape(const std::vector<DimensionHandle>& dims,
                            ShapeHandle* out) {
    const ShapeHandle created = MakeShape(dims);
    *out = created;
    return Status::OK();
  }

  // Adds additional context to the given status.
  Status AttachContext(const Status& status);

  // Relaxes an existing value <d_old> with a new value <d_new> and returns the
  // relaxed dimension in <*out>. If <d_old> and <d_new> have known but
  // different values, <*out> is set to a new unknown dimension (see the rules
  // listed for 'RelaxInput').
  //
  // Note that <*out> may be set to <d_old> or <d_new>.
  void Relax(DimensionHandle d_old, DimensionHandle d_new,
             DimensionHandle* out);
  // Relaxes an existing shape <s_old> with a new shape <s_new> and returns the
  // relaxed shape in <*out>. See 'RelaxInput' function for full details and
  // examples.
  void Relax(ShapeHandle s_old, ShapeHandle s_new, ShapeHandle* out);

  // Used to implement MergeInputHandleShapesAndTypes and
  // MergeOutputHandleShapesAndTypes.
  bool MergeHandleShapesAndTypes(
      const std::vector<ShapeAndType>& shapes_and_types,
      std::vector<ShapeAndType>* to_update) TF_MUST_USE_RESULT;
  // Used to implement RelaxInputHandleShapesAndMergeTypes and
  // RelaxOutputHandleShapesAndMergeTypes.
  bool RelaxHandleShapesAndMergeTypes(
      const std::vector<ShapeAndType>& shapes_and_types,
      std::vector<ShapeAndType>* to_update) TF_MUST_USE_RESULT;

  // Forget all the previous merged shapes and dims.
  void ForgetMerges() {
    // The two lists are independent; both are simply emptied.
    merged_dims_.clear();
    merged_shapes_.clear();
  }

  ShapeManager shape_manager_;

  // inputs_, outputs_, and input_tensors_as_shapes_ refer to values from
  // `shape_manager_`.
  std::vector<ShapeHandle> inputs_;
  std::vector<const Tensor*> input_tensors_;
  std::vector<bool> requested_input_tensor_;
  std::vector<ShapeHandle> outputs_;
  // Can have fewer elements than inputs_.
  std::vector<ShapeHandle> input_tensors_as_shapes_;
  std::vector<bool> requested_input_tensor_as_partial_shape_;

  // input_handle_shapes_and_types_[i] is the list of shape/type pairs available
  // through the resource handle passed along input i of the node.
  //
  // Values may be NULL.
  std::vector<std::unique_ptr<std::vector<ShapeAndType>>>
      input_handle_shapes_and_types_;

  // output_handle_shapes_and_types_[i] is the list of shape/type pairs
  // available through the resource handle passed along output i of the node.
  //
  // Values may be NULL.
  std::vector<std::unique_ptr<std::vector<ShapeAndType>>>
      output_handle_shapes_and_types_;

  const int graph_def_version_;
  const NodeDef* node_def_;
  NameRangeMap input_name_map_;
  NameRangeMap output_name_map_;

  // An error set during construction. TODO(cwhipkey): remove when test
  // constructor is removed.
  Status construction_status_;

  // Pairs of shape or dim handles that are equivalent, i.e. that represent the
  // same underlying shape or dimension. Note that for each pair at least one
  // of the handles must contain an unknown shape, since we don't keep track of
  // known shapes or dims here.
  std::vector<std::pair<ShapeHandle, ShapeHandle>> merged_shapes_;
  std::vector<std::pair<DimensionHandle, DimensionHandle>> merged_dims_;

  TF_DISALLOW_COPY_AND_ASSIGN(InferenceContext);
};

// -----------------------------------------------------------------------------
// Template and inline method implementations, please ignore

inline Dimension::Dimension() : value_(InferenceContext::kUnknownDim) {}
inline Dimension::Dimension(int64 value) : value_(value) {
  DCHECK(value >= 0 || value == InferenceContext::kUnknownDim)
      << "Dimension must be non-negative or equal to "
         "InferenceContext::kUnknownDim but got "
      << value;
}

// A default-constructed shape has unknown rank and an empty dimension list.
inline Shape::Shape() : rank_(InferenceContext::kUnknownRank), dims_() {}

// Copies <dims>; the rank is the number of dimensions supplied.
inline Shape::Shape(const std::vector<DimensionHandle>& dims)
    : rank_(dims.size()), dims_(dims) {}

// Wraps an existing dimension handle.
//
// `val` is ignored whenever `dim` is set, but it is still given a defined
// value (the unknown-dimension marker) so that copying this struct never reads
// an indeterminate integer, and so that an unset handle slipping through in a
// build where the DCHECK below is compiled out yields "unknown" rather than
// garbage.
inline DimensionOrConstant::DimensionOrConstant(DimensionHandle dim)
    : dim(dim), val(InferenceContext::kUnknownDim) {
  DCHECK(dim.IsSet()) << "Internal error: Got nullptr for Dimension.";
}

inline DimensionOrConstant::DimensionOrConstant(int64 val) : val(val) {
  DCHECK(val >= 0 || val == InferenceContext::kUnknownDim)
      << "Dimension must be non-negative or equal to "
         "InferenceContext::kUnknownDim but got "
      << val;
}

// Forwards the lookup to GetNodeAttr on the stored NodeDef and returns its
// status unchanged.
template <class T>
Status InferenceContext::GetAttr(StringPiece attr_name, T* value) const {
  const NodeDef& def = *node_def_;
  return GetNodeAttr(def, attr_name, value);
}

}  // namespace shape_inference
}  // namespace tensorflow

#endif  // TENSORFLOW_CORE_FRAMEWORK_SHAPE_INFERENCE_H_