aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core/framework/shape_inference.cc
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2017-06-20 09:24:25 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2017-06-20 09:27:35 -0700
commit36567a695c4e1c364c33717036d4d64d33db2ba4 (patch)
tree1d627d79ba80b04c50db4759c632fcfe7113f40e /tensorflow/core/framework/shape_inference.cc
parentf580c21053c47a12b3b816cc50c261479f36db5d (diff)
Infer shapes for loops during Grappler static inference
PiperOrigin-RevId: 159570163
Diffstat (limited to 'tensorflow/core/framework/shape_inference.cc')
-rw-r--r--tensorflow/core/framework/shape_inference.cc123
1 files changed, 120 insertions, 3 deletions
diff --git a/tensorflow/core/framework/shape_inference.cc b/tensorflow/core/framework/shape_inference.cc
index 1f9e98551f..2c18ddd48f 100644
--- a/tensorflow/core/framework/shape_inference.cc
+++ b/tensorflow/core/framework/shape_inference.cc
@@ -314,6 +314,19 @@ Status InferenceContext::WithValue(DimensionHandle dim, int64 value,
existing);
}
+void InferenceContext::Relax(DimensionHandle d0, DimensionHandle d1,
+ DimensionHandle* out) {
+ if (d0.SameHandle(d1)) {
+ *out = d0;
+ } else if (!ValueKnown(d0) || !ValueKnown(d1)) {
+ *out = UnknownDim();
+ } else if (Value(d0) == Value(d1)) {
+ *out = d0;
+ } else {
+ *out = UnknownDim();
+ }
+}
+
Status InferenceContext::Merge(DimensionHandle d0, DimensionHandle d1,
DimensionHandle* out) {
if (d0.SameHandle(d1) || !ValueKnown(d1)) {
@@ -356,6 +369,48 @@ Status InferenceContext::MergePrefix(ShapeHandle s, ShapeHandle prefix,
return Status::OK();
}
+void InferenceContext::Relax(ShapeHandle s0, ShapeHandle s1, ShapeHandle* out) {
+ if (s0.SameHandle(s1)) {
+ *out = s0;
+ return;
+ } else if (!RankKnown(s0) || !RankKnown(s1)) {
+ *out = UnknownShape();
+ return;
+ }
+
+ const int32 rank = Rank(s0);
+ if (rank != Rank(s1)) {
+ *out = UnknownShape();
+ return;
+ }
+
+ bool return_s0 = true;
+ for (int i = 0; i < rank; ++i) {
+ auto d0 = Dim(s0, i);
+ auto d1 = Dim(s1, i);
+ if (d0.SameHandle(d1)) continue;
+
+ auto v0 = Value(d0);
+ auto v1 = Value(d1);
+ if (v0 == kUnknownDim || v1 == kUnknownDim || v0 != v1) {
+ return_s0 = false;
+ break;
+ }
+ }
+ if (return_s0) {
+ *out = s0;
+ return;
+ }
+
+ // Relax dims.
+ std::vector<DimensionHandle> dims(rank);
+ for (int i = 0; i < rank; ++i) {
+ // Invariant for relax was checked earlier, so CHECK is ok.
+ Relax(Dim(s0, i), Dim(s1, i), &dims[i]);
+ }
+ *out = MakeShape(dims);
+}
+
Status InferenceContext::Merge(ShapeHandle s0, ShapeHandle s1,
ShapeHandle* out) {
if (s0.SameHandle(s1) || !RankKnown(s1)) {
@@ -895,9 +950,15 @@ bool InferenceContext::MergeHandleShapesAndTypes(
bool refined = false;
for (int i = 0; i < shapes_and_types.size(); ++i) {
const ShapeAndType& existing = (*to_update)[i];
- new_values[i].dtype = shapes_and_types[i].dtype;
- if (new_values[i].dtype != existing.dtype && existing.dtype == DT_INVALID) {
- refined = true;
+ if (shapes_and_types[i].dtype == existing.dtype) {
+ new_values[i].dtype = existing.dtype;
+ } else {
+ if (existing.dtype != DT_INVALID) {
+ return false;
+ } else {
+ new_values[i].dtype = shapes_and_types[i].dtype;
+ refined = true;
+ }
}
if (!Merge(existing.shape, shapes_and_types[i].shape, &new_values[i].shape)
.ok()) {
@@ -939,6 +1000,62 @@ bool InferenceContext::MergeInputHandleShapesAndTypes(
input_handle_shapes_and_types_[idx].get());
}
+bool InferenceContext::RelaxHandleShapesAndMergeTypes(
+ const std::vector<ShapeAndType>& shapes_and_types,
+ std::vector<ShapeAndType>* to_update) {
+ if (shapes_and_types.size() != to_update->size()) {
+ return false;
+ }
+ std::vector<ShapeAndType> new_values(shapes_and_types.size());
+ bool refined = false;
+ for (int i = 0; i < shapes_and_types.size(); ++i) {
+ const ShapeAndType& existing = (*to_update)[i];
+ if (shapes_and_types[i].dtype == existing.dtype) {
+ new_values[i].dtype = existing.dtype;
+ } else {
+ if (existing.dtype != DT_INVALID) {
+ return false;
+ } else {
+ new_values[i].dtype = shapes_and_types[i].dtype;
+ refined = true;
+ }
+ }
+ Relax(existing.shape, shapes_and_types[i].shape, &new_values[i].shape);
+ if (!existing.shape.SameHandle(new_values[i].shape)) {
+ refined = true;
+ }
+ }
+ if (!refined) {
+ return false;
+ }
+ for (int i = 0; i < new_values.size(); ++i) {
+ (*to_update)[i] = new_values[i];
+ }
+ return true;
+}
+
+bool InferenceContext::RelaxOutputHandleShapesAndMergeTypes(
+ int idx, const std::vector<ShapeAndType>& shapes_and_types) {
+ if (output_handle_shapes_and_types_[idx] == nullptr) {
+ output_handle_shapes_and_types_[idx].reset(
+ new std::vector<ShapeAndType>(shapes_and_types));
+ return true;
+ }
+ return RelaxHandleShapesAndMergeTypes(
+ shapes_and_types, output_handle_shapes_and_types_[idx].get());
+}
+
+bool InferenceContext::RelaxInputHandleShapesAndMergeTypes(
+ int idx, const std::vector<ShapeAndType>& shapes_and_types) {
+ if (input_handle_shapes_and_types_[idx] == nullptr) {
+ input_handle_shapes_and_types_[idx].reset(
+ new std::vector<ShapeAndType>(shapes_and_types));
+ return true;
+ }
+ return RelaxHandleShapesAndMergeTypes(
+ shapes_and_types, input_handle_shapes_and_types_[idx].get());
+}
+
// -----------------------------------------------------------------------------
// ShapeManager
// -----------------------------------------------------------------------------