diff options
author | 2017-06-20 09:24:25 -0700 | |
---|---|---|
committer | 2017-06-20 09:27:35 -0700 | |
commit | 36567a695c4e1c364c33717036d4d64d33db2ba4 (patch) | |
tree | 1d627d79ba80b04c50db4759c632fcfe7113f40e /tensorflow/core/framework/shape_inference.cc | |
parent | f580c21053c47a12b3b816cc50c261479f36db5d (diff) |
Infer shapes for loops during Grappler static inference
PiperOrigin-RevId: 159570163
Diffstat (limited to 'tensorflow/core/framework/shape_inference.cc')
-rw-r--r-- | tensorflow/core/framework/shape_inference.cc | 123 |
1 files changed, 120 insertions, 3 deletions
diff --git a/tensorflow/core/framework/shape_inference.cc b/tensorflow/core/framework/shape_inference.cc index 1f9e98551f..2c18ddd48f 100644 --- a/tensorflow/core/framework/shape_inference.cc +++ b/tensorflow/core/framework/shape_inference.cc @@ -314,6 +314,19 @@ Status InferenceContext::WithValue(DimensionHandle dim, int64 value, existing); } +void InferenceContext::Relax(DimensionHandle d0, DimensionHandle d1, + DimensionHandle* out) { + if (d0.SameHandle(d1)) { + *out = d0; + } else if (!ValueKnown(d0) || !ValueKnown(d1)) { + *out = UnknownDim(); + } else if (Value(d0) == Value(d1)) { + *out = d0; + } else { + *out = UnknownDim(); + } +} + Status InferenceContext::Merge(DimensionHandle d0, DimensionHandle d1, DimensionHandle* out) { if (d0.SameHandle(d1) || !ValueKnown(d1)) { @@ -356,6 +369,48 @@ Status InferenceContext::MergePrefix(ShapeHandle s, ShapeHandle prefix, return Status::OK(); } +void InferenceContext::Relax(ShapeHandle s0, ShapeHandle s1, ShapeHandle* out) { + if (s0.SameHandle(s1)) { + *out = s0; + return; + } else if (!RankKnown(s0) || !RankKnown(s1)) { + *out = UnknownShape(); + return; + } + + const int32 rank = Rank(s0); + if (rank != Rank(s1)) { + *out = UnknownShape(); + return; + } + + bool return_s0 = true; + for (int i = 0; i < rank; ++i) { + auto d0 = Dim(s0, i); + auto d1 = Dim(s1, i); + if (d0.SameHandle(d1)) continue; + + auto v0 = Value(d0); + auto v1 = Value(d1); + if (v0 == kUnknownDim || v1 == kUnknownDim || v0 != v1) { + return_s0 = false; + break; + } + } + if (return_s0) { + *out = s0; + return; + } + + // Relax dims. + std::vector<DimensionHandle> dims(rank); + for (int i = 0; i < rank; ++i) { + // Invariant for relax was checked earlier, so CHECK is ok. + Relax(Dim(s0, i), Dim(s1, i), &dims[i]); + } + *out = MakeShape(dims); +} + Status InferenceContext::Merge(ShapeHandle s0, ShapeHandle s1, ShapeHandle* out) { if (s0.SameHandle(s1) || !RankKnown(s1)) { @@ -895,9 +950,15 @@ bool InferenceContext::MergeHandleShapesAndTypes( bool refined = false; for (int i = 0; i < shapes_and_types.size(); ++i) { const ShapeAndType& existing = (*to_update)[i]; - new_values[i].dtype = shapes_and_types[i].dtype; - if (new_values[i].dtype != existing.dtype && existing.dtype == DT_INVALID) { - refined = true; + if (shapes_and_types[i].dtype == existing.dtype) { + new_values[i].dtype = existing.dtype; + } else { + if (existing.dtype != DT_INVALID) { + return false; + } else { + new_values[i].dtype = shapes_and_types[i].dtype; + refined = true; + } } if (!Merge(existing.shape, shapes_and_types[i].shape, &new_values[i].shape) .ok()) { @@ -939,6 +1000,62 @@ bool InferenceContext::MergeInputHandleShapesAndTypes( input_handle_shapes_and_types_[idx].get()); } +bool InferenceContext::RelaxHandleShapesAndMergeTypes( + const std::vector<ShapeAndType>& shapes_and_types, + std::vector<ShapeAndType>* to_update) { + if (shapes_and_types.size() != to_update->size()) { + return false; + } + std::vector<ShapeAndType> new_values(shapes_and_types.size()); + bool refined = false; + for (int i = 0; i < shapes_and_types.size(); ++i) { + const ShapeAndType& existing = (*to_update)[i]; + if (shapes_and_types[i].dtype == existing.dtype) { + new_values[i].dtype = existing.dtype; + } else { + if (existing.dtype != DT_INVALID) { + return false; + } else { + new_values[i].dtype = shapes_and_types[i].dtype; + refined = true; + } + } + Relax(existing.shape, shapes_and_types[i].shape, &new_values[i].shape); + if (!existing.shape.SameHandle(new_values[i].shape)) { + refined = true; + } + } + if (!refined) { + return false; + } + for (int i = 0; i < new_values.size(); ++i) { + (*to_update)[i] = new_values[i]; + } + return true; +} + +bool InferenceContext::RelaxOutputHandleShapesAndMergeTypes( + int idx, const std::vector<ShapeAndType>& shapes_and_types) { + if (output_handle_shapes_and_types_[idx] == nullptr) { + output_handle_shapes_and_types_[idx].reset( + new std::vector<ShapeAndType>(shapes_and_types)); + return true; + } + return RelaxHandleShapesAndMergeTypes( + shapes_and_types, output_handle_shapes_and_types_[idx].get()); +} + +bool InferenceContext::RelaxInputHandleShapesAndMergeTypes( + int idx, const std::vector<ShapeAndType>& shapes_and_types) { + if (input_handle_shapes_and_types_[idx] == nullptr) { + input_handle_shapes_and_types_[idx].reset( + new std::vector<ShapeAndType>(shapes_and_types)); + return true; + } + return RelaxHandleShapesAndMergeTypes( + shapes_and_types, input_handle_shapes_and_types_[idx].get()); +} + // ----------------------------------------------------------------------------- // ShapeManager // ----------------------------------------------------------------------------- |