diff options
author | 2016-07-14 20:06:30 -0800 | |
---|---|---|
committer | 2016-07-14 21:17:26 -0700 | |
commit | 7cc7b56b8a605f52d717173122a382cadd611793 (patch) | |
tree | 629584b639d0ab0aa21e3788f0ef8132baf7c297 /tensorflow/core/public | |
parent | 333b69580537bf14a3072f1388de64eb3fb5ebc2 (diff) |
Implementing more C API functions (executing more TODOs):
* TF_NodeGetAttrValueProto()
* TF_SessionPRunSetup()
* TF_SessionPRun()
Also:
* Clarify that it isn't required to set attrs that have defaults.
* Fewer string conversions in TF_SessionRun().
* Add "const" to a few more arguments.
* Some non-functional changes to make things more consistent.
Change: 127505523
Diffstat (limited to 'tensorflow/core/public')
-rw-r--r-- | tensorflow/core/public/tensor_c_api.h | 38 |
1 files changed, 35 insertions, 3 deletions
diff --git a/tensorflow/core/public/tensor_c_api.h b/tensorflow/core/public/tensor_c_api.h index bd45c31ed7..9f4f7adde9 100644 --- a/tensorflow/core/public/tensor_c_api.h +++ b/tensorflow/core/public/tensor_c_api.h @@ -304,7 +304,8 @@ extern void TF_AddInputList(TF_NodeDescription* desc, const TF_Port* inputs, extern void TF_AddControlInput(TF_NodeDescription* desc, TF_Node* input); // Call some TF_SetAttr*() function for every attr that is not -// inferred from an input. +// inferred from an input and doesn't have a default value you wish to +// keep. // `value` must point to a string of length `length` bytes. extern void TF_SetAttrString(TF_NodeDescription* desc, const char* attr_name, @@ -449,6 +450,12 @@ extern int TF_NodeNumControlOutputs(TF_Node* node); extern int TF_NodeGetControlOutputs(TF_Node* node, TF_Node** control_outputs, int max_control_outputs); +// Sets `output_attr_value` to the binary-serialized AttrValue proto +// representation of the value of the `attr_name` attr of `node`. +extern void TF_NodeGetAttrValueProto(TF_Node* node, const char* attr_name, + TF_Buffer* output_attr_value, + TF_Status* status); + // Returns the node in the graph with `node_name`. Returns nullptr if // no node found. extern TF_Node* TF_GraphNodeByName(TF_Graph* graph, const char* node_name); @@ -531,7 +538,32 @@ extern void TF_SessionRun(TF_SessionWithGraph* session, // Output status TF_Status*); -// TODO(josh11b): TF_SessionPRunSetup() and TF_SessionPRun(). +// See TF_PRunSetup() below. +extern void TF_SessionPRunSetup(TF_SessionWithGraph*, + // Input names + const TF_Port* inputs, int ninputs, + // Output names + const TF_Port* outputs, int noutputs, + // Target nodes + const TF_Node* const* target_nodes, + int ntargets, + // Output handle + const char** handle, + // Output status + TF_Status*); + +// See TF_PRun() below. +extern void TF_SessionPRun(TF_SessionWithGraph*, const char* handle, + // Input tensors + const TF_Port* inputs, + TF_Tensor* const* input_values, int ninputs, + // Output tensors + const TF_Port* outputs, TF_Tensor** output_values, + int noutputs, + // Target nodes + const TF_Node* const* target_nodes, int ntargets, + // Output status + TF_Status*); // -------------------------------------------------------------------------- // The deprecated session API. Please switch to the above instead of @@ -616,7 +648,7 @@ extern void TF_PRunSetup(TF_Session*, // Target nodes const char** target_node_names, int ntargets, // Output handle - char** handle, + const char** handle, // Output status TF_Status*); |