aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core/grappler/costs/utils.cc
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/core/grappler/costs/utils.cc')
-rw-r--r--tensorflow/core/grappler/costs/utils.cc40
1 files changed, 39 insertions, 1 deletions
diff --git a/tensorflow/core/grappler/costs/utils.cc b/tensorflow/core/grappler/costs/utils.cc
index 5415324b48..2fcadf1de3 100644
--- a/tensorflow/core/grappler/costs/utils.cc
+++ b/tensorflow/core/grappler/costs/utils.cc
@@ -74,7 +74,8 @@ static std::vector<TensorProto> ExtractTensors(const AttrValue& attr_value) {
}
break;
}
- default: {}
+ default: {
+ }
}
return tensors;
}
@@ -201,6 +202,43 @@ std::vector<OpInfo::TensorProperties> FindInputFeatures(
return inputs;
}
+int64 CalculateTensorSize(const OpInfo::TensorProperties& prop) {
+ int64 size = DataTypeSize(BaseType(prop.dtype()));
+ TensorShapeProto shape = prop.shape();
+
+ // Can't infer the size if the rank is unknown. It has to be at least a
+ // scalar though.
+ if (shape.unknown_rank()) {
+ LOG(WARNING) << "CalculateTensorSize() -- unknown rank";
+ return size;
+ }
+
+ // If one of the dimensions is unknown statically, assume it's at least one.
+ for (int i = 0; i < shape.dim_size(); ++i) {
+ if (shape.dim(i).size() < 0) {
+ shape.mutable_dim(i)->set_size(1);
+ LOG(WARNING) << "CalculateTensorSize() -- unknown dim: " << i;
+ }
+ }
+
+ int64 num_elems = TensorShape(shape).num_elements();
+ return num_elems * size;
+}
+
+int64 CalculateOutputSize(
+ const std::vector<OpInfo::TensorProperties>& output_properties,
+ const int port_num) {
+ if (port_num < 0) return 4; // 4B for control dependency.
+
+ if (port_num >= output_properties.size()) {
+ LOG(ERROR) << "CalculateOutputSize() -- port_num: " << port_num
+ << " >= output_properties.size(): " << output_properties.size();
+ return 0;
+ }
+
+ return CalculateTensorSize(output_properties[port_num]);
+}
+
DeviceProperties GetDeviceInfo(const string& device_str) {
DeviceProperties unknown;
unknown.set_type("UNKNOWN");