aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/c
diff options
context:
space:
mode:
authorGravatar Akshay Modi <nareshmodi@google.com>2018-06-26 15:36:59 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-06-26 15:43:10 -0700
commit79c828ea6ddbcfccd43a2be176fc1dcad4daf34e (patch)
treeadd42271d2d23096b03af3bf8fb89384cffc9a54 /tensorflow/c
parentec34de06981eed74c2c2a47c8a6372735e9d3622 (diff)
Support shapes for remote eager tensor handles.
Since we respond with the shape, all RPCs will happen sync (note that we may still hide the python overhead, since the op is still scheduled for execution via the eager executor). PiperOrigin-RevId: 202207324
Diffstat (limited to 'tensorflow/c')
-rw-r--r--tensorflow/c/eager/c_api.cc26
1 files changed, 14 insertions, 12 deletions
diff --git a/tensorflow/c/eager/c_api.cc b/tensorflow/c/eager/c_api.cc
index 00b474fe86..82ca2be2cf 100644
--- a/tensorflow/c/eager/c_api.cc
+++ b/tensorflow/c/eager/c_api.cc
@@ -156,12 +156,14 @@ tensorflow::Status NewRemoteAwareTFE_Context(const TFE_ContextOptions* opts,
// server object (which currently CHECK-fails) and we miss the error, instead,
// we log the error, and then return to allow the user to see the error
// message.
-#define LOG_AND_RETURN_IF_ERROR(...) \
- do { \
- const ::tensorflow::Status _status = (__VA_ARGS__); \
- LOG(ERROR) << _status.error_message(); \
- if (TF_PREDICT_FALSE(!_status.ok())) return _status; \
- } while (0)
+#define LOG_AND_RETURN_IF_ERROR(...) \
+ do { \
+ const ::tensorflow::Status _status = (__VA_ARGS__); \
+ if (TF_PREDICT_FALSE(!_status.ok())) { \
+ LOG(ERROR) << _status.error_message(); \
+ return _status; \
+ } \
+ } while (0);
string worker_name = tensorflow::strings::StrCat(
"/job:", opts->server_def.job_name(),
@@ -346,16 +348,16 @@ TF_DataType TFE_TensorHandleDataType(TFE_TensorHandle* h) {
}
int TFE_TensorHandleNumDims(TFE_TensorHandle* h, TF_Status* status) {
- const tensorflow::Tensor* t = nullptr;
- status->status = h->handle->Tensor(&t);
- return t == nullptr ? 0 : t->dims();
+ int result;
+ status->status = h->handle->NumDims(&result);
+ return result;
}
int64_t TFE_TensorHandleDim(TFE_TensorHandle* h, int dim_index,
TF_Status* status) {
- const tensorflow::Tensor* t = nullptr;
- status->status = h->handle->Tensor(&t);
- return t == nullptr ? 0 : t->dim_size(dim_index);
+ tensorflow::int64 result;
+ status->status = h->handle->Dim(dim_index, &result);
+ return result;
}
const char* TFE_TensorHandleDeviceName(TFE_TensorHandle* h, TF_Status* status) {