aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/c/c_api.cc
diff options
context:
space:
mode:
authorGravatar Skye Wanderman-Milne <skyewm@google.com>2018-03-06 13:38:56 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-03-06 13:42:58 -0800
commit1d64f9038084095bf92a8ca120d7e1f34ec24ac9 (patch)
tree0a521bc53e682b7ecec1aab5519b983d02a503c4 /tensorflow/c/c_api.cc
parentad08baa5c27ab063596116a178ccff7d3796df65 (diff)
Add TF_TryEvaluateConstant to the C API and have smart_cond call it.
This effectively plumbs EvaluateConstantTensor to smart_cond. This makes smart_cond even smarter by trying to evaluate the predicate if it can't statically infer it. PiperOrigin-RevId: 188073244
Diffstat (limited to 'tensorflow/c/c_api.cc')
-rw-r--r--tensorflow/c/c_api.cc20
1 files changed, 20 insertions, 0 deletions
diff --git a/tensorflow/c/c_api.cc b/tensorflow/c/c_api.cc
index 85f1d1639b..3d0e886476 100644
--- a/tensorflow/c/c_api.cc
+++ b/tensorflow/c/c_api.cc
@@ -30,6 +30,7 @@ limitations under the License.
#endif
#include "tensorflow/c/c_api_internal.h"
#include "tensorflow/core/common_runtime/device_mgr.h"
+#include "tensorflow/core/common_runtime/eval_const_tensor.h"
#include "tensorflow/core/common_runtime/shape_refiner.h"
#include "tensorflow/core/framework/allocation_description.pb.h"
#include "tensorflow/core/framework/log_memory.h"
@@ -73,6 +74,7 @@ using tensorflow::NodeBuilder;
using tensorflow::NodeDef;
using tensorflow::OpDef;
using tensorflow::OpRegistry;
+using tensorflow::OutputTensor;
using tensorflow::PartialTensorShape;
using tensorflow::RunMetadata;
using tensorflow::RunOptions;
@@ -2682,6 +2684,24 @@ void TF_SessionPRun(TF_Session* session, const char* handle,
output_values, target_names, nullptr, status);
}
+unsigned char TF_TryEvaluateConstant(TF_Graph* graph, TF_Output output,
+ TF_Tensor** result, TF_Status* status) {
+ *result = nullptr;
+ mutex_lock l(graph->mu);
+ OutputTensor tensor(&output.oper->node, output.index);
+ bool evaluated;
+ Tensor result_tensor;
+ status->status = EvaluateConstantTensor(
+ tensor, graph->refiner, *graph->graph.op_registry(),
+ graph->graph.versions().producer(), &evaluated, &result_tensor);
+ if (evaluated) {
+ DCHECK(status->status.ok());
+ *result = TF_TensorFromTensor(result_tensor, status);
+ if (!status->status.ok()) evaluated = false;
+ }
+ return evaluated;
+}
+
TF_ApiDefMap* TF_NewApiDefMap(TF_Buffer* op_list_buffer, TF_Status* status) {
tensorflow::OpList op_list;
if (!op_list.ParseFromArray(op_list_buffer->data, op_list_buffer->length)) {