diff options
author | Benoit Steiner <bsteiner@google.com> | 2017-11-06 11:12:11 -0800 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2017-11-06 11:17:02 -0800 |
commit | 4a0eb28a01f72c8182cf0895c7817a4e0137f8f7 (patch) | |
tree | 77173572d8e7ca8ef4976a2e59e116a952d70e27 | |
parent | 0556834abd8a994012d15ad081a850d24ce8fbdd (diff) |
Improved encoding on shapes in grappler.
PiperOrigin-RevId: 174733491
5 files changed, 28 insertions, 5 deletions
diff --git a/tensorflow/core/framework/shape_inference.h b/tensorflow/core/framework/shape_inference.h index d1b610d682..b68e6100df 100644 --- a/tensorflow/core/framework/shape_inference.h +++ b/tensorflow/core/framework/shape_inference.h @@ -71,6 +71,7 @@ class DimensionHandle { friend class ShapeInferenceTestutil; friend class ::tensorflow::ShapeRefinerTest; friend class ShapeManager; + friend class ::tensorflow::grappler::GraphProperties; // Intentionally copyable. }; diff --git a/tensorflow/core/grappler/costs/graph_properties.cc b/tensorflow/core/grappler/costs/graph_properties.cc index e9cb2ee09d..741c1fe272 100644 --- a/tensorflow/core/grappler/costs/graph_properties.cc +++ b/tensorflow/core/grappler/costs/graph_properties.cc @@ -394,6 +394,7 @@ Status GraphProperties::InferStatically() { } while (!done); } + std::unordered_map<const shape_inference::Dimension*, int> dim_ids; for (const Node* const node : graph.nodes()) { VLOG(1) << "<Node> " << node->name(); auto ctx = shape_refiner.GetContext(node); @@ -412,7 +413,7 @@ Status GraphProperties::InferStatically() { input_properties.resize(ctx->num_inputs()); for (int i = 0; i < ctx->num_inputs(); ++i) { FillTensorPropertiesFromContext(ctx->input(i), node->input_type(i), ctx, - &input_properties[i]); + &dim_ids, &input_properties[i]); } for (const auto& edge : node->in_edges()) { if (!edge->src()->IsConstant()) { @@ -439,7 +440,7 @@ Status GraphProperties::InferStatically() { output_properties.resize(ctx->num_outputs()); for (int i = 0; i < ctx->num_outputs(); ++i) { FillTensorPropertiesFromContext(ctx->output(i), node->output_type(i), - ctx, &output_properties[i]); + ctx, &dim_ids, &output_properties[i]); } } } @@ -533,6 +534,7 @@ GraphProperties::GetOutputProperties(const string& node_name) const { void GraphProperties::FillTensorPropertiesFromContext( const ShapeHandle& shape, const DataType& type, InferenceContext* ctx, + std::unordered_map<const shape_inference::Dimension*, int>* dim_ids, OpInfo::TensorProperties* properties) { properties->set_dtype(type); if (!ctx->RankKnown(shape)) { @@ -541,6 +543,17 @@ void GraphProperties::FillTensorPropertiesFromContext( for (int j = 0; j < ctx->Rank(shape); ++j) { shape_inference::DimensionHandle dim = ctx->Dim(shape, j); int64 d = ctx->Value(dim); + // Assign a negative id to unknown dimensions, starting at -2 (the -1 id + // reserved by TensorFlow). + if (d < 0) { + auto it = dim_ids->find(dim.ptr_); + if (it != dim_ids->end()) { + d = it->second; + } else { + d = -(dim_ids->size() + 2); + dim_ids->emplace(dim.ptr_, d); + } + } properties->mutable_shape()->add_dim()->set_size(d); } } diff --git a/tensorflow/core/grappler/costs/graph_properties.h b/tensorflow/core/grappler/costs/graph_properties.h index 5649788be5..92d9574f02 100644 --- a/tensorflow/core/grappler/costs/graph_properties.h +++ b/tensorflow/core/grappler/costs/graph_properties.h @@ -42,6 +42,12 @@ class GraphProperties { // Stores `item_.graph` with the inferred output shapes to `output_graph_def`. Status AnnotateOutputShapes(GraphDef* output_graph_def); + // Return the properties of node inputs/outputs, including data types and + // shapes. Note that the dimensions in the shapes can be negative. We use the + // -1 value to denote that we don't know anything about a dimension. We use + // values strictly less than -1 to encode symbolic dimensions: although we + // don't know the actual value of the symbolic dimension, we know that all the + // dimensions denoted by the same negative value are the equal. bool HasInputProperties(const string& name) const; bool HasOutputProperties(const string& name) const; const std::vector<OpInfo::TensorProperties>& GetInputProperties( @@ -51,7 +57,9 @@ class GraphProperties { static void FillTensorPropertiesFromContext( const shape_inference::ShapeHandle&, const DataType&, - shape_inference::InferenceContext*, OpInfo::TensorProperties*); + shape_inference::InferenceContext*, + std::unordered_map<const shape_inference::Dimension*, int>* dim_ids, + OpInfo::TensorProperties*); private: // Inputs diff --git a/tensorflow/core/grappler/costs/graph_properties_test.cc b/tensorflow/core/grappler/costs/graph_properties_test.cc index 134db5ec5a..7fe7d5b511 100644 --- a/tensorflow/core/grappler/costs/graph_properties_test.cc +++ b/tensorflow/core/grappler/costs/graph_properties_test.cc @@ -54,7 +54,8 @@ class GraphPropertiesTest : public ::testing::Test { } else { strings::StrAppend(&s, "["); for (int i = 0; i < p.shape().dim_size(); ++i) { - strings::StrAppend(&s, i == 0 ? "" : ",", p.shape().dim(i).size()); + strings::StrAppend(&s, i == 0 ? "" : ",", + std::max<int64>(p.shape().dim(i).size(), -1)); } strings::StrAppend(&s, "]"); } diff --git a/tensorflow/core/grappler/costs/op_level_cost_estimator.cc b/tensorflow/core/grappler/costs/op_level_cost_estimator.cc index a2fa847df2..bd84331b67 100644 --- a/tensorflow/core/grappler/costs/op_level_cost_estimator.cc +++ b/tensorflow/core/grappler/costs/op_level_cost_estimator.cc @@ -98,7 +98,7 @@ TensorShapeProto MaybeGetMinimumShape(const TensorShapeProto& original_shape, } } else { for (int i = 0; i < shape.dim_size(); i++) { - if (shape.dim(i).size() == -1) { + if (shape.dim(i).size() < 0) { *found_unknown_shapes = true; VLOG(2) << "Use minimum dim size 1 because the shape is unknown."; // The size of each dimension is at least 1, if unknown. |