aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/tensorrt/resources/trt_int8_calibrator.h
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/contrib/tensorrt/resources/trt_int8_calibrator.h')
-rw-r--r--tensorflow/contrib/tensorrt/resources/trt_int8_calibrator.h15
1 files changed, 11 insertions, 4 deletions
diff --git a/tensorflow/contrib/tensorrt/resources/trt_int8_calibrator.h b/tensorflow/contrib/tensorrt/resources/trt_int8_calibrator.h
index 8830f7efe7..d77aa2c5ab 100644
--- a/tensorflow/contrib/tensorrt/resources/trt_int8_calibrator.h
+++ b/tensorflow/contrib/tensorrt/resources/trt_int8_calibrator.h
@@ -24,7 +24,10 @@ limitations under the License.
#if GOOGLE_CUDA
#if GOOGLE_TENSORRT
+
+#include "cuda/include/cuda_runtime_api.h"
#include "tensorrt/include/NvInfer.h"
+
namespace tensorflow {
namespace tensorrt {
// This class provides a 1 element queue to match TFs push model to
@@ -39,8 +42,9 @@ struct TRTInt8Calibrator : public nvinfer1::IInt8EntropyCalibrator {
int getBatchSize() const override;
bool getBatch(void* bindings[], const char* names[],
int num_bindings) override;
- bool setBatch(const std::unordered_map<string, void*>& data);
- void setDone() { done_ = true; }
+ bool setBatch(const std::unordered_map<string, void*>& data,
+ const cudaStream_t stream);
+ void setDone();
const void* readCalibrationCache(std::size_t& length) override;
void writeCalibrationCache(const void* ptr, std::size_t length) override;
~TRTInt8Calibrator();
@@ -55,11 +59,14 @@ struct TRTInt8Calibrator : public nvinfer1::IInt8EntropyCalibrator {
const std::unordered_map<string, std::pair<void*, size_t>>
dev_buffers_; // map to keep tensorrt input buffers and sizes keyed with
// buffer names
- std::atomic_bool calib_running_;
+ bool calib_running_;
+ bool batch_is_set_;
string engine_name_;
};
+
} // namespace tensorrt
} // namespace tensorflow
-#endif // TENSORFLOW_CONTRIB_TENSORRT_RESOURCES_TRT_INT8_CALIBRATOR_H_
+
#endif
#endif
+#endif // TENSORFLOW_CONTRIB_TENSORRT_RESOURCES_TRT_INT8_CALIBRATOR_H_