aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core/grappler/costs/utils.h
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/core/grappler/costs/utils.h')
-rw-r--r--tensorflow/core/grappler/costs/utils.h11
1 files changed, 11 insertions, 0 deletions
diff --git a/tensorflow/core/grappler/costs/utils.h b/tensorflow/core/grappler/costs/utils.h
index 5fd6717712..ea64e5a41d 100644
--- a/tensorflow/core/grappler/costs/utils.h
+++ b/tensorflow/core/grappler/costs/utils.h
@@ -43,6 +43,17 @@ std::vector<OpInfo::TensorProperties> FindInputFeatures(
const std::unordered_map<string, const CostGraphDef::Node*>& name_to_cost,
const std::unordered_map<string, const NodeDef*>& name_to_node);
+// Returns the size of tensor (unit: bytes). For tensor shape with unknown rank,
+// it assumes the tensor to be scalar. For any unknown dimension, it assumes
+// size one.
+int64 CalculateTensorSize(const OpInfo::TensorProperties& prop);
+
+// Returns the size of output at port_num (unit: bytes). A special case is
+// port_num -1, which is for control dependency and assumed to be 4 bytes.
+int64 CalculateOutputSize(
+ const std::vector<OpInfo::TensorProperties>& output_properties,
+ int port_num);
+
// Returns the DeviceProperties of the device on which 'node' runs.
DeviceProperties GetDeviceInfo(const CostGraphDef::Node& node);
DeviceProperties GetDeviceInfo(const string& device_str);