aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
-rw-r--r--tensorflow/core/BUILD2
-rw-r--r--tensorflow/core/common_runtime/simple_graph_execution_state.cc11
-rw-r--r--tensorflow/core/grappler/clusters/BUILD33
-rw-r--r--tensorflow/core/grappler/clusters/single_machine.cc2
-rw-r--r--tensorflow/core/grappler/clusters/utils.cc114
-rw-r--r--tensorflow/core/grappler/clusters/utils.h38
-rw-r--r--tensorflow/core/grappler/clusters/virtual_cluster.cc44
-rw-r--r--tensorflow/core/grappler/clusters/virtual_cluster.h46
-rw-r--r--tensorflow/core/grappler/costs/BUILD4
-rw-r--r--tensorflow/core/grappler/costs/op_level_cost_estimator.cc2
-rw-r--r--tensorflow/core/grappler/costs/op_level_cost_estimator.h1
-rw-r--r--tensorflow/core/grappler/costs/utils.cc66
-rw-r--r--tensorflow/core/grappler/costs/utils.h7
-rw-r--r--tensorflow/core/grappler/optimizers/meta_optimizer.cc8
-rw-r--r--tensorflow/core/grappler/optimizers/meta_optimizer.h2
-rw-r--r--tensorflow/python/grappler/tf_optimizer.i5
16 files changed, 300 insertions, 85 deletions
diff --git a/tensorflow/core/BUILD b/tensorflow/core/BUILD
index 1178d4e5d2..78b67941fe 100644
--- a/tensorflow/core/BUILD
+++ b/tensorflow/core/BUILD
@@ -1534,6 +1534,8 @@ tf_cuda_library(
":proto_text",
":protos_all_cc",
"//tensorflow/core/grappler:grappler_item",
+ "//tensorflow/core/grappler/clusters:utils",
+ "//tensorflow/core/grappler/clusters:virtual_cluster",
"//tensorflow/core/grappler/optimizers:meta_optimizer",
"//third_party/eigen3",
"//tensorflow/core/kernels:required",
diff --git a/tensorflow/core/common_runtime/simple_graph_execution_state.cc b/tensorflow/core/common_runtime/simple_graph_execution_state.cc
index 31e63a9ef7..590b5d47ba 100644
--- a/tensorflow/core/common_runtime/simple_graph_execution_state.cc
+++ b/tensorflow/core/common_runtime/simple_graph_execution_state.cc
@@ -29,6 +29,8 @@ limitations under the License.
#include "tensorflow/core/graph/graph_constructor.h"
#include "tensorflow/core/graph/subgraph.h"
#include "tensorflow/core/graph/validate.h"
+#include "tensorflow/core/grappler/clusters/utils.h"
+#include "tensorflow/core/grappler/clusters/virtual_cluster.h"
#include "tensorflow/core/grappler/grappler_item.h"
#include "tensorflow/core/grappler/optimizers/meta_optimizer.h"
#include "tensorflow/core/lib/core/errors.h"
@@ -267,7 +269,14 @@ Status SimpleGraphExecutionState::InitBaseGraph(
}
if (s.ok()) {
- s = grappler::RunMetaOptimizer(item, rewrite_options, &optimized_graph);
+ std::unordered_map<string, DeviceProperties> device_map;
+ for (const auto& device : device_set_->devices()) {
+ device_map[device->name()] =
+ grappler::GetDeviceInfo(device->parsed_name());
+ }
+ grappler::VirtualCluster cluster(device_map);
+ s = grappler::RunMetaOptimizer(item, rewrite_options, &cluster,
+ &optimized_graph);
}
if (s.ok()) {
graph_def = &optimized_graph;
diff --git a/tensorflow/core/grappler/clusters/BUILD b/tensorflow/core/grappler/clusters/BUILD
index 33a716774f..578ec9798e 100644
--- a/tensorflow/core/grappler/clusters/BUILD
+++ b/tensorflow/core/grappler/clusters/BUILD
@@ -17,6 +17,8 @@ filegroup(
srcs = glob(
[
"cluster.*",
+ "utils.*",
+ "virtual_cluster.*",
],
),
visibility = ["//tensorflow:__subpackages__"],
@@ -29,6 +31,21 @@ alias(
)
cc_library(
+ name = "utils",
+ srcs = ["utils.cc"],
+ hdrs = [
+ "utils.h",
+ ],
+ visibility = ["//visibility:public"],
+ deps = [
+ "//tensorflow/core:framework",
+ "//tensorflow/core:lib",
+ "//tensorflow/core:protos_all_cc",
+ "//third_party/eigen3",
+ ],
+)
+
+cc_library(
name = "cluster",
srcs = ["cluster.cc"],
hdrs = [
@@ -45,6 +62,20 @@ cc_library(
)
cc_library(
+ name = "virtual_cluster",
+ srcs = ["virtual_cluster.cc"],
+ hdrs = [
+ "virtual_cluster.h",
+ ],
+ visibility = ["//visibility:public"],
+ deps = [
+ ":cluster",
+ ":utils",
+ "//tensorflow/core:protos_all_cc",
+ ],
+)
+
+cc_library(
name = "single_machine",
srcs = ["single_machine.cc"],
hdrs = [
@@ -53,6 +84,7 @@ cc_library(
visibility = ["//visibility:public"],
deps = [
":cluster",
+ ":utils",
"//tensorflow/cc:coordinator",
"//tensorflow/cc:queue_runner",
"//tensorflow/core:core_cpu",
@@ -60,7 +92,6 @@ cc_library(
"//tensorflow/core:lib",
"//tensorflow/core:protos_all_cc",
"//tensorflow/core/grappler:utils",
- "//tensorflow/core/grappler/costs:utils",
"//tensorflow/core/kernels:ops_util",
],
)
diff --git a/tensorflow/core/grappler/clusters/single_machine.cc b/tensorflow/core/grappler/clusters/single_machine.cc
index 6bb235b836..e255b17631 100644
--- a/tensorflow/core/grappler/clusters/single_machine.cc
+++ b/tensorflow/core/grappler/clusters/single_machine.cc
@@ -19,7 +19,7 @@ limitations under the License.
#include "tensorflow/cc/training/queue_runner.h"
#include "tensorflow/core/framework/step_stats.pb.h"
-#include "tensorflow/core/grappler/costs/utils.h"
+#include "tensorflow/core/grappler/clusters/utils.h"
#include "tensorflow/core/grappler/utils.h"
#include "tensorflow/core/kernels/ops_util.h"
#include "tensorflow/core/lib/core/errors.h"
diff --git a/tensorflow/core/grappler/clusters/utils.cc b/tensorflow/core/grappler/clusters/utils.cc
new file mode 100644
index 0000000000..24054d240e
--- /dev/null
+++ b/tensorflow/core/grappler/clusters/utils.cc
@@ -0,0 +1,114 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/core/grappler/clusters/utils.h"
+
+#include "third_party/eigen3/Eigen/Core"
+
+#if GOOGLE_CUDA
+#include "cuda/include/cuda.h"
+#include "cuda/include/cuda_runtime_api.h"
+#include "cuda/include/cudnn.h"
+#endif
+
+#include "tensorflow/core/lib/strings/numbers.h"
+#include "tensorflow/core/lib/strings/strcat.h"
+#include "tensorflow/core/platform/cpu_info.h"
+
+namespace tensorflow {
+namespace grappler {
+
+DeviceProperties GetLocalCPUInfo() {
+ DeviceProperties device;
+ device.set_type("CPU");
+
+ device.set_vendor(port::CPUVendorIDString());
+ // Combine cpu family and model into the model string.
+ device.set_model(
+ strings::StrCat((port::CPUFamily() << 4) + port::CPUModelNum()));
+ device.set_frequency(port::NominalCPUFrequency() * 1e-6);
+ device.set_num_cores(port::NumSchedulableCPUs());
+ device.set_l1_cache_size(Eigen::l1CacheSize());
+ device.set_l2_cache_size(Eigen::l2CacheSize());
+ device.set_l3_cache_size(Eigen::l3CacheSize());
+
+ (*device.mutable_environment())["cpu_instruction_set"] =
+ Eigen::SimdInstructionSetsInUse();
+
+ (*device.mutable_environment())["eigen"] = strings::StrCat(
+ EIGEN_WORLD_VERSION, ".", EIGEN_MAJOR_VERSION, ".", EIGEN_MINOR_VERSION);
+#ifdef EIGEN_USE_LIBXSMM
+ (*device.mutable_environment())["libxsmm"] = LIBXSMM_VERSION;
+#endif
+
+ return device;
+}
+
+DeviceProperties GetLocalGPUInfo(int gpu_id) {
+ DeviceProperties device;
+ device.set_type("GPU");
+
+#if GOOGLE_CUDA
+ cudaDeviceProp properties;
+ cudaError_t error = cudaGetDeviceProperties(&properties, gpu_id);
+ if (error == cudaSuccess) {
+ device.set_vendor("NVidia");
+ device.set_model(properties.name);
+ device.set_frequency(properties.clockRate * 1e-3);
+ device.set_num_cores(properties.multiProcessorCount);
+ device.set_num_registers(properties.regsPerMultiprocessor);
+ // For compute capability less than 5, l1 cache size is configurable to
+ // either 16 KB or 48 KB. We use the initial configuration 16 KB here. For
+ // compute capability larger or equal to 5, l1 cache (unified with texture
+ // cache) size is 24 KB. This number may need to be updated for future
+ // compute capabilities.
+ device.set_l1_cache_size((properties.major < 5) ? 16 * 1024 : 24 * 1024);
+ device.set_l2_cache_size(properties.l2CacheSize);
+ device.set_l3_cache_size(0);
+ device.set_shared_memory_size_per_multiprocessor(
+ properties.sharedMemPerMultiprocessor);
+ device.set_memory_size(properties.totalGlobalMem);
+ // 8 is the number of bits per byte. 2 is accounted for
+ // double data rate (DDR).
+ device.set_bandwidth(properties.memoryBusWidth / 8 *
+ properties.memoryClockRate * 2);
+ }
+
+ (*device.mutable_environment())["architecture"] =
+ strings::StrCat(properties.major, ".", properties.minor);
+ (*device.mutable_environment())["cuda"] = strings::StrCat(CUDA_VERSION);
+ (*device.mutable_environment())["cudnn"] = strings::StrCat(CUDNN_VERSION);
+#endif
+
+ return device;
+}
+
+DeviceProperties GetDeviceInfo(const DeviceNameUtils::ParsedName& device) {
+ if (device.type == "CPU") {
+ return GetLocalCPUInfo();
+ } else if (device.type == "GPU") {
+ if (device.has_id) {
+ return GetLocalGPUInfo(device.id);
+ } else {
+ return GetLocalGPUInfo(0);
+ }
+ }
+ DeviceProperties result;
+ result.set_type("UNKNOWN");
+ return result;
+}
+
+} // end namespace grappler
+} // end namespace tensorflow
diff --git a/tensorflow/core/grappler/clusters/utils.h b/tensorflow/core/grappler/clusters/utils.h
new file mode 100644
index 0000000000..191942040a
--- /dev/null
+++ b/tensorflow/core/grappler/clusters/utils.h
@@ -0,0 +1,38 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_GRAPPLER_CLUSTERS_UTILS_H_
+#define TENSORFLOW_GRAPPLER_CLUSTERS_UTILS_H_
+
+#include "tensorflow/core/protobuf/device_properties.pb.h"
+#include "tensorflow/core/util/device_name_utils.h"
+
+namespace tensorflow {
+namespace grappler {
+
+// Returns the DeviceProperties of the CPU on which grappler is running.
+DeviceProperties GetLocalCPUInfo();
+
+// Returns the DeviceProperties for the specified GPU attached to the server on
+// which grappler is running.
+DeviceProperties GetLocalGPUInfo(int gpu_id);
+
+// Returns the DeviceProperties of the specified device
+DeviceProperties GetDeviceInfo(const DeviceNameUtils::ParsedName& device);
+
+} // end namespace grappler
+} // end namespace tensorflow
+
+#endif // TENSORFLOW_GRAPPLER_CLUSTERS_UTILS_H_
diff --git a/tensorflow/core/grappler/clusters/virtual_cluster.cc b/tensorflow/core/grappler/clusters/virtual_cluster.cc
new file mode 100644
index 0000000000..4ca4c03dbb
--- /dev/null
+++ b/tensorflow/core/grappler/clusters/virtual_cluster.cc
@@ -0,0 +1,44 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/core/grappler/clusters/virtual_cluster.h"
+
+namespace tensorflow {
+namespace grappler {
+
+VirtualCluster::VirtualCluster(
+ const std::unordered_map<string, DeviceProperties>& devices)
+ : Cluster(0) {
+ devices_ = devices;
+}
+
+VirtualCluster::~VirtualCluster() {}
+
+Status VirtualCluster::Provision() { return Status::OK(); }
+
+Status VirtualCluster::Initialize(const GrapplerItem& item) {
+ return Status::OK();
+}
+
+Status VirtualCluster::Run(const GraphDef& item,
+ const std::vector<std::pair<string, Tensor>>& feed,
+ const std::vector<string>& fetch,
+ RunMetadata* metadata) {
+ return Status::OK();
+
+}
+
+} // namespace grappler
+} // namespace tensorflow
diff --git a/tensorflow/core/grappler/clusters/virtual_cluster.h b/tensorflow/core/grappler/clusters/virtual_cluster.h
new file mode 100644
index 0000000000..cd8436a987
--- /dev/null
+++ b/tensorflow/core/grappler/clusters/virtual_cluster.h
@@ -0,0 +1,46 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_GRAPPLER_CLUSTERS_VIRTUAL_CLUSTER_H_
+#define TENSORFLOW_GRAPPLER_CLUSTERS_VIRTUAL_CLUSTER_H_
+
+#include <unordered_map>
+#include "tensorflow/core/grappler/clusters/cluster.h"
+#include "tensorflow/core/protobuf/device_properties.pb.h"
+
+namespace tensorflow {
+namespace grappler {
+
+// Create a simple cluster that lists the devices (and their properties)
+// available in a TensorFlow session. This cluster doesn't allow running an
+// actual graph. It is useful however when used in conjusction with costs models
+// that aren't based on the execution of the graph.
+class VirtualCluster : public Cluster {
+ public:
+ VirtualCluster(const std::unordered_map<string, DeviceProperties>& devices);
+
+ ~VirtualCluster() override;
+
+ Status Provision() override;
+ Status Initialize(const GrapplerItem& item) override;
+ Status Run(const GraphDef& item,
+ const std::vector<std::pair<string, Tensor>>& feed,
+ const std::vector<string>& fetch, RunMetadata* metadata) override;
+};
+
+} // end namespace grappler
+} // end namespace tensorflow
+
+#endif // TENSORFLOW_GRAPPLER_CLUSTERS_VIRTUAL_CLUSTER_H_
diff --git a/tensorflow/core/grappler/costs/BUILD b/tensorflow/core/grappler/costs/BUILD
index 596b288eb7..ceb4050eae 100644
--- a/tensorflow/core/grappler/costs/BUILD
+++ b/tensorflow/core/grappler/costs/BUILD
@@ -118,6 +118,7 @@ cc_library(
deps = [
":op_performance_data_cc",
"//third_party/eigen3",
+ "//tensorflow/core/grappler/clusters:utils",
"//tensorflow/core:core_cpu_internal",
"//tensorflow/core:framework",
"//tensorflow/core:lib",
@@ -195,9 +196,8 @@ cc_library(
deps = [
":cost_estimator",
":op_performance_data_cc",
- ":utils",
- "//tensorflow/core:core_cpu_internal",
"//tensorflow/core:framework",
+ "//tensorflow/core/grappler/clusters:utils",
],
)
diff --git a/tensorflow/core/grappler/costs/op_level_cost_estimator.cc b/tensorflow/core/grappler/costs/op_level_cost_estimator.cc
index a6ae2c744c..84e5461d7f 100644
--- a/tensorflow/core/grappler/costs/op_level_cost_estimator.cc
+++ b/tensorflow/core/grappler/costs/op_level_cost_estimator.cc
@@ -15,7 +15,7 @@ limitations under the License.
#include "tensorflow/core/grappler/costs/op_level_cost_estimator.h"
#include "tensorflow/core/framework/attr_value_util.h"
-#include "tensorflow/core/grappler/costs/utils.h"
+#include "tensorflow/core/grappler/clusters/utils.h"
namespace tensorflow {
namespace grappler {
diff --git a/tensorflow/core/grappler/costs/op_level_cost_estimator.h b/tensorflow/core/grappler/costs/op_level_cost_estimator.h
index 7a594e2a01..78ce69a597 100644
--- a/tensorflow/core/grappler/costs/op_level_cost_estimator.h
+++ b/tensorflow/core/grappler/costs/op_level_cost_estimator.h
@@ -20,7 +20,6 @@ limitations under the License.
#include <map>
#include <string>
-#include "tensorflow/core/graph/types.h"
#include "tensorflow/core/grappler/costs/cost_estimator.h"
#include "tensorflow/core/grappler/costs/op_performance_data.pb.h"
#include "tensorflow/core/util/padding.h"
diff --git a/tensorflow/core/grappler/costs/utils.cc b/tensorflow/core/grappler/costs/utils.cc
index e3f11272b2..bdfb17a456 100644
--- a/tensorflow/core/grappler/costs/utils.cc
+++ b/tensorflow/core/grappler/costs/utils.cc
@@ -30,6 +30,7 @@ limitations under the License.
#include "tensorflow/core/framework/types.pb.h"
#include "tensorflow/core/graph/graph.h"
#include "tensorflow/core/graph/tensor_id.h"
+#include "tensorflow/core/grappler/clusters/utils.h"
#include "tensorflow/core/lib/strings/numbers.h"
#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/platform/cpu_info.h"
@@ -139,70 +140,5 @@ DeviceProperties GetDeviceInfo(const CostGraphDef::Node& node) {
return device;
}
-DeviceProperties GetLocalCPUInfo() {
- DeviceProperties device;
- device.set_type("CPU");
-
- device.set_vendor(port::CPUVendorIDString());
- // Combine cpu family and model into the model string.
- device.set_model(
- strings::StrCat((port::CPUFamily() << 4) + port::CPUModelNum()));
- device.set_frequency(port::NominalCPUFrequency() * 1e-6);
- device.set_num_cores(port::NumSchedulableCPUs());
- device.set_l1_cache_size(Eigen::l1CacheSize());
- device.set_l2_cache_size(Eigen::l2CacheSize());
- device.set_l3_cache_size(Eigen::l3CacheSize());
-
- (*device.mutable_environment())["cpu_instruction_set"] =
- Eigen::SimdInstructionSetsInUse();
-
- (*device.mutable_environment())["eigen"] = strings::StrCat(
- EIGEN_WORLD_VERSION, ".", EIGEN_MAJOR_VERSION, ".", EIGEN_MINOR_VERSION);
-#ifdef EIGEN_USE_LIBXSMM
- (*device.mutable_environment())["libxsmm"] = LIBXSMM_VERSION;
-#endif
-
- return device;
-}
-
-DeviceProperties GetLocalGPUInfo(int gpu_id) {
- DeviceProperties device;
- device.set_type("GPU");
-
-#if GOOGLE_CUDA
- cudaDeviceProp properties;
- cudaError_t error = cudaGetDeviceProperties(&properties, gpu_id);
- if (error == cudaSuccess) {
- device.set_vendor("NVidia");
- device.set_model(properties.name);
- device.set_frequency(properties.clockRate * 1e-3);
- device.set_num_cores(properties.multiProcessorCount);
- device.set_num_registers(properties.regsPerMultiprocessor);
- // For compute capability less than 5, l1 cache size is configurable to
- // either 16 KB or 48 KB. We use the initial configuration 16 KB here. For
- // compute capability larger or equal to 5, l1 cache (unified with texture
- // cache) size is 24 KB. This number may need to be updated for future
- // compute capabilities.
- device.set_l1_cache_size((properties.major < 5) ? 16 * 1024 : 24 * 1024);
- device.set_l2_cache_size(properties.l2CacheSize);
- device.set_l3_cache_size(0);
- device.set_shared_memory_size_per_multiprocessor(
- properties.sharedMemPerMultiprocessor);
- device.set_memory_size(properties.totalGlobalMem);
- // 8 is the number of bits per byte. 2 is accounted for
- // double data rate (DDR).
- device.set_bandwidth(properties.memoryBusWidth / 8 *
- properties.memoryClockRate * 2);
- }
-
- (*device.mutable_environment())["architecture"] =
- strings::StrCat(properties.major, ".", properties.minor);
- (*device.mutable_environment())["cuda"] = strings::StrCat(CUDA_VERSION);
- (*device.mutable_environment())["cudnn"] = strings::StrCat(CUDNN_VERSION);
-#endif
-
- return device;
-}
-
} // end namespace grappler
} // end namespace tensorflow
diff --git a/tensorflow/core/grappler/costs/utils.h b/tensorflow/core/grappler/costs/utils.h
index 0886dfbde3..32e32a09e1 100644
--- a/tensorflow/core/grappler/costs/utils.h
+++ b/tensorflow/core/grappler/costs/utils.h
@@ -43,13 +43,6 @@ std::vector<OpInfo::TensorProperties> FindInputFeatures(
// Returns the DeviceProperties of the device on which 'node' runs.
DeviceProperties GetDeviceInfo(const CostGraphDef::Node& node);
-// Returns the DeviceProperties of the CPU on which grappler is running.
-DeviceProperties GetLocalCPUInfo();
-
-// Returns the DeviceProperties for the specified GPU attached to the server on
-// which grappler is running.
-DeviceProperties GetLocalGPUInfo(int gpu_id);
-
} // end namespace grappler
} // end namespace tensorflow
diff --git a/tensorflow/core/grappler/optimizers/meta_optimizer.cc b/tensorflow/core/grappler/optimizers/meta_optimizer.cc
index 2ea5adffeb..23adaab8ed 100644
--- a/tensorflow/core/grappler/optimizers/meta_optimizer.cc
+++ b/tensorflow/core/grappler/optimizers/meta_optimizer.cc
@@ -90,13 +90,13 @@ Status MetaOptimizer::Optimize(Cluster* cluster, const GrapplerItem& item,
bool already_optimized = false;
for (const auto& optimizer : optimizers) {
if (!already_optimized) {
- TF_RETURN_IF_ERROR(optimizer->Optimize(nullptr, item, optimized_graph));
+ TF_RETURN_IF_ERROR(optimizer->Optimize(cluster, item, optimized_graph));
already_optimized = true;
} else {
GrapplerItem optimized_item = item;
optimized_item.graph = *optimized_graph;
TF_RETURN_IF_ERROR(
- optimizer->Optimize(nullptr, optimized_item, optimized_graph));
+ optimizer->Optimize(cluster, optimized_item, optimized_graph));
}
}
// Copy the graph version.
@@ -116,9 +116,9 @@ bool MetaOptimizerEnabled(const RewriterConfig& cfg) {
}
Status RunMetaOptimizer(const GrapplerItem& item, const RewriterConfig& cfg,
- GraphDef* optimized_graph) {
+ Cluster* cluster, GraphDef* optimized_graph) {
MetaOptimizer optimizer(cfg);
- return optimizer.Optimize(nullptr, item, optimized_graph);
+ return optimizer.Optimize(cluster, item, optimized_graph);
}
} // namespace grappler
diff --git a/tensorflow/core/grappler/optimizers/meta_optimizer.h b/tensorflow/core/grappler/optimizers/meta_optimizer.h
index 9def2cd711..6b950c973d 100644
--- a/tensorflow/core/grappler/optimizers/meta_optimizer.h
+++ b/tensorflow/core/grappler/optimizers/meta_optimizer.h
@@ -46,7 +46,7 @@ class MetaOptimizer : public GraphOptimizer {
bool MetaOptimizerEnabled(const RewriterConfig& cfg);
Status RunMetaOptimizer(const GrapplerItem& item, const RewriterConfig& cfg,
- GraphDef* optimized_graph);
+ Cluster* cluster, GraphDef* optimized_graph);
} // namespace grappler
} // namespace tensorflow
diff --git a/tensorflow/python/grappler/tf_optimizer.i b/tensorflow/python/grappler/tf_optimizer.i
index ab887e63e5..404ce35180 100644
--- a/tensorflow/python/grappler/tf_optimizer.i
+++ b/tensorflow/python/grappler/tf_optimizer.i
@@ -58,6 +58,7 @@ limitations under the License.
#include "tensorflow/core/framework/graph.pb.h"
#include "tensorflow/core/grappler/grappler_item.h"
#include "tensorflow/core/grappler/grappler_item_builder.h"
+ #include "tensorflow/core/grappler/clusters/virtual_cluster.h"
#include "tensorflow/core/grappler/optimizers/meta_optimizer.h"
#include "tensorflow/core/protobuf/meta_graph.pb.h"
#include "tensorflow/core/protobuf/rewriter_config.pb.h"
@@ -69,9 +70,11 @@ PyObject* TF_OptimizeGraph(
const tensorflow::grappler::ItemConfig item_config;
std::unique_ptr<tensorflow::grappler::GrapplerItem> grappler_item =
tensorflow::grappler::GrapplerItemFromMetaGraphDef(graph_id, metagraph, item_config);
+ std::unordered_map<string, tensorflow::DeviceProperties> device_map;
+ tensorflow::grappler::VirtualCluster cluster(device_map);
tensorflow::GraphDef out_graph;
tensorflow::Status status = tensorflow::grappler::RunMetaOptimizer(
- *grappler_item, rewriter_config, &out_graph);
+ *grappler_item, rewriter_config, &cluster, &out_graph);
tensorflow::Set_TF_Status_from_Status(out_status, status);
string out_graph_str = out_graph.SerializeAsString();
PyObject* ret = PyBytes_FromStringAndSize(out_graph_str.data(),