diff options
author | Skye Wanderman-Milne <skyewm@google.com> | 2018-03-06 13:38:56 -0800 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-03-06 13:42:58 -0800 |
commit | 1d64f9038084095bf92a8ca120d7e1f34ec24ac9 (patch) | |
tree | 0a521bc53e682b7ecec1aab5519b983d02a503c4 /tensorflow/c/c_api.cc | |
parent | ad08baa5c27ab063596116a178ccff7d3796df65 (diff) |
Add TF_TryEvaluateConstant to the C API and have smart_cond call it.
This effectively plumbs EvaluateConstantTensor to smart_cond. This makes smart_cond even smarter by trying to evaluate the predicate
if it can't statically infer it.
PiperOrigin-RevId: 188073244
Diffstat (limited to 'tensorflow/c/c_api.cc')
-rw-r--r-- | tensorflow/c/c_api.cc | 20 |
1 files changed, 20 insertions, 0 deletions
diff --git a/tensorflow/c/c_api.cc b/tensorflow/c/c_api.cc index 85f1d1639b..3d0e886476 100644 --- a/tensorflow/c/c_api.cc +++ b/tensorflow/c/c_api.cc @@ -30,6 +30,7 @@ limitations under the License. #endif #include "tensorflow/c/c_api_internal.h" #include "tensorflow/core/common_runtime/device_mgr.h" +#include "tensorflow/core/common_runtime/eval_const_tensor.h" #include "tensorflow/core/common_runtime/shape_refiner.h" #include "tensorflow/core/framework/allocation_description.pb.h" #include "tensorflow/core/framework/log_memory.h" @@ -73,6 +74,7 @@ using tensorflow::NodeBuilder; using tensorflow::NodeDef; using tensorflow::OpDef; using tensorflow::OpRegistry; +using tensorflow::OutputTensor; using tensorflow::PartialTensorShape; using tensorflow::RunMetadata; using tensorflow::RunOptions; @@ -2682,6 +2684,24 @@ void TF_SessionPRun(TF_Session* session, const char* handle, output_values, target_names, nullptr, status); } +unsigned char TF_TryEvaluateConstant(TF_Graph* graph, TF_Output output, + TF_Tensor** result, TF_Status* status) { + *result = nullptr; + mutex_lock l(graph->mu); + OutputTensor tensor(&output.oper->node, output.index); + bool evaluated; + Tensor result_tensor; + status->status = EvaluateConstantTensor( + tensor, graph->refiner, *graph->graph.op_registry(), + graph->graph.versions().producer(), &evaluated, &result_tensor); + if (evaluated) { + DCHECK(status->status.ok()); + *result = TF_TensorFromTensor(result_tensor, status); + if (!status->status.ok()) evaluated = false; + } + return evaluated; +} + TF_ApiDefMap* TF_NewApiDefMap(TF_Buffer* op_list_buffer, TF_Status* status) { tensorflow::OpList op_list; if (!op_list.ParseFromArray(op_list_buffer->data, op_list_buffer->length)) { |