From 79c828ea6ddbcfccd43a2be176fc1dcad4daf34e Mon Sep 17 00:00:00 2001 From: Akshay Modi Date: Tue, 26 Jun 2018 15:36:59 -0700 Subject: Support shapes for remote eager tensor handles. Since we respond with the shape, all RPCs will happen sync (note that we may still hide the python overhead, since the op is still scheduled for execution via the eager executor). PiperOrigin-RevId: 202207324 --- tensorflow/c/eager/c_api.cc | 26 ++++++++++++++------------ 1 file changed, 14 insertions(+), 12 deletions(-) (limited to 'tensorflow/c') diff --git a/tensorflow/c/eager/c_api.cc b/tensorflow/c/eager/c_api.cc index 00b474fe86..82ca2be2cf 100644 --- a/tensorflow/c/eager/c_api.cc +++ b/tensorflow/c/eager/c_api.cc @@ -156,12 +156,14 @@ tensorflow::Status NewRemoteAwareTFE_Context(const TFE_ContextOptions* opts, // server object (which currently CHECK-fails) and we miss the error, instead, // we log the error, and then return to allow the user to see the error // message. -#define LOG_AND_RETURN_IF_ERROR(...) \ - do { \ - const ::tensorflow::Status _status = (__VA_ARGS__); \ - LOG(ERROR) << _status.error_message(); \ - if (TF_PREDICT_FALSE(!_status.ok())) return _status; \ - } while (0) +#define LOG_AND_RETURN_IF_ERROR(...) \ + do { \ + const ::tensorflow::Status _status = (__VA_ARGS__); \ + if (TF_PREDICT_FALSE(!_status.ok())) { \ + LOG(ERROR) << _status.error_message(); \ + return _status; \ + } \ + } while (0); string worker_name = tensorflow::strings::StrCat( "/job:", opts->server_def.job_name(), @@ -346,16 +348,16 @@ TF_DataType TFE_TensorHandleDataType(TFE_TensorHandle* h) { } int TFE_TensorHandleNumDims(TFE_TensorHandle* h, TF_Status* status) { - const tensorflow::Tensor* t = nullptr; - status->status = h->handle->Tensor(&t); - return t == nullptr ? 0 : t->dims(); + int result; + status->status = h->handle->NumDims(&result); + return result; } int64_t TFE_TensorHandleDim(TFE_TensorHandle* h, int dim_index, TF_Status* status) { - const tensorflow::Tensor* t = nullptr; - status->status = h->handle->Tensor(&t); - return t == nullptr ? 0 : t->dim_size(dim_index); + tensorflow::int64 result; + status->status = h->handle->Dim(dim_index, &result); + return result; } const char* TFE_TensorHandleDeviceName(TFE_TensorHandle* h, TF_Status* status) { -- cgit v1.2.3