diff options
Diffstat (limited to 'tensorflow/core/grappler/costs/utils.cc')
-rw-r--r-- | tensorflow/core/grappler/costs/utils.cc | 40 |
1 files changed, 39 insertions, 1 deletions
diff --git a/tensorflow/core/grappler/costs/utils.cc b/tensorflow/core/grappler/costs/utils.cc index 5415324b48..2fcadf1de3 100644 --- a/tensorflow/core/grappler/costs/utils.cc +++ b/tensorflow/core/grappler/costs/utils.cc @@ -74,7 +74,8 @@ static std::vector<TensorProto> ExtractTensors(const AttrValue& attr_value) { } break; } - default: {} + default: { + } } return tensors; } @@ -201,6 +202,43 @@ std::vector<OpInfo::TensorProperties> FindInputFeatures( return inputs; } +int64 CalculateTensorSize(const OpInfo::TensorProperties& prop) { + int64 size = DataTypeSize(BaseType(prop.dtype())); + TensorShapeProto shape = prop.shape(); + + // Can't infer the size if the rank is unknown. It has to be at least a + // scalar though. + if (shape.unknown_rank()) { + LOG(WARNING) << "CalculateTensorSize() -- unknown rank"; + return size; + } + + // If one of the dimensions is unknown statically, assume it's at least one. + for (int i = 0; i < shape.dim_size(); ++i) { + if (shape.dim(i).size() < 0) { + shape.mutable_dim(i)->set_size(1); + LOG(WARNING) << "CalculateTensorSize() -- unknown dim: " << i; + } + } + + int64 num_elems = TensorShape(shape).num_elements(); + return num_elems * size; +} + +int64 CalculateOutputSize( + const std::vector<OpInfo::TensorProperties>& output_properties, + const int port_num) { + if (port_num < 0) return 4; // 4B for control dependency. + + if (port_num >= output_properties.size()) { + LOG(ERROR) << "CalculateOutputSize() -- port_num: " << port_num + << " >= output_properties.size(): " << output_properties.size(); + return 0; + } + + return CalculateTensorSize(output_properties[port_num]); +} + DeviceProperties GetDeviceInfo(const string& device_str) { DeviceProperties unknown; unknown.set_type("UNKNOWN"); |