aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/cc/framework/scope.cc
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/cc/framework/scope.cc')
-rw-r--r--tensorflow/cc/framework/scope.cc33
1 files changed, 32 insertions, 1 deletions
diff --git a/tensorflow/cc/framework/scope.cc b/tensorflow/cc/framework/scope.cc
index 7f6ac4cae7..6abc9e268e 100644
--- a/tensorflow/cc/framework/scope.cc
+++ b/tensorflow/cc/framework/scope.cc
@@ -62,7 +62,7 @@ Scope::Impl::Impl(const std::shared_ptr<Graph>& graph,
refiner_(refiner),
scope_used_(nullptr),
colocation_constraints_(),
- disable_shape_inference_(false) {}
+ disable_shape_inference_(refiner_ == nullptr) {}
Scope Scope::NewRootScope() {
Graph* graph = new Graph(OpRegistry::Global());
@@ -94,6 +94,7 @@ Scope::Impl::Impl(const Scope& other, Tags::ScopeName, const string& name,
exit_on_error_(other.impl()->exit_on_error_),
kernel_label_(other.impl()->kernel_label_),
device_(other.impl()->device_),
+ assigned_device_(other.impl()->assigned_device_),
colocation_constraints_(other.impl()->colocation_constraints_),
disable_shape_inference_(other.impl()->disable_shape_inference_) {}
@@ -110,6 +111,7 @@ Scope::Impl::Impl(const Scope& other, Tags::OpName, const string& name,
exit_on_error_(other.impl()->exit_on_error_),
kernel_label_(other.impl()->kernel_label_),
device_(other.impl()->device_),
+ assigned_device_(other.impl()->assigned_device_),
colocation_constraints_(other.impl()->colocation_constraints_),
disable_shape_inference_(other.impl()->disable_shape_inference_) {}
@@ -132,6 +134,7 @@ Scope::Impl::Impl(const Scope& other, Tags::ControlDeps,
exit_on_error_(other.impl()->exit_on_error_),
kernel_label_(other.impl()->kernel_label_),
device_(other.impl()->device_),
+ assigned_device_(other.impl()->assigned_device_),
colocation_constraints_(other.impl()->colocation_constraints_),
disable_shape_inference_(other.impl()->disable_shape_inference_) {}
@@ -163,6 +166,7 @@ Scope::Impl::Impl(const Scope& other, Tags::SingleUseScope,
exit_on_error_(other.impl()->exit_on_error_),
kernel_label_(other.impl()->kernel_label_),
device_(other.impl()->device_),
+ assigned_device_(other.impl()->assigned_device_),
colocation_constraints_(other.impl()->colocation_constraints_),
disable_shape_inference_(other.impl()->disable_shape_inference_) {}
@@ -178,6 +182,7 @@ Scope::Impl::Impl(const Scope& other, Tags::ExitOnError)
exit_on_error_(true),
kernel_label_(other.impl()->kernel_label_),
device_(other.impl()->device_),
+ assigned_device_(other.impl()->assigned_device_),
colocation_constraints_(other.impl()->colocation_constraints_),
disable_shape_inference_(other.impl()->disable_shape_inference_) {}
@@ -194,6 +199,7 @@ Scope::Impl::Impl(const Scope& other, Tags::KernelLabel,
exit_on_error_(other.impl()->exit_on_error_),
kernel_label_(kernel_label),
device_(other.impl()->device_),
+ assigned_device_(other.impl()->assigned_device_),
colocation_constraints_(other.impl()->colocation_constraints_),
disable_shape_inference_(other.impl()->disable_shape_inference_) {}
@@ -210,12 +216,30 @@ Scope::Impl::Impl(const Scope& other, Tags::Colocate,
exit_on_error_(other.impl()->exit_on_error_),
kernel_label_(other.impl()->kernel_label_),
device_(other.impl()->device_),
+ assigned_device_(other.impl()->assigned_device_),
colocation_constraints_(
clear_colocations
? std::unordered_set<string>()
: other.impl()->GetColocationConstraints(colocate_with_op)),
disable_shape_inference_(other.impl()->disable_shape_inference_) {}
+Scope::Impl::Impl(const Scope& other, Tags::AssignedDevice,
+ const string& assigned_device)
+ : graph_(other.impl()->graph_),
+ status_(other.impl()->status_),
+ name_map_(other.impl()->name_map_),
+ refiner_(other.impl()->refiner_),
+ scope_used_(other.impl()->scope_used_),
+ control_deps_(other.impl()->control_deps_),
+ name_(other.impl()->name_),
+ op_name_(other.impl()->op_name_),
+ exit_on_error_(other.impl()->exit_on_error_),
+ kernel_label_(other.impl()->kernel_label_),
+ device_(other.impl()->device_),
+ assigned_device_(assigned_device),
+ colocation_constraints_(other.impl()->colocation_constraints_),
+ disable_shape_inference_(other.impl()->disable_shape_inference_) {}
+
std::unordered_set<string> Scope::Impl::GetColocationConstraints(
const Operation& colocate_with_op) const {
std::unordered_set<string> current_constraints(colocation_constraints_);
@@ -299,6 +323,9 @@ void Scope::UpdateBuilder(NodeBuilder* builder) const {
if (!impl()->device_.empty()) {
builder->Device(impl()->device_);
}
+ if (!impl()->assigned_device_.empty()) {
+ builder->AssignedDevice(impl()->assigned_device_);
+ }
}
string Scope::Impl::GetUniqueName(const string& prefix,
@@ -394,6 +421,10 @@ Scope Scope::WithDevice(const string& device) const {
return Scope(new Impl(*this, Impl::Tags::Device(), device));
}
+Scope Scope::WithAssignedDevice(const string& assigned_device) const {
+ return Scope(new Impl(*this, Impl::Tags::AssignedDevice(), assigned_device));
+}
+
Scope Scope::ColocateWith(const Operation& op) const {
return Scope(new Impl(*this, Impl::Tags::Colocate(), op,
/* clear_colocations */ false));