Diffstat (limited to 'tensorflow/core/grappler/costs/graph_properties.cc')
-rw-r--r--  tensorflow/core/grappler/costs/graph_properties.cc  113
1 file changed, 63 insertions(+), 50 deletions(-)
diff --git a/tensorflow/core/grappler/costs/graph_properties.cc b/tensorflow/core/grappler/costs/graph_properties.cc
index 35048a4fcf..44322a2d8c 100644
--- a/tensorflow/core/grappler/costs/graph_properties.cc
+++ b/tensorflow/core/grappler/costs/graph_properties.cc
@@ -50,9 +50,13 @@ template <typename Handle>
struct HandleToObject {};
template <>
struct HandleToObject<ShapeHandle> {
- typedef ShapeHandle Object;
+ typedef TensorShapeProto Object;
- static ShapeHandle Unknown() { return ShapeHandle(); }
+ static TensorShapeProto Unknown() {
+ TensorShapeProto result;
+ result.set_unknown_rank(true);
+ return result;
+ }
};
template <>
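
A note on the HandleToObject<ShapeHandle> change above: callers that previously received a default-constructed ShapeHandle now get a TensorShapeProto explicitly marked as having unknown rank. A minimal standalone sketch of that default, assuming only the standard TensorShapeProto accessors (set_unknown_rank, unknown_rank, dim_size):

#include "tensorflow/core/framework/tensor_shape.pb.h"
#include "tensorflow/core/platform/logging.h"

void UnknownShapeSketch() {
  tensorflow::TensorShapeProto unknown;
  unknown.set_unknown_rank(true);
  // A fully unknown shape carries no dims and reports unknown_rank() == true;
  // AsTensorProperties (further down in this diff) emits the same marker.
  CHECK(unknown.unknown_rank());
  CHECK_EQ(unknown.dim_size(), 0);
}
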
@@ -63,24 +67,13 @@ struct HandleToObject<DimensionHandle> {
};
template <typename Handle>
-struct Processor {};
-
-template <>
-struct Processor<ShapeHandle> {
+struct Processor {
// Extract the shape or dim denoted by the handle.
- void ExtractValue(ShapeHandle h, ShapeHandle* result) { *result = h; }
+ void ExtractValue(Handle /*t1*/,
+ typename HandleToObject<Handle>::Object* result) {}
// Merge the shapes or dims.
- Status Merge(ShapeHandle h1, ShapeHandle h2, ShapeHandle* result) {
- if (InferenceContext::RankKnown(*result)) {
- // The result was initialized in a previous merge to a shape of known
- // rank, make sure we preserve that information.
- return Status::OK();
- }
- if (InferenceContext::RankKnown(h1)) {
- *result = h1;
- } else {
- *result = h2;
- }
+ Status Merge(Handle /*t1*/, Handle /*t2*/,
+ typename HandleToObject<Handle>::Object* result) {
return Status::OK();
}
};
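
With the ShapeHandle specialization folded into the primary Processor template above, ExtractValue and Merge become no-ops for shapes; only Processor<DimensionHandle> (next hunk) keeps real merging logic. The snippet below is a hypothetical illustration, not the file's actual DisjointSet interface, of how a union-find style container generic over the handle type would delegate to Processor:

// Hypothetical container, for illustration only.
template <typename Handle>
class TinyDisjointSet {
 public:
  Status Union(Handle a, Handle b) {
    typename HandleToObject<Handle>::Object merged =
        HandleToObject<Handle>::Unknown();
    // For ShapeHandle this Merge is now a no-op; for DimensionHandle it
    // applies the value-merging rules shown in the next hunk.
    return processor_.Merge(a, b, &merged);
  }

 private:
  Processor<Handle> processor_;
};
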
@@ -108,37 +101,24 @@ struct Processor<DimensionHandle> {
if (dim1 >= 0 && dim2 >= 0) {
CHECK_EQ(dim1, dim2);
- return RefineDim(dim1, result);
+ *result = dim1;
} else if (dim1 >= 0 && dim2 < 0) {
- return RefineDim(dim1, result);
+ *result = dim1;
} else if (dim1 < 0 && dim2 >= 0) {
- return RefineDim(dim2, result);
+ *result = dim2;
} else if (dim1 < -1) {
- return RefineDim(dim1, result);
+ *result = dim1;
} else if (dim2 < -1) {
- return RefineDim(dim2, result);
+ *result = dim2;
} else {
CHECK_EQ(dim1, dim2);
CHECK_EQ(-1, dim1);
- return RefineDim(-1, result);
+ *result = -1;
}
return Status::OK();
}
private:
- Status RefineDim(int64 dim, int64* result) {
- if (*result >= 0) {
- if (!(*result == dim || dim < 0)) {
- return errors::InvalidArgument("Inconsistent dimensions detected");
- }
- } else if (dim >= 0) {
- *result = dim;
- } else if (dim < *result) {
- *result = dim;
- }
- return Status::OK();
- }
-
int64 counter = 2;
};
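
The dimension merge above now writes the chosen value straight into *result instead of routing it through RefineDim. The conventions are unchanged: d >= 0 is a known size, -1 is fully unknown, and values below -1 are symbolic ids handed out from counter. A standalone, runnable sketch of the same value-level rules, using plain int64_t in place of DimensionHandles; this is a restatement of the branch logic above, not the class itself:

#include <cassert>
#include <cstdint>

int64_t MergeDimValues(int64_t dim1, int64_t dim2) {
  if (dim1 >= 0 && dim2 >= 0) { assert(dim1 == dim2); return dim1; }
  if (dim1 >= 0) return dim1;   // known size wins over any unknown
  if (dim2 >= 0) return dim2;
  if (dim1 < -1) return dim1;   // symbolic id wins over plain -1
  if (dim2 < -1) return dim2;
  return -1;                    // nothing known on either side
}

int main() {
  assert(MergeDimValues(4, 4) == 4);
  assert(MergeDimValues(4, -1) == 4);
  assert(MergeDimValues(-1, -7) == -7);
  assert(MergeDimValues(-7, -1) == -7);
  assert(MergeDimValues(-1, -1) == -1);
  return 0;
}
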
@@ -374,17 +354,18 @@ class SymbolicShapeManager {
return dims_.Merge(d1, d2);
}
+ int64 Value(DimensionHandle d) { return dims_.GetMergedValue(d); }
+
void AsTensorProperties(const ShapeHandle& shape, const DataType& type,
+ InferenceContext* ctx,
OpInfo::TensorProperties* properties) {
properties->set_dtype(type);
- ShapeHandle actual_shape = shapes_.GetMergedValue(shape);
- if (!InferenceContext::RankKnown(actual_shape)) {
+ if (!ctx->RankKnown(shape)) {
properties->mutable_shape()->set_unknown_rank(true);
} else {
- for (int j = 0; j < InferenceContext::Rank(actual_shape); ++j) {
- shape_inference::DimensionHandle dim =
- InferenceContext::DimKnownRank(actual_shape, j);
- int64 d = dims_.GetMergedValue(dim);
+ for (int j = 0; j < ctx->Rank(shape); ++j) {
+ shape_inference::DimensionHandle dim = ctx->Dim(shape, j);
+ int64 d = Value(dim);
properties->mutable_shape()->add_dim()->set_size(d);
}
}
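
AsTensorProperties now resolves each dimension through the merged symbolic value (Value(dim)) and takes the InferenceContext explicitly rather than looking shapes up via shapes_.GetMergedValue. For reference, a small sketch of the proto it fills, assuming the OpInfo message from op_performance_data.proto; the sizes are invented, with a negative value such as -2 standing for a symbolic/unknown dimension:

#include "tensorflow/core/framework/types.pb.h"
#include "tensorflow/core/grappler/costs/op_performance_data.pb.h"

void FillExampleProperties(tensorflow::OpInfo::TensorProperties* properties) {
  properties->set_dtype(tensorflow::DT_FLOAT);
  // Rank-2 shape: first dim resolved to a known 8, second left symbolic (-2).
  properties->mutable_shape()->add_dim()->set_size(8);
  properties->mutable_shape()->add_dim()->set_size(-2);
}
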
@@ -466,11 +447,6 @@ Status GraphProperties::InferStatically() {
shape_refiner.set_disable_constant_propagation(true);
shape_refiner.set_function_library_for_shape_inference(&function_library);
ImportGraphDefOptions options;
- // Graph optimization happens at the late stage of graph execution,
- // when colocation constraints are already validated previously and
- // the device placement of nodes has also completed, so there
- // is no need to validate colocation constraints again.
- options.validate_colocation_constraints = false;
Status s = ImportGraphDef(options, item_.graph, &graph, &shape_refiner);
TF_RETURN_IF_ERROR(s);
@@ -496,6 +472,41 @@ Status GraphProperties::InferStatically() {
}
}
}
+
+ // Infer output shape for Restore op.
+ if (node->op_def().name() == "Restore" ||
+ node->op_def().name() == "RestoreV2" ||
+ node->op_def().name() == "RestoreSlice") {
+ auto ctx = shape_refiner.GetContext(node);
+ for (const Edge* out_edge : node->out_edges()) {
+ const Node* output = out_edge->dst();
+ int output_idx = out_edge->src_output();
+ if (output_idx < 0) {
+ continue;
+ }
+ if (!ctx->FullyDefined(ctx->output(output_idx)) &&
+ output->op_def().name() == "Assign") {
+ if (!output->attrs().Find("validate_shape") ||
+ !output->attrs().Find("validate_shape")->b()) {
+ continue;
+ }
+ auto output_ctx = shape_refiner.GetContext(output);
+ if (output_ctx->FullyDefined(output_ctx->output(0))) {
+ ctx->set_output(output_idx, output_ctx->output(0));
+ output_ctx->MergeInput(1, output_ctx->output(0));
+ } else {
+ const Node* var;
+ TF_CHECK_OK(output->input_node(0, &var));
+ if (var->IsVariable()) {
+ auto var_ctx = shape_refiner.GetContext(var);
+ CHECK(var_ctx->FullyDefined(var_ctx->output(0)));
+ ctx->set_output(output_idx, var_ctx->output(0));
+ output_ctx->MergeInput(1, var_ctx->output(0));
+ }
+ }
+ }
+ }
+ }
}
// Propagate the initial shapes of Enter nodes manually (the Enter shape
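
The block added above gives Restore, RestoreV2 and RestoreSlice outputs a concrete shape when they feed an Assign with validate_shape=true, borrowing the fully defined shape of the variable being assigned. The graph pattern it targets, sketched in C++ comments only (node names illustrative, not a builder API):

//   v       = Variable(shape=[128, 256])         // shape fully defined
//   restore = RestoreV2(prefix, names, slices)   // output shape unknown
//   init    = Assign(v, restore, validate_shape=true)
//
// Because validate_shape is true, Assign requires restore's output to match
// v's shape exactly, so shape inference can set restore's output to
// [128, 256] and merge it back into Assign's value input. When Assign's own
// output is not yet fully defined, the code instead reads the shape directly
// from the variable node feeding the Assign.
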
@@ -628,6 +639,8 @@ Status GraphProperties::InferStatically() {
} while (!done);
}
+ std::unordered_map<const shape_inference::Dimension*, int> dim_ids;
+
// Track shapes globally across the graph.
SymbolicShapeManager shape_manager;
bool found_error = false;
@@ -675,7 +688,7 @@ Status GraphProperties::InferStatically() {
input_properties.resize(ctx->num_inputs());
for (int i = 0; i < ctx->num_inputs(); ++i) {
shape_manager.AsTensorProperties(ctx->input(i), node->input_type(i),
- &input_properties[i]);
+ ctx, &input_properties[i]);
}
for (const auto& edge : node->in_edges()) {
if (!edge->src()->IsConstant()) {
@@ -702,7 +715,7 @@ Status GraphProperties::InferStatically() {
output_properties.resize(ctx->num_outputs());
for (int i = 0; i < ctx->num_outputs(); ++i) {
shape_manager.AsTensorProperties(ctx->output(i), node->output_type(i),
- &output_properties[i]);
+ ctx, &output_properties[i]);
}
}
}
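
Taken together, these changes affect what GraphProperties reports after static inference. A hypothetical usage sketch, assuming the GraphProperties interface of this era (a constructor taking a GrapplerItem, InferStatically(), and GetOutputProperties(node_name)); the node name is illustrative:

#include "tensorflow/core/grappler/costs/graph_properties.h"
#include "tensorflow/core/grappler/grappler_item.h"
#include "tensorflow/core/lib/core/errors.h"

tensorflow::Status InferExample(const tensorflow::grappler::GrapplerItem& item) {
  tensorflow::grappler::GraphProperties properties(item);
  // Runs the static shape inference implemented in this file, including the
  // new Restore/Assign handling.
  TF_RETURN_IF_ERROR(properties.InferStatically());
  // "some_node" is a placeholder name, not something from this diff.
  const auto& outputs = properties.GetOutputProperties("some_node");
  (void)outputs;
  return tensorflow::Status::OK();
}
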