diff options
Diffstat (limited to 'tensorflow/core/grappler/costs/graph_properties.cc')
-rw-r--r-- | tensorflow/core/grappler/costs/graph_properties.cc | 113 |
1 files changed, 63 insertions, 50 deletions
diff --git a/tensorflow/core/grappler/costs/graph_properties.cc b/tensorflow/core/grappler/costs/graph_properties.cc index 35048a4fcf..44322a2d8c 100644 --- a/tensorflow/core/grappler/costs/graph_properties.cc +++ b/tensorflow/core/grappler/costs/graph_properties.cc @@ -50,9 +50,13 @@ template <typename Handle> struct HandleToObject {}; template <> struct HandleToObject<ShapeHandle> { - typedef ShapeHandle Object; + typedef TensorShapeProto Object; - static ShapeHandle Unknown() { return ShapeHandle(); } + static TensorShapeProto Unknown() { + TensorShapeProto result; + result.set_unknown_rank(true); + return result; + } }; template <> @@ -63,24 +67,13 @@ struct HandleToObject<DimensionHandle> { }; template <typename Handle> -struct Processor {}; - -template <> -struct Processor<ShapeHandle> { +struct Processor { // Extract the shape or dim denoted by the handle. - void ExtractValue(ShapeHandle h, ShapeHandle* result) { *result = h; } + void ExtractValue(Handle /*t1*/, + typename HandleToObject<Handle>::Object* result) {} // Merge the shapes or dims. - Status Merge(ShapeHandle h1, ShapeHandle h2, ShapeHandle* result) { - if (InferenceContext::RankKnown(*result)) { - // The result was initialized in a previous merge to a shape of known - // rank, make sure we preserve that information. - return Status::OK(); - } - if (InferenceContext::RankKnown(h1)) { - *result = h1; - } else { - *result = h2; - } + Status Merge(Handle /*t1*/, Handle /*t2*/, + typename HandleToObject<Handle>::Object* result) { return Status::OK(); } }; @@ -108,37 +101,24 @@ struct Processor<DimensionHandle> { if (dim1 >= 0 && dim2 >= 0) { CHECK_EQ(dim1, dim2); - return RefineDim(dim1, result); + *result = dim1; } else if (dim1 >= 0 && dim2 < 0) { - return RefineDim(dim1, result); + *result = dim1; } else if (dim1 < 0 && dim2 >= 0) { - return RefineDim(dim2, result); + *result = dim2; } else if (dim1 < -1) { - return RefineDim(dim1, result); + *result = dim1; } else if (dim2 < -1) { - return RefineDim(dim2, result); + *result = dim2; } else { CHECK_EQ(dim1, dim2); CHECK_EQ(-1, dim1); - return RefineDim(-1, result); + *result = -1; } return Status::OK(); } private: - Status RefineDim(int64 dim, int64* result) { - if (*result >= 0) { - if (!(*result == dim || dim < 0)) { - return errors::InvalidArgument("Inconsistent dimensions detected"); - } - } else if (dim >= 0) { - *result = dim; - } else if (dim < *result) { - *result = dim; - } - return Status::OK(); - } - int64 counter = 2; }; @@ -374,17 +354,18 @@ class SymbolicShapeManager { return dims_.Merge(d1, d2); } + int64 Value(DimensionHandle d) { return dims_.GetMergedValue(d); } + void AsTensorProperties(const ShapeHandle& shape, const DataType& type, + InferenceContext* ctx, OpInfo::TensorProperties* properties) { properties->set_dtype(type); - ShapeHandle actual_shape = shapes_.GetMergedValue(shape); - if (!InferenceContext::RankKnown(actual_shape)) { + if (!ctx->RankKnown(shape)) { properties->mutable_shape()->set_unknown_rank(true); } else { - for (int j = 0; j < InferenceContext::Rank(actual_shape); ++j) { - shape_inference::DimensionHandle dim = - InferenceContext::DimKnownRank(actual_shape, j); - int64 d = dims_.GetMergedValue(dim); + for (int j = 0; j < ctx->Rank(shape); ++j) { + shape_inference::DimensionHandle dim = ctx->Dim(shape, j); + int64 d = Value(dim); properties->mutable_shape()->add_dim()->set_size(d); } } @@ -466,11 +447,6 @@ Status GraphProperties::InferStatically() { shape_refiner.set_disable_constant_propagation(true); shape_refiner.set_function_library_for_shape_inference(&function_library); ImportGraphDefOptions options; - // Graph optimization happens at the late stage of graph execution, - // when colocation constraints are already validated previously and - // the device placement of nodes has also completed, so there - // is no need to validate colocation constraints again. - options.validate_colocation_constraints = false; Status s = ImportGraphDef(options, item_.graph, &graph, &shape_refiner); TF_RETURN_IF_ERROR(s); @@ -496,6 +472,41 @@ Status GraphProperties::InferStatically() { } } } + + // Infer output shape for Restore op. + if (node->op_def().name() == "Restore" || + node->op_def().name() == "RestoreV2" || + node->op_def().name() == "RestoreSlice") { + auto ctx = shape_refiner.GetContext(node); + for (const Edge* out_edge : node->out_edges()) { + const Node* output = out_edge->dst(); + int output_idx = out_edge->src_output(); + if (output_idx < 0) { + continue; + } + if (!ctx->FullyDefined(ctx->output(output_idx)) && + output->op_def().name() == "Assign") { + if (!output->attrs().Find("validate_shape") || + !output->attrs().Find("validate_shape")->b()) { + continue; + } + auto output_ctx = shape_refiner.GetContext(output); + if (output_ctx->FullyDefined(output_ctx->output(0))) { + ctx->set_output(output_idx, output_ctx->output(0)); + output_ctx->MergeInput(1, output_ctx->output(0)); + } else { + const Node* var; + TF_CHECK_OK(node->input_node(0, &var)); + if (node->IsVariable()) { + auto var_ctx = shape_refiner.GetContext(var); + CHECK(var_ctx->FullyDefined(var_ctx->output(0))); + ctx->set_output(output_idx, var_ctx->output(0)); + output_ctx->MergeInput(1, var_ctx->output(0)); + } + } + } + } + } } // Propagate the initial shapes of Enter nodes manually (the Enter shape @@ -628,6 +639,8 @@ Status GraphProperties::InferStatically() { } while (!done); } + std::unordered_map<const shape_inference::Dimension*, int> dim_ids; + // Track shapes globally accross the graph. SymbolicShapeManager shape_manager; bool found_error = false; @@ -675,7 +688,7 @@ Status GraphProperties::InferStatically() { input_properties.resize(ctx->num_inputs()); for (int i = 0; i < ctx->num_inputs(); ++i) { shape_manager.AsTensorProperties(ctx->input(i), node->input_type(i), - &input_properties[i]); + ctx, &input_properties[i]); } for (const auto& edge : node->in_edges()) { if (!edge->src()->IsConstant()) { @@ -702,7 +715,7 @@ Status GraphProperties::InferStatically() { output_properties.resize(ctx->num_outputs()); for (int i = 0; i < ctx->num_outputs(); ++i) { shape_manager.AsTensorProperties(ctx->output(i), node->output_type(i), - &output_properties[i]); + ctx, &output_properties[i]); } } } |