From 03d097bc96080981098ffdbaf1b3465e6e153a6a Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Mon, 8 Oct 2018 17:33:22 -0700 Subject: Consolidate device parameter arguments into a shared DeviceInfo struct PiperOrigin-RevId: 216280197 --- tensorflow/core/grappler/costs/cost_estimator.h | 5 +++++ tensorflow/core/grappler/costs/op_level_cost_estimator.cc | 2 +- tensorflow/core/grappler/costs/op_level_cost_estimator.h | 6 ------ tensorflow/python/grappler/cluster.i | 2 +- 4 files changed, 7 insertions(+), 8 deletions(-) diff --git a/tensorflow/core/grappler/costs/cost_estimator.h b/tensorflow/core/grappler/costs/cost_estimator.h index e91f0cc9da..569d9da683 100644 --- a/tensorflow/core/grappler/costs/cost_estimator.h +++ b/tensorflow/core/grappler/costs/cost_estimator.h @@ -30,6 +30,11 @@ struct GrapplerItem; constexpr int64 kMemoryUnknown = -1ll; constexpr int64 kZeroMemory = 0ll; +struct DeviceInfo { + double gigaops; // Billions of operations executed per second. + double gb_per_sec; // Bandwidth to main memory in GB per second. +}; + // Holds the set of things we might want to estimate or measure in Grappler. // Always produce execution time. Other fields are optional depending on the // estimator being used. diff --git a/tensorflow/core/grappler/costs/op_level_cost_estimator.cc b/tensorflow/core/grappler/costs/op_level_cost_estimator.cc index 71f4d9fd05..f363f2915f 100644 --- a/tensorflow/core/grappler/costs/op_level_cost_estimator.cc +++ b/tensorflow/core/grappler/costs/op_level_cost_estimator.cc @@ -372,7 +372,7 @@ Costs OpLevelCostEstimator::PredictCosts(const OpContext& op_context) const { return costs; } -OpLevelCostEstimator::DeviceInfo OpLevelCostEstimator::GetDeviceInfo( +DeviceInfo OpLevelCostEstimator::GetDeviceInfo( const DeviceProperties& device) const { double gflops = -1; double gb_per_sec = -1; diff --git a/tensorflow/core/grappler/costs/op_level_cost_estimator.h b/tensorflow/core/grappler/costs/op_level_cost_estimator.h index a277dfdf65..dd1ee39cb2 100644 --- a/tensorflow/core/grappler/costs/op_level_cost_estimator.h +++ b/tensorflow/core/grappler/costs/op_level_cost_estimator.h @@ -40,12 +40,6 @@ class OpLevelCostEstimator { virtual Costs PredictCosts(const OpContext& op_context) const; - // Basic device performance info, sufficient for roofline estimate. - struct DeviceInfo { - double gigaops; // Billions of operations executed per second. - double gb_per_sec; // Bandwidth to main memory in GB per second. - }; - // Returns basic device performance info. virtual DeviceInfo GetDeviceInfo(const DeviceProperties& device) const; diff --git a/tensorflow/python/grappler/cluster.i b/tensorflow/python/grappler/cluster.i index 6816e20407..87795ffcfb 100644 --- a/tensorflow/python/grappler/cluster.i +++ b/tensorflow/python/grappler/cluster.i @@ -308,7 +308,7 @@ static PyObject* TF_GetSupportedDevices(GCluster cluster, GItem item) { static double TF_EstimatePerformance(const tensorflow::NamedDevice& device) { tensorflow::grappler::OpLevelCostEstimator estimator; - tensorflow::grappler::OpLevelCostEstimator::DeviceInfo info = + tensorflow::grappler::DeviceInfo info = estimator.GetDeviceInfo(device.properties()); return info.gigaops; } -- cgit v1.2.3