aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core/grappler/costs/utils.h
diff options
context:
space:
mode:
authorGravatar Peter Ma <pcma@google.com>2018-10-08 23:12:08 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-10-08 23:16:17 -0700
commite27ee15fa45a5f4e43e10ed1fe0eb3a1feb4253a (patch)
tree2588e0531141c95d8c443fa4923d2df20b4970fc /tensorflow/core/grappler/costs/utils.h
parentd1f0494b89a31298df7743018c0a3fa388ac16a2 (diff)
Refactor CalculateOutputSize() from VirtualScheduler protected member function to utils; Refactor EstimateSize() from memory_optimizer.cc to utils; some small changes for readability improvement
PiperOrigin-RevId: 216307257
Diffstat (limited to 'tensorflow/core/grappler/costs/utils.h')
-rw-r--r--tensorflow/core/grappler/costs/utils.h11
1 files changed, 11 insertions, 0 deletions
diff --git a/tensorflow/core/grappler/costs/utils.h b/tensorflow/core/grappler/costs/utils.h
index 5fd6717712..ea64e5a41d 100644
--- a/tensorflow/core/grappler/costs/utils.h
+++ b/tensorflow/core/grappler/costs/utils.h
@@ -43,6 +43,17 @@ std::vector<OpInfo::TensorProperties> FindInputFeatures(
const std::unordered_map<string, const CostGraphDef::Node*>& name_to_cost,
const std::unordered_map<string, const NodeDef*>& name_to_node);
+// Returns the size of tensor (unit: bytes). For tensor shape with unknown rank,
+// it assumes the tensor to be scalar. For any unknown dimension, it assumes
+// size one.
+int64 CalculateTensorSize(const OpInfo::TensorProperties& prop);
+
+// Returns the size of output at port_num (unit: bytes). A special case is
+// port_num -1, which is for control dependency and assumed to be 4 bytes.
+int64 CalculateOutputSize(
+ const std::vector<OpInfo::TensorProperties>& output_properties,
+ int port_num);
+
// Returns the DeviceProperties of the device on which 'node' runs.
DeviceProperties GetDeviceInfo(const CostGraphDef::Node& node);
DeviceProperties GetDeviceInfo(const string& device_str);