diff options
-rw-r--r-- | tensorflow/core/BUILD | 2 | ||||
-rw-r--r-- | tensorflow/core/common_runtime/simple_graph_execution_state.cc | 11 | ||||
-rw-r--r-- | tensorflow/core/grappler/clusters/BUILD | 33 | ||||
-rw-r--r-- | tensorflow/core/grappler/clusters/single_machine.cc | 2 | ||||
-rw-r--r-- | tensorflow/core/grappler/clusters/utils.cc | 114 | ||||
-rw-r--r-- | tensorflow/core/grappler/clusters/utils.h | 38 | ||||
-rw-r--r-- | tensorflow/core/grappler/clusters/virtual_cluster.cc | 44 | ||||
-rw-r--r-- | tensorflow/core/grappler/clusters/virtual_cluster.h | 46 | ||||
-rw-r--r-- | tensorflow/core/grappler/costs/BUILD | 4 | ||||
-rw-r--r-- | tensorflow/core/grappler/costs/op_level_cost_estimator.cc | 2 | ||||
-rw-r--r-- | tensorflow/core/grappler/costs/op_level_cost_estimator.h | 1 | ||||
-rw-r--r-- | tensorflow/core/grappler/costs/utils.cc | 66 | ||||
-rw-r--r-- | tensorflow/core/grappler/costs/utils.h | 7 | ||||
-rw-r--r-- | tensorflow/core/grappler/optimizers/meta_optimizer.cc | 8 | ||||
-rw-r--r-- | tensorflow/core/grappler/optimizers/meta_optimizer.h | 2 | ||||
-rw-r--r-- | tensorflow/python/grappler/tf_optimizer.i | 5 |
16 files changed, 300 insertions, 85 deletions
diff --git a/tensorflow/core/BUILD b/tensorflow/core/BUILD index 1178d4e5d2..78b67941fe 100644 --- a/tensorflow/core/BUILD +++ b/tensorflow/core/BUILD @@ -1534,6 +1534,8 @@ tf_cuda_library( ":proto_text", ":protos_all_cc", "//tensorflow/core/grappler:grappler_item", + "//tensorflow/core/grappler/clusters:utils", + "//tensorflow/core/grappler/clusters:virtual_cluster", "//tensorflow/core/grappler/optimizers:meta_optimizer", "//third_party/eigen3", "//tensorflow/core/kernels:required", diff --git a/tensorflow/core/common_runtime/simple_graph_execution_state.cc b/tensorflow/core/common_runtime/simple_graph_execution_state.cc index 31e63a9ef7..590b5d47ba 100644 --- a/tensorflow/core/common_runtime/simple_graph_execution_state.cc +++ b/tensorflow/core/common_runtime/simple_graph_execution_state.cc @@ -29,6 +29,8 @@ limitations under the License. #include "tensorflow/core/graph/graph_constructor.h" #include "tensorflow/core/graph/subgraph.h" #include "tensorflow/core/graph/validate.h" +#include "tensorflow/core/grappler/clusters/utils.h" +#include "tensorflow/core/grappler/clusters/virtual_cluster.h" #include "tensorflow/core/grappler/grappler_item.h" #include "tensorflow/core/grappler/optimizers/meta_optimizer.h" #include "tensorflow/core/lib/core/errors.h" @@ -267,7 +269,14 @@ Status SimpleGraphExecutionState::InitBaseGraph( } if (s.ok()) { - s = grappler::RunMetaOptimizer(item, rewrite_options, &optimized_graph); + std::unordered_map<string, DeviceProperties> device_map; + for (const auto& device : device_set_->devices()) { + device_map[device->name()] = + grappler::GetDeviceInfo(device->parsed_name()); + } + grappler::VirtualCluster cluster(device_map); + s = grappler::RunMetaOptimizer(item, rewrite_options, &cluster, + &optimized_graph); } if (s.ok()) { graph_def = &optimized_graph; diff --git a/tensorflow/core/grappler/clusters/BUILD b/tensorflow/core/grappler/clusters/BUILD index 33a716774f..578ec9798e 100644 --- a/tensorflow/core/grappler/clusters/BUILD +++ 
b/tensorflow/core/grappler/clusters/BUILD @@ -17,6 +17,8 @@ filegroup( srcs = glob( [ "cluster.*", + "utils.*", + "virtual_cluster.*", ], ), visibility = ["//tensorflow:__subpackages__"], @@ -29,6 +31,21 @@ alias( ) cc_library( + name = "utils", + srcs = ["utils.cc"], + hdrs = [ + "utils.h", + ], + visibility = ["//visibility:public"], + deps = [ + "//tensorflow/core:framework", + "//tensorflow/core:lib", + "//tensorflow/core:protos_all_cc", + "//third_party/eigen3", + ], +) + +cc_library( name = "cluster", srcs = ["cluster.cc"], hdrs = [ @@ -45,6 +62,20 @@ cc_library( ) cc_library( + name = "virtual_cluster", + srcs = ["virtual_cluster.cc"], + hdrs = [ + "virtual_cluster.h", + ], + visibility = ["//visibility:public"], + deps = [ + ":cluster", + ":utils", + "//tensorflow/core:protos_all_cc", + ], +) + +cc_library( name = "single_machine", srcs = ["single_machine.cc"], hdrs = [ @@ -53,6 +84,7 @@ cc_library( visibility = ["//visibility:public"], deps = [ ":cluster", + ":utils", "//tensorflow/cc:coordinator", "//tensorflow/cc:queue_runner", "//tensorflow/core:core_cpu", @@ -60,7 +92,6 @@ cc_library( "//tensorflow/core:lib", "//tensorflow/core:protos_all_cc", "//tensorflow/core/grappler:utils", - "//tensorflow/core/grappler/costs:utils", "//tensorflow/core/kernels:ops_util", ], ) diff --git a/tensorflow/core/grappler/clusters/single_machine.cc b/tensorflow/core/grappler/clusters/single_machine.cc index 6bb235b836..e255b17631 100644 --- a/tensorflow/core/grappler/clusters/single_machine.cc +++ b/tensorflow/core/grappler/clusters/single_machine.cc @@ -19,7 +19,7 @@ limitations under the License. 
#include "tensorflow/cc/training/queue_runner.h" #include "tensorflow/core/framework/step_stats.pb.h" -#include "tensorflow/core/grappler/costs/utils.h" +#include "tensorflow/core/grappler/clusters/utils.h" #include "tensorflow/core/grappler/utils.h" #include "tensorflow/core/kernels/ops_util.h" #include "tensorflow/core/lib/core/errors.h" diff --git a/tensorflow/core/grappler/clusters/utils.cc b/tensorflow/core/grappler/clusters/utils.cc new file mode 100644 index 0000000000..24054d240e --- /dev/null +++ b/tensorflow/core/grappler/clusters/utils.cc @@ -0,0 +1,114 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/core/grappler/clusters/utils.h" + +#include "third_party/eigen3/Eigen/Core" + +#if GOOGLE_CUDA +#include "cuda/include/cuda.h" +#include "cuda/include/cuda_runtime_api.h" +#include "cuda/include/cudnn.h" +#endif + +#include "tensorflow/core/lib/strings/numbers.h" +#include "tensorflow/core/lib/strings/strcat.h" +#include "tensorflow/core/platform/cpu_info.h" + +namespace tensorflow { +namespace grappler { + +DeviceProperties GetLocalCPUInfo() { + DeviceProperties device; + device.set_type("CPU"); + + device.set_vendor(port::CPUVendorIDString()); + // Combine cpu family and model into the model string. 
+ device.set_model( + strings::StrCat((port::CPUFamily() << 4) + port::CPUModelNum())); + device.set_frequency(port::NominalCPUFrequency() * 1e-6); + device.set_num_cores(port::NumSchedulableCPUs()); + device.set_l1_cache_size(Eigen::l1CacheSize()); + device.set_l2_cache_size(Eigen::l2CacheSize()); + device.set_l3_cache_size(Eigen::l3CacheSize()); + + (*device.mutable_environment())["cpu_instruction_set"] = + Eigen::SimdInstructionSetsInUse(); + + (*device.mutable_environment())["eigen"] = strings::StrCat( + EIGEN_WORLD_VERSION, ".", EIGEN_MAJOR_VERSION, ".", EIGEN_MINOR_VERSION); +#ifdef EIGEN_USE_LIBXSMM + (*device.mutable_environment())["libxsmm"] = LIBXSMM_VERSION; +#endif + + return device; +} + +DeviceProperties GetLocalGPUInfo(int gpu_id) { + DeviceProperties device; + device.set_type("GPU"); + +#if GOOGLE_CUDA + cudaDeviceProp properties; + cudaError_t error = cudaGetDeviceProperties(&properties, gpu_id); + if (error == cudaSuccess) { + device.set_vendor("NVidia"); + device.set_model(properties.name); + device.set_frequency(properties.clockRate * 1e-3); + device.set_num_cores(properties.multiProcessorCount); + device.set_num_registers(properties.regsPerMultiprocessor); + // For compute capability less than 5, l1 cache size is configurable to + // either 16 KB or 48 KB. We use the initial configuration 16 KB here. For + // compute capability larger or equal to 5, l1 cache (unified with texture + // cache) size is 24 KB. This number may need to be updated for future + // compute capabilities. + device.set_l1_cache_size((properties.major < 5) ? 16 * 1024 : 24 * 1024); + device.set_l2_cache_size(properties.l2CacheSize); + device.set_l3_cache_size(0); + device.set_shared_memory_size_per_multiprocessor( + properties.sharedMemPerMultiprocessor); + device.set_memory_size(properties.totalGlobalMem); + // 8 is the number of bits per byte. 2 is accounted for + // double data rate (DDR). 
+ device.set_bandwidth(properties.memoryBusWidth / 8 * + properties.memoryClockRate * 2); + } + + (*device.mutable_environment())["architecture"] = + strings::StrCat(properties.major, ".", properties.minor); + (*device.mutable_environment())["cuda"] = strings::StrCat(CUDA_VERSION); + (*device.mutable_environment())["cudnn"] = strings::StrCat(CUDNN_VERSION); +#endif + + return device; +} + +DeviceProperties GetDeviceInfo(const DeviceNameUtils::ParsedName& device) { + if (device.type == "CPU") { + return GetLocalCPUInfo(); + } else if (device.type == "GPU") { + if (device.has_id) { + return GetLocalGPUInfo(device.id); + } else { + return GetLocalGPUInfo(0); + } + } + DeviceProperties result; + result.set_type("UNKNOWN"); + return result; +} + +} // end namespace grappler +} // end namespace tensorflow diff --git a/tensorflow/core/grappler/clusters/utils.h b/tensorflow/core/grappler/clusters/utils.h new file mode 100644 index 0000000000..191942040a --- /dev/null +++ b/tensorflow/core/grappler/clusters/utils.h @@ -0,0 +1,38 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#ifndef TENSORFLOW_GRAPPLER_CLUSTERS_UTILS_H_ +#define TENSORFLOW_GRAPPLER_CLUSTERS_UTILS_H_ + +#include "tensorflow/core/protobuf/device_properties.pb.h" +#include "tensorflow/core/util/device_name_utils.h" + +namespace tensorflow { +namespace grappler { + +// Returns the DeviceProperties of the CPU on which grappler is running. +DeviceProperties GetLocalCPUInfo(); + +// Returns the DeviceProperties for the specified GPU attached to the server on +// which grappler is running. +DeviceProperties GetLocalGPUInfo(int gpu_id); + +// Returns the DeviceProperties of the specified device +DeviceProperties GetDeviceInfo(const DeviceNameUtils::ParsedName& device); + +} // end namespace grappler +} // end namespace tensorflow + +#endif // TENSORFLOW_GRAPPLER_CLUSTERS_UTILS_H_ diff --git a/tensorflow/core/grappler/clusters/virtual_cluster.cc b/tensorflow/core/grappler/clusters/virtual_cluster.cc new file mode 100644 index 0000000000..4ca4c03dbb --- /dev/null +++ b/tensorflow/core/grappler/clusters/virtual_cluster.cc @@ -0,0 +1,44 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "tensorflow/core/grappler/clusters/virtual_cluster.h" + +namespace tensorflow { +namespace grappler { + +VirtualCluster::VirtualCluster( + const std::unordered_map<string, DeviceProperties>& devices) + : Cluster(0) { + devices_ = devices; +} + +VirtualCluster::~VirtualCluster() {} + +Status VirtualCluster::Provision() { return Status::OK(); } + +Status VirtualCluster::Initialize(const GrapplerItem& item) { + return Status::OK(); +} + +Status VirtualCluster::Run(const GraphDef& item, + const std::vector<std::pair<string, Tensor>>& feed, + const std::vector<string>& fetch, + RunMetadata* metadata) { + return Status::OK(); + +} + +} // namespace grappler +} // namespace tensorflow diff --git a/tensorflow/core/grappler/clusters/virtual_cluster.h b/tensorflow/core/grappler/clusters/virtual_cluster.h new file mode 100644 index 0000000000..cd8436a987 --- /dev/null +++ b/tensorflow/core/grappler/clusters/virtual_cluster.h @@ -0,0 +1,46 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#ifndef TENSORFLOW_GRAPPLER_CLUSTERS_VIRTUAL_CLUSTER_H_ +#define TENSORFLOW_GRAPPLER_CLUSTERS_VIRTUAL_CLUSTER_H_ + +#include <unordered_map> +#include "tensorflow/core/grappler/clusters/cluster.h" +#include "tensorflow/core/protobuf/device_properties.pb.h" + +namespace tensorflow { +namespace grappler { + +// Create a simple cluster that lists the devices (and their properties) +// available in a TensorFlow session. This cluster doesn't allow running an +// actual graph. It is useful however when used in conjunction with cost models +// that aren't based on the execution of the graph. +class VirtualCluster : public Cluster { + public: + VirtualCluster(const std::unordered_map<string, DeviceProperties>& devices); + + ~VirtualCluster() override; + + Status Provision() override; + Status Initialize(const GrapplerItem& item) override; + Status Run(const GraphDef& item, + const std::vector<std::pair<string, Tensor>>& feed, + const std::vector<string>& fetch, RunMetadata* metadata) override; +}; + +} // end namespace grappler +} // end namespace tensorflow + +#endif // TENSORFLOW_GRAPPLER_CLUSTERS_VIRTUAL_CLUSTER_H_ diff --git a/tensorflow/core/grappler/costs/BUILD b/tensorflow/core/grappler/costs/BUILD index 596b288eb7..ceb4050eae 100644 --- a/tensorflow/core/grappler/costs/BUILD +++ b/tensorflow/core/grappler/costs/BUILD @@ -118,6 +118,7 @@ cc_library( deps = [ ":op_performance_data_cc", "//third_party/eigen3", + "//tensorflow/core/grappler/clusters:utils", "//tensorflow/core:core_cpu_internal", "//tensorflow/core:framework", "//tensorflow/core:lib", @@ -195,9 +196,8 @@ cc_library( deps = [ ":cost_estimator", ":op_performance_data_cc", - ":utils", - "//tensorflow/core:core_cpu_internal", "//tensorflow/core:framework", + "//tensorflow/core/grappler/clusters:utils", ], ) diff --git a/tensorflow/core/grappler/costs/op_level_cost_estimator.cc 
b/tensorflow/core/grappler/costs/op_level_cost_estimator.cc index a6ae2c744c..84e5461d7f 100644 --- a/tensorflow/core/grappler/costs/op_level_cost_estimator.cc +++ b/tensorflow/core/grappler/costs/op_level_cost_estimator.cc @@ -15,7 +15,7 @@ limitations under the License. #include "tensorflow/core/grappler/costs/op_level_cost_estimator.h" #include "tensorflow/core/framework/attr_value_util.h" -#include "tensorflow/core/grappler/costs/utils.h" +#include "tensorflow/core/grappler/clusters/utils.h" namespace tensorflow { namespace grappler { diff --git a/tensorflow/core/grappler/costs/op_level_cost_estimator.h b/tensorflow/core/grappler/costs/op_level_cost_estimator.h index 7a594e2a01..78ce69a597 100644 --- a/tensorflow/core/grappler/costs/op_level_cost_estimator.h +++ b/tensorflow/core/grappler/costs/op_level_cost_estimator.h @@ -20,7 +20,6 @@ limitations under the License. #include <map> #include <string> -#include "tensorflow/core/graph/types.h" #include "tensorflow/core/grappler/costs/cost_estimator.h" #include "tensorflow/core/grappler/costs/op_performance_data.pb.h" #include "tensorflow/core/util/padding.h" diff --git a/tensorflow/core/grappler/costs/utils.cc b/tensorflow/core/grappler/costs/utils.cc index e3f11272b2..bdfb17a456 100644 --- a/tensorflow/core/grappler/costs/utils.cc +++ b/tensorflow/core/grappler/costs/utils.cc @@ -30,6 +30,7 @@ limitations under the License. 
#include "tensorflow/core/framework/types.pb.h" #include "tensorflow/core/graph/graph.h" #include "tensorflow/core/graph/tensor_id.h" +#include "tensorflow/core/grappler/clusters/utils.h" #include "tensorflow/core/lib/strings/numbers.h" #include "tensorflow/core/lib/strings/strcat.h" #include "tensorflow/core/platform/cpu_info.h" @@ -139,70 +140,5 @@ DeviceProperties GetDeviceInfo(const CostGraphDef::Node& node) { return device; } -DeviceProperties GetLocalCPUInfo() { - DeviceProperties device; - device.set_type("CPU"); - - device.set_vendor(port::CPUVendorIDString()); - // Combine cpu family and model into the model string. - device.set_model( - strings::StrCat((port::CPUFamily() << 4) + port::CPUModelNum())); - device.set_frequency(port::NominalCPUFrequency() * 1e-6); - device.set_num_cores(port::NumSchedulableCPUs()); - device.set_l1_cache_size(Eigen::l1CacheSize()); - device.set_l2_cache_size(Eigen::l2CacheSize()); - device.set_l3_cache_size(Eigen::l3CacheSize()); - - (*device.mutable_environment())["cpu_instruction_set"] = - Eigen::SimdInstructionSetsInUse(); - - (*device.mutable_environment())["eigen"] = strings::StrCat( - EIGEN_WORLD_VERSION, ".", EIGEN_MAJOR_VERSION, ".", EIGEN_MINOR_VERSION); -#ifdef EIGEN_USE_LIBXSMM - (*device.mutable_environment())["libxsmm"] = LIBXSMM_VERSION; -#endif - - return device; -} - -DeviceProperties GetLocalGPUInfo(int gpu_id) { - DeviceProperties device; - device.set_type("GPU"); - -#if GOOGLE_CUDA - cudaDeviceProp properties; - cudaError_t error = cudaGetDeviceProperties(&properties, gpu_id); - if (error == cudaSuccess) { - device.set_vendor("NVidia"); - device.set_model(properties.name); - device.set_frequency(properties.clockRate * 1e-3); - device.set_num_cores(properties.multiProcessorCount); - device.set_num_registers(properties.regsPerMultiprocessor); - // For compute capability less than 5, l1 cache size is configurable to - // either 16 KB or 48 KB. We use the initial configuration 16 KB here. 
For - // compute capability larger or equal to 5, l1 cache (unified with texture - // cache) size is 24 KB. This number may need to be updated for future - // compute capabilities. - device.set_l1_cache_size((properties.major < 5) ? 16 * 1024 : 24 * 1024); - device.set_l2_cache_size(properties.l2CacheSize); - device.set_l3_cache_size(0); - device.set_shared_memory_size_per_multiprocessor( - properties.sharedMemPerMultiprocessor); - device.set_memory_size(properties.totalGlobalMem); - // 8 is the number of bits per byte. 2 is accounted for - // double data rate (DDR). - device.set_bandwidth(properties.memoryBusWidth / 8 * - properties.memoryClockRate * 2); - } - - (*device.mutable_environment())["architecture"] = - strings::StrCat(properties.major, ".", properties.minor); - (*device.mutable_environment())["cuda"] = strings::StrCat(CUDA_VERSION); - (*device.mutable_environment())["cudnn"] = strings::StrCat(CUDNN_VERSION); -#endif - - return device; -} - } // end namespace grappler } // end namespace tensorflow diff --git a/tensorflow/core/grappler/costs/utils.h b/tensorflow/core/grappler/costs/utils.h index 0886dfbde3..32e32a09e1 100644 --- a/tensorflow/core/grappler/costs/utils.h +++ b/tensorflow/core/grappler/costs/utils.h @@ -43,13 +43,6 @@ std::vector<OpInfo::TensorProperties> FindInputFeatures( // Returns the DeviceProperties of the device on which 'node' runs. DeviceProperties GetDeviceInfo(const CostGraphDef::Node& node); -// Returns the DeviceProperties of the CPU on which grappler is running. -DeviceProperties GetLocalCPUInfo(); - -// Returns the DeviceProperties for the specified GPU attached to the server on -// which grappler is running. 
-DeviceProperties GetLocalGPUInfo(int gpu_id); - } // end namespace grappler } // end namespace tensorflow diff --git a/tensorflow/core/grappler/optimizers/meta_optimizer.cc b/tensorflow/core/grappler/optimizers/meta_optimizer.cc index 2ea5adffeb..23adaab8ed 100644 --- a/tensorflow/core/grappler/optimizers/meta_optimizer.cc +++ b/tensorflow/core/grappler/optimizers/meta_optimizer.cc @@ -90,13 +90,13 @@ Status MetaOptimizer::Optimize(Cluster* cluster, const GrapplerItem& item, bool already_optimized = false; for (const auto& optimizer : optimizers) { if (!already_optimized) { - TF_RETURN_IF_ERROR(optimizer->Optimize(nullptr, item, optimized_graph)); + TF_RETURN_IF_ERROR(optimizer->Optimize(cluster, item, optimized_graph)); already_optimized = true; } else { GrapplerItem optimized_item = item; optimized_item.graph = *optimized_graph; TF_RETURN_IF_ERROR( - optimizer->Optimize(nullptr, optimized_item, optimized_graph)); + optimizer->Optimize(cluster, optimized_item, optimized_graph)); } } // Copy the graph version. 
@@ -116,9 +116,9 @@ bool MetaOptimizerEnabled(const RewriterConfig& cfg) { } Status RunMetaOptimizer(const GrapplerItem& item, const RewriterConfig& cfg, - GraphDef* optimized_graph) { + Cluster* cluster, GraphDef* optimized_graph) { MetaOptimizer optimizer(cfg); - return optimizer.Optimize(nullptr, item, optimized_graph); + return optimizer.Optimize(cluster, item, optimized_graph); } } // namespace grappler diff --git a/tensorflow/core/grappler/optimizers/meta_optimizer.h b/tensorflow/core/grappler/optimizers/meta_optimizer.h index 9def2cd711..6b950c973d 100644 --- a/tensorflow/core/grappler/optimizers/meta_optimizer.h +++ b/tensorflow/core/grappler/optimizers/meta_optimizer.h @@ -46,7 +46,7 @@ class MetaOptimizer : public GraphOptimizer { bool MetaOptimizerEnabled(const RewriterConfig& cfg); Status RunMetaOptimizer(const GrapplerItem& item, const RewriterConfig& cfg, - GraphDef* optimized_graph); + Cluster* cluster, GraphDef* optimized_graph); } // namespace grappler } // namespace tensorflow diff --git a/tensorflow/python/grappler/tf_optimizer.i b/tensorflow/python/grappler/tf_optimizer.i index ab887e63e5..404ce35180 100644 --- a/tensorflow/python/grappler/tf_optimizer.i +++ b/tensorflow/python/grappler/tf_optimizer.i @@ -58,6 +58,7 @@ limitations under the License. 
#include "tensorflow/core/framework/graph.pb.h" #include "tensorflow/core/grappler/grappler_item.h" #include "tensorflow/core/grappler/grappler_item_builder.h" + #include "tensorflow/core/grappler/clusters/virtual_cluster.h" #include "tensorflow/core/grappler/optimizers/meta_optimizer.h" #include "tensorflow/core/protobuf/meta_graph.pb.h" #include "tensorflow/core/protobuf/rewriter_config.pb.h" @@ -69,9 +70,11 @@ PyObject* TF_OptimizeGraph( const tensorflow::grappler::ItemConfig item_config; std::unique_ptr<tensorflow::grappler::GrapplerItem> grappler_item = tensorflow::grappler::GrapplerItemFromMetaGraphDef(graph_id, metagraph, item_config); + std::unordered_map<string, tensorflow::DeviceProperties> device_map; + tensorflow::grappler::VirtualCluster cluster(device_map); tensorflow::GraphDef out_graph; tensorflow::Status status = tensorflow::grappler::RunMetaOptimizer( - *grappler_item, rewriter_config, &out_graph); + *grappler_item, rewriter_config, &cluster, &out_graph); tensorflow::Set_TF_Status_from_Status(out_status, status); string out_graph_str = out_graph.SerializeAsString(); PyObject* ret = PyBytes_FromStringAndSize(out_graph_str.data(), |