diff options
author | A. Unique TensorFlower <gardener@tensorflow.org> | 2018-10-08 17:33:22 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-10-08 17:38:39 -0700 |
commit | 03d097bc96080981098ffdbaf1b3465e6e153a6a (patch) | |
tree | 3ab671663167deda026298a6e85f09376c4b5d22 | |
parent | 49643265c3f1f279a93bd8bc3a126e11e979bc44 (diff) |
Consolidate device parameter arguments into a shared DeviceInfo struct
PiperOrigin-RevId: 216280197
4 files changed, 7 insertions, 8 deletions
diff --git a/tensorflow/core/grappler/costs/cost_estimator.h b/tensorflow/core/grappler/costs/cost_estimator.h index e91f0cc9da..569d9da683 100644 --- a/tensorflow/core/grappler/costs/cost_estimator.h +++ b/tensorflow/core/grappler/costs/cost_estimator.h @@ -30,6 +30,11 @@ struct GrapplerItem; constexpr int64 kMemoryUnknown = -1ll; constexpr int64 kZeroMemory = 0ll; +struct DeviceInfo { + double gigaops; // Billions of operations executed per second. + double gb_per_sec; // Bandwidth to main memory in GB per second. +}; + // Holds the set of things we might want to estimate or measure in Grappler. // Always produce execution time. Other fields are optional depending on the // estimator being used. diff --git a/tensorflow/core/grappler/costs/op_level_cost_estimator.cc b/tensorflow/core/grappler/costs/op_level_cost_estimator.cc index 71f4d9fd05..f363f2915f 100644 --- a/tensorflow/core/grappler/costs/op_level_cost_estimator.cc +++ b/tensorflow/core/grappler/costs/op_level_cost_estimator.cc @@ -372,7 +372,7 @@ Costs OpLevelCostEstimator::PredictCosts(const OpContext& op_context) const { return costs; } -OpLevelCostEstimator::DeviceInfo OpLevelCostEstimator::GetDeviceInfo( +DeviceInfo OpLevelCostEstimator::GetDeviceInfo( const DeviceProperties& device) const { double gflops = -1; double gb_per_sec = -1; diff --git a/tensorflow/core/grappler/costs/op_level_cost_estimator.h b/tensorflow/core/grappler/costs/op_level_cost_estimator.h index a277dfdf65..dd1ee39cb2 100644 --- a/tensorflow/core/grappler/costs/op_level_cost_estimator.h +++ b/tensorflow/core/grappler/costs/op_level_cost_estimator.h @@ -40,12 +40,6 @@ class OpLevelCostEstimator { virtual Costs PredictCosts(const OpContext& op_context) const; - // Basic device performance info, sufficient for roofline estimate. - struct DeviceInfo { - double gigaops; // Billions of operations executed per second. - double gb_per_sec; // Bandwidth to main memory in GB per second. - }; - // Returns basic device performance info. virtual DeviceInfo GetDeviceInfo(const DeviceProperties& device) const; diff --git a/tensorflow/python/grappler/cluster.i b/tensorflow/python/grappler/cluster.i index 6816e20407..87795ffcfb 100644 --- a/tensorflow/python/grappler/cluster.i +++ b/tensorflow/python/grappler/cluster.i @@ -308,7 +308,7 @@ static PyObject* TF_GetSupportedDevices(GCluster cluster, GItem item) { static double TF_EstimatePerformance(const tensorflow::NamedDevice& device) { tensorflow::grappler::OpLevelCostEstimator estimator; - tensorflow::grappler::OpLevelCostEstimator::DeviceInfo info = + tensorflow::grappler::DeviceInfo info = estimator.GetDeviceInfo(device.properties()); return info.gigaops; } |