Diffstat (limited to 'tensorflow/core/grappler/optimizers/constant_folding.cc')
-rw-r--r--  tensorflow/core/grappler/optimizers/constant_folding.cc | 172
1 file changed, 11 insertions(+), 161 deletions(-)
diff --git a/tensorflow/core/grappler/optimizers/constant_folding.cc b/tensorflow/core/grappler/optimizers/constant_folding.cc
index 02a732b092..cb02314183 100644
--- a/tensorflow/core/grappler/optimizers/constant_folding.cc
+++ b/tensorflow/core/grappler/optimizers/constant_folding.cc
@@ -36,7 +36,6 @@ limitations under the License.
#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/public/version.h"
-#include "tensorflow/core/util/bcast.h"
namespace tensorflow {
namespace grappler {
@@ -96,15 +95,11 @@ class DeviceSimple : public DeviceBase {
};
} // namespace
-ConstantFolding::ConstantFolding(RewriterConfig::Toggle opt_level,
- DeviceBase* cpu_device)
- : opt_level_(opt_level), cpu_device_(cpu_device) {
+ConstantFolding::ConstantFolding(DeviceBase* cpu_device)
+ : cpu_device_(cpu_device) {
resource_mgr_.reset(new ResourceMgr());
}
-ConstantFolding::ConstantFolding(DeviceBase* cpu_device)
- : ConstantFolding(RewriterConfig::ON, cpu_device) {}
-
// static
string ConstantFolding::AddControlDependency(const string& input_name,
GraphDef* graph,
@@ -286,149 +281,6 @@ Status ConstantFolding::MaterializeShapes(const GrapplerItem& item,
return Status::OK();
}
-bool ShapesEqual(const TensorShapeProto& shape1,
- const TensorShapeProto& shape2) {
- if (shape1.unknown_rank() || shape2.unknown_rank()) {
- return false;
- }
- if (shape1.dim_size() != shape2.dim_size()) {
- return false;
- }
- for (int i = 0; i < shape1.dim_size(); ++i) {
- if (shape1.dim(i).size() != shape2.dim(i).size()) {
- return false;
- }
- }
- return true;
-}
-
-namespace {
-bool ExtractShape(const NodeDef& shape_node, const GraphProperties& properties,
- BCast::Vec* shape, int64* min_id) {
- if (shape_node.op() == "Shape") {
- const std::vector<OpInfo::TensorProperties>& prop1 =
- properties.GetInputProperties(shape_node.name());
- if (prop1.size() != 1) {
- return false;
- }
- const TensorShapeProto& shp = prop1[0].shape();
- if (shp.unknown_rank()) {
- return false;
- }
- for (const auto& dim : shp.dim()) {
- shape->push_back(dim.size());
- *min_id = std::min<int64>(*min_id, dim.size());
- }
- } else {
- const TensorProto& raw_val = shape_node.attr().at("value").tensor();
- if (raw_val.dtype() != DT_INT64 && raw_val.dtype() != DT_INT32) {
- return false;
- }
- Tensor value(raw_val.dtype(), raw_val.tensor_shape());
- if (!value.FromProto(raw_val)) {
- return false;
- }
- for (int j = 0; j < value.NumElements(); ++j) {
- if (raw_val.dtype() == DT_INT64) {
- shape->push_back(value.vec<int64>()(j));
- } else {
- shape->push_back(value.vec<int>()(j));
- }
- }
- }
- return true;
-}
-} // namespace
-
-Status ConstantFolding::MaterializeConstants(
- const GrapplerItem& item, const GraphProperties& properties) {
- const int node_count = graph_.node_size();
- for (int i = 0; i < node_count; ++i) {
- NodeDef& node = *graph_.mutable_node(i);
- const string& op = node.op();
- if (op != "BroadcastGradientArgs") {
- continue;
- }
- const NodeDef* shape_node1 = node_map_->GetNode(node.input(0));
- const NodeDef* shape_node2 = node_map_->GetNode(node.input(1));
- if (shape_node1 == nullptr ||
- (shape_node1->op() != "Shape" && shape_node1->op() != "Const") ||
- shape_node2 == nullptr ||
- (shape_node2->op() != "Shape" && shape_node2->op() != "Const")) {
- continue;
- }
- int64 min_id = 0;
- BCast::Vec shape1;
- if (!ExtractShape(*shape_node1, properties, &shape1, &min_id)) {
- continue;
- }
- BCast::Vec shape2;
- if (!ExtractShape(*shape_node2, properties, &shape2, &min_id)) {
- continue;
- }
-    // A value of -1 means we don't know anything about the dimension. Replace
- // the -1 values with unique dimension ids since we don't want two '-1'
- // dimensions to be considered equal.
- for (auto& id : shape1) {
- if (id == -1) {
- id = --min_id;
- }
- }
- for (auto& id : shape2) {
- if (id == -1) {
- id = --min_id;
- }
- }
- BCast bcast(shape1, shape2);
- if (!bcast.IsValid()) {
- continue;
- }
- BCast::Vec reduce_dims[2];
- reduce_dims[0] = bcast.grad_x_reduce_idx();
- reduce_dims[1] = bcast.grad_y_reduce_idx();
-
- const DataType type = node.attr().at("T").type();
- NodeDef* out[2];
- for (int j = 0; j < 2; ++j) {
- if (!reduce_dims[j].empty()) {
- // This is the case when a tensor dimension 1 is matched against an
- // unknown dimension. The unknown dimension could also be equal to 1, in
- // which case there would be no reduction.
- out[j] = nullptr;
- } else {
- Tensor value(type, TensorShape({0}));
- string const_name = AddPrefixToNodeName(
- strings::StrCat(node.name(), "-", j), kConstantFoldingConst);
- out[j] = node_map_->GetNode(const_name);
- if (!out[j]) {
- out[j] = graph_.add_node();
- *out[j] = CreateNodeDef(const_name, TensorValue(&value));
- out[j]->set_device(node.device());
- node_map_->AddNode(const_name, out[j]);
- string ctrl_dep =
- AddControlDependency(node.name(), &graph_, node_map_.get());
- *out[j]->add_input() = ctrl_dep;
- node_map_->AddOutput(NodeName(ctrl_dep), const_name);
- }
- }
- }
-
- auto outputs = node_map_->GetOutputs(node.name());
- for (const auto& output : outputs) {
- for (int k = 0; k < output->input_size(); ++k) {
- int port;
- string node_name = ParseNodeName(output->input(k), &port);
- if (node_name == node.name() && port >= 0 && port < 2 && out[port]) {
- *output->mutable_input(k) = out[port]->name();
- node_map_->UpdateInput(output->name(), node_name, out[port]->name());
- }
- }
- }
- }
-
- return Status::OK();
-}
-
bool ConstantFolding::IsFoldable(const NodeDef& node) const {
// Folding not applicable to ops with no inputs.
if (node.input().empty()) {
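For reference, the MaterializeConstants pass deleted above hinged on BCast's gradient-reduction index vectors: whenever an output's reduce-index vector was provably empty, that output of BroadcastGradientArgs was replaced by an empty constant tensor. The standalone sketch below reproduces that check in isolation; it assumes a TensorFlow source tree (so "tensorflow/core/util/bcast.h" resolves), and the shapes are made-up illustrations, not values from this commit.

#include <iostream>
#include "tensorflow/core/util/bcast.h"

int main() {
  // e.g. values recovered from Shape/Const nodes, as ExtractShape did above.
  tensorflow::BCast::Vec shape1 = {2, 3, 1};
  tensorflow::BCast::Vec shape2 = {2, 3, 4};
  tensorflow::BCast bcast(shape1, shape2);
  if (!bcast.IsValid()) {
    std::cout << "shapes do not broadcast together\n";
    return 0;
  }
  // grad_x_reduce_idx() lists the axes the gradient w.r.t. input 0 must be
  // summed over. Here x's trailing 1 is broadcast against 4, so the gradient
  // for x reduces over axis 2 while y needs no reduction at all.
  std::cout << "x reduce axes: " << bcast.grad_x_reduce_idx().size()
            << ", y reduce axes: " << bcast.grad_y_reduce_idx().size() << "\n";
}

Only an empty reduce list was foldable: because the pass replaced unknown dimensions with unique negative ids, a non-empty list may be an artifact of two dimensions that are in fact equal, which is exactly the caveat in the deleted comment.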
@@ -1069,23 +921,23 @@ Status ConstantFolding::RunOptimizationPass(Cluster* cluster,
}
GraphProperties properties(item);
- Status s = properties.InferStatically();
bool has_feed = !item.feed.empty();
-
- if (!has_feed && s.ok()) {
+ if (!has_feed) {
// Only use static shape information when there is no feed in the
// graph. That's because it's possible to feed a placeholder with a tensor
// of any shape, which could make the static information inconsistent with
// the shapes actually fed.
- TF_RETURN_IF_ERROR(MaterializeShapes(item, properties));
- }
- if (opt_level_ == RewriterConfig::AGGRESSIVE && s.ok()) {
- TF_RETURN_IF_ERROR(MaterializeConstants(item, properties));
+ Status s = properties.InferStatically();
+ if (!s.ok()) {
+ VLOG(1) << "Failed to infer graph shapes: " << s;
+ } else {
+ TF_RETURN_IF_ERROR(MaterializeShapes(item, properties));
+ }
}
TF_RETURN_IF_ERROR(FoldGraph(output));
- if (!has_feed && s.ok()) {
+ if (!has_feed) {
TF_RETURN_IF_ERROR(SimplifyGraph(output, properties));
}
return Status::OK();
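The reworked control flow above turns a failed static shape inference from a pass-aborting error into a logged, skipped step. A minimal sketch of that best-effort pattern follows, with illustrative stand-ins rather than the real grappler types (the hard-coded failure is for demonstration only):

#include <iostream>
#include <string>
#include <utility>

struct Status {  // stand-in for tensorflow::Status
  Status() = default;
  Status(bool ok, std::string msg) : ok_(ok), msg_(std::move(msg)) {}
  bool ok() const { return ok_; }
  const std::string& message() const { return msg_; }
 private:
  bool ok_ = true;
  std::string msg_;
};

// Hypothetical stand-ins for GraphProperties::InferStatically() and
// ConstantFolding::MaterializeShapes().
Status InferStatically() { return Status(false, "placeholder has unknown rank"); }
Status MaterializeShapes() { return Status(); }

Status RunOptimizationPass(bool has_feed) {
  if (!has_feed) {  // static shapes are only trustworthy when nothing is fed
    Status s = InferStatically();
    if (!s.ok()) {
      // Log and move on: shape materialization is an optional refinement.
      std::cerr << "Failed to infer graph shapes: " << s.message() << "\n";
    } else {
      Status m = MaterializeShapes();
      if (!m.ok()) return m;  // a failure here is still a real error
    }
  }
  // ...FoldGraph() and the rest of the pass run either way...
  return Status();
}

int main() { return RunOptimizationPass(/*has_feed=*/false).ok() ? 0 : 1; }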
@@ -1104,14 +956,12 @@ Status ConstantFolding::Optimize(Cluster* cluster, const GrapplerItem& item,
GrapplerItem item_to_optimize = item;
*output = item.graph;
- int64 node_count;
do {
graph_.Swap(output);
item_to_optimize.graph = graph_;
*output = GraphDef();
- node_count = graph_.node_size();
TF_RETURN_IF_ERROR(RunOptimizationPass(cluster, item_to_optimize, output));
- } while (output->node_size() != node_count);
+ } while (output->node_size() < graph_.node_size());
*output->mutable_library() = item.graph.library();
*output->mutable_versions() = item.graph.versions();
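The final hunk tightens the fixed-point driver: the loop now reruns the pass only while the graph strictly shrinks. With the old '!=' test, a pass that adds nodes (as the deleted MaterializeConstants could) risks never settling; with '<', the node count is a non-negative integer that must strictly decrease on every extra iteration, so termination is guaranteed. A minimal sketch of the pattern, with illustrative names rather than the real grappler API:

#include <cstddef>
#include <vector>

struct Graph { std::vector<int> nodes; };  // stand-in for GraphDef

// Hypothetical pass: fold away every node marked foldable (value == 0 here).
void RunOptimizationPass(Graph* g) {
  std::vector<int> kept;
  for (int n : g->nodes)
    if (n != 0) kept.push_back(n);
  g->nodes.swap(kept);
}

void Optimize(Graph* g) {
  std::size_t before;
  do {
    before = g->nodes.size();
    RunOptimizationPass(g);
  } while (g->nodes.size() < before);  // strict decrease, as in the diff
}

int main() {
  Graph g{{1, 0, 2, 0, 3}};
  Optimize(&g);  // first pass folds two nodes; the second pass changes nothing
  return g.nodes.size() == 3 ? 0 : 1;
}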