aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/cc
diff options
context:
space:
mode:
authorGravatar Sanjoy Das <sanjoy@google.com>2018-10-01 15:41:29 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-10-01 15:52:53 -0700
commitdc4ac1b84c9c74655f04254779516f9968a5c385 (patch)
treefdf43c0b81c93e5b5116d0087af46087e2cea67a /tensorflow/cc
parent55d96e8ea93407da156c156702a38fd8b5d06b2a (diff)
Clean up the build_xla_ops to use the generated C++ TF op wrappers.
This cleanup will make the future CL implementing lazy compilation simpler. Includes some supporting changes: - Teach NewInternalScope to create a scope that doesn't do shape inference. We need this because we don't have a ShapeRefiner that has been run over the entire graph available in the build_xla_ops pass. - Add a WithAssignedDevice modifier to tensorflow::Scope. - Make cc_op_gen write out an Operation field for nodes which may not necessarily have any outputs. We already did this in most cases, but we weren't doing it for nodes that have possibly-empty list outputs. - Minor change renaming ops/xla_jit_op.cc to ops/xla_jit_ops.cc, now that we have more than one XLA JIT op. PiperOrigin-RevId: 215293817
Diffstat (limited to 'tensorflow/cc')
-rw-r--r--tensorflow/cc/framework/cc_op_gen.cc10
-rw-r--r--tensorflow/cc/framework/scope.cc33
-rw-r--r--tensorflow/cc/framework/scope.h4
-rw-r--r--tensorflow/cc/framework/scope_internal.h5
4 files changed, 45 insertions, 7 deletions
diff --git a/tensorflow/cc/framework/cc_op_gen.cc b/tensorflow/cc/framework/cc_op_gen.cc
index a32d1b1eb5..39593370d1 100644
--- a/tensorflow/cc/framework/cc_op_gen.cc
+++ b/tensorflow/cc/framework/cc_op_gen.cc
@@ -853,11 +853,7 @@ void OpInfo::WriteClassDecl(WritableFile* h) const {
}
}
- strings::StrAppend(&class_decl, "\n");
-
- if (output_types.empty()) {
- strings::StrAppend(&class_decl, " Operation operation;\n");
- }
+ strings::StrAppend(&class_decl, "\n Operation operation;\n");
for (int i = 0; i < output_types.size(); ++i) {
strings::StrAppend(&class_decl, " ", output_types[i], " ", output_names[i],
";\n");
@@ -878,9 +874,11 @@ void OpInfo::GetOutput(string* out) const {
string return_on_error =
strings::StrCat("if (!", scope_str, ".ok()) return;");
+ strings::StrAppend(out, " this->operation = Operation(ret);\n");
+
// No outputs.
if (graph_op_def.output_arg_size() == 0) {
- strings::StrAppend(out, " this->operation = Operation(ret);\n return;\n");
+ strings::StrAppend(out, " return;\n");
return;
}
if (graph_op_def.output_arg_size() == 1) {
diff --git a/tensorflow/cc/framework/scope.cc b/tensorflow/cc/framework/scope.cc
index 7f6ac4cae7..6abc9e268e 100644
--- a/tensorflow/cc/framework/scope.cc
+++ b/tensorflow/cc/framework/scope.cc
@@ -62,7 +62,7 @@ Scope::Impl::Impl(const std::shared_ptr<Graph>& graph,
refiner_(refiner),
scope_used_(nullptr),
colocation_constraints_(),
- disable_shape_inference_(false) {}
+ disable_shape_inference_(refiner_ == nullptr) {}
Scope Scope::NewRootScope() {
Graph* graph = new Graph(OpRegistry::Global());
@@ -94,6 +94,7 @@ Scope::Impl::Impl(const Scope& other, Tags::ScopeName, const string& name,
exit_on_error_(other.impl()->exit_on_error_),
kernel_label_(other.impl()->kernel_label_),
device_(other.impl()->device_),
+ assigned_device_(other.impl()->assigned_device_),
colocation_constraints_(other.impl()->colocation_constraints_),
disable_shape_inference_(other.impl()->disable_shape_inference_) {}
@@ -110,6 +111,7 @@ Scope::Impl::Impl(const Scope& other, Tags::OpName, const string& name,
exit_on_error_(other.impl()->exit_on_error_),
kernel_label_(other.impl()->kernel_label_),
device_(other.impl()->device_),
+ assigned_device_(other.impl()->assigned_device_),
colocation_constraints_(other.impl()->colocation_constraints_),
disable_shape_inference_(other.impl()->disable_shape_inference_) {}
@@ -132,6 +134,7 @@ Scope::Impl::Impl(const Scope& other, Tags::ControlDeps,
exit_on_error_(other.impl()->exit_on_error_),
kernel_label_(other.impl()->kernel_label_),
device_(other.impl()->device_),
+ assigned_device_(other.impl()->assigned_device_),
colocation_constraints_(other.impl()->colocation_constraints_),
disable_shape_inference_(other.impl()->disable_shape_inference_) {}
@@ -163,6 +166,7 @@ Scope::Impl::Impl(const Scope& other, Tags::SingleUseScope,
exit_on_error_(other.impl()->exit_on_error_),
kernel_label_(other.impl()->kernel_label_),
device_(other.impl()->device_),
+ assigned_device_(other.impl()->assigned_device_),
colocation_constraints_(other.impl()->colocation_constraints_),
disable_shape_inference_(other.impl()->disable_shape_inference_) {}
@@ -178,6 +182,7 @@ Scope::Impl::Impl(const Scope& other, Tags::ExitOnError)
exit_on_error_(true),
kernel_label_(other.impl()->kernel_label_),
device_(other.impl()->device_),
+ assigned_device_(other.impl()->assigned_device_),
colocation_constraints_(other.impl()->colocation_constraints_),
disable_shape_inference_(other.impl()->disable_shape_inference_) {}
@@ -194,6 +199,7 @@ Scope::Impl::Impl(const Scope& other, Tags::KernelLabel,
exit_on_error_(other.impl()->exit_on_error_),
kernel_label_(kernel_label),
device_(other.impl()->device_),
+ assigned_device_(other.impl()->assigned_device_),
colocation_constraints_(other.impl()->colocation_constraints_),
disable_shape_inference_(other.impl()->disable_shape_inference_) {}
@@ -210,12 +216,30 @@ Scope::Impl::Impl(const Scope& other, Tags::Colocate,
exit_on_error_(other.impl()->exit_on_error_),
kernel_label_(other.impl()->kernel_label_),
device_(other.impl()->device_),
+ assigned_device_(other.impl()->assigned_device_),
colocation_constraints_(
clear_colocations
? std::unordered_set<string>()
: other.impl()->GetColocationConstraints(colocate_with_op)),
disable_shape_inference_(other.impl()->disable_shape_inference_) {}
+Scope::Impl::Impl(const Scope& other, Tags::AssignedDevice,
+ const string& assigned_device)
+ : graph_(other.impl()->graph_),
+ status_(other.impl()->status_),
+ name_map_(other.impl()->name_map_),
+ refiner_(other.impl()->refiner_),
+ scope_used_(other.impl()->scope_used_),
+ control_deps_(other.impl()->control_deps_),
+ name_(other.impl()->name_),
+ op_name_(other.impl()->op_name_),
+ exit_on_error_(other.impl()->exit_on_error_),
+ kernel_label_(other.impl()->kernel_label_),
+ device_(other.impl()->device_),
+ assigned_device_(assigned_device),
+ colocation_constraints_(other.impl()->colocation_constraints_),
+ disable_shape_inference_(other.impl()->disable_shape_inference_) {}
+
std::unordered_set<string> Scope::Impl::GetColocationConstraints(
const Operation& colocate_with_op) const {
std::unordered_set<string> current_constraints(colocation_constraints_);
@@ -299,6 +323,9 @@ void Scope::UpdateBuilder(NodeBuilder* builder) const {
if (!impl()->device_.empty()) {
builder->Device(impl()->device_);
}
+ if (!impl()->assigned_device_.empty()) {
+ builder->AssignedDevice(impl()->assigned_device_);
+ }
}
string Scope::Impl::GetUniqueName(const string& prefix,
@@ -394,6 +421,10 @@ Scope Scope::WithDevice(const string& device) const {
return Scope(new Impl(*this, Impl::Tags::Device(), device));
}
+Scope Scope::WithAssignedDevice(const string& assigned_device) const {
+ return Scope(new Impl(*this, Impl::Tags::AssignedDevice(), assigned_device));
+}
+
Scope Scope::ColocateWith(const Operation& op) const {
return Scope(new Impl(*this, Impl::Tags::Colocate(), op,
/* clear_colocations */ false));
diff --git a/tensorflow/cc/framework/scope.h b/tensorflow/cc/framework/scope.h
index 30c32bd44b..e307d8989b 100644
--- a/tensorflow/cc/framework/scope.h
+++ b/tensorflow/cc/framework/scope.h
@@ -133,6 +133,10 @@ class Scope {
/// the device field set to 'device'.
Scope WithDevice(const string& device) const;
+ /// Returns a new scope. All ops created within the returned scope will have
+ /// their assigned device set to `assigned_device`.
+ Scope WithAssignedDevice(const string& assigned_device) const;
+
/// Return a new scope. All ops created within the returned scope will be
/// co-located on the device where op is placed.
/// NOTE: This function is intended to be use internal libraries only for
diff --git a/tensorflow/cc/framework/scope_internal.h b/tensorflow/cc/framework/scope_internal.h
index 58adaef2e9..514e02e841 100644
--- a/tensorflow/cc/framework/scope_internal.h
+++ b/tensorflow/cc/framework/scope_internal.h
@@ -26,6 +26,8 @@ class ShapeRefiner;
// graph, status, name_map, and refiner.
// This is intended to enable the C API (which are used by other language
// bindings) to create a Scope and access C++ functionality (i.e. gradients).
+//
+// Shape inference is disabled if `refiner` is nullptr.
Scope NewInternalScope(Graph* graph, Status* status, ShapeRefiner* refiner);
class Scope::Impl {
@@ -58,6 +60,7 @@ class Scope::Impl {
enum class ExitOnError;
enum class KernelLabel;
enum class Colocate;
+ enum class AssignedDevice;
};
Impl(Graph* graph, Status* status, NameMap* name_map, ShapeRefiner* refiner,
@@ -74,6 +77,7 @@ class Scope::Impl {
Impl(const Scope& other, Tags::KernelLabel, const string& kernel_label);
Impl(const Scope& other, Tags::Colocate, const Operation& colocate_with_op,
bool clear_colocations);
+ Impl(const Scope& other, Tags::AssignedDevice, const string& assigned_device);
std::unordered_set<string> GetColocationConstraints(
const Operation& colocate_with_op) const;
@@ -107,6 +111,7 @@ class Scope::Impl {
const bool exit_on_error_ = false;
const string kernel_label_ = "";
const string device_ = "";
+ const string assigned_device_ = "";
const std::unordered_set<string> colocation_constraints_;
// If true, Scope::DoShapeInference() always returns Status:OK().