 tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc  | 49
 tensorflow/core/grappler/optimizers/arithmetic_optimizer.h   |  1
 tensorflow/core/grappler/optimizers/graph_optimizer_stage.cc |  4
 tensorflow/core/grappler/optimizers/graph_optimizer_stage.h  |  3
 4 files changed, 21 insertions, 36 deletions

diff --git a/tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc b/tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc
index 629872bf19..5dd0b6f4b0 100644
--- a/tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc
+++ b/tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc
@@ -196,6 +196,8 @@ void SetSourceDataType(DataType dtype, NodeDef* node) {
bool IsNumberType(DataType dtype) { return kNumberTypes.Contains(dtype); }
+const char kOutputShapesAttr[] = "_output_shapes";
+
// Shape is symbolically defined if it has a known rank, and each dimension is
// defined, or is an unknown symbol (dim.size <= -2).
bool ShapeIsSymbolicallyDefined(const TensorShapeProto& shape) {
@@ -232,19 +234,16 @@ bool ShapesSymbolicallyEqual(const OpInfo::TensorProperties& left,
// Returns whether `reshape` is an identity op. The tensor that `reshape`
// reshapes is the `output_pos`-th output of node `input`.
bool ReshapeIsIdentity(const NodeDef& reshape, const NodeDef& input,
- const int output_pos,
- const GraphProperties& graph_properties) {
- const std::vector<OpInfo::TensorProperties>& reshape_props =
- graph_properties.GetOutputProperties(reshape.name());
- const std::vector<OpInfo::TensorProperties>& input_props =
- graph_properties.GetOutputProperties(input.name());
- if (reshape_props.empty() || input_props.empty() ||
- input_props.size() <= output_pos) {
+ const int output_pos) {
+ if (!reshape.attr().count(kOutputShapesAttr) ||
+ !input.attr().count(kOutputShapesAttr)) {
return false;
}
- const PartialTensorShape& src_shape = input_props[output_pos].shape();
- const PartialTensorShape& dst_shape = reshape_props[0].shape();
+ PartialTensorShape src_shape(
+ input.attr().at(kOutputShapesAttr).list().shape(output_pos));
+ PartialTensorShape dst_shape(
+ reshape.attr().at(kOutputShapesAttr).list().shape(0));
if (src_shape.unknown_rank() || dst_shape.unknown_rank()) {
return false;
}
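
Note: the new shape lookup above boils down to reading TensorShapeProto entries out of a node's "_output_shapes" attribute instead of querying GraphProperties. A minimal standalone sketch of that pattern, assuming only the stock NodeDef/AttrValue protos (the helper name GetOutputShape is illustrative and not part of this patch):

    #include "tensorflow/core/framework/node_def.pb.h"
    #include "tensorflow/core/framework/partial_tensor_shape.h"

    namespace tensorflow {
    namespace grappler {

    // Shape of the `output_pos`-th output of `node`, taken from the
    // "_output_shapes" annotation; unknown rank if the attribute is absent
    // or does not cover that output.
    PartialTensorShape GetOutputShape(const NodeDef& node, int output_pos) {
      const auto it = node.attr().find("_output_shapes");
      if (it == node.attr().end() ||
          output_pos >= it->second.list().shape_size()) {
        return PartialTensorShape();  // unknown rank
      }
      return PartialTensorShape(it->second.list().shape(output_pos));
    }

    }  // namespace grappler
    }  // namespace tensorflow

Compared with the previous GraphProperties lookup, this avoids threading a GraphProperties reference into every helper, at the cost of temporarily mutating the graph; the attribute is stripped again at the end of Optimize() below.
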
@@ -1273,8 +1272,7 @@ string ArithmeticOptimizer::TrySimplifyAndReplaceUses(
// outputs tensors of shape [M, N] while feeding it with tensors of shape
// [M*N] (or worse). The reshape nodes are then necessary to update the
// tensor metadata to the required shape.
- if (can_use_shapes_ &&
- ReshapeIsIdentity(*reshape, *input, output_pos, *graph_properties_)) {
+ if (ReshapeIsIdentity(*reshape, *input, output_pos)) {
return reshape->input(0);
}
}
@@ -1588,11 +1586,11 @@ Status ArithmeticOptimizer::SimplifyArithmeticOps() {
std::vector<std::unique_ptr<ArithmeticOptimizerStage>> stages;
- if (options_.combine_add_to_addn && can_use_shapes_) {
+ if (options_.combine_add_to_addn) {
stages.push_back(std::unique_ptr<ArithmeticOptimizerStage>(
new AddOpsRewriteStage(ctx, ctx_ext)));
}
- if (options_.hoist_common_factor_out_of_aggregation && can_use_shapes_) {
+ if (options_.hoist_common_factor_out_of_aggregation) {
stages.push_back(std::unique_ptr<ArithmeticOptimizerStage>(
new HoistCommonFactorOutOfAggregation(ctx, ctx_ext)));
}
@@ -1629,15 +1627,7 @@ Status ArithmeticOptimizer::SimplifyArithmeticOps() {
if (simplified_tensor.empty()) {
for (auto& stage : stages) {
if (stage->IsSupported(node)) {
- const Status stage_status =
- stage->TrySimplify(node, &simplified_tensor);
- // Each stage must be "error safe" (just like exception safe). In
- // case of any error it must leave optimized graph unmodified.
- if (!stage_status.ok()) {
- LOG(WARNING) << "Failed to run arithmetic optimizer stage "
- << stage->stage_name()
- << ". Error: " << stage_status.error_message();
- }
+ TF_RETURN_IF_ERROR(stage->TrySimplify(node, &simplified_tensor));
if (!simplified_tensor.empty()) {
break;
}
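
Note: stage failures are no longer downgraded to a warning; TF_RETURN_IF_ERROR aborts SimplifyArithmeticOps() (and hence the whole optimizer pass) on the first non-OK status. Roughly, the macro expands to the following simplified sketch (the real definition in tensorflow/core/lib/core/errors.h also adds a branch-prediction hint):

    #define TF_RETURN_IF_ERROR(expr)                  \
      do {                                            \
        const ::tensorflow::Status _status = (expr);  \
        if (!_status.ok()) return _status;            \
      } while (0)
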
@@ -1704,16 +1694,19 @@ Status ArithmeticOptimizer::Optimize(Cluster* /*cluster*/,
&frame_map_, &num_frames));
// Shapes are only needed in aggressive mode.
graph_properties_.reset(new GraphProperties(item));
- const Status status = graph_properties_->InferStatically(false);
- can_use_shapes_ = status.ok();
- if (!can_use_shapes_) {
- LOG(WARNING) << "Shape inference failed.";
- }
+ TF_RETURN_IF_ERROR(graph_properties_->InferStatically(false));
+  // TODO(ezhulenev): Use GraphProperties to look up tensor shapes directly
+ TF_RETURN_IF_ERROR(graph_properties_->AnnotateOutputShapes(optimized_graph_));
// Perform the optimizations.
DedupComputations();
TF_RETURN_IF_ERROR(SimplifyArithmeticOps());
+ // Clear output shapes.
+  for (int i = 0; i < optimized_graph_->node_size(); ++i) {
+    optimized_graph_->mutable_node(i)->mutable_attr()->erase(kOutputShapesAttr);
+ }
+
return Status::OK();
}
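
Note: because the optimizer now propagates shape-inference failures, callers see them as a non-OK status from Optimize() rather than a LOG(WARNING) plus a best-effort graph. A minimal calling sketch, modeled on the optimizer's unit tests (construction of a populated GrapplerItem is elided):

    #include "tensorflow/core/framework/graph.pb.h"
    #include "tensorflow/core/grappler/grappler_item.h"
    #include "tensorflow/core/grappler/optimizers/arithmetic_optimizer.h"

    tensorflow::Status RunArithmeticOptimizer(
        const tensorflow::grappler::GrapplerItem& item,
        tensorflow::GraphDef* optimized) {
      tensorflow::grappler::ArithmeticOptimizer optimizer;
      // A failure in InferStatically() or in any rewrite stage surfaces here;
      // on success the temporary "_output_shapes" annotations have already
      // been erased from `optimized`.
      return optimizer.Optimize(/*cluster=*/nullptr, item, optimized);
    }
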
diff --git a/tensorflow/core/grappler/optimizers/arithmetic_optimizer.h b/tensorflow/core/grappler/optimizers/arithmetic_optimizer.h
index cdeed0554e..965f0e9ea2 100644
--- a/tensorflow/core/grappler/optimizers/arithmetic_optimizer.h
+++ b/tensorflow/core/grappler/optimizers/arithmetic_optimizer.h
@@ -126,7 +126,6 @@ class ArithmeticOptimizer : public GraphOptimizer {
RewriterConfig::Toggle opt_level_;
ArithmeticOptimizerOptions options_;
- bool can_use_shapes_ = false;
bool fetch_nodes_known_ = false;
std::unordered_set<string> nodes_to_preserve_;
std::unique_ptr<NodeMap> node_map_;
diff --git a/tensorflow/core/grappler/optimizers/graph_optimizer_stage.cc b/tensorflow/core/grappler/optimizers/graph_optimizer_stage.cc
index 1ea57f7b4f..7044705ade 100644
--- a/tensorflow/core/grappler/optimizers/graph_optimizer_stage.cc
+++ b/tensorflow/core/grappler/optimizers/graph_optimizer_stage.cc
@@ -42,10 +42,6 @@ Status GetInputNode(const GraphOptimizerContext& ctx, const string& input,
Status GetTensorProperties(const GraphOptimizerContext& ctx,
const string& tensor,
OpInfo::TensorProperties* properties) {
- if (ctx.graph_properties == nullptr) {
- return errors::InvalidArgument("Graph properties are unknown.");
- }
-
int port;
string tensor_node_name = ParseNodeName(tensor, &port);
if (port < 0) {
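
Note: dropping the nullptr check means GetTensorProperties now expects ctx.graph_properties to be populated whenever the stages run, which holds because Optimize() fails early if shape inference does not succeed. The remaining step splits a tensor name into node name and output port via grappler's ParseNodeName; a small illustration (the values in the comments are the expected behavior of that helper):

    #include <string>
    #include "tensorflow/core/grappler/utils.h"

    void ParseNodeNameExamples() {
      int port;
      std::string node;
      node = tensorflow::grappler::ParseNodeName("add:1", &port);  // "add", port == 1
      node = tensorflow::grappler::ParseNodeName("add", &port);    // "add", port == 0
      node = tensorflow::grappler::ParseNodeName("^add", &port);   // "add", port == -1
      // port < 0 marks a control input, which carries no tensor value, hence
      // the special-casing of `port < 0` in GetTensorProperties above.
    }
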
diff --git a/tensorflow/core/grappler/optimizers/graph_optimizer_stage.h b/tensorflow/core/grappler/optimizers/graph_optimizer_stage.h
index c7af82abbb..be95c00d2d 100644
--- a/tensorflow/core/grappler/optimizers/graph_optimizer_stage.h
+++ b/tensorflow/core/grappler/optimizers/graph_optimizer_stage.h
@@ -117,9 +117,6 @@ class GraphOptimizerStage {
: optimizer_name_(optimizer_name), stage_name_(stage_name), ctx_(ctx) {}
virtual ~GraphOptimizerStage() = default;
- const string& stage_name() const { return stage_name_; }
- const string& optimizer_name() const { return optimizer_name_; }
-
// Check if we should try to simplify node. Returning true doesn't
// guarantee that node will be simplified.
//