diff options
author | 2017-03-13 10:24:30 -0800 | |
---|---|---|
committer | 2017-03-13 11:48:07 -0700 | |
commit | 93e20c87eb0fbd7b5ae98dc318a056ab3368d699 (patch) | |
tree | bb085333c53d008b3686c8a193b936f9b0f713be /tensorflow | |
parent | 2ee0a9c9c70b36753977536a7ffca6a4846390f1 (diff) |
Improve error messages for resource variable type mismatches.
Generate a C++ wrapper for resource variable ops.
Handle dumping graphs with special characters.
Don't prune Send operators during XLA compilation.
Change: 149966629
Diffstat (limited to 'tensorflow')
-rw-r--r-- | tensorflow/cc/BUILD | 10 | ||||
-rw-r--r-- | tensorflow/compiler/tf2xla/dump_graph.cc | 10 | ||||
-rw-r--r-- | tensorflow/compiler/tf2xla/kernels/declaration_op.cc | 5 | ||||
-rw-r--r-- | tensorflow/compiler/tf2xla/xla_compiler.cc | 4 | ||||
-rw-r--r-- | tensorflow/core/ops/resource_variable_ops.cc | 6 |
5 files changed, 26 insertions, 9 deletions
diff --git a/tensorflow/cc/BUILD b/tensorflow/cc/BUILD index a4d93e6ae9..9a41d2bb1d 100644 --- a/tensorflow/cc/BUILD +++ b/tensorflow/cc/BUILD @@ -376,6 +376,16 @@ tf_gen_op_wrappers_cc( ) tf_gen_op_wrappers_cc( + name = "resource_variable_ops", + include_internal_ops = 1, + op_lib_names = [ + "resource_variable_ops", + ], + pkg = "//tensorflow/core", + visibility = ["//tensorflow:internal"], +) + +tf_gen_op_wrappers_cc( name = "remote_fused_graph_ops", op_lib_names = [ "remote_fused_graph_ops", diff --git a/tensorflow/compiler/tf2xla/dump_graph.cc b/tensorflow/compiler/tf2xla/dump_graph.cc index 5aa6f806ac..af5753c260 100644 --- a/tensorflow/compiler/tf2xla/dump_graph.cc +++ b/tensorflow/compiler/tf2xla/dump_graph.cc @@ -33,8 +33,16 @@ struct NameCounts { std::unordered_map<string, int> counts; }; -string MakeUniquePath(const string& name) { +string MakeUniquePath(string name) { static NameCounts& instance = *new NameCounts; + + // Remove illegal characters from `name`. + for (int i = 0; i < name.size(); ++i) { + if (name[i] == '/') { + name[i] = '_'; + } + } + int count; { mutex_lock lock(instance.counts_mutex); diff --git a/tensorflow/compiler/tf2xla/kernels/declaration_op.cc b/tensorflow/compiler/tf2xla/kernels/declaration_op.cc index ddb0f7d649..0a86709269 100644 --- a/tensorflow/compiler/tf2xla/kernels/declaration_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/declaration_op.cc @@ -104,9 +104,8 @@ class ArgOp : public XlaOpKernel { if (arg.is_variable) { // We use the argument position of the variable input as a unique ID. // TODO(phawkins): this code assumes that variables do not alias. - // TODO(b/32704451): Don't just ignore the ::tensorflow::Status object! - tc.CreateVariable(index_, arg.name, arg.value.type, arg.value.handle) - .IgnoreError(); + OP_REQUIRES_OK(ctx, tc.CreateVariable(index_, arg.name, arg.value.type, + arg.value.handle)); ctx->SetVariableOutput(0, index_); } else if (arg.value.is_constant) { ctx->SetConstantOutput(0, arg.value.constant_value); diff --git a/tensorflow/compiler/tf2xla/xla_compiler.cc b/tensorflow/compiler/tf2xla/xla_compiler.cc index 71a994b451..ba975d617d 100644 --- a/tensorflow/compiler/tf2xla/xla_compiler.cc +++ b/tensorflow/compiler/tf2xla/xla_compiler.cc @@ -76,7 +76,8 @@ int64 XlaCompiler::NextStepId() { static void PruneUnreachableNodes(Graph* graph) { std::unordered_set<const Node*> nodes; for (Node* node : graph->nodes()) { - if (node->type_string() == "_Retval") { + if (node->type_string() == "_Retval" || + StringPiece(node->type_string()).ends_with("Send")) { nodes.insert(node); } } @@ -379,7 +380,6 @@ Status XlaCompiler::CompileGraph(string const& name, VLOG(1) << "Executing graph symbolically to populate ComputationBuilder."; xla::ComputationBuilder builder(client(), name); - XlaContext* context = new XlaContext(this, &builder, options_.allow_cpu_custom_calls, options_.resolve_compile_time_constants); diff --git a/tensorflow/core/ops/resource_variable_ops.cc b/tensorflow/core/ops/resource_variable_ops.cc index 7316328f44..acab6d7b89 100644 --- a/tensorflow/core/ops/resource_variable_ops.cc +++ b/tensorflow/core/ops/resource_variable_ops.cc @@ -65,7 +65,7 @@ REGISTER_OP("ReadVariableOp") return errors::InvalidArgument( "Trying to read variable with wrong dtype. " "Expected ", - handle_dtype, " got ", value_dtype); + DataTypeString(handle_dtype), " got ", DataTypeString(value_dtype)); } c->set_output(0, c->input_handle_shape(0)); return Status::OK(); @@ -96,7 +96,7 @@ REGISTER_OP("_UnsafeReadVariable") return errors::InvalidArgument( "Trying to read variable with wrong dtype. " "Expected ", - handle_dtype, " got ", value_dtype); + DataTypeString(handle_dtype), " got ", DataTypeString(value_dtype)); } c->set_output(0, c->input_handle_shape(0)); return Status::OK(); @@ -137,7 +137,7 @@ Status CreateAssignShapeFn(InferenceContext* c) { return errors::InvalidArgument( "Trying to initialize handle for variable with wrong dtype. " "Expected ", - handle_dtype, " got ", value_dtype); + DataTypeString(handle_dtype), " got ", DataTypeString(value_dtype)); } ShapeHandle s = c->input_handle_shape(0); ShapeHandle value_shape = c->input(1); |