aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2018-10-08 17:33:22 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-10-08 17:38:39 -0700
commit03d097bc96080981098ffdbaf1b3465e6e153a6a (patch)
tree3ab671663167deda026298a6e85f09376c4b5d22
parent49643265c3f1f279a93bd8bc3a126e11e979bc44 (diff)
Consolidate device parameter arguments into a shared DeviceInfo struct
PiperOrigin-RevId: 216280197
-rw-r--r--tensorflow/core/grappler/costs/cost_estimator.h5
-rw-r--r--tensorflow/core/grappler/costs/op_level_cost_estimator.cc2
-rw-r--r--tensorflow/core/grappler/costs/op_level_cost_estimator.h6
-rw-r--r--tensorflow/python/grappler/cluster.i2
4 files changed, 7 insertions, 8 deletions
diff --git a/tensorflow/core/grappler/costs/cost_estimator.h b/tensorflow/core/grappler/costs/cost_estimator.h
index e91f0cc9da..569d9da683 100644
--- a/tensorflow/core/grappler/costs/cost_estimator.h
+++ b/tensorflow/core/grappler/costs/cost_estimator.h
@@ -30,6 +30,11 @@ struct GrapplerItem;
constexpr int64 kMemoryUnknown = -1ll;
constexpr int64 kZeroMemory = 0ll;
+struct DeviceInfo {
+ double gigaops; // Billions of operations executed per second.
+ double gb_per_sec; // Bandwidth to main memory in GB per second.
+};
+
// Holds the set of things we might want to estimate or measure in Grappler.
// Always produce execution time. Other fields are optional depending on the
// estimator being used.
diff --git a/tensorflow/core/grappler/costs/op_level_cost_estimator.cc b/tensorflow/core/grappler/costs/op_level_cost_estimator.cc
index 71f4d9fd05..f363f2915f 100644
--- a/tensorflow/core/grappler/costs/op_level_cost_estimator.cc
+++ b/tensorflow/core/grappler/costs/op_level_cost_estimator.cc
@@ -372,7 +372,7 @@ Costs OpLevelCostEstimator::PredictCosts(const OpContext& op_context) const {
return costs;
}
-OpLevelCostEstimator::DeviceInfo OpLevelCostEstimator::GetDeviceInfo(
+DeviceInfo OpLevelCostEstimator::GetDeviceInfo(
const DeviceProperties& device) const {
double gflops = -1;
double gb_per_sec = -1;
diff --git a/tensorflow/core/grappler/costs/op_level_cost_estimator.h b/tensorflow/core/grappler/costs/op_level_cost_estimator.h
index a277dfdf65..dd1ee39cb2 100644
--- a/tensorflow/core/grappler/costs/op_level_cost_estimator.h
+++ b/tensorflow/core/grappler/costs/op_level_cost_estimator.h
@@ -40,12 +40,6 @@ class OpLevelCostEstimator {
virtual Costs PredictCosts(const OpContext& op_context) const;
- // Basic device performance info, sufficient for roofline estimate.
- struct DeviceInfo {
- double gigaops; // Billions of operations executed per second.
- double gb_per_sec; // Bandwidth to main memory in GB per second.
- };
-
// Returns basic device performance info.
virtual DeviceInfo GetDeviceInfo(const DeviceProperties& device) const;
diff --git a/tensorflow/python/grappler/cluster.i b/tensorflow/python/grappler/cluster.i
index 6816e20407..87795ffcfb 100644
--- a/tensorflow/python/grappler/cluster.i
+++ b/tensorflow/python/grappler/cluster.i
@@ -308,7 +308,7 @@ static PyObject* TF_GetSupportedDevices(GCluster cluster, GItem item) {
static double TF_EstimatePerformance(const tensorflow::NamedDevice& device) {
tensorflow::grappler::OpLevelCostEstimator estimator;
- tensorflow::grappler::OpLevelCostEstimator::DeviceInfo info =
+ tensorflow::grappler::DeviceInfo info =
estimator.GetDeviceInfo(device.properties());
return info.gigaops;
}