aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorGravatar Benoit Steiner <bsteiner@google.com>2017-11-06 11:12:11 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2017-11-06 11:17:02 -0800
commit4a0eb28a01f72c8182cf0895c7817a4e0137f8f7 (patch)
tree77173572d8e7ca8ef4976a2e59e116a952d70e27
parent0556834abd8a994012d15ad081a850d24ce8fbdd (diff)
Improved encoding on shapes in grappler.
PiperOrigin-RevId: 174733491
-rw-r--r--tensorflow/core/framework/shape_inference.h1
-rw-r--r--tensorflow/core/grappler/costs/graph_properties.cc17
-rw-r--r--tensorflow/core/grappler/costs/graph_properties.h10
-rw-r--r--tensorflow/core/grappler/costs/graph_properties_test.cc3
-rw-r--r--tensorflow/core/grappler/costs/op_level_cost_estimator.cc2
5 files changed, 28 insertions, 5 deletions
diff --git a/tensorflow/core/framework/shape_inference.h b/tensorflow/core/framework/shape_inference.h
index d1b610d682..b68e6100df 100644
--- a/tensorflow/core/framework/shape_inference.h
+++ b/tensorflow/core/framework/shape_inference.h
@@ -71,6 +71,7 @@ class DimensionHandle {
friend class ShapeInferenceTestutil;
friend class ::tensorflow::ShapeRefinerTest;
friend class ShapeManager;
+ friend class ::tensorflow::grappler::GraphProperties;
// Intentionally copyable.
};
diff --git a/tensorflow/core/grappler/costs/graph_properties.cc b/tensorflow/core/grappler/costs/graph_properties.cc
index e9cb2ee09d..741c1fe272 100644
--- a/tensorflow/core/grappler/costs/graph_properties.cc
+++ b/tensorflow/core/grappler/costs/graph_properties.cc
@@ -394,6 +394,7 @@ Status GraphProperties::InferStatically() {
} while (!done);
}
+ std::unordered_map<const shape_inference::Dimension*, int> dim_ids;
for (const Node* const node : graph.nodes()) {
VLOG(1) << "<Node> " << node->name();
auto ctx = shape_refiner.GetContext(node);
@@ -412,7 +413,7 @@ Status GraphProperties::InferStatically() {
input_properties.resize(ctx->num_inputs());
for (int i = 0; i < ctx->num_inputs(); ++i) {
FillTensorPropertiesFromContext(ctx->input(i), node->input_type(i), ctx,
- &input_properties[i]);
+ &dim_ids, &input_properties[i]);
}
for (const auto& edge : node->in_edges()) {
if (!edge->src()->IsConstant()) {
@@ -439,7 +440,7 @@ Status GraphProperties::InferStatically() {
output_properties.resize(ctx->num_outputs());
for (int i = 0; i < ctx->num_outputs(); ++i) {
FillTensorPropertiesFromContext(ctx->output(i), node->output_type(i),
- ctx, &output_properties[i]);
+ ctx, &dim_ids, &output_properties[i]);
}
}
}
@@ -533,6 +534,7 @@ GraphProperties::GetOutputProperties(const string& node_name) const {
void GraphProperties::FillTensorPropertiesFromContext(
const ShapeHandle& shape, const DataType& type, InferenceContext* ctx,
+ std::unordered_map<const shape_inference::Dimension*, int>* dim_ids,
OpInfo::TensorProperties* properties) {
properties->set_dtype(type);
if (!ctx->RankKnown(shape)) {
@@ -541,6 +543,17 @@ void GraphProperties::FillTensorPropertiesFromContext(
for (int j = 0; j < ctx->Rank(shape); ++j) {
shape_inference::DimensionHandle dim = ctx->Dim(shape, j);
int64 d = ctx->Value(dim);
+ // Assign a negative id to unknown dimensions, starting at -2 (the -1 id
+ // reserved by TensorFlow).
+ if (d < 0) {
+ auto it = dim_ids->find(dim.ptr_);
+ if (it != dim_ids->end()) {
+ d = it->second;
+ } else {
+ d = -(dim_ids->size() + 2);
+ dim_ids->emplace(dim.ptr_, d);
+ }
+ }
properties->mutable_shape()->add_dim()->set_size(d);
}
}
diff --git a/tensorflow/core/grappler/costs/graph_properties.h b/tensorflow/core/grappler/costs/graph_properties.h
index 5649788be5..92d9574f02 100644
--- a/tensorflow/core/grappler/costs/graph_properties.h
+++ b/tensorflow/core/grappler/costs/graph_properties.h
@@ -42,6 +42,12 @@ class GraphProperties {
// Stores `item_.graph` with the inferred output shapes to `output_graph_def`.
Status AnnotateOutputShapes(GraphDef* output_graph_def);
+ // Return the properties of node inputs/outputs, including data types and
+ // shapes. Note that the dimensions in the shapes can be negative. We use the
+ // -1 value to denote that we don't know anything about a dimension. We use
+ // values strictly less than -1 to encode symbolic dimensions: although we
+ // don't know the actual value of the symbolic dimension, we know that all the
+ // dimensions denoted by the same negative value are the equal.
bool HasInputProperties(const string& name) const;
bool HasOutputProperties(const string& name) const;
const std::vector<OpInfo::TensorProperties>& GetInputProperties(
@@ -51,7 +57,9 @@ class GraphProperties {
static void FillTensorPropertiesFromContext(
const shape_inference::ShapeHandle&, const DataType&,
- shape_inference::InferenceContext*, OpInfo::TensorProperties*);
+ shape_inference::InferenceContext*,
+ std::unordered_map<const shape_inference::Dimension*, int>* dim_ids,
+ OpInfo::TensorProperties*);
private:
// Inputs
diff --git a/tensorflow/core/grappler/costs/graph_properties_test.cc b/tensorflow/core/grappler/costs/graph_properties_test.cc
index 134db5ec5a..7fe7d5b511 100644
--- a/tensorflow/core/grappler/costs/graph_properties_test.cc
+++ b/tensorflow/core/grappler/costs/graph_properties_test.cc
@@ -54,7 +54,8 @@ class GraphPropertiesTest : public ::testing::Test {
} else {
strings::StrAppend(&s, "[");
for (int i = 0; i < p.shape().dim_size(); ++i) {
- strings::StrAppend(&s, i == 0 ? "" : ",", p.shape().dim(i).size());
+ strings::StrAppend(&s, i == 0 ? "" : ",",
+ std::max<int64>(p.shape().dim(i).size(), -1));
}
strings::StrAppend(&s, "]");
}
diff --git a/tensorflow/core/grappler/costs/op_level_cost_estimator.cc b/tensorflow/core/grappler/costs/op_level_cost_estimator.cc
index a2fa847df2..bd84331b67 100644
--- a/tensorflow/core/grappler/costs/op_level_cost_estimator.cc
+++ b/tensorflow/core/grappler/costs/op_level_cost_estimator.cc
@@ -98,7 +98,7 @@ TensorShapeProto MaybeGetMinimumShape(const TensorShapeProto& original_shape,
}
} else {
for (int i = 0; i < shape.dim_size(); i++) {
- if (shape.dim(i).size() == -1) {
+ if (shape.dim(i).size() < 0) {
*found_unknown_shapes = true;
VLOG(2) << "Use minimum dim size 1 because the shape is unknown.";
// The size of each dimension is at least 1, if unknown.