diff options
Diffstat (limited to 'tensorflow/core/grappler/costs/utils.h')
-rw-r--r-- | tensorflow/core/grappler/costs/utils.h | 11 |
1 files changed, 11 insertions, 0 deletions
diff --git a/tensorflow/core/grappler/costs/utils.h b/tensorflow/core/grappler/costs/utils.h index 5fd6717712..ea64e5a41d 100644 --- a/tensorflow/core/grappler/costs/utils.h +++ b/tensorflow/core/grappler/costs/utils.h @@ -43,6 +43,17 @@ std::vector<OpInfo::TensorProperties> FindInputFeatures( const std::unordered_map<string, const CostGraphDef::Node*>& name_to_cost, const std::unordered_map<string, const NodeDef*>& name_to_node); +// Returns the size of tensor (unit: bytes). For tensor shape with unknown rank, +// it assumes the tensor to be scalar. For any unknown dimension, it assumes +// size one. +int64 CalculateTensorSize(const OpInfo::TensorProperties& prop); + +// Returns the size of output at port_num (unit: bytes). A special case is +// port_num -1, which is for control dependency and assumed to be 4 bytes. +int64 CalculateOutputSize( + const std::vector<OpInfo::TensorProperties>& output_properties, + int port_num); + // Returns the DeviceProperties of the device on which 'node' runs. DeviceProperties GetDeviceInfo(const CostGraphDef::Node& node); DeviceProperties GetDeviceInfo(const string& device_str); |