aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow
diff options
context:
space:
mode:
authorGravatar Peter Hawkins <phawkins@google.com>2017-03-13 10:24:30 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2017-03-13 11:48:07 -0700
commit93e20c87eb0fbd7b5ae98dc318a056ab3368d699 (patch)
treebb085333c53d008b3686c8a193b936f9b0f713be /tensorflow
parent2ee0a9c9c70b36753977536a7ffca6a4846390f1 (diff)
Improve error messages for resource variable type mismatches.
Generate a C++ wrapper for resource variable ops. Handle dumping graphs with special characters. Don't prune Send operators during XLA compilation. Change: 149966629
Diffstat (limited to 'tensorflow')
-rw-r--r--tensorflow/cc/BUILD10
-rw-r--r--tensorflow/compiler/tf2xla/dump_graph.cc10
-rw-r--r--tensorflow/compiler/tf2xla/kernels/declaration_op.cc5
-rw-r--r--tensorflow/compiler/tf2xla/xla_compiler.cc4
-rw-r--r--tensorflow/core/ops/resource_variable_ops.cc6
5 files changed, 26 insertions, 9 deletions
diff --git a/tensorflow/cc/BUILD b/tensorflow/cc/BUILD
index a4d93e6ae9..9a41d2bb1d 100644
--- a/tensorflow/cc/BUILD
+++ b/tensorflow/cc/BUILD
@@ -376,6 +376,16 @@ tf_gen_op_wrappers_cc(
)
tf_gen_op_wrappers_cc(
+ name = "resource_variable_ops",
+ include_internal_ops = 1,
+ op_lib_names = [
+ "resource_variable_ops",
+ ],
+ pkg = "//tensorflow/core",
+ visibility = ["//tensorflow:internal"],
+)
+
+tf_gen_op_wrappers_cc(
name = "remote_fused_graph_ops",
op_lib_names = [
"remote_fused_graph_ops",
diff --git a/tensorflow/compiler/tf2xla/dump_graph.cc b/tensorflow/compiler/tf2xla/dump_graph.cc
index 5aa6f806ac..af5753c260 100644
--- a/tensorflow/compiler/tf2xla/dump_graph.cc
+++ b/tensorflow/compiler/tf2xla/dump_graph.cc
@@ -33,8 +33,16 @@ struct NameCounts {
std::unordered_map<string, int> counts;
};
-string MakeUniquePath(const string& name) {
+string MakeUniquePath(string name) {
static NameCounts& instance = *new NameCounts;
+
+ // Remove illegal characters from `name`.
+ for (int i = 0; i < name.size(); ++i) {
+ if (name[i] == '/') {
+ name[i] = '_';
+ }
+ }
+
int count;
{
mutex_lock lock(instance.counts_mutex);
diff --git a/tensorflow/compiler/tf2xla/kernels/declaration_op.cc b/tensorflow/compiler/tf2xla/kernels/declaration_op.cc
index ddb0f7d649..0a86709269 100644
--- a/tensorflow/compiler/tf2xla/kernels/declaration_op.cc
+++ b/tensorflow/compiler/tf2xla/kernels/declaration_op.cc
@@ -104,9 +104,8 @@ class ArgOp : public XlaOpKernel {
if (arg.is_variable) {
// We use the argument position of the variable input as a unique ID.
// TODO(phawkins): this code assumes that variables do not alias.
- // TODO(b/32704451): Don't just ignore the ::tensorflow::Status object!
- tc.CreateVariable(index_, arg.name, arg.value.type, arg.value.handle)
- .IgnoreError();
+ OP_REQUIRES_OK(ctx, tc.CreateVariable(index_, arg.name, arg.value.type,
+ arg.value.handle));
ctx->SetVariableOutput(0, index_);
} else if (arg.value.is_constant) {
ctx->SetConstantOutput(0, arg.value.constant_value);
diff --git a/tensorflow/compiler/tf2xla/xla_compiler.cc b/tensorflow/compiler/tf2xla/xla_compiler.cc
index 71a994b451..ba975d617d 100644
--- a/tensorflow/compiler/tf2xla/xla_compiler.cc
+++ b/tensorflow/compiler/tf2xla/xla_compiler.cc
@@ -76,7 +76,8 @@ int64 XlaCompiler::NextStepId() {
static void PruneUnreachableNodes(Graph* graph) {
std::unordered_set<const Node*> nodes;
for (Node* node : graph->nodes()) {
- if (node->type_string() == "_Retval") {
+ if (node->type_string() == "_Retval" ||
+ StringPiece(node->type_string()).ends_with("Send")) {
nodes.insert(node);
}
}
@@ -379,7 +380,6 @@ Status XlaCompiler::CompileGraph(string const& name,
VLOG(1) << "Executing graph symbolically to populate ComputationBuilder.";
xla::ComputationBuilder builder(client(), name);
-
XlaContext* context =
new XlaContext(this, &builder, options_.allow_cpu_custom_calls,
options_.resolve_compile_time_constants);
diff --git a/tensorflow/core/ops/resource_variable_ops.cc b/tensorflow/core/ops/resource_variable_ops.cc
index 7316328f44..acab6d7b89 100644
--- a/tensorflow/core/ops/resource_variable_ops.cc
+++ b/tensorflow/core/ops/resource_variable_ops.cc
@@ -65,7 +65,7 @@ REGISTER_OP("ReadVariableOp")
return errors::InvalidArgument(
"Trying to read variable with wrong dtype. "
"Expected ",
- handle_dtype, " got ", value_dtype);
+ DataTypeString(handle_dtype), " got ", DataTypeString(value_dtype));
}
c->set_output(0, c->input_handle_shape(0));
return Status::OK();
@@ -96,7 +96,7 @@ REGISTER_OP("_UnsafeReadVariable")
return errors::InvalidArgument(
"Trying to read variable with wrong dtype. "
"Expected ",
- handle_dtype, " got ", value_dtype);
+ DataTypeString(handle_dtype), " got ", DataTypeString(value_dtype));
}
c->set_output(0, c->input_handle_shape(0));
return Status::OK();
@@ -137,7 +137,7 @@ Status CreateAssignShapeFn(InferenceContext* c) {
return errors::InvalidArgument(
"Trying to initialize handle for variable with wrong dtype. "
"Expected ",
- handle_dtype, " got ", value_dtype);
+ DataTypeString(handle_dtype), " got ", DataTypeString(value_dtype));
}
ShapeHandle s = c->input_handle_shape(0);
ShapeHandle value_shape = c->input(1);